
Make it an option to use TransformerEngine activation function in FFN block #1233

Open

guyueh1 wants to merge 4 commits into main
Conversation

@guyueh1 (Contributor) commented Oct 21, 2024

[draft]
Add a new config parameter, use_te_activation_func, to control whether to use TE custom kernels for the activation function in the MLP.
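
A minimal sketch of how such a flag might gate the activation choice. Apart from use_te_activation_func, the MLPSubmodules.activation_func field, and the te.ops API mentioned in this PR, every name below (the helper, the GELU default, the error message) is an illustrative assumption rather than the PR's actual code:

```python
# Sketch only: how use_te_activation_func might choose the activation that is
# placed into MLPSubmodules.activation_func. Helper name and GELU default are
# illustrative assumptions, not this PR's actual code.
import torch.nn.functional as F

try:
    import transformer_engine.pytorch.ops as te_ops  # TE operation-based API
    HAVE_TE_OPS = True
except ImportError:
    HAVE_TE_OPS = False


def choose_mlp_activation(use_te_activation_func: bool):
    """Return the activation callable/op for the MLP's FFN block."""
    if use_te_activation_func:
        if not HAVE_TE_OPS:
            raise RuntimeError(
                "use_te_activation_func=True requires Transformer Engine with "
                "the te.ops API; consider setting use_te_activation_func=False."
            )
        # TE custom kernel for the activation (the exact op depends on the
        # configured activation; GELU is shown only for illustration).
        return te_ops.GELU()
    return F.gelu  # default eager-mode activation
```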

Guyue Huang added 2 commits October 21, 2024 11:57
* Add an activation_func field to MLPSubmodules
* In extensions/transformer_engine.py, add TEActivationOp, TEActivationOpFp8, and TERowParallelLinearOp, each of type te.ops.Sequential(te.ops.<OP_NAME>)
* Add specs of the new classes to get_layer_specs.py when instantiating the MLP

Signed-off-by: Guyue Huang <[email protected]>
Review comment on extensions/transformer_engine.py, attached to this hunk:

'transformer engine. Consider setting use_te_activation_func=False')
return te_ops.Sequential(instance)

class TEActivationOpFp8:
Contributor


Why do we need this?
In the case of an FP8 recipe, the FP8 quantization should be handled by the linear module, right?

Also, te_ops.Quantize doesn't cast using the scale_factor from the amax history, right?

Contributor Author


I designed this module to enable cast fusion, but I have decided to keep this PR focused on enabling the TE activation and to enable cast fusion in a follow-up PR, so we can discuss the API design once I create that PR. I have removed this class.

Review comment on extensions/transformer_engine.py, attached to this hunk:

@@ -523,6 +575,181 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
)

class TERowParallelLinearOp(te_ops.Sequential):
Contributor


Why do we need the TERowParallelLinearOp module?

Contributor Author


tl;dr: this module is no longer necessary due to a recent refactor on the TE side, and I have removed it.
This module was designed to make RowParallelLinear a TE op. Previously there were two code paths implementing row-parallel linear: a legacy one and one based on the TE operation API; the latter is what I needed. However, a recent TE refactor has unified the two code paths and the legacy layer no longer exists, so my wrapper class here is not necessary. I have removed it.

Conflicts:
	megatron/core/transformer/mlp.py
	megatron/core/transformer/transformer_config.py
* Remove the TE sequential wrapper around the TE activation; use the TE op class directly
* Remove the TE activation class dedicated to FP8; cast fusion will be enabled in a follow-up PR
* Remove the TE linear op class, because TE has refactored its linear class to use ops, so mcore doesn't need to wrap it
* Fix a bug
* Remove the unused file megatron/core/transformer/te_activation_func_utils.py

Signed-off-by: Guyue Huang <[email protected]>
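
Putting the changes described in the commit above together, a rough sketch of the simplified MLP spec wiring after this commit. This is not the PR's actual code: the helper name is made up, TEColumnParallelLinear/TERowParallelLinear are the existing Megatron-Core TE wrappers, and te_ops.GELU stands in for whichever TE activation op matches the configured activation:

```python
# Sketch of the end state described above: the TE op class is used directly as
# the MLP activation, with no Sequential wrapper and no dedicated FP8 or
# row-parallel-linear wrapper classes. Import paths and the activation_func
# field follow this PR's description and may differ in detail.
import transformer_engine.pytorch.ops as te_ops
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec


def get_mlp_module_spec(use_te_activation_func: bool) -> ModuleSpec:
    """Build the MLP spec, optionally using the TE activation op."""
    submodules = MLPSubmodules(
        linear_fc1=TEColumnParallelLinear,
        linear_fc2=TERowParallelLinear,
    )
    if use_te_activation_func:
        # Field added by this PR; GELU is shown for illustration -- the op
        # would be chosen to match config.activation_func.
        submodules.activation_func = te_ops.GELU
    return ModuleSpec(module=MLP, submodules=submodules)
```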
@guyueh1 changed the title from "Make it an option to use te activation function in MLP" to "Make it an option to use TransformerEngine activation function in FFN block" on Nov 19, 2024
@guyueh1 marked this pull request as ready for review on November 19, 2024 19:29