메뉴
HN
Hacker News 13일 전

트랜스포머의 자기회귀 예측과 KV 캐시

IMP
7/10
핵심 요약

대규모 언어 모델(LLM)이 토큰을 생성하는 핵심 원리인 '자기회귀 다음 토큰 예측'과 추론 속도를 획기적으로 높이는 'KV 캐시' 최적화 기법을 설명합니다. 이 과정을 통해 모델이 텍스트를 벡터로 변환하여 디코더 블록을 거치고 다음 단어를 예측하며, 이후 반복적인 연산을 줄여 긴 문장을 빠르게 생성할 수 있게 되는 원리를 이해할 수 있습니다.

번역된 본문

트랜스포머에서의 자기회귀 다음 토큰 예측 및 KV 캐시 | Frederik vom Lehn | 7분 읽기 · 2일 전

LLM에서 토큰 생성 속도를 높이는 최적화 기술을 이해해 봅시다.

큰 그림 (The Big Picture) 어텐션 헤드(attention heads), KV 캐시, 그리고 텍스트 생성의 메커니즘을 자세히 파고들기 전에, 자기회귀 언어 모델(autoregressive language model)이 실제로 무엇인지 한발짝 물러서서 전체적으로 살펴보는 것이 좋습니다.

프롬프트는 "How are you?"라는 일반 텍스트로 입력됩니다. 토크나이저(tokenizer)는 이를 어휘 ID(여기서는 3, 7, 1, 9)로 자르고, 맨 앞에는 BOS(시작 토큰, "beginning of sequence")를 추가합니다. 각 ID는 단순히 룩업 테이블(lookup table)을 가리키는 정수입니다. 이 테이블은 (어휘 크기, c) 모양을 가진 학습된 행렬로, 각 행은 어휘에 있는 하나의 토큰에 대한 임베딩 벡터(embedding vector)입니다. 입력받은 5개의 ID에 해당하는 행을 선택하면 (5, 4) 크기의 행렬 X가 생성됩니다. 이는 5개의 토큰이 각각 4차원 임베딩 공간에 존재함을 의미합니다. 바로 이 지점에서 텍스트는 기호의 세계를 떠나 벡터의 세계로 들어갑니다. (여기서는 예시를 위해 가상의 작은 차원을 사용합니다.)

이제부터 X는 디코더 블록(decoder blocks) 스택을 통과합니다. 각 블록은 동일한 구조(멀티 헤드 셀프 어텐션 + MLP)를 가지며, 입력을 동일한 형태의 정제된 (5, 4) 표현으로 변환합니다. 깊은 트랜스포머를 학습 가능하게 만드는 핵심은 모든 블록을 감싸는 잔차 연결(residual connection)입니다. 각 블록은 입력을 대체하는 것이 아니라 그 위에 더하게 됩니다 (X₁ = X + block_output). 정보는 각 레이어가 덮어쓰는 대신 편집하는 연속적인 '잔차 스트림(residual stream)'을 따라 흐릅니다. 이러한 블록 3개를 쌓으면 최종 은닉 상태(hidden state)인 X₃을 얻게 됩니다.

마지막 단계는 첫 번째 단계를 역전시키는 것입니다. 언임베딩 행렬(unembedding matrix, 입력과 출력 어휘가 동일하므로 보통 룩업 테이블의 전치 행렬 사용)은 X₃의 각 행을 다시 어휘 공간으로 투영하여 (5, 12) 크기의 로짓(logits) 행렬을 생성합니다. 이는 모든 위치에서 모든 어휘 토큰에 대한 점수입니다. '다음 토큰 생성(next-token generation)'을 위해서는 마지막 행만 중요합니다. 이 행렬의 argmax(최댓값 인덱스)가 모델이 다음에 말하고자 하는 토큰입니다. 여기서는 토큰 ID 5가 해당됩니다.

이것이 전방 패스(forward pass)가 고차원적으로 일어나는 전체 과정입니다. 이 글의 나머지 부분에서는 이러한 디코더 블록 중 하나 내부에서 발생하는 일과, 긴 시퀀스 생성을 가능하게 하는 최적화 기법인 'KV 캐싱'에 대해 자세히 알아보겠습니다.

이제 한 발 더 들어가서 단일 디코딩 레이어 내의 첫 번째 전방 패스 동안 한 레이어 내부에서 어떤 일이 일어나는지 살펴보겠습니다.

