메뉴
BL
MarkTechPost 7일 전

엔비디아, 기억 수정 최적화 선형 어텐션 모델 공개

IMP
8/10
핵심 요약

엔비디아가 기존 선형 어텐션 모델들의 한계였던 '기억 덮어쓰기' 문제를 해결한 'Gated DeltaNet-2'를 발표했습니다. 이 모델은 기존의 단일 게이트를 키(Key) 축의 '삭제 게이트'와 값(Value) 축의 '쓰기 게이트'로 분리하여 메모리를 더 정교하게 제어합니다. 그 결과, Mamba-2, Mamba-3 등 기존 최고 성능 모델들을 벤치마크에서 모두 능가하는 우수한 성능을 입증했습니다.

번역된 본문

인공지능(AI) 인프라 기술 뉴스, 편집자 추천, 대규모 언어 모델(LLM), 머신러닝, 오픈소스, 소프트웨어 엔지니어링 등

선형 어텐션(Linear Attention)은 소프트맥스 어텐션(Softmax Attention)의 무한한 크기의 KV 캐시를 고정 크기의 순환 상태(Recurrent State)로 대체합니다. 이를 통해 시퀀스 혼합(Sequence mixing) 시간을 선형 시간으로 줄이고, 디코딩 시 메모리 사용량을 상수(O(1))로 낮춥니다. 하지만 가장 어려운 문제는 '무엇을 잊을 것인가'가 아니라, 기존에 학습된 연관성을 파괴하지 않고 압축된 메모리를 어떻게 수정할 것인가입니다.

엔비디아는 이러한 병목 현상을 해결하기 위해 'Gated DeltaNet-2'라는 선형 어텐션 레이어를 공개했습니다. 이 모델은 능동적인 메모리 편집(Active memory edit) 과정을 두 개의 채널 단위 게이트(Channel-wise gate)로 분리합니다. 100B(천억) 개의 FineWeb-Edu 토큰으로 1.3B(13억) 파라미터 모델을 학습시켰으며, 연구진의 벤치마크 스위트 전반에서 Mamba-2, Gated DeltaNet, KDA(Kimi Delta Attention), Mamba-3를 모두 능가하는 성능을 보였습니다.

델타 규칙(Delta Rule) 모델의 스칼라 게이트(Scalar Gate) 문제 순환 선형 어텐션 레이어는 행렬 상태 S_t를 저장하고 쿼리(Query)를 통해 이를 읽어옵니다. DeltaNet은 현재 키(Key)에 연결된 기존 값을 빼는 방식으로 능동적인 편집을 추가합니다. 이때 스칼라 스텝 크기 β_t를 사용하여 얼마나 덮어쓸지를 결정합니다. Mamba-2는 전역적인 망각(Global forgetting)을 위해 데이터 종속적인 스칼라 감쇠(Decay) α_t를 추가했습니다.

Gated DeltaNet은 이 두 가지 연산을 결합했지만, 여전히 헤드 당 스칼라 게이트 방식을 사용했습니다. KDA는 감쇠 측면을 개선하여 스칼라 α_t를 채널 단위 벡터로 교체했습니다. 하지만 KDA 역시 능동적 편집을 위해 단일 스칼라 β_t를 유지합니다.

이 단일 스칼라는 두 가지 다른 역할을 동시에 제어합니다. 하나는 키(Key) 축에서 얼마나 많은 기존 콘텐츠를 지울지 결정하는 것이고, 다른 하나는 값(Value) 축에서 얼마나 많은 새로운 콘텐츠를 기록할지 결정하는 것입니다. 이 두 결정은 상태의 다른 축에 영향을 미칩니다. 이를 하나로 묶는 것은 델타 규칙의 속성이 아니라 모델링의 제약 조건일 뿐입니다.

게이트 델타 규칙-2 (Gated Delta Rule-2): 하나 대신 두 개의 게이트 Gated DeltaNet-2는 'Gated Delta Rule-2'를 통해 이 두 가지 결정을 분리합니다. 이 규칙은 키 축에 대해 채널 단위의 삭제 게이트(Erase gate) b_t ∈ [0,1]^(d_k)를 도입합니다. 또한 값 축에 대해 채널 단위의 쓰기 게이트(Write gate) w_t ∈ [0,1]^(d_v)를 도입합니다. 두 게이트는 모두 토큰 표현(Token representation)의 시그모이드(Sigmoid) 프로젝션을 통해 생성됩니다.

업데이트 시 능동적 편집 전에 감쇠(Decay)를 적용합니다. 간결하게 표현하면 순환 식은 다음과 같습니다: S_t = (I − k_t (b_t ⊙ k_t)^⊤) D_t S_{t−1} + k_t (w_t ⊙ v_t)^⊤

여기서 D_t = Diag(α_t)는 KDA에서 가져온 채널 단위 감쇠입니다. 삭제 행렬(Erase matrix)의 왼쪽 요인은 k_t로 유지되어 델타 규칙의 쓰기 방향을 보존합니다. 오른쪽 요인은 b_t ⊙ k_t가 되어 읽기 방향을 채널 선택적으로 만듭니다. 쓰기 항 k_t z_t^⊤는 z_t = w_t ⊙ v_t를 사용하여 값 업데이트를 채널 선택적으로 만듭니다.

