https://github.com/jhong92-pro/llm-batch-invariance/blob/main/llm-batch-invariance.ipynb
결합법칙
LLM의 대부분 연산은 matmul / reduction인데, GPU에서는 병렬 합산이 일어나며 합치는 순서가 바뀌면 반올림 오차가 달라질 수 있다
import torch
a = torch.Tensor([1e16])
b = torch.Tensor([-1e16])
c = torch.Tensor([1])
print("(a + b) + c :" , (a + b) + c)
print("a + (b + c) :" , a + (b + c))
(a + b) + c : tensor([1.])
a + (b + c) : tensor([0.])
수학적으로는 둘 다 1이어야 할 것 같지만, floating-point에서는 반올림 때문에 결과가 달라질 수 있음
Batch invariance
M, K, N = 512, 512, 512
A0 = torch.randn(M, K, device=device, dtype=dtype)
B0 = torch.randn(K, N, device=device, dtype=dtype)
C_single = A0 @ B0
def make_batched_inputs(batch_size: int):
A = torch.randn(batch_size, M, K, device=device, dtype=dtype)
B = torch.randn(batch_size, K, N, device=device, dtype=dtype)
A[0].copy_(A0)
B[0].copy_(B0)
return A, B
batch_sizes = [1,2,3,4,5,6,7,8,9,10,100,1000]
for B in batch_sizes:
A, Bmat = make_batched_inputs(B)
C_batch = torch.bmm(A, Bmat)
diff = C_batch[0] - C_single
max_abs = diff.abs().max().item()
l2 = torch.norm(diff).item()
same = torch.allclose(C_batch[0], C_single, atol=1e-5, rtol=1e-5)
print(f"Batch {B}: max_abs_diff={max_abs}, l2_diff={l2}, allclose={same}")
Batch 1: max_abs_diff=0.0, l2_diff=0.0, allclose=True
Batch 2: max_abs_diff=0.0001068115234375, l2_diff=0.005174573510885239, allclose=False
Batch 3: max_abs_diff=0.0001068115234375, l2_diff=0.005174573510885239, allclose=False
...
Batch 1000: max_abs_diff=0.0001068115234375, l2_diff=0.005174573510885239, allclose=False
Batch=1 이후로는 결과가 약간 다르다
커널/연산 전략 차이로 인한 FP 오차차이가 생기기 때문이다
torch.profiler로 operation을 확인하면 아래와 같다
[single (bmm)] GEMM-related ops: ['aten::bmm', 'ampere_sgemm_32x32_sliced1x4_nn']
[batched (bmm, B=2)] GEMM-related ops: ['aten::bmm', 'ampere_sgemm_128x128_nn']
[batched (bmm, B=200)] GEMM-related ops: ['aten::bmm', 'ampere_sgemm_128x128_nn']
간단한 transformer를 만들어서 실험해봐도 layer output이 달라진다(깃허브 코드 참고)
B=1: max_abs=4.882812e-03, l2=1.297607e-01 (worst layer_3)
B=4: max_abs=5.859375e-03, l2=1.132202e-01 (worst layer_3)
B=8, 16: exact match (all zeros)
B=32: max_abs=3.906250e-03, l2=7.080078e-02 (worst layer_1)
꼭, 배치가 커질수록 크게 달라지지는 않는다
GPT-2 테스트
- 모델: gpt2-medium
- decoding: greedy(argmax)
- baseline: batch=2
- 비교: batch=3,4,7,8,16,32
- 배치 구성: 0번은 동일 prompt, 나머지는 distractor(다른 질문들)
def build_batch(batch_size: int, prompt: str):
"""index0 : prompt / rest : DISTRACTORS"""
if batch_size == 1:
return [prompt]
need = batch_size - 1
return [prompt] + DISTRACTORS[:need]
"Tell me about yourself.\nanswer : " input에서 바로 flip이 관측된다
=== prompt 1/103, Flip checks (vs baseline) ===
Baseline(B=2) : I'm a professional artist. I'm a professional ...
batch=3: FLIPS at steps: [ 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18]...
batch=4: FLIPS at steps: [ 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18]...
batch=7: FLIPS at steps: [ 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18]...
...
[Baseline B=2 top-k]
'\xa0': -90.499222
' young': -90.501602
' writer': -90.504623
' woman': -90.577217
' student': -90.678535
' man': -90.699829
' guy': -90.841461
' professional': -90.907951
' twenty': -91.086029
' 22': -91.121254
[Batch B=3 top-k]
' young': -90.602028
' writer': -90.617401
'\xa0': -90.664429
' woman': -90.695312
' student': -90.778191
' man': -90.795197
' guy': -90.942856
' professional': -91.005234
' twenty': -91.189301
' 22': -91.211235
...
분류 모델도 애매한 샘플에서는 top-1과 top-2의 차이가 작아질 수 있지만, 보통 결정은 한 번만 내리고 그 결과가 다음 입력으로 피드백되지 않는다. 반면 LLM은 토큰을 여러 번 연속으로 선택해야 하고, next-token 분포에서 top-1/top-2 마진이 작은 순간이 자주 나타난다. 그래서 아주 작은 수치 차이(배치/커널/반올림)가 특정 스텝에서 top-1을 바꾸면 토큰 flip이 발생하고, 그 flip이 이후 생성 과정 전체로 누적 및 증폭되기 쉽다.
실무적으로 왜 중요할까?
- 요청 단위로 caching할 때, 항상 같은 답을 기대하면 깨질 수 있음
- golden set 평가를 배치로 돌릴 때, 배치 크기나 배치 구성(정렬/섞임)에 따라 결과가 미세하게 달라짐
- 실제로는 수치 오차/커널 선택 문제지만, 관측되는 현상은 '남이 뭐 넣었냐에 따라 내 답이 바뀐다'로 보임
'인공지능' 카테고리의 다른 글
| 사내 교육용 자연어처리 ppt (0) | 2025.10.21 |
|---|---|
| DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification (0) | 2025.05.18 |
| Kernel Ridge Regression (0) | 2025.05.06 |
| Dataset Distillation (0) | 2025.05.01 |
| 모델 경량화 프루닝 (Pruning) - Structured (0) | 2025.03.30 |