인공지능/TVM

TVM - Tensor Expression Optimization

NickTop 2024. 5. 15. 13:14

최적화를 할 때 relay를 TE로 가져와 TE를 직접 다루지는 않는 것 같습니다.

TE에서 어떠한 방식으로 최적화가 이루어지는지 살펴봅시다

 

https://tvm.apache.org/docs/how_to/optimize_operators/opt_gemm.html#sphx-glr-how-to-optimize-operators-opt-gemm-py

 

How to optimize GEMM on CPU — tvm 0.17.dev0 documentation

Table of Contents Docs > How To Guides > Optimize Tensor Operators > How to optimize GEMM on CPU Edit on GitHub Note This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally. How to optimize GEM

tvm.apache.org

 

두 행렬을 naive 하게 곱하면 다음과 같습니다

for y in range(1024):
    for x in range(1024):
        C[y][x] = 0
        for k in range(1024):
            C[y][x]+=A[y][k]*B[k][x]

TE로 스케줄링한 결과를 봅시다

import tvm
from tvm import te

# The size of the matrix
# (M, K) x (K, N)
M = 1024
K = 1024
N = 1024

# The default tensor type in tvm
dtype = "float32"
target = "llvm"
dev = tvm.device(target, 0)

# Algorithm
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")

# Default schedule
s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))

아래는 결과입니다

참고로 for m, n in T.grid(1024, 1024)for m in range(1024): for n in range(1024)와 같은 뜻입니다

1048576은 행렬을 flatten해서 그렇습니다.

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m, n in T.grid(1024, 1024):
            C_1 = T.Buffer((1048576,), data=C.data)
            C_1[m * 1024 + n] = T.float32(0)
            for k in range(1024):
                cse_var_2: T.int32 = m * 1024
                cse_var_1: T.int32 = cse_var_2 + n
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[cse_var_2 + k] * B_1[k * 1024 + n]

 

위를 어떻게 최적화 할 수 있는지 알아봅시다

 

Blocking

작은 inner loop를 만들어 cache hit rate를 높이는 방법입니다.

for yo in range(128):
    for xo in range(128):
        C[yo*8:yo*8+8][xo*8:xo*8+8] = 0
        for ko in range(128):
            for ki in range(8):
                for yi in range(8):
                    for xi in range(8):
                        C[yo*8+yi][xo*8+xi] += A[yo*8+yi][ko*8+ki] * B[ko*8+ki][xo*8+xi]

스케줄을 재정의 합니다

bn = 32 # x_inner, y_inner 정의
kfactor = 4 # k_inner 정의

# Blocking by loop tiling
mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(kaxis,) = s[C].op.reduce_axis
ko, ki = s[C].split(kaxis, factor=kfactor)

# Hoist reduction domain outside the blocking loop
s[C].reorder(mo, no, ko, ki, mi, ni)