사전 채우기 전방 패스 (The Prefill Forward Pass) 언어 모델이 단 하나의 새로운 토큰을 생성하기 전에 먼저 프롬프트를 처리해야 합니다. 이 단계(사전 채우기, prefill)는 전체 입력 시퀀스를 네트워크를 통해 하나의 병렬된 전방 패스로 실행합니다. 이 작업의 역할은 두 가지입니다. 첫 번째 예측 토큰을 생성하는 것과, 이후의 디코딩 단계 비용을 낮게 유지하기 위해 KV 캐시를 채우는 것입니다.

은닉 차원 c = 4, 2개의 어텐션 헤드, 12개의 토큰으로 이루어진 작은 모델에서 5개의 토큰으로 된 프롬프트에 어떤 일이 발생하는지 단계별로 살펴보겠습니다.

토큰에서 Q, K, V로 (From tokens to Q, K, V) 입력 X는 (5, 4) 행렬로 들어옵니다. 즉, 5개의 토큰이 각각 룩업 테이블에서 가져온 4차원 임베딩으로 표현됩니다. 학습된 세 개의 투영 행렬(projection matrices)인 Wq, Wk, Wv(각각 (4, 4) 모양)가 X를 쿼리(Query), 키(Key), 밸류(Value) 행렬인 Q, K, V(모두 (5, 4) 모양)로 변환합니다.

헤드가 2개이므로, 각 (5, 4) 행렬은 열 방향으로 두 개의 (5, 2) 조각으로 나뉘며, 헤드당 하나의 조각씩 할당됩니다. 각 헤드는 자신만의 2차원 부분 공간에서 독립적으로 어텐션을 계산합니다.

헤드 내부의 어텐션 (Attention within a head) 단일 헤드 내에서 어텐션은 가중치가 적용된 룩업(lookup) 과정입니다. 해당 헤드의 Q 조각(5, 2)에 K 조각의 전치 행렬을 곱하여 (5, 5) 크기의 어텐션 점수 행렬을 생성합니다. 이는 모든 토큰의 쿼리가 모든 토큰의 키와 내적(dot product)된 결과입니다. 스케일링(scaling)과 소프트맥스(softmax)를 적용한 후(자기회귀 모델이므로 토큰 t는 t보다 큰 토큰을 보지 않도록 하는 인과 마스크, causal mask도 함께 적용됨), 이 행렬의 각 행은 "과거의 어떤 토큰으로부터 정보를 가져와야 하는가?"에 대한 확률 분포가 됩니다. 그런 다음 이 가중치들은 해당 헤드의 V 조각(5, 2)에 곱해져 (5, 2) 모양의 헤드 출력을 생성합니다. 이제 각 토큰은 허용된 위치의 밸류 벡터들이 문맥을 인식한 상태로 혼합된 값을 갖게 됩니다.

연결 (Concatenation)

