인공지능/기타

Backpropagation for a Linear Layer

NickTop 2024. 10. 31. 00:55

벡터 미분이 항상 생각이 안나서 정리합니다

 

Chain rule

스칼라 f, 벡터 x,y가 주어져있습니다

$\frac{\partial f}{\partial \textbf{x}} = \left[ \frac{\partial f}{\partial x_1} \;\; \frac{\partial f}{\partial x_2}  \right]$ 
   
$\frac{\partial f}{\partial \textbf{y}} = \left[ \frac{\partial f}{\partial y_1} \;\; \frac{\partial f}{\partial y_2} \;\; \frac{\partial f}{\partial y_3}  \right]$  
  
${\frac{\partial \textbf{y}}{\partial \textbf{x}} = 
\begin{bmatrix}
\frac{\partial y_1}{\partial x_1} & \frac{\partial y_1}{\partial x_2} \\ 
\frac{\partial y_2}{\partial x_1} & \frac{\partial y_2}{\partial x_2} \\ 
\frac{\partial y_3}{\partial x_1} & \frac{\partial y_3}{\partial x_2}
\end{bmatrix}
}$  
  
${
\frac{\partial f}{\partial x_1}
=
\frac{\partial f}{\partial y_1}\frac{\partial y_1}{\partial x_1} + 
\frac{\partial f}{\partial y_2}\frac{\partial y_2}{\partial x_1} + 
\frac{\partial f}{\partial y_3}\frac{\partial y_3}{\partial x_1}
}$  
  
${
\frac{\partial f}{\partial x_2}
=
\frac{\partial f}{\partial y_1}\frac{\partial y_1}{\partial x_2} + 
\frac{\partial f}{\partial y_2}\frac{\partial y_2}{\partial x_2} + 
\frac{\partial f}{\partial y_3}\frac{\partial y_3}{\partial x_2}
}$ 
  
이기 때문에  
  
${
\frac{\partial f}{\partial \textbf{x}} = \frac{\partial f}{\partial \textbf{y}}\frac{\partial \textbf{y}}{\partial \textbf{x}}
}$
  
곱셈 순서는 꼭 지켜야 합니다

 

찾아보니 위 예시에서는 inner product로 곱셈을 했지만, 두 행렬간 inner product말고 다른 곱셈을 하는 경우도 있습니다

matrix chain rule이라고 검색해보니 일반화된 규칙은 없는것같습니다 (교수님도 직접 계산을 해봐야 아는거라고 합니다)

 

 

Backpropagation for a Linear Layer

backpropagation 시 Y = XW에서 미분 계산하는 방법을 보겠습니다

X는 (N * D)의 shape을 가지고 있고 W는 (D * M)의 shape를 가지고 있습니다

 

$\frac{\partial L}{\partial Y}$(upstream gradient)가 이미 주어져있다고 가정해봅시다 (L은 스칼라)

 

구하고 싶은 값은 $\frac{\partial L}{\partial X}$와 $\frac{\partial L}{\partial X}$입니다

 

체인룰에 따라,

 

$\frac{\partial L}{\partial X} = \frac{\partial Y}{\partial X}  \frac{\partial L}{\partial Y}$
$\frac{\partial L}{\partial W} = \frac{\partial Y}{\partial W}  \frac{\partial L}{\partial Y}$

 

X,Y,Z가 각각 matrix입니다

따라서 $ \frac{\partial Y}{\partial X} $의 shape은 (N*M*N*D)입니다

이를 직접 계산하게 되면, 메모리를 너무 많이 차지하게 됩니다

 

X가 (2*2) W가 (2*3)이라고 합시다

 

 

$\frac{\partial L}{\partial x_{1,1}}$를 먼저 구해봅시다 (결과값은 스칼라)

$\frac{\partial L}{\partial x_{1,1}} = \sum_{i=1}^{2} \sum_{j=1}^{3} \frac{\partial L}{\partial y_{i,j}}  \frac{\partial y_{i,j}}{\partial x_{1,1}} = \frac{\partial L}{\partial Y} \cdot  \frac{\partial Y}{\partial x_{1,1}}$

 

Y는 XW이므로

