roofline.dev (Public Beta)

Matrix LogSumExp

Difficulty: Medium · Ranked

Given a matrix $A$ of shape $m \times n$, produce $s$ such that

$$\hat{A}_{ij} = \max_{i,j} A_{ij} \quad s = \log\left(\sum_{i=1}^{m} \sum_{j=1}^{n} \exp\left(A_{ij} - \hat{A}_{ij}\right)\right) + \hat{A}_{ij}$$

NB: The shift by $\hat{A}_{ij}$ is required for numerical stability. Without it, $\exp(A_{ij})$ overflows the f32 range as soon as $A_{ij}$ exceeds $\sim 88$.
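
As a short worked derivation of why the shift leaves the result unchanged: treating $\hat{A}_{ij}$ as the single scalar maximum over all entries, and assuming it is finite, the factor $\exp(\hat{A}_{ij})$ can be pulled out of the double sum.

$$
\log \sum_{i=1}^{m} \sum_{j=1}^{n} \exp\left(A_{ij}\right)
= \log\left( \exp\left(\hat{A}_{ij}\right) \sum_{i=1}^{m} \sum_{j=1}^{n} \exp\left(A_{ij} - \hat{A}_{ij}\right) \right)
= \hat{A}_{ij} + \log \sum_{i=1}^{m} \sum_{j=1}^{n} \exp\left(A_{ij} - \hat{A}_{ij}\right)
$$

Every exponent on the right is at most $0$, so each term is at most $1$, and the term at the maximum entry is exactly $1$. The double sum therefore lies between $1$ and $mn$: it cannot overflow, and the argument of the logarithm never drops below $1$.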

Input

  • A - input matrix of shape [m][n] stored in row-major order.
  • m - the number of rows in A.
  • n - the number of columns in A.

Output

  • s - the log-sum-exp value computed from the formula above.
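
The input and output description above can be sketched as a plain reference routine. This is a minimal illustration in Python, not the submission format the site expects; the function name `matrix_logsumexp` is hypothetical, and `A` is assumed to be a flat sequence of `m * n` floats in row-major order, so entry $(i, j)$ lives at index `i * n + j`.

```python
import math


def matrix_logsumexp(A, m, n):
    """Return log(sum(exp(A[i][j]))) over all entries of a row-major m x n matrix.

    A is assumed to be a flat sequence of m * n floats (hypothetical interface).
    """
    if m <= 0 or n <= 0 or len(A) != m * n:
        raise ValueError("A must hold exactly m * n entries with m, n > 0")

    # Pass 1: find the largest entry, used as the stabilizing shift.
    a_max = A[0]
    for i in range(m):
        for j in range(n):
            a_max = max(a_max, A[i * n + j])

    # If the maximum is infinite, the shifted exponent would be inf - inf (NaN),
    # so return the limit value directly.
    if math.isinf(a_max):
        return a_max

    # Pass 2: accumulate exp(A_ij - a_max); every term is at most 1.
    total = 0.0
    for i in range(m):
        for j in range(n):
            total += math.exp(A[i * n + j] - a_max)

    # Undo the shift.
    return math.log(total) + a_max


if __name__ == "__main__":
    # 2 x 2 matrix of 1000.0: the unshifted exp(1000.0) would overflow,
    # while the shifted form evaluates to 1000 + log(4).
    print(matrix_logsumexp([1000.0, 1000.0, 1000.0, 1000.0], 2, 2))
```

The routine makes two passes over `A`, one for the maximum and one for the shifted sum. A single-pass variant that rescales a running sum whenever the maximum changes is also possible, but it is not shown here.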