벡터 미분이 항상 생각이 안나서 정리합니다
Chain rule
스칼라 f, 벡터 x,y가 주어져있습니다
$\frac{\partial f}{\partial \textbf{x}} = \left[ \frac{\partial f}{\partial x_1} \;\; \frac{\partial f}{\partial x_2} \right]$
$\frac{\partial f}{\partial \textbf{y}} = \left[ \frac{\partial f}{\partial y_1} \;\; \frac{\partial f}{\partial y_2} \;\; \frac{\partial f}{\partial y_3} \right]$
${\frac{\partial \textbf{y}}{\partial \textbf{x}} =
\begin{bmatrix}
\frac{\partial y_1}{\partial x_1} & \frac{\partial y_1}{\partial x_2} \\
\frac{\partial y_2}{\partial x_1} & \frac{\partial y_2}{\partial x_2} \\
\frac{\partial y_3}{\partial x_1} & \frac{\partial y_3}{\partial x_2}
\end{bmatrix}
}$
${
\frac{\partial f}{\partial x_1}
=
\frac{\partial f}{\partial y_1}\frac{\partial y_1}{\partial x_1} +
\frac{\partial f}{\partial y_2}\frac{\partial y_2}{\partial x_1} +
\frac{\partial f}{\partial y_3}\frac{\partial y_3}{\partial x_1}
}$
${
\frac{\partial f}{\partial x_2}
=
\frac{\partial f}{\partial y_1}\frac{\partial y_1}{\partial x_2} +
\frac{\partial f}{\partial y_2}\frac{\partial y_2}{\partial x_2} +
\frac{\partial f}{\partial y_3}\frac{\partial y_3}{\partial x_2}
}$
이기 때문에
${
\frac{\partial f}{\partial \textbf{x}} = \frac{\partial f}{\partial \textbf{y}}\frac{\partial \textbf{y}}{\partial \textbf{x}}
}$
곱셈 순서는 꼭 지켜야 합니다
찾아보니 위 예시에서는 inner product로 곱셈을 했지만, 두 행렬간 inner product말고 다른 곱셈을 하는 경우도 있습니다
matrix chain rule이라고 검색해보니 일반화된 규칙은 없는것같습니다 (교수님도 직접 계산을 해봐야 아는거라고 합니다)
Backpropagation for a Linear Layer
backpropagation 시 Y = XW에서 미분 계산하는 방법을 보겠습니다
X는 (N * D)의 shape을 가지고 있고 W는 (D * M)의 shape를 가지고 있습니다
$\frac{\partial L}{\partial Y}$(upstream gradient)가 이미 주어져있다고 가정해봅시다 (L은 스칼라)
구하고 싶은 값은 $\frac{\partial L}{\partial X}$와 $\frac{\partial L}{\partial X}$입니다
체인룰에 따라,
$\frac{\partial L}{\partial X} = \frac{\partial Y}{\partial X} \frac{\partial L}{\partial Y}$
$\frac{\partial L}{\partial W} = \frac{\partial Y}{\partial W} \frac{\partial L}{\partial Y}$
X,Y,Z가 각각 matrix입니다
따라서 $ \frac{\partial Y}{\partial X} $의 shape은 (N*M*N*D)입니다
이를 직접 계산하게 되면, 메모리를 너무 많이 차지하게 됩니다
X가 (2*2) W가 (2*3)이라고 합시다
$\frac{\partial L}{\partial x_{1,1}}$를 먼저 구해봅시다 (결과값은 스칼라)
$\frac{\partial L}{\partial x_{1,1}} = \sum_{i=1}^{2} \sum_{j=1}^{3} \frac{\partial L}{\partial y_{i,j}} \frac{\partial y_{i,j}}{\partial x_{1,1}} = \frac{\partial L}{\partial Y} \cdot \frac{\partial Y}{\partial x_{1,1}}$
Y는 XW이므로
${Y=XW=\begin{pmatrix}
x_{1,1}w_{1,1} + x_{1,2}w_{2,1} & x_{1,1}w_{1,2} + x_{1,2}w_{2,2} & x_{1,1}w_{1,3} + x_{1,2}w_{2,3} \\
x_{2,1}w_{1,1} + x_{2,2}w_{2,1} & x_{2,1}w_{1,2} + x_{2,2}w_{2,2} & x_{2,1}w_{1,3} + x_{2,2}w_{2,3} \\
\end{pmatrix}}$
${\frac{\partial Y}{\partial x_{1,1}} = \begin{pmatrix}
w_{1,1} & w_{1,2} & w_{1,3} \\
0 & 0 & 0 \\
\end{pmatrix}}$
${\begin{align}
\frac {\partial L}{x_{1,1}} &= \frac { \partial L}{ \partial Y}\cdot \frac { \partial Y}{ \partial x_{1,1}}
\\ &= \begin{pmatrix}
\frac{ \partial L}{ \partial y_{1,1}} & \frac { \partial L}{ \partial y_{1,2}} & \frac { \partial L}{ \partial y_{1,3}}
\\ \frac { \partial L}{ \partial y_{2,1}} & \frac { \partial L}{ \partial y_{2,2}} & \frac { \partial L}{ \partial y_{2,3}}
\end{pmatrix}\cdot
\begin{pmatrix}
w_{1,1} & w_{1,2} & w_{1,3}
\\ 0 & 0 & 0
\end{pmatrix}
\\ &= \frac { \partial L}{ \partial y_{1,1}}w_{1,1} + \frac { \partial L}{ \partial y_{1,2}}w_{1,2} + \frac { \partial L}{ \partial y_{1,3}}w_{1,3}
\label{eq:dLdX11}
\end{align}}$
비슷한 방법으로 하면
$ \frac {\partial L}{x_{1,2}} = \frac { \partial L}{ \partial Y}\cdot \frac { \partial Y}{ \partial x_{1,2}} = \frac { \partial L}{ \partial y_{2,1}}w_{2,1} + \frac { \partial L}{ \partial y_{2,2}}w_{2,2} + \frac { \partial L}{ \partial y_{2,3}}w_{2,3}$
$ \frac {\partial L}{x_{2,1}} = \frac { \partial L}{ \partial Y}\cdot \frac { \partial Y}{ \partial x_{2,1}} = \frac { \partial L}{ \partial y_{2,1}}w_{1,1} + \frac { \partial L}{ \partial y_{2,2}}w_{1,2} + \frac { \partial L}{ \partial y_{2,3}}w_{1,3}$
$ \frac {\partial L}{x_{2,2}} = \frac { \partial L}{ \partial Y}\cdot \frac { \partial Y}{ \partial x_{2,2}} = \frac { \partial L}{ \partial y_{2,1}}w_{2,1} + \frac { \partial L}{ \partial y_{2,2}}w_{2,2} + \frac { \partial L}{ \partial y_{2,3}}w_{2,3}$
식을 조합하면
${
\begin{align}
\ \frac{\partial L}{\partial X} &= \begin{pmatrix}
\frac{ \partial L}{ \partial y_{1,1}} & \frac { \partial L}{ \partial y_{1,2}} & \frac { \partial L}{ \partial y_{1,3}}
\\ \frac { \partial L}{ \partial y_{2,1}} & \frac { \partial L}{ \partial y_{2,2}} & \frac { \partial L}{ \partial y_{2,3}}
\end{pmatrix}
\begin{pmatrix}
w_{1,1} & w_{2,1} \\
w_{1,2} & w_{2,2} \\
w_{1,3} & w_{2,3}
\end{pmatrix}
\\ &= \frac{\partial L}{\partial Y}W^T
\end{align}
}$
또한 같은 방법으로 $ \frac{\partial L}{\partial W}$ 를 계산하면
$ \frac{\partial L}{\partial W}= X^T \frac{\partial L}{\partial Y}$
https://web.eecs.umich.edu/~justincj/teaching/eecs498/FA2020/linear-backprop.html
'인공지능 > 기타' 카테고리의 다른 글
Multimodal Unsupervised Image-to-Image Translation (1) | 2024.12.15 |
---|---|
Deep Image Prior (2) | 2024.12.14 |
라그랑주 승수법 (Lagrange multipliers) (0) | 2024.08.05 |
Transformer: Attention is All you need (2) | 2024.07.27 |
[강화학습 요약] Policy Gradient (0) | 2023.10.15 |