print(tvm.lower(s, [A, B, C], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m_outer, n_outer in T.grid(32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            for m_inner_init, n_inner_init in T.grid(32, 32):
                C_1[m_outer * 32768 + m_inner_init * 1024 + n_outer * 32 + n_inner_init] = T.float32(0)
            for k_outer, k_inner, m_inner, n_inner in T.grid(256, 4, 32, 32):
                cse_var_3: T.int32 = n_outer * 32
                cse_var_2: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_1: T.int32 = cse_var_2 + cse_var_3 + n_inner
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[cse_var_2 + k_outer * 4 + k_inner] * B_1[k_outer * 4096 + k_inner * 1024 + cse_var_3 + n_inner]

inner loop가 하나 생겼습니다

 

Vectorization

memory access pattern이 같으면 컴파일러가 SIMD를 쓸 수 있습니다. (동일한 operation에 thread가 서로 다른 data 처리)

SIMD 설명 : https://johnnysswlab.com/crash-course-introduction-to-parallelism-simd-parallelism/

 

# Vectorization
s[C].vectorize(ni)

print(tvm.lower(s, [A, B, C], simple_mode=True))

 

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m_outer, n_outer in T.grid(32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            for m_inner_init in range(32):
                C_1[m_outer * 32768 + m_inner_init * 1024 + n_outer * 32:m_outer * 32768 + m_inner_init * 1024 + n_outer * 32 + 32] = T.Broadcast(T.float32(0), 32)
            for k_outer, k_inner, m_inner in T.grid(256, 4, 32):
                cse_var_3: T.int32 = n_outer * 32
                cse_var_2: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_1: T.int32 = cse_var_2 + cse_var_3
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[cse_var_1:cse_var_1 + 32] = C_1[cse_var_1:cse_var_1 + 32] + T.Broadcast(A_1[cse_var_2 + k_outer * 4 + k_inner], 32) * B_1[k_outer * 4096 + k_inner * 1024 + cse_var_3:k_outer * 4096 + k_inner * 1024 + cse_var_3 + 32]

T.Broadcast 가 vectorize된 부분입니다

 

Loop Permutation

row를 순서대로 access해야 cache friendly 합니다. Blocking 예시에서 A가 cache friendly 하지 않으므로 inner loop 순서를 변경합니다

for yo in range(128):
    for xo in range(128):
        C[yo*8:yo*8+8][xo*8:xo*8+8] = 0
        for ko in range(128):
            for yi in range(8): # 서로 순서바꿈
                for ki in range(8): # 서로 순서바꿈
                    for xi in range(8):
                        C[yo*8+yi][xo*8+xi] += A[yo*8+yi][ko*8+ki] * B[ko*8+ki][xo*8+xi]

이제는 A도 row를 순서대로 access합니다. TE에서 reorder의 순서를 바꿔줍니다

 

# Loop permutation
s[C].reorder(mo, no, ko, mi, ki, ni)

# Vectorization
s[C].vectorize(ni)

print(tvm.lower(s, [A, B, C], simple_mode=True))

 

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for m_outer, n_outer in T.grid(32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            for m_inner_init in range(32):
                C_1[m_outer * 32768 + m_inner_init * 1024 + n_outer * 32:m_outer * 32768 + m_inner_init * 1024 + n_outer * 32 + 32] = T.Broadcast(T.float32(0), 32)
            for k_outer, m_inner, k_inner in T.grid(256, 32, 4):
                cse_var_3: T.int32 = n_outer * 32
                cse_var_2: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_1: T.int32 = cse_var_2 + cse_var_3
                A_1 = T.Buffer((1048576,), data=A.data)
                B_1 = T.Buffer((1048576,), data=B.data)
                C_1[cse_var_1:cse_var_1 + 32] = C_1[cse_var_1:cse_var_1 + 32] + T.Broadcast(A_1[cse_var_2 + k_outer * 4 + k_inner], 32) * B_1[k_outer * 4096 + k_inner * 1024 + cse_var_3:k_outer * 4096 + k_inner * 1024 + cse_var_3 + 32]

 

 

Array Packing

Array Packing

inner block에 access하는 순서를 보면, A는 sequential하게 access하는 반면 B는 그렇지 못합니다.

B의 모양을 변경해 Sequential하게 access 할수있도록 합니다

B를 [N/bn][K][bn]로 변경해주면 됩니다

# We have to re-write the algorithm slightly.
packedB = te.compute(
    (N / bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
)
C = te.compute(
    (M, N),
    lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
    name="C",
)

s = te.create_schedule(C.op)

mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(kaxis,) = s[C].op.reduce_axis
ko, ki = s[C].split(kaxis, factor=kfactor)

s[C].reorder(mo, no, ko, mi, ki, ni)
s[C].vectorize(ni)

bigN, _, littleN = s[packedB].op.axis
s[packedB].vectorize(littleN)
s[packedB].parallel(bigN)

print(tvm.lower(s, [A, B, C], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        packedB = T.allocate([32768], "float32x32", "global")
        packedB_1 = T.Buffer((32768,), "float32x32", data=packedB)
        for bigN in T.parallel(32):
            for k in range(1024):
                B_1 = T.Buffer((1048576,), data=B.data)
                packedB_1[bigN * 1024 + k] = B_1[k * 1024 + bigN * 32:k * 1024 + bigN * 32 + 32]
        for m_outer, n_outer in T.grid(32, 32):
            C_1 = T.Buffer((1048576,), data=C.data)
            for m_inner_init in range(32):
                C_1[m_outer * 32768 + m_inner_init * 1024 + n_outer * 32:m_outer * 32768 + m_inner_init * 1024 + n_outer * 32 + 32] = T.Broadcast(T.float32(0), 32)
            for k_outer, m_inner, k_inner in T.grid(256, 32, 4):
                cse_var_3: T.int32 = m_outer * 32768 + m_inner * 1024
                cse_var_2: T.int32 = k_outer * 4
                cse_var_1: T.int32 = cse_var_3 + n_outer * 32
                A_1 = T.Buffer((1048576,), data=A.data)
                C_1[cse_var_1:cse_var_1 + 32] = C_1[cse_var_1:cse_var_1 + 32] + T.Broadcast(A_1[cse_var_3 + cse_var_2 + k_inner], 32) * packedB_1[n_outer * 1024 + cse_var_2 + k_inner]

packedB가 새로 생겼습니다

 

Write cache for blocks

blocking으로 인해 C에 write하는 패턴이 sequential하지 않습니다.

결과를 임시로 저장했다가 C에 write합니다

 

s = te.create_schedule(C.op)

# Allocate write cache
CC = s.cache_write(C, "global")

mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)

# Write cache is computed at no
s[CC].compute_at(s[C], no)

# New inner axes
mc, nc = s[CC].op.axis

(kaxis,) = s[CC].op.reduce_axis
ko, ki = s[CC].split(kaxis, factor=kfactor)
s[CC].reorder(ko, mc, ki, nc)
s[CC].vectorize(nc)

# TODO: Add separate optimization step to discuss loop unrolling
# unrolling is a loop optimization strategy which can reduce branch
# prediction failures and increases the chance of concurrent execution
# unroll kfactor loops
s[CC].unroll(ki)

bigN, _, littleN = s[packedB].op.axis
s[packedB].vectorize(littleN)
s[packedB].parallel(bigN)

print(tvm.lower(s, [A, B, C], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        packedB = T.allocate([32768], "float32x32", "global")
        C_global = T.allocate([1024], "float32", "global")
        packedB_1 = T.Buffer((32768,), "float32x32", data=packedB)
        for bigN in T.parallel(32):
            for k in range(1024):
                B_1 = T.Buffer((1048576,), data=B.data)
                packedB_1[bigN * 1024 + k] = B_1[k * 1024 + bigN * 32:k * 1024 + bigN * 32 + 32]
        for m_outer, n_outer in T.grid(32, 32):
            C_global_1 = T.Buffer((1024,), data=C_global)
            for m_c_init in range(32):
                C_global_1[m_c_init * 32:m_c_init * 32 + 32] = T.Broadcast(T.float32(0), 32)
            for k_outer, m_c in T.grid(256, 32):
                cse_var_4: T.int32 = k_outer * 4
                cse_var_3: T.int32 = m_c * 32
                cse_var_2: T.int32 = n_outer * 1024 + cse_var_4
                cse_var_1: T.int32 = m_outer * 32768 + m_c * 1024 + cse_var_4
                A_1 = T.Buffer((1048576,), data=A.data)
                C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1], 32) * packedB_1[cse_var_2]
                C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1 + 1], 32) * packedB_1[cse_var_2 + 1]
                C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1 + 2], 32) * packedB_1[cse_var_2 + 2]
                C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1 + 3], 32) * packedB_1[cse_var_2 + 3]
            for m_inner, n_inner in T.grid(32, 32):
                C_1 = T.Buffer((1048576,), data=C.data)
                C_1[m_outer * 32768 + m_inner * 1024 + n_outer * 32 + n_inner] = C_global_1[m_inner * 32 + n_inner]

CC에 저장했다가 C에 다시 저장합니다

 

Parallel

멀티코어 프로세서일 경우 thread level parallel을 수행합니다

# parallel
s[C].parallel(mo)

print(tvm.lower(s, [A, B, C], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        packedB = T.allocate([32768], "float32x32", "global")
        packedB_1 = T.Buffer((32768,), "float32x32", data=packedB)
        for bigN in T.parallel(32):
            for k in range(1024):
                B_1 = T.Buffer((1048576,), data=B.data)
                packedB_1[bigN * 1024 + k] = B_1[k * 1024 + bigN * 32:k * 1024 + bigN * 32 + 32]
        for m_outer in T.parallel(32):
            C_global = T.allocate([1024], "float32", "global")
            for n_outer in range(32):
                C_global_1 = T.Buffer((1024,), data=C_global)
                for m_c_init in range(32):
                    C_global_1[m_c_init * 32:m_c_init * 32 + 32] = T.Broadcast(T.float32(0), 32)
                for k_outer, m_c in T.grid(256, 32):
                    cse_var_4: T.int32 = k_outer * 4
                    cse_var_3: T.int32 = m_c * 32
                    cse_var_2: T.int32 = n_outer * 1024 + cse_var_4
                    cse_var_1: T.int32 = m_outer * 32768 + m_c * 1024 + cse_var_4
                    A_1 = T.Buffer((1048576,), data=A.data)
                    C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1], 32) * packedB_1[cse_var_2]
                    C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1 + 1], 32) * packedB_1[cse_var_2 + 1]
                    C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1 + 2], 32) * packedB_1[cse_var_2 + 2]
                    C_global_1[cse_var_3:cse_var_3 + 32] = C_global_1[cse_var_3:cse_var_3 + 32] + T.Broadcast(A_1[cse_var_1 + 3], 32) * packedB_1[cse_var_2 + 3]
                for m_inner, n_inner in T.grid(32, 32):
                    C_1 = T.Buffer((1048576,), data=C.data)
                    C_1[m_outer * 32768 + m_inner * 1024 + n_outer * 32 + n_inner] = C_global_1[m_inner * 32 + n_inner]

 

 