두 게이트가 동일한 스칼라 β_t로 축소되면 업데이트는 KDA와 정확히 일치합니다. 감쇠 α_t도 스칼라로 축소되면 기존 Gated DeltaNet을 복원합니다. 두 이전 모델은 모두 새로운 업데이트의 제약된 부분 공간(Tied subspace)으로 보존됩니다.

Fast-weight 관점에서 Gated Delta Rule-2는 국소 회귀 손실(Local regression loss)에 대한 하나의 온라인 경사 하강법(Online gradient step) 단계입니다. 감쇠된 상태는 메모리에 가깝게 유지되고, 잔차 편집은 게이트가 적용된 읽기 및 쓰기 타겟을 사용합니다.

청크 단위 학습(Chunkwise Training) 및 게이트 인식 역전파(Backward Pass) 이 순환 구조는 KDA에 사용된 구조와 일치하는 청크 단위 WY 형식을 허용합니다. 누적 채널 감쇠는 각 계수-1(Rank-one) 삭제의 두 요인에 흡수됩니다. 청크별 업데이트는 I − k̄_r ē_r^⊤ 형태의 비대칭 행렬의 곱이 됩니다. 구현에는 융합된 Triton 커널을 사용하여 청크 크기 C = 64로 설정되었습니다.

역전파(Backward pass)의 경우, KDA에서 사용하던 스칼라 지름길(Shortcut)을 더 이상 사용할 수 없습니다. 쓰기 측에는 값 채널에 대한 다른 대각선 게이트가 포함되어 있고, 삭제 측에는 키 채널에 대한 다른 대각선 게이트가 포함되어 있습니다. 따라서 그래디언트를 누적하는 내적(Dot product) 내부에 게이트 요인이 반드시 나타나야 합니다. 논문에서는 이 게이트 인식 벡터-야코비안 곱(Vector-Jacobian product)을 명시적으로 도출했습니다. Hopper GPU에서 융합된 WY 역방향 커널은 Triton WGMMA 레이아웃 어설션 오류를 피하기 위해 2개 또는 4개의 워프(Warp)로 제한됩니다.

블록 설계 및 하이브리드 모델 Gated DeltaNet-2는 표준 트랜스포머(Transformer) 아키텍처에서 순환 토큰 믹서(Recurrent token mixer)로 사용됩니다.

