Make it an option to use TransformerEngine activation function in FFN block #1233
base: main
Conversation
Signed-off-by: Guyue Huang <[email protected]>
* Add activation_func field in MLPSubmodules
* In extensions/transformer_engine.py, add TEActivationOp, TEActivationOpFp8, and TERowParallelLinearOp, which all have type te.ops.Sequential(te.ops.<OP_NAME>)
* Add specs of the new classes to get_layer_specs.py when instantiating mlp

Signed-off-by: Guyue Huang <[email protected]>
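As a hedged illustration of the pattern this commit describes (wrapping a single TE op in te.ops.Sequential and exposing it through MLPSubmodules.activation_func), here is a minimal sketch. The factory-function shape and the GELU op name are assumptions for illustration, not the PR's actual code:

```python
# Hedged sketch, not the PR's implementation. Assumes transformer_engine.pytorch.ops
# exposes Sequential plus an activation op such as GELU (the op name is an assumption).
import transformer_engine.pytorch.ops as te_ops


def build_te_activation_op(op_cls=None):
    """Wrap a single TE activation op in te_ops.Sequential, mirroring the
    te.ops.Sequential(te.ops.<OP_NAME>) shape described in the commit."""
    op_cls = op_cls if op_cls is not None else te_ops.GELU  # assumed op name
    return te_ops.Sequential(op_cls())


# Hypothetical use inside a layer spec (MLPSubmodules.activation_func is the field added by the PR):
# submodules = MLPSubmodules(..., activation_func=build_te_activation_op())
```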
    'transformer engine. Consider setting use_te_activation_func=False')
    return te_ops.Sequential(instance)


class TEActivationOpFp8:
Why do we need this?
In the case of an fp8 recipe, the fp8 quantization should be handled by the linear module, right?
Also, te_ops.Quantize doesn't cast using the scale_factor from the amax history, right?
This module is what I designed to enable cast fusion. However, I have decided to keep this PR focused on enabling the TE activation and to enable cast fusion in a new PR, so we can discuss the API design once I create that PR. I have removed this class.
@@ -523,6 +575,181 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
    )


class TERowParallelLinearOp(te_ops.Sequential):
Why do we need the TERowParallelLinearOp module?
tl;dr: this module is no longer necessary due to a recent refactor on the TE side. I have removed it.
This module was designed to make RowParallelLinear a TE op. Previously there were two code paths implementing row-parallel linear: one was the legacy path and the other was the TE operation-based API. The latter is what I needed. However, a recent refactor of TE has unified the two code paths and there is no longer a legacy layer, so my wrapper class here is not necessary. I have removed it.
Conflicts:
    megatron/core/transformer/mlp.py
    megatron/core/transformer/transformer_config.py
* Remove wrapping the TE activation with TE Sequential; directly use the TE op class
* Remove the TE activation class dedicated to fp8; we will enable cast fusion in a new PR
* Remove the TE linear op class, because TE has refactored its linear class to use ops, so mcore doesn't need to
* Fix bug
* Remove unused file megatron/core/transformer/te_activation_func_utils.py

Signed-off-by: Guyue Huang <[email protected]>
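A hedged before/after sketch of the simplification described in this commit: instead of wrapping the activation in te_ops.Sequential, the spec references the TE op class directly. The GELU op name is an assumption for illustration; only the wrap-vs-direct-use distinction comes from the commit message:

```python
# Hedged before/after sketch of this commit's simplification; not the PR's actual diff.
# Assumes transformer_engine.pytorch.ops provides a GELU op (the op name is an assumption);
# the activation_func field on MLPSubmodules is the one added by this PR.
import transformer_engine.pytorch.ops as te_ops


def activation_spec_before():
    # Old approach: wrap the op instance in a TE Sequential container.
    return te_ops.Sequential(te_ops.GELU())


def activation_spec_after():
    # New approach: hand the TE op class to the spec directly; the MLP
    # instantiates it itself, so no extra wrapper class is needed.
    return te_ops.GELU
```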
[draft]
Add a new config parameter use_te_activation_func to control whether we want to use TE custom kernels for the activation function in the MLP.
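A hedged sketch of how such a flag might gate the activation choice when the layer spec is built; the selection helper, the PyTorch fallback, and the te_ops.GELU op name are illustrative assumptions, with only use_te_activation_func and the TE-op idea coming from the PR description:

```python
# Hedged sketch, not the PR's implementation. use_te_activation_func is the config
# parameter described in the PR; the helper function, the torch fallback, and the
# te_ops.GELU op name are assumptions for illustration.
import torch.nn.functional as F
import transformer_engine.pytorch.ops as te_ops


def select_activation_func(use_te_activation_func: bool):
    """Pick the MLP activation: a TE custom-kernel op when the flag is set,
    otherwise the usual framework-level activation function."""
    if use_te_activation_func:
        return te_ops.GELU  # assumed TE op name; kernel provided by TransformerEngine
    return F.gelu  # standard PyTorch activation as the fallback
```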