Liger Kernel is a collection of Triton kernels designed specifically for LLM training. It can increase multi-GPU training throughput by 20% and reduce memory usage by 60%. We have implemented Hugging Face-compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.
We’ve also added optimized Post-Training kernels that deliver up to 80% memory savings for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out how we optimize the memory.
You can view the documentation site for additional installation instructions, usage examples, and API references: https://linkedin.github.io/Liger-Kernel/
With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.
[Figures: Speed Up and Memory Reduction benchmark plots]
Note:
- Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
- Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
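A back-of-the-envelope calculation helps show why context length and vocabulary size dominate memory. The sketch below estimates the size of the full logits tensor that a standard language-model head materializes. The vocabulary size of 128,256 is LLaMA 3's; treating the batch size of 8 as a per-device batch, and counting only the bf16 logits (no gradients, no fp32 upcast), are simplifying assumptions rather than details stated in the benchmark.

```python
def logits_bytes(batch_size: int, seq_len: int, vocab_size: int,
                 bytes_per_elem: int = 2) -> int:
    """Bytes needed to hold one full (batch, seq, vocab) logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_elem


LLAMA3_VOCAB = 128_256  # LLaMA 3 tokenizer vocabulary size

for seq_len in (4096, 16384):
    gib = logits_bytes(8, seq_len, LLAMA3_VOCAB) / 2**30
    print(f"seq_len={seq_len}: {gib:.1f} GiB of bf16 logits")
```

Under these assumptions the logits alone come to roughly 7.8 GiB at a 4K context and 31.3 GiB at 16K, which is the kind of allocation that fused, chunked loss kernels aim to avoid.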
We provide optimized post-training kernels such as DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can use them directly as Python modules.
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
orpo_loss = LigerFusedLinearORPOLoss()
y = orpo_loss(lm_head.weight, x, target)
Use Case | Description |
---|---|
Hugging Face Trainer | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
Lightning Trainer | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
Medusa Multi-head LLM (Retraining Phase) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
Vision-Language Model SFT | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
Liger ORPO Trainer | Align Llama 3.2 using Liger ORPO Trainer with FSDP with 50% memory reduction |
Dependencies for CUDA (NVIDIA GPUs):
- torch >= 2.1.2
- triton >= 2.3.0
Dependencies for ROCm (AMD GPUs):
- torch >= 2.5.0: Install according to the instructions on the official PyTorch webpage.
- triton >= 3.0.0: Install from PyPI (e.g. pip install triton==3.0.0).
# Need to pass the url when installing
pip install -e ".[dev]" --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
Optional dependencies:
- transformers >= 4.x: Required if you plan to use the transformers model patching APIs. The specific model you are working with will dictate the minimum version of transformers.
Note: Our kernels inherit the full spectrum of hardware compatibility offered by Triton.
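Before installing, it can be useful to check what is already in the environment. The following is a minimal standard-library sketch that compares installed package versions against the first pair of minimums listed above (torch >= 2.1.2, triton >= 2.3.0). The helper names `parse_version`, `meets_minimum`, and `check_dependencies` are hypothetical, and the version parsing is deliberately simple: it looks only at the leading numeric components and ignores pre-release ordering.

```python
import re
from importlib import metadata

# Minimums copied from the dependency list above; adjust for your platform.
MINIMUMS = {"torch": "2.1.2", "triton": "2.3.0"}


def parse_version(version: str) -> tuple:
    """Turn '2.5.0+cu121' into (2, 5, 0), ignoring local and pre-release tags."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:  # pad '2.3' to (2, 3, 0)
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    return parse_version(installed) >= parse_version(minimum)


def check_dependencies(minimums=MINIMUMS) -> dict:
    """Return a short status string per package."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
            continue
        ok = meets_minimum(installed, minimum)
        report[name] = f"{installed} ({'ok' if ok else 'need >= ' + minimum})"
    return report


for name, status in check_dependencies().items():
    print(f"{name}: {status}")
```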
To install the stable version:
$ pip install liger-kernel
To install the nightly version:
$ pip install liger-kernel-nightly
To install from source:
git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel
# Install Default Dependencies
# Setup.py will detect whether you are using AMD or NVIDIA
pip install -e .
# Setup Development Dependencies
pip install -e ".[dev]"
There are a couple of ways to apply Liger kernels, depending on the level of customization required.
Using the AutoLigerKernelForCausalLM is the simplest approach, as you don’t have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.
from liger_kernel.transformers import AutoLigerKernelForCausalLM
# This AutoModel wrapper class automatically monkey-patches the
# model with the optimized Liger kernels if the model is supported.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
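The behaviour described above, where a supported model type is patched and anything else is left untouched, can be pictured as a lookup table from model type to patch function. The following is a library-free sketch of that dispatch pattern only. `PATCH_REGISTRY`, `patch_llama`, `patch_gemma`, and `auto_patch` are hypothetical names, and Liger Kernel's actual implementation may differ.

```python
from typing import Callable, Dict, List

applied: List[str] = []  # records which patches ran, for demonstration only


def patch_llama(**kwargs) -> None:
    applied.append("llama")


def patch_gemma(**kwargs) -> None:
    applied.append("gemma")


# Hypothetical registry: model type -> function that patches that model family.
PATCH_REGISTRY: Dict[str, Callable[..., None]] = {
    "llama": patch_llama,
    "gemma": patch_gemma,
}


def auto_patch(model_type: str, **kwargs) -> bool:
    """Apply the registered patch; return False if the type is unsupported."""
    patch_fn = PATCH_REGISTRY.get(model_type)
    if patch_fn is None:
        return False  # unsupported: keep the unpatched code path
    patch_fn(**kwargs)
    return True


print(auto_patch("llama"))    # True
print(auto_patch("unknown"))  # False
```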
Using the patching APIs, you can swap Hugging Face models with optimized Liger Kernels.
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama
# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()
# 1b. You could alternatively specify exactly which kernels are applied
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=False,
)
# 2. Instantiate the patched model
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
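Monkey-patching here means rebinding names in the modeling module so that models built afterwards pick up the replacement classes. The sketch below shows the mechanism with a stand-in namespace rather than the real `transformers` module. `fake_modeling`, the `Slow*` and `Fast*` classes, `apply_patch`, and `build_model` are all hypothetical; only the `rms_norm` and `swiglu` flag names mirror the call shown above.

```python
from types import SimpleNamespace


class SlowRMSNorm: ...
class SlowMLP: ...
class FastRMSNorm: ...
class FastMLP: ...


# Stand-in for a modeling module that exposes its layer classes by name.
fake_modeling = SimpleNamespace(RMSNorm=SlowRMSNorm, MLP=SlowMLP)


def apply_patch(module, rms_norm: bool = True, swiglu: bool = True) -> None:
    """Rebind selected layer classes on the module."""
    if rms_norm:
        module.RMSNorm = FastRMSNorm
    if swiglu:
        module.MLP = FastMLP


def build_model(module):
    """Looks classes up at call time, so it sees whatever is bound now."""
    return [module.RMSNorm(), module.MLP()]


apply_patch(fake_modeling, rms_norm=False, swiglu=True)
model = build_model(fake_modeling)
print([type(layer).__name__ for layer in model])  # ['SlowRMSNorm', 'FastMLP']
```

In this sketch, objects created before `apply_patch` runs would keep the classes they were built from, which is why the patch call comes before instantiation.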
You can take individual kernels to compose your models.
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
import torch.nn as nn
import torch
model = nn.Linear(128, 256).cuda()
# fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()
input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.randint(256, (4, ), device="cuda")
loss = loss_fn(model.weight, input, target)
loss.backward()
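The chunk-by-chunk idea mentioned in the comment above can be illustrated without a GPU. The following pure-Python sketch computes the same mean cross-entropy two ways: once by building the logits for every token first, and once by processing a few tokens at a time so that only one chunk of logits exists at any moment. It is a conceptual illustration under simplifying assumptions (forward pass only, plain lists instead of tensors), not Liger Kernel's Triton implementation. All function names are hypothetical.

```python
import math


def linear(weight, x):
    """weight: (vocab, hidden) rows, x: (hidden,) -> logits of length vocab."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def cross_entropy(logits, target):
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


def full_loss(weight, inputs, targets):
    all_logits = [linear(weight, x) for x in inputs]  # (tokens, vocab) at once
    return sum(cross_entropy(l, t) for l, t in zip(all_logits, targets)) / len(targets)


def chunked_loss(weight, inputs, targets, chunk_size=2):
    total = 0.0
    for start in range(0, len(inputs), chunk_size):
        xs = inputs[start:start + chunk_size]
        ts = targets[start:start + chunk_size]
        chunk_logits = [linear(weight, x) for x in xs]  # only (chunk, vocab)
        total += sum(cross_entropy(l, t) for l, t in zip(chunk_logits, ts))
    return total / len(targets)


weight = [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2]]            # vocab=3, hidden=2
inputs = [[1.0, 2.0], [0.5, -1.0], [-1.5, 0.3], [2.0, 2.0], [0.0, 1.0]]
targets = [0, 2, 1, 1, 0]

print(math.isclose(full_loss(weight, inputs, targets),
                   chunked_loss(weight, inputs, targets)))  # True
```

The peak size of the intermediate logits drops from tokens × vocab to chunk_size × vocab, while the loss value is unchanged up to floating-point rounding.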
AutoModel Variant | API |
---|---|
AutoModelForCausalLM | liger_kernel.transformers.AutoLigerKernelForCausalLM |
Model | API | Supported Operations |
---|---|---|
LLaMA 2 & 3 | liger_kernel.transformers.apply_liger_kernel_to_llama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
LLaMA 3.2-Vision | liger_kernel.transformers.apply_liger_kernel_to_mllama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Mistral | liger_kernel.transformers.apply_liger_kernel_to_mistral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Mixtral | liger_kernel.transformers.apply_liger_kernel_to_mixtral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Gemma1 | liger_kernel.transformers.apply_liger_kernel_to_gemma | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Gemma2 | liger_kernel.transformers.apply_liger_kernel_to_gemma2 | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Gemma3 (Text) | liger_kernel.transformers.apply_liger_kernel_to_gemma3_text | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Gemma3 (Multimodal) | liger_kernel.transformers.apply_liger_kernel_to_gemma3 | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Paligemma, Paligemma2, & Paligemma2 Mix | liger_kernel.transformers.apply_liger_kernel_to_paligemma | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Qwen2, Qwen2.5, & QwQ | liger_kernel.transformers.apply_liger_kernel_to_qwen2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Qwen2-VL & QVQ | liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Qwen2.5-VL | liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Qwen3 | liger_kernel.transformers.apply_liger_kernel_to_qwen3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Qwen3 MoE | liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Phi3 & Phi3.5 | liger_kernel.transformers.apply_liger_kernel_to_phi3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Granite 3.0 & 3.1 | liger_kernel.transformers.apply_liger_kernel_to_granite | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
OLMo2 | liger_kernel.transformers.apply_liger_kernel_to_olmo2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
GLM-4 | liger_kernel.transformers.apply_liger_kernel_to_glm4 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Fused Linear kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
Kernel | API |
---|---|
RMSNorm | liger_kernel.transformers.LigerRMSNorm |
LayerNorm | liger_kernel.transformers.LigerLayerNorm |
RoPE | liger_kernel.transformers.liger_rotary_pos_emb |
SwiGLU | liger_kernel.transformers.LigerSwiGLUMLP |
GeGLU | liger_kernel.transformers.LigerGEGLUMLP |
CrossEntropy | liger_kernel.transformers.LigerCrossEntropyLoss |
Fused Linear CrossEntropy | liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss |
Sparsemax | liger_kernel.transformers.LigerSparsemax |
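For readers unfamiliar with the operations in this table, the following plain-Python reference spells out the math of two of them: RMSNorm and the elementwise part of SwiGLU. It is a readability aid using the standard textbook definitions, not the Triton kernels themselves. The function names and the default `eps` are assumptions for illustration, and model-specific variants (for example, differing weight offsets or casting rules) are not covered.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """y_i = x_i / sqrt(mean(x^2) + eps) * weight_i"""
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def silu(v):
    """SiLU (swish) activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu(gate, up):
    """Elementwise SiLU(gate) * up, the gating step inside a SwiGLU MLP."""
    return [silu(g) * u for g, u in zip(gate, up)]


normed = rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(sum(v * v for v in normed) / len(normed))  # mean square is ~1 after normalization
print(swiglu([0.0, 2.0], [5.0, 1.0]))
```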
Kernel | API |
---|---|
Fused Linear CPO Loss | liger_kernel.chunked_loss.LigerFusedLinearCPOLoss |
Fused Linear DPO Loss | liger_kernel.chunked_loss.LigerFusedLinearDPOLoss |
Fused Linear ORPO Loss | liger_kernel.chunked_loss.LigerFusedLinearORPOLoss |
Fused Linear SimPO Loss | liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss |
Fused Linear KTO Loss | liger_kernel.chunked_loss.LigerFusedLinearKTOLoss |
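As a reminder of what these preference losses compute, here is a scalar reference for the standard DPO objective, written for a single chosen/rejected pair. It takes summed sequence log-probabilities as inputs; the fused kernels in the table instead start from hidden states and the LM head weight, so that the logits are not fully materialized. The function name, argument names, and the default `beta` are assumptions for illustration.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """-log sigmoid(beta * [(pi_c - ref_c) - (pi_r - ref_r)])"""
    margin = (policy_chosen_logp - ref_chosen_logp) - (
        policy_rejected_logp - ref_rejected_logp
    )
    z = beta * margin
    # -log(sigmoid(z)) = log(1 + exp(-z)), evaluated without overflow
    if z > 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


# Policy identical to the reference: no preference signal, so the loss is log(2).
print(dpo_loss(-1.0, -1.0, -1.0, -1.0))
# Policy favours the chosen response more than the reference does: lower loss.
print(dpo_loss(-1.0, -5.0, -3.0, -3.0))
```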
Kernel | API |
---|---|
KLDivergence | liger_kernel.transformers.LigerKLDIVLoss |
JSD | liger_kernel.transformers.LigerJSD |
Fused Linear JSD | liger_kernel.transformers.LigerFusedLinearJSD |
TVD | liger_kernel.transformers.LigerTVDLoss |
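The divergences listed above measure how far a student distribution is from a teacher distribution. The following plain-Python reference gives the standard definitions for two discrete probability vectors. It is meant as a definitional aid only: the kernels in the table operate on tensors and may use different input conventions (for example, log-probabilities) or generalized weightings that are not shown here. The function names are assumptions.

```python
import math


def kl_div(p, q):
    """KL(p || q), using the convention 0 * log(0 / q) = 0.
    Assumes q_i > 0 wherever p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence: mean KL of p and q to their midpoint."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)


def tvd(p, q):
    """Total variation distance: half the L1 distance."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
print(kl_div(teacher, student), jsd(teacher, student), tvd(teacher, student))
```

Unlike KL divergence, JSD and TVD are symmetric and bounded, with JSD reaching at most log(2) and TVD at most 1 for fully disjoint distributions.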
Kernel | API |
---|---|
Embedding | liger_kernel.transformers.experimental.LigerEmbedding |
Matmul int2xint8 | liger_kernel.transformers.experimental.matmul |
Biblatex entry:
@article{hsu2024ligerkernelefficienttriton,
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
year={2024},
eprint={2410.10989},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.10989},
journal={arXiv preprint arXiv:2410.10989},
}