원문 보기
원문 보기 (영어)
Autoregressive next token prediction & KV Cache in transformers Frederik vom Lehn 7 min read · 2 days ago -- Listen Share Understand the optimization technique in LLMs to speed up token generation The Big Picture Before we dive into attention heads, KV caches, and the mechanics of generation, it helps to zoom out and see what an autoregressive language model actually is at a glance. A prompt enters as plain text: “How are you?”. A tokenizer chops it into vocabulary IDs — here 3, 7, 1, 9 , prefixed with a BOS ("beginning of sequence") token. Each ID is just an integer pointing into a lookup table : a learned matrix of shape (vocab_size, c) where every row is the embedding vector for one token in the vocabulary. Selecting the rows for our 5 input IDs produces X , a (5, 4) matrix, five tokens, each living in a 4-dimensional embedding space. This is where text leaves the world of symbols and enters the world of vectors. We use toy dimensions for our examples here. From here, X flows through a stack of decoder blocks . Each block is the same architecture, multi-head self-attention followed by an MLP, and each block transforms its input into a refined (5, 4) representation of the same shape. The trick that makes deep transformers trainable is the residual connection wrapped around every block: instead of replacing the input, each block adds to it ( X₁ = X + block_output ). Information flows along a continuous "residual stream" that each layer edits rather than overwrites. Stack three of these and you get X₃ , the final hidden state. The last step inverts the first. The unembedding matrix, often the lookup table transposed, since input and output vocabularies are the same, projects each row of X₃ back into vocabulary space, producing a (5, 12) logits matrix: a score for every vocabulary token at every position. For next-token generation, only the last row matters. Its argmax is the token the model wants to say next. Here, that's token ID 5. That’s the whole forward pass at altitude. The rest of this article zooms in on what happens inside one of those decoder blocks and on the optimization, KV caching , that makes generating long sequences feasible at all. Let's zoom in and check what happens inside one layer during the first forward pass inside a single decoding layer. The Prefill Forward Pass Before a language model can generate a single new token, it has to process the prompt. This step ( prefill) runs the entire input sequence through the network in one parallel forward pass. Its job is twofold: produce the first predicted token, and populate the KV cache so that subsequent decode steps stay cheap. Let’s walk through what happens to a 5-token prompt in a tiny model with hidden dimension c = 4 , 2 attention heads, and a vocabulary of 12 tokens. From tokens to Q, K, V The input X arrives as a (5, 4) matrix: 5 tokens, each represented by a 4-dimensional embedding pulled from the lookup table. Three learned projection matrices Wq , Wk , Wv , each of shape (4, 4) , transform X into the query, key, and value matrices Q , K , V , all of shape (5, 4) . Because we have 2 heads, each (5, 4) matrix is split column-wise into two (5, 2) slices, one slice per head. Each head will compute attention independently in its own 2-dimensional subspace. Attention within a head Inside a single head, attention is a weighted lookup. The head’s Q slice (5, 2) is multiplied by the transpose of its K slice to produce a (5, 5) matrix of attention scores — every token's query dotted with every token's key. After scaling and softmax (and a causal mask, since this is an autoregressive model, token t must not see tokens > t ), each row of this matrix becomes a probability distribution over "which past tokens should I pull information from." These weights then multiply the head’s V slice (5, 2) , yielding the head's output of shape (5, 2) : each token now holds a context-aware mix of value vectors from its allowed positions. Concatenation and projection The two heads’ outputs are concatenated back into a (5, 4) matrix, then passed through an output projection (4, 4) . The result, X' , is again (5, 4), same shape as the input, but every row now reflects information gathered from across the sequence. The MLP Each token’s vector is then sent independently through a two-layer MLP. W_up of shape (4, 8) expands each row to 8 dimensions, GeLU adds non-linearity, and W_down of shape (8, 4) projects back down. The output X₁ is (5, 4) and in a real model, this would feed into the next transformer block. Stack a few of these (here, 3 layers) and you have the full forward pass. Lets assume this is the final layer here. Logits and the first prediction After the final layer, the (5, 4) hidden states are multiplied by the unembedding matrix (12, 4).T to produce logits of shape (5, 12) , a score for every vocabulary token at every position. For generation, only the last row matters: it tells us what the model thinks comes after token 5. Argmax (or sampling) over that row gives us the first generated token. In our case token ID 5. What the cache holds onto Here’s the quiet but crucial part: during this single pass, every layer computed K and V of shape (5, 4) for the prompt. Those tensors get stored . They are everything future tokens will ever need to know about the prompt at this layer. The embeddings, the queries, the MLP activations — all discarded. From here on, generation moves into decode mode, processing one new token at a time and reading from this cache instead of redoing the work. So now let’s understand the big picture, what happens when we generate the next token with KV cache. The Decode Step with KV Cache Once prefill is done, the model switches into decode mode . Every subsequent token is generated by a forward pass that looks structurally similar to prefill — but operates on just one row at a time, leaning on the KV cache to remember everything that came before. Let’s continue our example. Prefill predicted token 5, so we now feed token 5 back in as the input for the next step. One token in, one token out The new input X is a single row of shape (1, 4) which is just token 5's embedding, looked up from the same table used during prefill. The previous 5 tokens of the prompt are not re-fed. They don't need to be: everything the model will ever need from them at this layer is already sitting in the cache. Multiplying this (1, 4) row by Wq , Wk , Wv (each still (4, 4) ) yields a fresh Q , K , and V , each of shape (1, 4) . Only the new token gets its query, key, and value computed. Appending to the cache The newly computed K and V rows are appended to the cached K and V matrices from the previous step. The cache, which held (5, 4) after prefill, now holds (6, 4), five rows from the prompt plus one fresh row for token 5. This concatenated tensor is what attention will read against. Attention against the cache Splitting across heads as before, each head now has a query of shape (1, 2) and a full key/value matrix of shape (6, 2) . The dot product Q · K^T produces a (1, 6) score row — token 5's attention weights over all 6 positions, itself included. No causal mask is needed here: every cached position is in the past by construction, so every score is valid. Softmax turns this into a probability distribution, and the weighted sum over V (6, 2) produces a (1, 2) head output. Concatenating both heads gives (1, 4) , and the output projection (4, 4) yields X' of shape (1, 4) . Why this matters Compare the shapes. Prefill processed a (5, 4) input and ran every operation on 5 rows in parallel, which is necessary to populate the cache. Decode processes a (1, 4) input and runs every operation on a single row, with the cache silently providing the historical context where it's needed (inside attention). The MLP, the projections, the unembedding, all do 1/N of the work they'd do in a no-cache forward pass. This is the whole reason long-context