Liger-Kernel

Liger Kernel: Efficient Triton Kernels for LLM Training


Latest News 🔥

- [2025/03/06] We release a joint blog post on TorchTune × Liger: [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory-efficient post-training losses (DPO, ORPO, CPO, etc.)!
- [2024/12/5] We release a LinkedIn Engineering Blog post: [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
- [2024/10/21] We have released the tech report of Liger Kernel on arXiv: https://arxiv.org/pdf/2410.10989
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
- [2024/8/31] CUDA MODE talk: [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
- [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)

Liger Kernel is a collection of Triton kernels designed specifically for LLM training. It can increase multi-GPU training throughput by 20% and reduce memory usage by 60%. We have implemented Hugging Face-compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.

We’ve also added optimized post-training kernels that deliver up to 80% memory savings for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out how we optimize memory.

You can view the documentation site for additional installation instructions, usage examples, and API references: https://linkedin.github.io/Liger-Kernel/

Supercharge Your Model with Liger Kernel


With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.

(Figures: Speed Up; Memory Reduction)


Optimize Post Training with Liger Kernel


We provide optimized post-training kernels like DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can easily use them as Python modules.

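import torch
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
lm_head = torch.nn.Linear(64, 128, bias=False).cuda()  # hidden size 64 -> vocab size 128
x = torch.randn(4, 8, 64, device="cuda")  # (2B, T, H): chosen rows first, then rejected rows
target = torch.randint(128, (4, 8), device="cuda")  # (2B, T) target token ids
y = LigerFusedLinearORPOLoss()(lm_head.weight, x, target)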
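To make the objective behind `LigerFusedLinearORPOLoss` concrete, here is a minimal, unfused reference sketch in plain PyTorch following the ORPO formulation: an NLL term on the chosen responses plus `lam` times the negative log-sigmoid of the log odds ratio between chosen and rejected responses. The helper name `orpo_loss_reference`, the `[chosen; rejected]` batch layout, the length-normalized sequence log-probabilities, and the token-level NLL averaging are assumptions for illustration; Liger's fused kernel computes its loss chunk by chunk from the hidden states and `lm_head` weight and may differ in normalization details.

```python
import torch
import torch.nn.functional as F

def orpo_loss_reference(logits, target, lam=0.1, ignore_index=-100):
    """Unfused ORPO loss sketch. logits: (2B, T, V), target: (2B, T).

    Assumed layout: the first B rows are chosen responses and the last B rows
    are the matching rejected responses.
    """
    mask = target != ignore_index
    safe_target = target.masked_fill(~mask, 0)
    # Log-probability of each target token, zeroed where the target is ignored.
    token_logps = logits.log_softmax(-1).gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
    token_logps = token_logps * mask
    # Length-normalized sequence log-probability, log P(y | x).
    seq_logps = token_logps.sum(-1) / mask.sum(-1)
    b = target.shape[0] // 2
    chosen, rejected = seq_logps[:b], seq_logps[b:]
    # log odds(y) = log P - log(1 - P); log1p(-exp(log P)) computes log(1 - P).
    log_odds = (chosen - rejected) - (
        torch.log1p(-chosen.exp()) - torch.log1p(-rejected.exp())
    )
    or_term = -F.logsigmoid(log_odds).mean()
    # SFT term: mean token-level NLL over the chosen responses only.
    nll = -token_logps[:b].sum() / mask[:b].sum()
    return nll + lam * or_term
```

When the chosen and rejected halves are identical, the log odds ratio is zero, so the odds-ratio term reduces to log 2 and the loss equals the chosen NLL plus `lam * log 2`.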

Examples

| Use Case | Description |
| --- | --- |
| Hugging Face Trainer | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on the Alpaca dataset using 4 A100s with FSDP |
| Lightning Trainer | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
| Medusa Multi-head LLM (Retraining Phase) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
| Vision-Language Model SFT | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
| Liger ORPO Trainer | Align Llama 3.2 using the Liger ORPO Trainer with FSDP, with 50% memory reduction |

Key Features

Installation

Dependencies

CUDA

ROCm

# Need to pass the URL when installing
pip install -e ".[dev]" --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2

Optional Dependencies

Note: Our kernels inherit the full spectrum of hardware compatibility offered by Triton.

To install the stable version:

$ pip install liger-kernel

To install the nightly version:

$ pip install liger-kernel-nightly

To install from source:

git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel

# Install default dependencies
# setup.py detects whether you are using AMD or NVIDIA
pip install -e .

# Install development dependencies
pip install -e ".[dev]"

Getting Started

There are three ways to apply Liger kernels, depending on the level of customization required.

1. Use AutoLigerKernelForCausalLM

Using AutoLigerKernelForCausalLM is the simplest approach, as you don’t have to import a model-specific patching API. If the model type is supported, the modeling code is automatically patched using the default settings.

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# This AutoModel wrapper class automatically monkey-patches the
# model with the optimized Liger kernels if the model is supported.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
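The automatic patching described above amounts to a dispatch on the model type. A minimal sketch of that idea follows, assuming a registry keyed by the Hugging Face config's `model_type` string; `patch_llama`, `patch_mistral`, `MODEL_TYPE_TO_PATCH_FN`, and `maybe_apply_liger` are hypothetical stand-ins for illustration, not Liger's internal code.

```python
from types import SimpleNamespace

# Hypothetical patch functions standing in for the apply_liger_kernel_to_* APIs.
def patch_llama(**kwargs):
    return ("llama", kwargs)

def patch_mistral(**kwargs):
    return ("mistral", kwargs)

# Hypothetical registry keyed by the config's `model_type` string.
MODEL_TYPE_TO_PATCH_FN = {"llama": patch_llama, "mistral": patch_mistral}

def maybe_apply_liger(config, **kwargs):
    """Call the registered patch function for config.model_type, or do nothing."""
    patch_fn = MODEL_TYPE_TO_PATCH_FN.get(config.model_type)
    if patch_fn is None:
        return None  # unsupported model type: the model loads unpatched
    return patch_fn(**kwargs)

# Usage: for this sketch a config only needs a `model_type` attribute.
print(maybe_apply_liger(SimpleNamespace(model_type="llama"), rope=True))
# ('llama', {'rope': True})
```

The design point is that unsupported model types fall through silently, which matches the behavior described above: only supported models are patched.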

2. Apply Model-Specific Patching APIs

Using the patching APIs, you can swap components of Hugging Face models for optimized Liger kernels.

import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()

# 1b. You could alternatively specify exactly which kernels are applied
apply_liger_kernel_to_llama(
  rope=True,
  swiglu=True,
  cross_entropy=True,
  fused_linear_cross_entropy=False,
  rms_norm=False
)

# 2. Instantiate patched model
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
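Monkey-patching works by rebinding names in a model's modeling module, which is why the patch call comes before the model is instantiated in the example above. The toy sketch below illustrates that ordering for module-class replacement; every name in it (`modeling`, `RMSNorm`, `FastRMSNorm`, `Model`, `apply_fast_kernels`) is hypothetical, and no Liger or transformers code is involved.

```python
import types

# Hypothetical stand-in for a Hugging Face modeling module.
modeling = types.SimpleNamespace()

class RMSNorm:
    kind = "reference"

class FastRMSNorm:
    kind = "fast"

modeling.RMSNorm = RMSNorm

class Model:
    def __init__(self):
        # The layer class is looked up on the module at construction time.
        self.norm = modeling.RMSNorm()

def apply_fast_kernels():
    """Monkey-patch: rebind the module attribute to the optimized class."""
    modeling.RMSNorm = FastRMSNorm

before = Model()       # built before patching: keeps the reference layer
apply_fast_kernels()
after = Model()        # built after patching: picks up the fast layer
print(before.norm.kind, after.norm.kind)  # reference fast
```

Objects created before the patch keep whatever class they were built with, so in this style of patching the model must be constructed after the patch runs.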

3. Compose Your Own Model

You can also use individual kernels to compose your own models.

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
import torch.nn as nn
import torch

model = nn.Linear(128, 256).cuda()

# fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()

input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.randint(256, (4, ), device="cuda")

loss = loss_fn(model.weight, input, target)
loss.backward()
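The comment in the example above says the loss fuses the linear layer with cross-entropy and works chunk by chunk. A minimal plain-PyTorch sketch of that idea follows, assuming mean reduction and no `ignore_index`: each chunk's logits are materialized, used for the loss, immediately turned into gradients for the input and weight, and discarded, so peak logit memory scales with `chunk_size × V` instead of `N × V`. The helper `chunked_linear_cross_entropy` is hypothetical; Liger's actual kernel does this in Triton and differs in details.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()  # gradients are computed by hand, chunk by chunk
def chunked_linear_cross_entropy(x, weight, target, chunk_size=2):
    """Mean CE(x @ weight.T, target) plus its gradients, one row chunk at a time.

    x: (N, H), weight: (V, H), target: (N,). Returns (loss, grad_x, grad_weight).
    """
    n = x.shape[0]
    loss = x.new_zeros(())
    grad_x = torch.zeros_like(x)
    grad_w = torch.zeros_like(weight)
    for start in range(0, n, chunk_size):
        xc = x[start:start + chunk_size]
        tc = target[start:start + chunk_size]
        logits = xc @ weight.T                       # only a (chunk, V) block exists
        loss += F.cross_entropy(logits, tc, reduction="sum")
        # d(CE)/d(logits) = softmax(logits) - one_hot(target)
        g = logits.softmax(-1)
        g[torch.arange(len(tc)), tc] -= 1
        g /= n                                       # account for the mean over all N rows
        grad_x[start:start + chunk_size] = g @ weight  # (chunk, H)
        grad_w += g.T @ xc                           # accumulate (V, H)
    return loss / n, grad_x, grad_w
```

The result matches the unfused `F.cross_entropy(x @ weight.T, target)` and its autograd gradients; the chunking only changes how much of the logit matrix is alive at once.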

High-level APIs

AutoModel

| AutoModel Variant | API |
| --- | --- |
| AutoModelForCausalLM | liger_kernel.transformers.AutoLigerKernelForCausalLM |

Patching

| Model | API | Supported Operations |
| --- | --- | --- |
| LLaMA 2 & 3 | liger_kernel.transformers.apply_liger_kernel_to_llama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| LLaMA 3.2-Vision | liger_kernel.transformers.apply_liger_kernel_to_mllama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mistral | liger_kernel.transformers.apply_liger_kernel_to_mistral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mixtral | liger_kernel.transformers.apply_liger_kernel_to_mixtral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma1 | liger_kernel.transformers.apply_liger_kernel_to_gemma | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | liger_kernel.transformers.apply_liger_kernel_to_gemma2 | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Text) | liger_kernel.transformers.apply_liger_kernel_to_gemma3_text | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Multimodal) | liger_kernel.transformers.apply_liger_kernel_to_gemma3 | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Paligemma, Paligemma2, & Paligemma2 Mix | liger_kernel.transformers.apply_liger_kernel_to_paligemma | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | liger_kernel.transformers.apply_liger_kernel_to_qwen2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL & QVQ | liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2.5-VL | liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen3 | liger_kernel.transformers.apply_liger_kernel_to_qwen3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen3 MoE | liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 & Phi3.5 | liger_kernel.transformers.apply_liger_kernel_to_phi3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Granite 3.0 & 3.1 | liger_kernel.transformers.apply_liger_kernel_to_granite | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| OLMo2 | liger_kernel.transformers.apply_liger_kernel_to_olmo2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| GLM-4 | liger_kernel.transformers.apply_liger_kernel_to_glm4 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

Low-level APIs

Model Kernels

| Kernel | API |
| --- | --- |
| RMSNorm | liger_kernel.transformers.LigerRMSNorm |
| LayerNorm | liger_kernel.transformers.LigerLayerNorm |
| RoPE | liger_kernel.transformers.liger_rotary_pos_emb |
| SwiGLU | liger_kernel.transformers.LigerSwiGLUMLP |
| GeGLU | liger_kernel.transformers.LigerGEGLUMLP |
| CrossEntropy | liger_kernel.transformers.LigerCrossEntropyLoss |
| Fused Linear CrossEntropy | liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss |
| Sparsemax | liger_kernel.transformers.LigerSparsemax |
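For checking what two of these kernels compute, here is a plain-PyTorch reference sketch of RMSNorm and a LLaMA-style SwiGLU MLP. The class names `RMSNormRef` and `SwiGLUMLPRef` and the `eps` default are assumptions for illustration; Liger's Triton kernels (`LigerRMSNorm`, `LigerSwiGLUMLP`) fuse these steps rather than running them one by one, and some model families use variants (for example an offset on the RMSNorm weight) that this sketch ignores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormRef(nn.Module):
    """y = x / sqrt(mean(x^2) + eps) * weight, normalized over the last dim."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        rms_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms_inv * self.weight

class SwiGLUMLPRef(nn.Module):
    """down_proj(silu(gate_proj(x)) * up_proj(x)), the gated MLP used by LLaMA."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```

With the weight left at ones, every output row of `RMSNormRef` has a root mean square of about 1, which is a quick sanity check when comparing against a fused kernel.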

Alignment Kernels

| Kernel | API |
| --- | --- |
| Fused Linear CPO Loss | liger_kernel.chunked_loss.LigerFusedLinearCPOLoss |
| Fused Linear DPO Loss | liger_kernel.chunked_loss.LigerFusedLinearDPOLoss |
| Fused Linear ORPO Loss | liger_kernel.chunked_loss.LigerFusedLinearORPOLoss |
| Fused Linear SimPO Loss | liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss |
| Fused Linear KTO Loss | liger_kernel.chunked_loss.LigerFusedLinearKTOLoss |

Distillation Kernels

| Kernel | API |
| --- | --- |
| KLDivergence | liger_kernel.transformers.LigerKLDIVLoss |
| JSD | liger_kernel.transformers.LigerJSD |
| Fused Linear JSD | liger_kernel.transformers.LigerFusedLinearJSD |
| TVD | liger_kernel.transformers.LigerTVDLoss |
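The distillation kernels above compare a student distribution with a teacher distribution. As a reference for what these divergences measure, here is a plain-PyTorch sketch of the textbook definitions, taking log-probabilities over the vocabulary. The function names are hypothetical, the JSD here uses the symmetric mixture M = (P + Q) / 2, and Liger's kernels may expose extra options (such as a weighting of the two distributions) and reductions not shown here.

```python
import torch

def kl_div(p_log, q_log):
    """KL(P || Q) summed over the last (vocab) dim, from log-probabilities."""
    return (p_log.exp() * (p_log - q_log)).sum(-1)

def jsd(p_log, q_log):
    """Symmetric Jensen-Shannon divergence with mixture M = (P + Q) / 2."""
    # log M = log(P + Q) - log 2, computed stably with logsumexp.
    m_log = torch.logsumexp(torch.stack([p_log, q_log]), dim=0) - torch.log(torch.tensor(2.0))
    return 0.5 * kl_div(p_log, m_log) + 0.5 * kl_div(q_log, m_log)

def tvd(p_log, q_log):
    """Total variation distance: half the L1 distance between P and Q."""
    return 0.5 * (p_log.exp() - q_log.exp()).abs().sum(-1)
```

Unlike KL, both JSD and TVD are symmetric and bounded (by log 2 and 1, respectively), which is one reason they are used as distillation losses.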

Experimental Kernels

| Kernel | API |
| --- | --- |
| Embedding | liger_kernel.transformers.experimental.LigerEmbedding |
| Matmul int2xint8 | liger_kernel.transformers.experimental.matmul |

Contributing, Acknowledgements, and License

Sponsorship and Collaboration

CI status


Contact

Cite this work

Biblatex entry:

@article{hsu2024ligerkernelefficienttriton,
      title={Liger Kernel: Efficient Triton Kernels for LLM Training},
      author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
      year={2024},
      eprint={2410.10989},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.10989},
      journal={arXiv preprint arXiv:2410.10989},
}