GPU

https://tvm.d2l.ai/chapter_gpu_schedules/matmul.html

https://tvm.apache.org/docs/how_to/optimize_operators/opt_conv_cuda.html

https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/

http://cuda-programming.blogspot.com/2013/02/bank-conflicts-in-shared-memory-in-cuda.html

GPU에서 최적화 방법을 봅시다. 최적화 기법이라기보다는 개인적인 느낌으로 GPU에서 matrix 연산을 하는 방법론같은 느낌입니다.

GPU memory hierachy

shared memory에 어떤 데이터를 올릴 지 사용자에서 조작가능합니다.

shared memory는 on chip에 있기 때문에 속도가 매우 빠릅니다. shared memory는 스레드 블록 단위로 할당됩니다. 각 블록은 다른 블록들과 독립적으로 실행할 수 있습니다.

shared memory를 잘 활용하는 것이 중요합니다

W*A = B 라는 연산을 수행할 때 A와 B를 shared memory에 올려 수행합니다. shared memory의 한계가 있어 쪼개서 계산을 하게 되는데 이를 Blocking이라고 합니다. B의 각 점을 계산할때마다 하나의 thread가 맡아서 수행합니다.

GPU blocking

shared memory는 bank 단위로 나누어져 있습니다. 같은 clock동안 하나의 bank에는 하나의 thread만 접근가능합니다. 만약 두개의 thread가 하나의 bank를 접근하려한다면 하나의 thread는 대기를 해야합니다. 이를 memory bank conflict이라고 하고, 이를 막기 위해 block을 여러개의 sub-block으로 쪼개는 Virtual Thread Split를 합니다

 

bank conflict

 

virtual thread split

또한 shared memory에 있는 값들이 모두 계산된 이후 모든 thread들이 병렬적으로 block에 메모리에 올리게 되는데 이를 cooperative fetching이라고 합니다

'인공지능 > TVM' 카테고리의 다른 글

TVM - autoTVM  (0) 2024.05.20
TVM - relay, Graph level Optimization  (0) 2024.05.06
TVM introduction  (0) 2024.05.06