엔비디아, 기억 수정 최적화 선형 어텐션 모델 공개
엔비디아가 기존 선형 어텐션 모델들의 한계였던 '기억 덮어쓰기' 문제를 해결한 'Gated DeltaNet-2'를 발표했습니다. 이 모델은 기존의 단일 게이트를 키(Key) 축의 '삭제 게이트'와 값(Value) 축의 '쓰기 게이트'로 분리하여 메모리를 더 정교하게 제어합니다. 그 결과, Mamba-2, Mamba-3 등 기존 최고 성능 모델들을 벤치마크에서 모두 능가하는 우수한 성능을 입증했습니다.
인공지능(AI) 인프라 기술 뉴스, 편집자 추천, 대규모 언어 모델(LLM), 머신러닝, 오픈소스, 소프트웨어 엔지니어링 등
선형 어텐션(Linear Attention)은 소프트맥스 어텐션(Softmax Attention)의 무한한 크기의 KV 캐시를 고정 크기의 순환 상태(Recurrent State)로 대체합니다. 이를 통해 시퀀스 혼합(Sequence mixing) 시간을 선형 시간으로 줄이고, 디코딩 시 메모리 사용량을 상수(O(1))로 낮춥니다. 하지만 가장 어려운 문제는 '무엇을 잊을 것인가'가 아니라, 기존에 학습된 연관성을 파괴하지 않고 압축된 메모리를 어떻게 수정할 것인가입니다.
엔비디아는 이러한 병목 현상을 해결하기 위해 'Gated DeltaNet-2'라는 선형 어텐션 레이어를 공개했습니다. 이 모델은 능동적인 메모리 편집(Active memory edit) 과정을 두 개의 채널 단위 게이트(Channel-wise gate)로 분리합니다. 100B(천억) 개의 FineWeb-Edu 토큰으로 1.3B(13억) 파라미터 모델을 학습시켰으며, 연구진의 벤치마크 스위트 전반에서 Mamba-2, Gated DeltaNet, KDA(Kimi Delta Attention), Mamba-3를 모두 능가하는 성능을 보였습니다.
델타 규칙(Delta Rule) 모델의 스칼라 게이트(Scalar Gate) 문제 순환 선형 어텐션 레이어는 행렬 상태 S_t를 저장하고 쿼리(Query)를 통해 이를 읽어옵니다. DeltaNet은 현재 키(Key)에 연결된 기존 값을 빼는 방식으로 능동적인 편집을 추가합니다. 이때 스칼라 스텝 크기 β_t를 사용하여 얼마나 덮어쓸지를 결정합니다. Mamba-2는 전역적인 망각(Global forgetting)을 위해 데이터 종속적인 스칼라 감쇠(Decay) α_t를 추가했습니다.
Gated DeltaNet은 이 두 가지 연산을 결합했지만, 여전히 헤드 당 스칼라 게이트 방식을 사용했습니다. KDA는 감쇠 측면을 개선하여 스칼라 α_t를 채널 단위 벡터로 교체했습니다. 하지만 KDA 역시 능동적 편집을 위해 단일 스칼라 β_t를 유지합니다.
이 단일 스칼라는 두 가지 다른 역할을 동시에 제어합니다. 하나는 키(Key) 축에서 얼마나 많은 기존 콘텐츠를 지울지 결정하는 것이고, 다른 하나는 값(Value) 축에서 얼마나 많은 새로운 콘텐츠를 기록할지 결정하는 것입니다. 이 두 결정은 상태의 다른 축에 영향을 미칩니다. 이를 하나로 묶는 것은 델타 규칙의 속성이 아니라 모델링의 제약 조건일 뿐입니다.
게이트 델타 규칙-2 (Gated Delta Rule-2): 하나 대신 두 개의 게이트 Gated DeltaNet-2는 'Gated Delta Rule-2'를 통해 이 두 가지 결정을 분리합니다. 이 규칙은 키 축에 대해 채널 단위의 삭제 게이트(Erase gate) b_t ∈ [0,1]^(d_k)를 도입합니다. 또한 값 축에 대해 채널 단위의 쓰기 게이트(Write gate) w_t ∈ [0,1]^(d_v)를 도입합니다. 두 게이트는 모두 토큰 표현(Token representation)의 시그모이드(Sigmoid) 프로젝션을 통해 생성됩니다.
업데이트 시 능동적 편집 전에 감쇠(Decay)를 적용합니다. 간결하게 표현하면 순환 식은 다음과 같습니다: S_t = (I − k_t (b_t ⊙ k_t)^⊤) D_t S_{t−1} + k_t (w_t ⊙ v_t)^⊤
여기서 D_t = Diag(α_t)는 KDA에서 가져온 채널 단위 감쇠입니다. 삭제 행렬(Erase matrix)의 왼쪽 요인은 k_t로 유지되어 델타 규칙의 쓰기 방향을 보존합니다. 오른쪽 요인은 b_t ⊙ k_t가 되어 읽기 방향을 채널 선택적으로 만듭니다. 쓰기 항 k_t z_t^⊤는 z_t = w_t ⊙ v_t를 사용하여 값 업데이트를 채널 선택적으로 만듭니다.
두 게이트가 동일한 스칼라 β_t로 축소되면 업데이트는 KDA와 정확히 일치합니다. 감쇠 α_t도 스칼라로 축소되면 기존 Gated DeltaNet을 복원합니다. 두 이전 모델은 모두 새로운 업데이트의 제약된 부분 공간(Tied subspace)으로 보존됩니다.
Fast-weight 관점에서 Gated Delta Rule-2는 국소 회귀 손실(Local regression loss)에 대한 하나의 온라인 경사 하강법(Online gradient step) 단계입니다. 감쇠된 상태는 메모리에 가깝게 유지되고, 잔차 편집은 게이트가 적용된 읽기 및 쓰기 타겟을 사용합니다.
청크 단위 학습(Chunkwise Training) 및 게이트 인식 역전파(Backward Pass) 이 순환 구조는 KDA에 사용된 구조와 일치하는 청크 단위 WY 형식을 허용합니다. 누적 채널 감쇠는 각 계수-1(Rank-one) 삭제의 두 요인에 흡수됩니다. 청크별 업데이트는 I − k̄_r ē_r^⊤ 형태의 비대칭 행렬의 곱이 됩니다. 구현에는 융합된 Triton 커널을 사용하여 청크 크기 C = 64로 설정되었습니다.
역전파(Backward pass)의 경우, KDA에서 사용하던 스칼라 지름길(Shortcut)을 더 이상 사용할 수 없습니다. 쓰기 측에는 값 채널에 대한 다른 대각선 게이트가 포함되어 있고, 삭제 측에는 키 채널에 대한 다른 대각선 게이트가 포함되어 있습니다. 따라서 그래디언트를 누적하는 내적(Dot product) 내부에 게이트 요인이 반드시 나타나야 합니다. 논문에서는 이 게이트 인식 벡터-야코비안 곱(Vector-Jacobian product)을 명시적으로 도출했습니다. Hopper GPU에서 융합된 WY 역방향 커널은 Triton WGMMA 레이아웃 어설션 오류를 피하기 위해 2개 또는 4개의 워프(Warp)로 제한됩니다.
블록 설계 및 하이브리드 모델 Gated DeltaNet-2는 표준 트랜스포머(Transformer) 아키텍처에서 순환 토큰 믹서(Recurrent token mixer)로 사용됩니다.