
High-Level APIs

AutoModel

| AutoModel Variant | API |
|-------------------|-----|
| AutoModelForCausalLM | liger_kernel.transformers.AutoLigerKernelForCausalLM |

This API extends Hugging Face's AutoModelForCausalLM from the transformers library.

liger_kernel.transformers.AutoLigerKernelForCausalLM

Bases: AutoModelForCausalLM

This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model if applicable.

Source code in src/liger_kernel/transformers/auto_model.py
class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
    """
    This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
    if applicable.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)

        # Determine the model type and apply the Liger Kernel if applicable
        # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
        model_type = model_config.model_type

        _apply_liger_kernel(model_type, **kwargs)

        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
        # model initialization errors otherwise
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
        apply_fn_signature = inspect.signature(apply_fn)

        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)

Try it Out

You can experiment with this API as shown in the example here; a minimal usage sketch follows.
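
For instance, a minimal usage sketch (the checkpoint name is illustrative; any causal LM with a supported model_type works):

```python
import torch
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Behaves like AutoModelForCausalLM.from_pretrained, but first applies the Liger
# kernels that match the checkpoint's model_type (here: llama).
model = AutoLigerKernelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)

# Kwargs understood by the matching apply_liger_kernel_to_* function (e.g. rope,
# rms_norm, swiglu) are forwarded to it and filtered out before model init.
model = AutoLigerKernelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    rope=True,
    rms_norm=True,
    fused_linear_cross_entropy=True,
)
```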


Patching

You can also use the Patching APIs to apply the kernels to a specific model architecture; a usage sketch follows the table below.

| Model | API | Supported Operations |
|-------|-----|----------------------|
| LLaMA 2 & 3 | liger_kernel.transformers.apply_liger_kernel_to_llama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| LLaMA 3.2-Vision | liger_kernel.transformers.apply_liger_kernel_to_mllama | RoPE, RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mistral | liger_kernel.transformers.apply_liger_kernel_to_mistral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mixtral | liger_kernel.transformers.apply_liger_kernel_to_mixtral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma1 | liger_kernel.transformers.apply_liger_kernel_to_gemma | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | liger_kernel.transformers.apply_liger_kernel_to_gemma2 | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | liger_kernel.transformers.apply_liger_kernel_to_qwen2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL | liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl | RoPE, RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 & Phi3.5 | liger_kernel.transformers.apply_liger_kernel_to_phi3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
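
For example, the following sketch patches the Llama modeling classes before the model is loaded, so the model is constructed with the Liger kernels already in place (checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Llama modeling classes at module level *before* loading the model.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
    cross_entropy=False,  # cannot be True together with fused_linear_cross_entropy
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
```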

Function Signatures

liger_kernel.transformers.apply_liger_kernel_to_llama

apply_liger_kernel_to_llama(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, swiglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| swiglu | bool | Whether to apply Liger's SwiGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_llama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.llama import modeling_llama
    from transformers.models.llama.modeling_llama import LlamaModel

    if rope:
        modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_llama.LlamaRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_llama.LlamaMLP = LigerSwiGLUMLP

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            if model is not None:
                model.forward = MethodType(llama_lce_forward, model)
            else:
                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            if model is not None:
                model.forward = MethodType(llama_lce_forward_deprecated, model)
            else:
                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

        # get the base model from the model instance
        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
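
If the model has already been instantiated, the same function can be applied afterwards by passing the instance via model=, which runs the instance-level patching branch above. A sketch (checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Patches the already-instantiated RMSNorm and MLP modules in place and binds
# the fused linear cross entropy forward onto this instance only.
apply_liger_kernel_to_llama(model=model)
```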

liger_kernel.transformers.apply_liger_kernel_to_mllama

apply_liger_kernel_to_mllama(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, layer_norm=True, rms_norm=True, swiglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace MLlama models. NOTE: MLlama is not available in transformers<4.45.0

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| layer_norm | bool | Whether to apply Liger's LayerNorm. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| swiglu | bool | Whether to apply Liger's SwiGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_mllama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    layer_norm: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
    NOTE: MLlama is not available in transformers<4.45.0

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.mllama import modeling_mllama
    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
    from transformers.models.mllama.modeling_mllama import MllamaTextModel
    from transformers.models.mllama.modeling_mllama import MllamaVisionModel

    from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated

    if rope:
        modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if layer_norm and model is None:
        modeling_mllama.nn.LayerNorm = LigerLayerNorm
    if rms_norm:
        modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            if model is not None:
                model.forward = MethodType(mllama_lce_forward, model)
            else:
                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            if model is not None:
                model.forward = MethodType(mllama_lce_forward_deprecated, model)
            else:
                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        if isinstance(model, MllamaForConditionalGeneration):
            language_model: MllamaForCausalLM = model.language_model
            vision_model: MllamaVisionModel = model.vision_model
            if isinstance(language_model, MllamaForCausalLM):
                text_model: MllamaTextModel = language_model.model
            else:
                text_model = language_model
        elif isinstance(model, MllamaForCausalLM):
            text_model = model.model
            vision_model = None
        elif isinstance(model, MllamaTextModel):
            text_model = model
            vision_model = None

        else:
            raise ValueError(f"Unsupported Mllama model type: {type(model)}")

        if text_model:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

        if vision_model:
            _patch_layer_norm_module(vision_model.layernorm_pre)
            _patch_layer_norm_module(vision_model.layernorm_post)

            for layer in vision_model.transformer.layers:
                if layer_norm:
                    _patch_layer_norm_module(layer.input_layernorm)
                    _patch_layer_norm_module(layer.post_attention_layernorm)

            for layer in vision_model.global_transformer.layers:
                if layer_norm:
                    _patch_layer_norm_module(layer.input_layernorm)
                    _patch_layer_norm_module(layer.post_attention_layernorm)
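
For an already-loaded multimodal checkpoint, passing the instance patches both the text decoder (RMSNorm, SwiGLU) and the vision transformer's LayerNorms. A sketch (checkpoint name is illustrative):

```python
from transformers import MllamaForConditionalGeneration
from liger_kernel.transformers import apply_liger_kernel_to_mllama

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct"  # illustrative checkpoint
)

# Patches the text model's RMSNorm/SwiGLU modules and the vision model's
# LayerNorms in place, and binds the fused linear cross entropy forward
# to this instance.
apply_liger_kernel_to_mllama(model=model)
```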

liger_kernel.transformers.apply_liger_kernel_to_mistral

apply_liger_kernel_to_mistral(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, swiglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Mistral models

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| swiglu | bool | Whether to apply Liger's SwiGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_mistral(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Mistral models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.mistral import modeling_mistral
    from transformers.models.mistral.modeling_mistral import MistralModel

    if rope:
        modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_mistral.MistralRMSNorm = LigerRMSNorm
    if cross_entropy:
        modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse("4.49.0"):
            if model is not None:
                model.forward = MethodType(mistral_lce_forward, model)
            else:
                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
        else:
            logger.warning(
                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
            )
            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")

    if swiglu:
        modeling_mistral.MistralMLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: MistralModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

liger_kernel.transformers.apply_liger_kernel_to_mixtral

apply_liger_kernel_to_mixtral(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, swiglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Mixtral models

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| swiglu | bool | Whether to apply Liger's SwiGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_mixtral(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Mixtral models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.mixtral import modeling_mixtral
    from transformers.models.mixtral.modeling_mixtral import MixtralModel

    if rope:
        modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            if model is not None:
                model.forward = MethodType(mixtral_lce_forward, model)
            else:
                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            if model is not None:
                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
            else:
                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
    if swiglu:
        modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                for expert in decoder_layer.block_sparse_moe.experts:
                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

liger_kernel.transformers.apply_liger_kernel_to_gemma

apply_liger_kernel_to_gemma(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, geglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Gemma (Gemma 1 and 1.1 supported, for Gemma2 please use apply_liger_kernel_to_gemma2) to make GPU go burrr.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| geglu | bool | Whether to apply Liger's GeGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_gemma(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma
    (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.gemma import modeling_gemma
    from transformers.models.gemma.modeling_gemma import GemmaModel

    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma

    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

    if rope:
        modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        modeling_gemma.GemmaMLP = LigerGEGLUMLP
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            if model is not None:
                model.forward = MethodType(gemma_lce_forward, model)
            else:
                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            if model is not None:
                model.forward = MethodType(gemma_lce_forward_deprecated, model)
            else:
                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module_for_gemma(base_model.norm)

        for decoder_layer in base_model.layers:
            if geglu:
                _patch_geglu_module(decoder_layer.mlp)
            if rms_norm:
                _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)

liger_kernel.transformers.apply_liger_kernel_to_gemma2

apply_liger_kernel_to_gemma2(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, geglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Gemma2 (for Gemma1 please use apply_liger_kernel_to_gemma) to make GPU go burrr.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| geglu | bool | Whether to apply Liger's GeGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_gemma2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma2
    (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.gemma2 import modeling_gemma2
    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2

    _patch_rms_norm_module_for_gemma2 = partial(
        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
    )

    if rope:
        modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
        modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            if model is not None:
                model.forward = MethodType(gemma2_lce_forward, model)
            else:
                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            if model is not None:
                model.forward = MethodType(gemma2_lce_forward_deprected, model)
            else:
                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
    if geglu:
        modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module_for_gemma2(base_model.norm)

        for decoder_layer in base_model.layers:
            if geglu:
                _patch_geglu_module(decoder_layer.mlp)
            if rms_norm:
                _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
                _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
                _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)

liger_kernel.transformers.apply_liger_kernel_to_qwen2

apply_liger_kernel_to_qwen2(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, swiglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| swiglu | bool | Whether to apply Liger's SwiGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_qwen2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

    if rope:
        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            if model is not None:
                model.forward = MethodType(qwen2_lce_forward, model)
            else:
                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            if model is not None:
                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
            else:
                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

    if swiglu:
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
    print("Applied Liger kernels to Qwen2")

liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl

apply_liger_kernel_to_qwen2_vl(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, layer_norm=True, swiglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models. NOTE: Qwen2-VL is not supported in transformers<4.52.4

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| layer_norm | bool | Whether to apply Liger's LayerNorm. | True |
| swiglu | bool | Whether to apply Liger's SwiGLU MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_qwen2_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
    NOTE: Qwen2-VL is not supported in transformers<4.52.4

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    if transformer_version < version.parse("4.52.4"):
        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
        return

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2_vl import modeling_qwen2_vl
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel

    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward

    if rope:
        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
    if layer_norm and model is None:
        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen2_vl_lce_forward, model)
        else:
            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
            # Note: language_model and visual properties can be accessed through the conditional class for BC.
            # Not sure if it is subject to changes in the future.
            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
            text_model: Qwen2VLTextModel = model.language_model
            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
        elif isinstance(model, Qwen2VLTextModel):
            text_model: Qwen2VLTextModel = model
            vision_model = None
        else:
            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
            )

        # Patch Qwen2VisionTransformerPretrainedModel
        if vision_model is not None:
            for vision_block in vision_model.blocks:
                if layer_norm:
                    _patch_layer_norm_module(vision_block.norm1)
                    _patch_layer_norm_module(vision_block.norm2)

        # Patch Qwen2VisionTextModel
        if text_model is not None:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
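
A sketch for Qwen2-VL (requires transformers >= 4.52.4 per the version check above; checkpoint name is illustrative):

```python
from transformers import Qwen2VLForConditionalGeneration
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct"  # illustrative checkpoint
)

# Patches the text model's RMSNorm/SwiGLU modules and the vision blocks'
# LayerNorms on this instance, and binds the fused linear cross entropy
# forward to it.
apply_liger_kernel_to_qwen2_vl(model=model)
```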

liger_kernel.transformers.apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3(rope=True, cross_entropy=False, fused_linear_cross_entropy=True, rms_norm=True, swiglu=True, model=None)

Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| rope | bool | Whether to apply Liger's rotary position embedding. | True |
| cross_entropy | bool | Whether to apply Liger's cross entropy loss. | False |
| fused_linear_cross_entropy | bool | Whether to apply Liger's fused linear cross entropy loss. cross_entropy and fused_linear_cross_entropy cannot both be True. If fused_linear_cross_entropy is True, the logits are not materialized, which is more memory efficient. | True |
| rms_norm | bool | Whether to apply Liger's RMSNorm. | True |
| swiglu | bool | Whether to apply Liger's SwiGLU Phi3MLP. | True |
| model | PreTrainedModel | The model instance to apply Liger kernels to, if the model has already been loaded. | None |
Source code in src/liger_kernel/transformers/monkey_patch.py
def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.phi3 import modeling_phi3
    from transformers.models.phi3.modeling_phi3 import Phi3Model

    if rope:
        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
    if rms_norm:
        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
    if swiglu:
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(phi3_lce_forward, model)
        else:
            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
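
Across all of these functions, cross_entropy and fused_linear_cross_entropy are mutually exclusive. To use Liger's standalone cross entropy kernel instead of the fused variant, disable the latter explicitly, as in this sketch:

```python
from liger_kernel.transformers import apply_liger_kernel_to_phi3

# Enable the plain Liger cross entropy kernel; the fused linear cross entropy
# path must be turned off, otherwise the assertion at the top of the function fires.
apply_liger_kernel_to_phi3(
    cross_entropy=True,
    fused_linear_cross_entropy=False,
)
```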