원문 보기
원문 보기 (영어)
Artificial Intelligence AI Infrastructure Tech News AI Paper Summary Applications Technology Editors Pick Language Model Large Language Model Machine Learning New Releases Open Source Physical AI Software Engineering Staff Linear attention replaces the unbounded KV cache of softmax attention with a fixed-size recurrent state. This cuts sequence mixing to linear time and decoding to constant memory. The hard part is not what to forget. It is how to edit a compressed memory without scrambling existing associations. NVIDIA has released Gated DeltaNet-2 , a linear attention layer that targets that bottleneck. The model decouples the active memory edit into two channel-wise gates. It is trained at 1.3B parameters on 100B FineWeb-Edu tokens. It outperforms Mamba-2, Gated DeltaNet, KDA, and Mamba-3 across the researchs benchmark suite. The scalar gate problem in delta-rule models A recurrent linear attention layer stores a matrix state S t and reads it with the query. DeltaNet adds an active edit by subtracting the value currently associated with the current key. It uses a scalar step size β t to control how much to overwrite. Mamba-2 adds a data-dependent scalar decay α t for global forgetting. Gated DeltaNet combined both operations, but both gates remained scalar per head. Kimi Delta Attention (KDA) refines the decay side. It replaces the scalar α t with a channel-wise vector. KDA still keeps a single scalar β t for the active edit. That scalar controls two different things at once. It decides how much old content to erase on the key side. It also decides how much new content to commit on the value side. These two decisions act on different axes of the state. Tying them together is a modeling restriction, not a property of the delta rule. Gated Delta Rule-2: two gates instead of one Gated DeltaNet-2 separates the two decisions through Gated Delta Rule-2. It introduces a channel-wise erase gate b t ∈ [0,1] d k on the key axis. It also introduces a channel-wise write gate w t ∈ [0,1] d v on the value axis. Both gates are produced by sigmoid projections of the token representation. The update applies decay before the active edit. Written compactly, the recurrence is: S t = (I − k t (b t ⊙ k t ) ⊤ ) D t S t−1 + k t (w t ⊙ v t ) ⊤ Here D t = Diag(α t ) is the channel-wise decay carried over from KDA. The left factor of the erase matrix stays k t , preserving the delta-rule write direction. The right factor becomes b t ⊙ k t , making the read direction channel-selective. The write term k t z t ⊤ uses z t = w t ⊙ v t , making the value update channel-selective. When both gates collapse to the same scalar β t , the update recovers KDA exactly. When the decay α t also collapses to a scalar, it recovers Gated DeltaNet. Both prior models are preserved as tied subspaces of the new update. In the fast-weight view, Gated Delta Rule-2 is one online gradient step on a local regression loss. The decayed state stays close to memory, while the residual edit uses gated read and gated write targets. Chunkwise training and gate-aware backward The recurrence admits a chunkwise WY form that matches the structure used by KDA. Cumulative channel-wise decay is absorbed into the two factors of each rank-one erase. The per-chunk update becomes a product of asymmetric matrices of the form I − k̄ r ē r ⊤ . The implementation uses chunk size C = 64 with fused Triton kernels. For the backward pass, the scalar shortcut used by KDA no longer applies. The write side contains a different diagonal gate over value channels. The erase side contains a different diagonal gate over key channels. So the gate factors must appear inside the dot products that accumulate gradients. The paper derives this gate-aware vector-Jacobian product explicitly. On Hopper GPUs, the fused WY backward kernel is restricted to two and four warps to avoid a Triton WGMMA layout assertion. Block design and hybrid model Gated DeltaNet-2 is used as the recurrent token mixer in a standard Transformer-style block. Query and key paths use linear projection, short causal convolution, SiLU, and L2 normalization. The value path uses linear projection, short convolution, and SiLU. The decay α t , erase gate b t , and write gate w t come from separate linear branches. The recurrent output is RMS-normalized, multiplied by a SiLU output gate, and projected back. A hybrid variant inserts Sliding-Window Attention (SWA) after the recurrent mixer. A repeated cell contains Gated DeltaNet-2, an MLP, SWA, and another MLP. SWA handles exact local interactions, while the recurrent mixer compresses long histories. The hybrid retains linear sequence scaling with a bounded attention cache. Results at 1.3B parameters All models are 1.3B parameters trained on 100B FineWeb-Edu tokens. Parameter count and recurrent state size are matched across models. The recurrent state holds 262,144 floats per layer per batch element. Training length is 4K tokens, and hybrid models use a 2K SWA window. The Mamba-3 MIMO baseline uses rank R = 4 . On language modeling and commonsense reasoning, Gated DeltaNet-2 has the best average in both settings. The recurrent model averages 53.11 across LAMBADA and the reasoning suite. That sits above Mamba-3 MIMO at 52.39 and KDA at 52.28. In the hybrid setting, Gated DeltaNet-2 averages 53.97 against Mamba-3 MIMO at 52.72. Since recurrent state size is matched, the gain points to the update rule, not more memory. The clearest gains appear on RULER long-context retrieval. In the recurrent setting, S-NIAH-2 at 4K rises from 89.0 (KDA) to 93.0. S-NIAH-3 at 2K jumps from 63.2 (KDA) to 89.8. MK-NIAH-1 at 4K climbs from 28.0 (KDA) to 37.8. On real-world retrieval (SWDE, SQuAD, FDA, TriviaQA, NQ, DROP), Gated DeltaNet-2 also leads both settings. The recurrent average is 29.88 and the hybrid average is 42.28. Marktechpost’s Visual Explainer Gated DeltaNet-2 · Quickstart 01 / 08 NVIDIA · 2026 Gated DeltaNet-2 Decoupling Erase and Write in Linear Attention. A delta-rule recurrent attention layer with channel-wise erase and write gates. PyTorch Triton kernels 1.3B params 100B FineWeb-Edu tokens Authors Ali Hatamizadeh, Yejin Choi, Jan Kautz Repo github.com/NVlabs/GatedDeltaNet-2 License NVIDIA Source Code License-NC Step 01 · The Idea Two gates instead of one scalar Linear attention compresses an unbounded KV cache into a fixed-size recurrent state. Editing this memory without scrambling existing associations is the hard part. The Problem Prior delta-rule models (Gated DeltaNet, KDA) tie erasing old content and writing new content to one scalar gate β_t . The Fix Split it: a channel-wise erase gate b_t on the key axis, and a channel-wise write gate w_t on the value axis. Erase gate picks which key-side coordinates of the decayed state are read and removed. Write gate picks which value-side coordinates of the new content are committed. Channel-wise decay is inherited from KDA for fine-grained global forgetting. Step 02 · The Update Rule The Gated Delta Rule-2 With erase gate b_t ∈ [0,1]^{d_k} , write gate w_t ∈ [0,1]^{d_v} , and channel-wise decay D_t = Diag(α_t) , the recurrent state evolves as: S_t = (I − k_t (b_t ⊙ k_t) ⊤ ) D_t S_{t−1} + k_t (w_t ⊙ v_t) ⊤ Recovers KDA exactly when both gates collapse to the same scalar. Recovers Gated DeltaNet when the decay also collapses to a scalar. Trains efficiently via a chunkwise WY form with channel-wise decay absorbed into asymmetric erase factors. Step 03 · Get the Code Clone the repo and build the environment The official PyTorch implementation ships with a Dockerfile, training scripts, and the lit_gpt model definitions. git clone https://github.com/NVlabs/GatedDeltaNet-2.git cd GatedDeltaNet-2 # build the environment from the provided Dockerfile docker build -t gdn2 . docker run --gpus all -it —ipc=host -v $PWD:/workspace gdn2 Repo layout lit_gpt/ model code · scripts/ launchers &midd