logo
|
Blog
    deeplearning

    KV Caching for LLM inference speed

    논문은 아니고 메모
    김호진's avatar
    김호진
    May 26, 2025
    KV Caching for LLM inference speed

    KV caching

    LLM의 구현 방식이 Transformer 구조, Auto-regressive next token prediction, 토큰간에 영향을 주지않는 Norm, Positional encoding 이라고 가정한다.
     
    1. Transformer 앞뒤의 연산, FeedForward, LayerNorm, token to embedding 등의 연산은 각 토큰별로 계산된다 - 캐싱이 가능한지 고려할 때 크게 신경쓰지 않아도 된다.
    1. Decoder-only causal transformer로 auto-regressive하게 생성하는 경우, 만약 입력으로 “I got it and”가 들어왔다고 가정해보자. (K/V cacing을 사용하지 않을 때)
    1. 그럼 self-attention에서 1번째 index의 Q, K, V 값은 이전 레이어의 1st index 임베딩 값에만 의존한다. (단순 nn.Linear)
      1. 그리고 attention 연산이 적용된 후 출력되는 값은 causal masking에 의해서 0번째, 1번째 index의 Q, K, V에만 의존한다 (정확히는 이전 토큰들의 Q, K로 계산한 attention score)
        다시말해 모든 레이어에서 i-th index의 출력 임베딩은 causal 한 특성에 의해 현재와 이전 토큰들에만 의존한다.
        notion image
    1. 그럼 만약 한 토큰을 생성한 뒤 “I got it and then”이 다시 모델에 들어왔을 때
      1. 마지막 토큰 이전의 토큰들의 레이어별 output은
        어짜피 새로 추가된 토큰에 영향을 받지 않기 때문에, 이전에 사용한 값을 그대로 사용할 수 있다.
    1. 따라서 각 레이어 별로 → O0..4O_{0..4}O0..4​을 캐싱해둔다면, O4O_4O4​만 추가로 이전 레이어의 4th index 출력값을 사용해서 연산하면 된다. (output은 OiO_iOi​라고 하겠습니다.)
      1. 즉, 이전 레이어의 입장에서도 N2N^2N2를 모두 계산해야하는게 아니라, NNN만 계산해서 다음 레이어로 넘기면 된다. 어짜피 다음 토큰 예측에도 그것만 사용한다.
    1. 여기서 O(S∗d2)O(S *d^2)O(S∗d2)를 구하기위한 self attention 연산은?
      1. O(1∗d2)O(1*d^2)O(1∗d2)와 O(S∗8∗d2)O(S * 8 * d^2)O(S∗8∗d2) 간의 mult, softmax, divice 연산 → attention score
      2. attention score를 이용한 O(1∗8∗d2)O(1 * 8 *d^2)O(1∗8∗d2) weighed sum →
      3. → 를 구하기 위해 필요한 query는 i-th quey일 뿐이다.
        그럼 K,V를 캐싱했다고 했을 때 첫번째 토큰이 아닌 다음 토큰 생성 부터는 마지막 토큰에 대한 embedding만 다음 레이어로 넘기면서 연산에 사용하면 된다(시간이 에서 이 된다). 첫번째 토큰을 생성할 때는 K/V가 필요하기 때문에 모든 토큰의 임베딩이 다 필요하지만.
     
    따라서 K/V를 캐싱함으로써 연산이 줄어드는 부분은 크게 2가지 이다.
    Sequence length = N, dimension = d
    1. K와 V행렬의 Linear Porjection, Feed-Forward 및 메모리 접근
      1. 각 레이어에서 Q, K, V를 계산하기 위해 필요한 Linear projection 연산
        →
         
        Feedforward 생략 : (hidden dimension이 mult 4라고 가정)
        →
    1. Self-Attention 연산
        • 연산 : →
        • Weighted Sum 연산 : →
        즉 N^2에서 N으로 줄어든다. context가 dimension보다 훨씬 긴 경우 여기가 더 중요
     
    결국 연산량이 꽤 줄어든다. GPT api 비용이 input token이 output token대비 절반일 수 있는 이유가 된다.
     
    KV caching을 설명하는 한 영상에서 어짜피 한번 토큰을 생성한 뒤라면 그 이후로는 마지막 index의 embedding (1, d)만 모델의 입력으로 사용된다고 설명을 시작해서 K,V를 caching 할 수 있다는걸 봤는데, 이해가 잘 되지 않는다. KV caching을 적용하기 때문에 마지막 토큰만 모델에 흘려보낼 수 있는거라 순서가 바뀐거 같은데
     
     
     
     
     
     
    Share article

    Kim Hojin

    RSS·Powered by Inblog