${Y=XW=\begin{pmatrix}
x_{1,1}w_{1,1} + x_{1,2}w_{2,1} & x_{1,1}w_{1,2} + x_{1,2}w_{2,2} & x_{1,1}w_{1,3} + x_{1,2}w_{2,3} \\
x_{2,1}w_{1,1} + x_{2,2}w_{2,1} & x_{2,1}w_{1,2} + x_{2,2}w_{2,2} & x_{2,1}w_{1,3} + x_{2,2}w_{2,3} \\
\end{pmatrix}}$

 

${\frac{\partial Y}{\partial x_{1,1}} = \begin{pmatrix}
w_{1,1} & w_{1,2} & w_{1,3} \\
0 & 0 & 0 \\
\end{pmatrix}}$

 

${\begin{align}
  \frac {\partial L}{x_{1,1}} &= \frac { \partial  L}{ \partial  Y}\cdot \frac { \partial  Y}{ \partial  x_{1,1}}
  \\ &= \begin{pmatrix}
         \frac{ \partial  L}{ \partial  y_{1,1}} & \frac { \partial  L}{ \partial  y_{1,2}} & \frac { \partial  L}{ \partial  y_{1,3}}
    \\ \frac { \partial  L}{ \partial  y_{2,1}} & \frac { \partial  L}{ \partial  y_{2,2}} & \frac { \partial  L}{ \partial  y_{2,3}}
  \end{pmatrix}\cdot
  \begin{pmatrix}
    w_{1,1} & w_{1,2} & w_{1,3}
    \\ 0 & 0 & 0 
  \end{pmatrix}
  \\ &= \frac { \partial  L}{ \partial y_{1,1}}w_{1,1} + \frac { \partial  L}{ \partial y_{1,2}}w_{1,2} + \frac { \partial  L}{ \partial y_{1,3}}w_{1,3}
  \label{eq:dLdX11}
\end{align}}$

 

비슷한 방법으로 하면

$ \frac {\partial L}{x_{1,2}} = \frac { \partial  L}{ \partial  Y}\cdot \frac { \partial  Y}{ \partial  x_{1,2}} = \frac { \partial  L}{ \partial y_{2,1}}w_{2,1} + \frac { \partial  L}{ \partial y_{2,2}}w_{2,2} + \frac { \partial  L}{ \partial y_{2,3}}w_{2,3}$

$ \frac {\partial L}{x_{2,1}} = \frac { \partial  L}{ \partial  Y}\cdot \frac { \partial  Y}{ \partial  x_{2,1}} = \frac { \partial  L}{ \partial y_{2,1}}w_{1,1} + \frac { \partial  L}{ \partial y_{2,2}}w_{1,2} + \frac { \partial  L}{ \partial y_{2,3}}w_{1,3}$

$ \frac {\partial L}{x_{2,2}} = \frac { \partial  L}{ \partial  Y}\cdot \frac { \partial  Y}{ \partial  x_{2,2}} = \frac { \partial  L}{ \partial y_{2,1}}w_{2,1} + \frac { \partial  L}{ \partial y_{2,2}}w_{2,2} + \frac { \partial  L}{ \partial y_{2,3}}w_{2,3}$

 

식을 조합하면

${
\begin{align}
  \ \frac{\partial L}{\partial X}  &= \begin{pmatrix}
         \frac{ \partial  L}{ \partial  y_{1,1}} & \frac { \partial  L}{ \partial  y_{1,2}} & \frac { \partial  L}{ \partial  y_{1,3}}
    \\ \frac { \partial  L}{ \partial  y_{2,1}} & \frac { \partial  L}{ \partial  y_{2,2}} & \frac { \partial  L}{ \partial  y_{2,3}}
  \end{pmatrix}
  \begin{pmatrix}
    w_{1,1} & w_{2,1} \\

    w_{1,2} & w_{2,2} \\

    w_{1,3} & w_{2,3} 
  \end{pmatrix}

 \\ &= \frac{\partial L}{\partial Y}W^T
\end{align}
}$

 

또한 같은 방법으로 $ \frac{\partial L}{\partial W}$ 를 계산하면

$ \frac{\partial L}{\partial W}= X^T \frac{\partial L}{\partial Y}$ 

 

https://web.eecs.umich.edu/~justincj/teaching/eecs498/FA2020/linear-backprop.html