- At Meta, we are constantly pushing the boundaries of LLM inference systems to power applications such as the Meta AI App.
- We’re sharing how we developed and implemented advanced parallelism techniques to optimize key performance metrics related to resource efficiency, throughput, and latency.
The rapid evolution of large language models (LLMs) has ushered in a new era of AI-powered applications, from conversational agents to advanced content generation. However, deploying these massive models at scale for real-time inference presents significant challenges, particularly in achieving high throughput, low latency, and efficient use of resources.
Our primary goal is to optimize key performance metrics:
- Resource efficiency: Maximizing GPU utilization to improve operational efficiency.
- Throughput (queries/s): Serving more users by processing a higher volume of requests.
- Latency: Minimizing response times for a seamless user experience. This includes:
  - Time-to-first-token (TTFT) for prefill: The time it takes for the first token of the response to appear, ideally under 350ms.
  - Time-to-incremental-token (TTIT) for decoding: The latency between subsequent tokens, targeting less than 25ms. (A minimal measurement sketch follows this list.)
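Both latencies can be derived from per-token emission timestamps. The sketch below is a minimal illustration; the function and argument names are hypothetical and not part of our serving stack.

```python
def measure_latencies(request_start: float, token_timestamps: list[float]) -> tuple[float, float]:
    """Compute TTFT and average TTIT in milliseconds from wall-clock timestamps.

    request_start:    when the request arrived (seconds, e.g. from time.monotonic()).
    token_timestamps: when each generated token was emitted, in order.
    """
    ttft_ms = (token_timestamps[0] - request_start) * 1e3
    # TTIT is the gap between consecutive tokens, averaged over the decode phase.
    gaps = [b - a for a, b in zip(token_timestamps, token_timestamps[1:])]
    avg_ttit_ms = (sum(gaps) / len(gaps)) * 1e3 if gaps else 0.0
    return ttft_ms, avg_ttit_ms
```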
These metrics highlight the distinct computational demands of LLM inference: Prefill is compute-intensive, while decoding is memory bandwidth-intensive. To address these challenges and enable the deployment of large models, we have developed and implemented advanced parallelism techniques.
The Two Stages of LLM Inference
A typical LLM generative-inference task unfolds in two stages (a simplified generation loop is sketched after the list below):
- Prefill stage: This stage processes the input prompt (which can be thousands of tokens long) to generate a key-value (KV) cache for each transformer layer of the LLM. Prefill is compute-bound, because the attention mechanism scales quadratically with sequence length.
- Decoding stage: This stage utilizes and incrementally updates the KV cache to generate tokens (words) one by one. Decoding is memory-bound, as the I/O time of reading memory dominates attention time, with model weights and the KV cache occupying the majority of memory.
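To make the two stages concrete, here is a minimal generation loop written against a Hugging Face-style causal-LM interface. This is an illustrative sketch rather than our production serving code, and the checkpoint name is only a placeholder: the single forward pass over the prompt is the prefill, and the one-token-at-a-time loop that reuses `past_key_values` is the decode.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint: any Hugging Face causal LM with a KV cache works here.
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).eval()

@torch.inference_mode()
def generate(prompt: str, max_new_tokens: int = 64) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Prefill: one compute-bound forward pass over the entire prompt. It yields
    # the logits for the first new token plus the KV cache (one key/value pair
    # per transformer layer).
    out = model(input_ids, use_cache=True)
    past_key_values = out.past_key_values
    next_token = out.logits[:, -1:].argmax(dim=-1)

    generated = [next_token]
    # Decode: memory-bandwidth-bound loop. Each step feeds a single token and
    # incrementally extends the KV cache instead of re-reading the prompt.
    for _ in range(max_new_tokens - 1):
        out = model(next_token, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1:].argmax(dim=-1)
        generated.append(next_token)

    return tokenizer.decode(torch.cat(generated, dim=-1)[0])
```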
Addressing Bottlenecks With Parallelism
To scale LLM inference effectively, especially for handling long contexts and massive models, we employ three main types of inference parallelism:
1. Tensor parallelism (TP), which makes it possible to fit large models across multiple GPUs and to achieve throughput that a single device cannot provide. It shards individual layers of the model, such as attention blocks and multi-layer perceptron (MLP) layers, into smaller, independent blocks that can be executed on different devices, as illustrated in the sketch below.
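As a rough illustration of layer sharding, the toy module below splits a transformer MLP Megatron-style across a tensor-parallel group: the up-projection is sharded column-wise, the down-projection row-wise, and a single allreduce reconstructs the output. This is an illustrative sketch assuming `torch.distributed` is already initialized, not our production implementation.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

class TensorParallelMLP(nn.Module):
    """Illustrative Megatron-style sharded MLP (not production code).

    The up-projection is split column-wise (each rank owns a slice of the
    hidden dimension) and the down-projection row-wise, so each rank computes
    a partial output and a single all-reduce reconstructs the full result.
    Assumes torch.distributed and a tensor-parallel process group are
    initialized (e.g. via torchrun).
    """

    def __init__(self, d_model: int, d_ff: int, tp_group=None):
        super().__init__()
        self.tp_group = tp_group
        tp_size = dist.get_world_size(group=tp_group)
        assert d_ff % tp_size == 0, "hidden dim must divide evenly across TP ranks"
        self.up = nn.Linear(d_model, d_ff // tp_size, bias=False)    # column-parallel shard
        self.down = nn.Linear(d_ff // tp_size, d_model, bias=False)  # row-parallel shard

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        partial = self.down(F.gelu(self.up(x)))
        # Each rank now holds a partial sum of the layer output; this allreduce
        # is exactly the communication step that DDA (below) is designed to speed up.
        dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=self.tp_group)
        return partial
```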
A challenge in tensor parallelism is the “allreduce” communication operation, which can contribute up to 30% of end-to-end latency. To mitigate this, we developed direct data access (DDA) algorithms (a simulation of both is sketched after this list):
- DDA flat algorithm: Improves small-message allreduce latency by allowing each rank to directly load memory from other ranks and perform local reduce operations. For N ranks, this reduces the number of communication steps from O(N) to O(1), at the cost of increasing the total data exchanged from O(N) to O(N²).
- DDA tree algorithm: Breaks the allreduce into two phases (reduce-scatter and all-gather) and uses direct data access in each step. This moves the same amount of data as the ring algorithm but reduces latency to a constant factor, making it suitable for slightly larger message sizes.
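The single-process simulation below illustrates the data-movement pattern of the two DDA variants. The real implementations are GPU kernels that load peer memory directly over NVLink or xGMI; this sketch only mimics that with local tensor reads, and the function names are ours.

```python
import torch

def dda_flat_allreduce(buffers: list[torch.Tensor]) -> list[torch.Tensor]:
    """Every simulated rank directly reads the full buffer of every peer and
    reduces locally: O(1) communication steps, O(N^2) total data moved."""
    total = torch.stack(buffers).sum(dim=0)
    return [total.clone() for _ in buffers]

def dda_tree_allreduce(buffers: list[torch.Tensor]) -> list[torch.Tensor]:
    """Two phases, each using direct loads from peer memory:
    reduce-scatter (rank r reads and reduces chunk r from every peer),
    then all-gather (every rank reads all reduced chunks)."""
    n = len(buffers)
    chunks = [b.chunk(n) for b in buffers]                    # buffer length must divide by n
    reduced = [torch.stack([chunks[peer][r] for peer in range(n)]).sum(dim=0)
               for r in range(n)]
    full = torch.cat(reduced)
    return [full.clone() for _ in buffers]

# Sanity check: both variants match a reference sum across 4 simulated ranks.
torch.manual_seed(0)
bufs = [torch.randn(8) for _ in range(4)]
ref = torch.stack(bufs).sum(dim=0)
assert all(torch.allclose(out, ref) for out in dda_flat_allreduce(bufs))
assert all(torch.allclose(out, ref) for out in dda_tree_allreduce(bufs))
```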
Our DDA solutions demonstrate significant speedups against baselines such as NCCL (NVIDIA Collective Communications Library) and RCCL (ROCm Communication Collectives Library for AMD GPUs). For instance, with AMD MI300X we achieved overall performance parity with NVIDIA H100, with DDA outperforming the RCCL baseline by 10-50% for decode (small message sizes) and delivering a 10-30% speedup for prefill, which translated into an approximately 10% reduction in TTIT.
2. Context parallelism (CP), which facilitates managing and processing extremely long contexts, such as the 1M/10M token capabilities introduced with Llama 4. Long-context inference presents unique challenges:
- Compute: Dense attention FLOPs scale quadratically with context length, so attention compute dominates prefill time.
- Memory: The KV cache grows linearly with context (see the sizing sketch after this list).
- Communication: Communication latency increases when parallelizing across multiple hosts.
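A back-of-the-envelope calculation shows why the KV cache alone quickly exceeds a single GPU's memory at these context lengths. The configuration below (80 layers, 8 KV heads with grouped-query attention, head dimension 128, 16-bit values) is illustrative of a Llama-3-70B-class model, not a specific production deployment.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """KV-cache size: 2 tensors (K and V) per layer, each of shape
    [batch, seq_len, n_kv_heads, head_dim], in 16-bit precision by default."""
    return 2 * n_layers * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem

# Illustrative Llama-3-70B-class config: 80 layers, 8 KV heads (GQA), head_dim 128.
for ctx in (8_192, 131_072, 1_000_000):
    gib = kv_cache_bytes(ctx, n_layers=80, n_kv_heads=8, head_dim=128) / 2**30
    print(f"{ctx:>9} tokens -> {gib:7.1f} GiB of KV cache")
```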
We have implemented two variants of context parallelism in the attention module, often referred to as “ring attention”:
- Pass-KV: In this approach, input tokens are split across multiple CP ranks. Each rank calculates its portion of query, key, and value tensors. Then, key and value tensors are exchanged between ranks to enable attention interactions across the full context (a simplified single-process sketch follows this list).
- Pass-Q: Similar to Pass-KV, but query tensors are exchanged between ranks instead of keys and values.
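The sketch below simulates pass-KV ring attention in a single process: query shards stay resident while key/value shards rotate around the simulated CP ranks, and partial results are merged with an online softmax. Causal masking and the load-balanced token ordering used in practice are omitted for clarity, and the helper names are ours, not those of the production kernel.

```python
import torch

def ring_attention_pass_kv(q, k, v, cp_ranks: int) -> torch.Tensor:
    """Single-process sketch of pass-KV ring attention.

    The sequence is split across `cp_ranks` simulated CP ranks. Each rank keeps
    its query shard resident while the key/value shards of every other rank
    arrive over `cp_ranks` ring steps; partial results are merged with a
    numerically stable online softmax.
    """
    d = q.shape[-1]
    q_shards = q.chunk(cp_ranks, dim=0)
    kv_shards = list(zip(k.chunk(cp_ranks, dim=0), v.chunk(cp_ranks, dim=0)))

    outputs = []
    for r in range(cp_ranks):
        q_r = q_shards[r]
        numer = torch.zeros_like(q_r)                         # running sum of exp(scores) @ V
        denom = torch.zeros(q_r.shape[0])                     # running softmax normalizer
        running_max = torch.full((q_r.shape[0],), float("-inf"))
        for step in range(cp_ranks):
            # In a real deployment this is the ring send/recv of a K/V shard;
            # here we just index the shard that would have arrived at this step.
            k_j, v_j = kv_shards[(r + step) % cp_ranks]
            scores = q_r @ k_j.T / d**0.5
            block_max = scores.max(dim=-1).values
            new_max = torch.maximum(running_max, block_max)
            scale = torch.exp(running_max - new_max)          # rescale previous partial sums
            p = torch.exp(scores - new_max[:, None])
            numer = numer * scale[:, None] + p @ v_j
            denom = denom * scale + p.sum(dim=-1)
            running_max = new_max
        outputs.append(numer / denom[:, None])
    return torch.cat(outputs, dim=0)

# Sanity check against plain full attention on a small example.
torch.manual_seed(0)
q, k, v = (torch.randn(16, 64) for _ in range(3))
ref = torch.softmax(q @ k.T / 64**0.5, dim=-1) @ v
assert torch.allclose(ring_attention_pass_kv(q, k, v, cp_ranks=4), ref, atol=1e-5)
```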
Our context parallelism optimizations, combined with a fast-attention kernel, have enabled remarkable performance for long-context capabilities. We achieved less than one minute for one million tokens on a single H100 host and less than one minute for 10 million tokens using distributed inference across multiple H100 hosts (e.g., 32 H100 hosts). With Llama 3 405B, we demonstrated near-linear scaling, achieving 128K token prefill in 3.8 seconds with CP over 16 nodes, and 1M-token prefill in 77 seconds.
3. Expert parallelism (EP), which helps with scaling mixture-of-experts (MoE) models, where the large number of “experts” (neural network modules) makes it impossible to fit the entire model onto a single host. In EP-based inference, we utilize a two-shot, all-to-all communication pattern to exchange tokens between data parallelism and expert parallelism ranks based on routing, as sketched in the toy example below.
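The toy layer below shows the routing pattern behind expert parallelism: tokens are dispatched to the experts chosen by the router, and the expert outputs are combined back into their original positions. In a real EP deployment, the dispatch and combine steps become the two all-to-all exchanges described above; here all experts live in one process, so they reduce to a masked gather/scatter, and the top-1 router is a simplification.

```python
import torch
import torch.nn as nn

class ToyMoELayer(nn.Module):
    """Toy expert-parallel routing pattern (top-1 gating for simplicity).

    In a real EP deployment the dispatch/combine permutations below become the
    two all-to-all exchanges between data-parallel and expert-parallel ranks;
    here all experts live in one process, so the exchanges reduce to a masked
    gather/scatter.
    """

    def __init__(self, d_model: int, n_experts: int):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                          nn.Linear(4 * d_model, d_model))
            for _ in range(n_experts)
        ])

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:  # [n_tokens, d_model]
        expert_ids = self.router(tokens).argmax(dim=-1)        # route each token to one expert
        out = torch.empty_like(tokens)
        for e, expert in enumerate(self.experts):
            mask = expert_ids == e
            if mask.any():
                # "Dispatch" the tokens routed to expert e, run the expert,
                # then "combine" the results back into their original slots.
                out[mask] = expert(tokens[mask])
        return out

moe = ToyMoELayer(d_model=64, n_experts=4)
print(moe(torch.randn(10, 64)).shape)  # torch.Size([10, 64])
```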
The all-to-all communication can contribute 10-30% to end-to-end latency, especially for decode messages (100KB to 2MB). To optimize this, we are exploring solutions including:
- Dynamic all-to-all: Sending sub-chunks of data to remote neighbors.
- Persistent all-to-all: Addressing slowdowns primarily caused by memory-handle exchange, network-load balancing, and CPU overhead.
Looking Ahead: Disaggregated Inference and Future Challenges
To further optimize LLM inference, we are moving towards N-D parallelism (CP, PP, EP, TP across nodes, with separate DP) and disaggregating the prefill and decoding tiers. This allows for better resource balancing and the potential to use heterogeneous hardware, where compute-heavy hardware is used for prefill and memory-bandwidth-heavy hardware for decoding (a toy handoff between the two tiers is sketched below). This multi-dimensional parallelism can help unblock the serving and evaluation of colossal models.
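As a conceptual sketch of disaggregation, not our production architecture, the toy program below separates the two roles: a prefill tier produces the KV cache and first token, then hands them to a decode tier that streams the remaining tokens. In production the handoff crosses hosts (e.g. over RDMA), and each tier runs on hardware matched to its bottleneck; all names and stand-in functions here are hypothetical.

```python
import queue
import threading

def toy_prefill(prompt: str) -> tuple[str, str]:
    """Pretend compute-bound prefill: the 'KV cache' here is just the prompt."""
    return "first-token", prompt

def toy_decode_step(kv_cache: str, step: int) -> str:
    """Pretend memory-bandwidth-bound decode step."""
    return f"token-{step}"

def prefill_tier(requests: queue.Queue, handoff: queue.Queue) -> None:
    while True:
        prompt = requests.get()
        if prompt is None:                    # shutdown signal
            handoff.put(None)
            return
        handoff.put(toy_prefill(prompt))      # hand the KV cache to the decode tier

def decode_tier(handoff: queue.Queue) -> None:
    while True:
        item = handoff.get()
        if item is None:
            return
        first_token, kv_cache = item
        tokens = [first_token] + [toy_decode_step(kv_cache, i) for i in range(3)]
        print(" ".join(tokens))

requests, handoff = queue.Queue(), queue.Queue()
workers = [threading.Thread(target=prefill_tier, args=(requests, handoff)),
           threading.Thread(target=decode_tier, args=(handoff,))]
for w in workers:
    w.start()
requests.put("hello world")
requests.put(None)
for w in workers:
    w.join()
```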
Future challenges in this space include:
- Cloud fabric design: Optimizing the underlying cloud infrastructure for LLM workloads.
- Communication moving into the kernel (fused kernels): Integrating communication operations directly into compute kernels for greater efficiency.
- Device-initiated kernel: Enabling devices to initiate operations directly, reducing CPU overhead.
These advancements in parallelization and system-level improvements have helped enable the next generation of AI applications and push the boundaries of what LLMs can achieve. We are committed to continuous innovation to ensure efficient and scalable LLM inference for millions of users worldwide.