Xintong Li

Mathematics

Online Softmax

Vanilla Softmax

Given logits: $x_1, x_2, \dots, x_n$,

\[\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum _{j=1} ^n e^{x_j}}\]

However, if any $x_j$ is very large, $e^{x_j}$ can overflow. A common trick to mitigate this is to subtract $m=\max_j x_j$, so that

\[\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum _{j=1} ^n e^{x_j - m}}\]

But this requires finding $m$ over all the logits, which may be too many to hold at once. Online softmax enables block-wise computation of softmax.
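For reference, a minimal NumPy version of this max-subtraction trick (the function name is illustrative):

import numpy as np

def softmax(x):
    # subtract the global max so every exponent is <= 0 and cannot overflow
    m = np.max(x)
    e = np.exp(x - m)
    return e / e.sum()

print(softmax(np.array([1000.0, 1001.0, 1002.0])))  # no overflow despite huge logits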

Online Softmax Derivation

Assuming only $t$ logits $x_1,\dots,x_t$ are available, we have

\[m_t = \max(x_1,\dots,x_t)\] \[d_t = \sum _{j=1} ^t e^{x_j - m_t}\]

When $x_{t+1}$ arrives, we can update $m$ and $d$ incrementally as

\[m_{t+1} = \max(m_t, x_{t+1})\] \[\begin{equation*} \begin{array}{rcl} d_{t+1} & = & \sum \limits _{j=1} ^{t+1} e^{x_j - m_{t+1}} \\ & = & \sum \limits _{j=1} ^{t} e^{x_j - m_{t+1}} + e^{x_{t+1} - m_{t+1}} \\ & = & \sum \limits _{j=1} ^{t} e^{x_j - m_{t}} \cdot e^{m_t - m_{t+1}} + e^{x_{t+1} - m_{t+1}} \\ & = & d_t \cdot e^{m_t - m_{t+1}} + e^{x_{t+1} - m_{t+1}} \end{array} \end{equation*}\]
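A direct translation of this recurrence into a single pass (a sketch, not a fused kernel; the function name is illustrative):

import numpy as np

def online_softmax(x):
    m = -np.inf   # running max m_t
    d = 0.0       # running sum d_t
    for x_t in x:
        m_new = max(m, x_t)
        # rescale the old sum to the new max, then add the new term
        d = d * np.exp(m - m_new) + np.exp(x_t - m_new)
        m = m_new
    return np.exp(x - m) / d   # softmax(x)_i = exp(x_i - m_n) / d_n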

So, when $t=n$,

\[\mathrm{softmax}(x)_i = \frac{e^{x_i - m_n}}{d_n}\]

More generally, merging the statistics of two blocks $A$ and $B$ is

\[m = \max(m_A, m_B)\] \[d = d_A e^{m_A-m} + d_B e^{m_B-m}\]
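A small check that two blocks merge into exactly the single-pass statistics (function names are illustrative):

import numpy as np

def block_stats(x_block):
    # (m, d) statistics of one block of logits
    m = np.max(x_block)
    return m, np.sum(np.exp(x_block - m))

def merge(m_a, d_a, m_b, d_b):
    m = max(m_a, m_b)
    d = d_a * np.exp(m_a - m) + d_b * np.exp(m_b - m)
    return m, d

x = np.random.randn(8)
assert np.allclose(merge(*block_stats(x[:5]), *block_stats(x[5:])),
                   block_stats(x))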

This is the foundation of flash attention, whose inner loop does the following:

for each K/V block:
  compute block-wise Q @ K^T
  get block-wise m_B and d_B
  merge to global m and d
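To make the loop concrete, here is a minimal NumPy sketch for a single query vector (the function name and block size are illustrative). Besides $m$ and $d$, flash attention also keeps a running unnormalized output, which is rescaled by the same factor $e^{m_t - m_{t+1}}$ as $d$:

import numpy as np

def attention_one_query(q, K, V, block=64):
    m = -np.inf                    # running max of the scores
    d = 0.0                        # running softmax denominator
    acc = np.zeros(V.shape[1])     # running unnormalized output
    for start in range(0, K.shape[0], block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        s = k_blk @ q                   # block-wise Q @ K^T scores
        m_new = max(m, np.max(s))
        p = np.exp(s - m_new)           # block-wise exponentials
        scale = np.exp(m - m_new)       # rescale old statistics to the new max
        d = d * scale + p.sum()
        acc = acc * scale + p @ v_blk
        m = m_new
    return acc / d                      # softmax(q K^T) @ V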

Reference