From 8e7963a4e2f563a13b0b2c10d2c2efabe2cb37e2 Mon Sep 17 00:00:00 2001 From: "Jachym.Barvinek" Date: Tue, 28 Apr 2026 20:08:42 +0200 Subject: [PATCH 1/5] support row_vector for softmax --- stan/math/fwd/fun/log_softmax.hpp | 9 +++++++- stan/math/fwd/fun/softmax.hpp | 6 ++++++ stan/math/prim/fun/softmax.hpp | 10 ++++----- stan/math/rev/fun/log_softmax.hpp | 9 +++++++- stan/math/rev/fun/softmax.hpp | 4 +++- test/unit/math/mix/fun/softmax_test.cpp | 26 +++++++++++++++++++++++ test/unit/math/prim/fun/softmax_test.cpp | 27 ++++++++++++++++++++++++ 7 files changed, 83 insertions(+), 8 deletions(-) diff --git a/stan/math/fwd/fun/log_softmax.hpp b/stan/math/fwd/fun/log_softmax.hpp index acaf71070cb..d8fefbf0c40 100644 --- a/stan/math/fwd/fun/log_softmax.hpp +++ b/stan/math/fwd/fun/log_softmax.hpp @@ -19,7 +19,14 @@ namespace math { * @return Softmax of the input. * @throw std::domain_error If the input vector is size 0. */ -template * = nullptr> +template * = nullptr, + require_t>>* = nullptr> +inline auto log_softmax(const RowVec& x) { + return log_softmax(x.transpose()).transpose().eval(); +} + +template * = nullptr, + require_not_t>>* = nullptr> inline auto log_softmax(T&& x) { return apply_vector_unary::apply(std::forward(x), [](auto&& alpha) { using T_alpha = decltype(alpha); diff --git a/stan/math/fwd/fun/softmax.hpp b/stan/math/fwd/fun/softmax.hpp index 3625332ddf2..bec25e978e5 100644 --- a/stan/math/fwd/fun/softmax.hpp +++ b/stan/math/fwd/fun/softmax.hpp @@ -10,6 +10,12 @@ namespace stan { namespace math { +template * = nullptr, + require_t>>* = nullptr> +inline auto softmax(const RowVec& alpha) { + return softmax(alpha.transpose()).transpose().eval(); +} + template * = nullptr> inline auto softmax(const ColVec& alpha) { diff --git a/stan/math/prim/fun/softmax.hpp b/stan/math/prim/fun/softmax.hpp index d3221f7ce72..6ab4e89ef1c 100644 --- a/stan/math/prim/fun/softmax.hpp +++ b/stan/math/prim/fun/softmax.hpp @@ -38,20 +38,20 @@ namespace math { * \end{array} * \f$ * 
- * @tparam ColVec type of elements in the vector + * @tparam Vec type of the input vector * @param[in] v Vector to transform. * @return Unit simplex result of the softmax transform of the vector. */ -template * = nullptr> -inline plain_type_t softmax(const ColVec& v) { +template * = nullptr> +inline plain_type_t softmax(const Vec& v) { using std::exp; if (v.size() == 0) { return v; } const auto& v_ref = to_ref(v); const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp().eval(); - return theta.array() / theta.sum(); + return (theta.array() / theta.sum()).matrix(); } } // namespace math diff --git a/stan/math/rev/fun/log_softmax.hpp b/stan/math/rev/fun/log_softmax.hpp index 92650558b65..ceca0538b80 100644 --- a/stan/math/rev/fun/log_softmax.hpp +++ b/stan/math/rev/fun/log_softmax.hpp @@ -53,7 +53,14 @@ class log_softmax_elt_vari : public vari { * @return softmax of the input * @throw std::domain_error if the input size is 0 */ -template * = nullptr> +template * = nullptr, + require_eigen_st* = nullptr> +inline auto log_softmax(const T& x) { + return log_softmax(x.transpose()).transpose().eval(); +} + +template * = nullptr, + require_not_t>>* = nullptr> inline auto log_softmax(const T& x) { const int a_size = x.size(); diff --git a/stan/math/rev/fun/softmax.hpp b/stan/math/rev/fun/softmax.hpp index a1bf786e826..03841e0574e 100644 --- a/stan/math/rev/fun/softmax.hpp +++ b/stan/math/rev/fun/softmax.hpp @@ -31,7 +31,9 @@ inline auto softmax(const Mat& alpha) { return ret_type(alpha); } arena_t alpha_arena = alpha; - arena_t res_val = softmax(value_of(alpha_arena)); + using double_mat_t + = Eigen::Matrix; + arena_t res_val = softmax(value_of(alpha_arena)); arena_t res = res_val; reverse_pass_callback([res_val, res, alpha_arena]() mutable { diff --git a/test/unit/math/mix/fun/softmax_test.cpp b/test/unit/math/mix/fun/softmax_test.cpp index bf748824173..d22fa33e37e 100644 --- a/test/unit/math/mix/fun/softmax_test.cpp +++ b/test/unit/math/mix/fun/softmax_test.cpp 
@@ -9,6 +9,7 @@ TEST(MathMixMatFun, softmax) { tols.hessian_hessian_ = 1e-2; tols.hessian_fvar_hessian_ = 1e-2; + // Column vectors Eigen::VectorXd a(0); stan::test::expect_ad(tols, f, a); expect_ad_matvar(f, a); @@ -41,4 +42,29 @@ TEST(MathMixMatFun, softmax) { d4 << 0, 3, -1; stan::test::expect_ad(tols, f, d4); expect_ad_matvar(f, d4); + + // Row vectors + Eigen::RowVectorXd ra(0); + stan::test::expect_ad(tols, f, ra); + expect_ad_matvar(f, ra); + + Eigen::RowVectorXd rb(1); + rb << 0; + stan::test::expect_ad(tols, f, rb); + expect_ad_matvar(f, rb); + + Eigen::RowVectorXd rc(2); + rc << -1, 1; + stan::test::expect_ad(tols, f, rc); + expect_ad_matvar(f, rc); + + Eigen::RowVectorXd rd(3); + rd << -1, 1, 10; + stan::test::expect_ad(tols, f, rd); + expect_ad_matvar(f, rd); + + Eigen::RowVectorXd rd2(3); + rd2 << 0.5, -1, 3; + stan::test::expect_ad(tols, f, rd2); + expect_ad_matvar(f, rd2); } diff --git a/test/unit/math/prim/fun/softmax_test.cpp b/test/unit/math/prim/fun/softmax_test.cpp index 8e3c8a13328..356856f4050 100644 --- a/test/unit/math/prim/fun/softmax_test.cpp +++ b/test/unit/math/prim/fun/softmax_test.cpp @@ -28,3 +28,30 @@ TEST(MathMatrixPrimMat, softmax) { EXPECT_FLOAT_EQ(exp(1) / (exp(-1) + exp(1) + exp(10.0)), theta3[1]); EXPECT_FLOAT_EQ(exp(10) / (exp(-1) + exp(1) + exp(10.0)), theta3[2]); } + +TEST(MathMatrixPrimMat, softmax_row_vector) { + using Eigen::Dynamic; + using Eigen::Matrix; + using stan::math::softmax; + + Matrix x(1); + x << 0.0; + Matrix theta = softmax(x); + EXPECT_EQ(1, theta.size()); + EXPECT_FLOAT_EQ(1.0, theta[0]); + + Matrix x2(2); + x2 << -1.0, 1.0; + Matrix theta2 = softmax(x2); + EXPECT_EQ(2, theta2.size()); + EXPECT_FLOAT_EQ(exp(-1) / (exp(-1) + exp(1)), theta2[0]); + EXPECT_FLOAT_EQ(exp(1) / (exp(-1) + exp(1)), theta2[1]); + + Matrix x3(3); + x3 << -1.0, 1.0, 10.0; + Matrix theta3 = softmax(x3); + EXPECT_EQ(3, theta3.size()); + EXPECT_FLOAT_EQ(exp(-1) / (exp(-1) + exp(1) + exp(10.0)), theta3[0]); + EXPECT_FLOAT_EQ(exp(1) / 
(exp(-1) + exp(1) + exp(10.0)), theta3[1]); + EXPECT_FLOAT_EQ(exp(10) / (exp(-1) + exp(1) + exp(10.0)), theta3[2]); +} From bf76b21758db6fa5e7f7624a9ac8d5ebe9d4d283 Mon Sep 17 00:00:00 2001 From: "Jachym.Barvinek" Date: Tue, 28 Apr 2026 21:08:25 +0200 Subject: [PATCH 2/5] add support for matrix --- stan/math/fwd/fun/log_softmax.hpp | 55 +++++++++++-------- stan/math/fwd/fun/softmax.hpp | 49 ++++++++--------- stan/math/prim/fun/log_softmax.hpp | 25 ++++++++- stan/math/prim/fun/softmax.hpp | 22 +++++++- stan/math/rev/fun/log_softmax.hpp | 56 +++++++++++++++++--- stan/math/rev/fun/softmax.hpp | 37 ++++++++++++- test/unit/math/mix/fun/log_softmax_test.cpp | 11 ++++ test/unit/math/mix/fun/softmax_test.cpp | 11 ++++ test/unit/math/prim/fun/log_softmax_test.cpp | 20 +++++++ test/unit/math/prim/fun/softmax_test.cpp | 21 ++++++++ 10 files changed, 247 insertions(+), 60 deletions(-) diff --git a/stan/math/fwd/fun/log_softmax.hpp b/stan/math/fwd/fun/log_softmax.hpp index d8fefbf0c40..2c317016cea 100644 --- a/stan/math/fwd/fun/log_softmax.hpp +++ b/stan/math/fwd/fun/log_softmax.hpp @@ -19,6 +19,28 @@ namespace math { * @return Softmax of the input. * @throw std::domain_error If the input vector is size 0. 
*/ +template * = nullptr, + require_not_eigen_vector_t* = nullptr, + require_t>>* = nullptr> +inline auto log_softmax(const Mat& m) { + check_nonzero_size("log_softmax", "m", m); + const auto& m_ref = to_ref(m); + const auto val = m_ref.val().eval(); + const auto shifted + = (val.array().colwise() - val.rowwise().maxCoeff().array()).eval(); + const auto exp_s = shifted.exp().eval(); + const auto row_sums = exp_s.rowwise().sum().eval(); + const auto lsm_val = (shifted.colwise() - row_sums.log()).matrix().eval(); + // softmax values needed for the tangent: d_in - softmax(x) ⊙ dot(softmax(x), d_in) + const auto s = (exp_s.colwise() / row_sums).eval(); + const auto d_in = m_ref.d().eval(); + const auto dots = (s.array() * d_in.array()).rowwise().sum().eval(); + plain_type_t result(m_ref.rows(), m_ref.cols()); + result.val() = lsm_val; + result.d() = (d_in.array().colwise() - dots.array()).matrix(); + return result; +} + template * = nullptr, require_t>>* = nullptr> inline auto log_softmax(const RowVec& x) { @@ -29,33 +51,20 @@ template * = nullptr, require_not_t>>* = nullptr> inline auto log_softmax(T&& x) { return apply_vector_unary::apply(std::forward(x), [](auto&& alpha) { - using T_alpha = decltype(alpha); + using T_alpha = std::decay_t; using T_fvar = value_type_t; - using T_fvar_inner = typename T_fvar::Scalar; + using T_inner = typename T_fvar::Scalar; auto&& alpha_ref = to_ref(std::forward(alpha)); - Eigen::Matrix alpha_t = alpha_ref.val(); - Eigen::Matrix softmax_alpha_t = softmax(alpha_t); - - Eigen::Matrix log_softmax_alpha(alpha_ref.size()); - log_softmax_alpha.val() = log_softmax(alpha_t); - log_softmax_alpha.d().setZero(); - - for (int m = 0; m < alpha_ref.size(); ++m) { - T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m - = -alpha_ref.coeff(m).d_ * softmax_alpha_t(m); - for (int k = 0; k < alpha_ref.size(); ++k) { - if (m == k) { - log_softmax_alpha(k).d_ - += alpha_ref.coeff(m).d_ - + negative_alpha_m_d_times_softmax_alpha_t_m; - } else { - 
log_softmax_alpha(k).d_ += negative_alpha_m_d_times_softmax_alpha_t_m; - } - } - } + const Eigen::Matrix val = alpha_ref.val(); + const Eigen::Matrix s = softmax(val); + const auto d_in = alpha_ref.d().eval(); + const T_inner dot_sd = s.dot(d_in); - return log_softmax_alpha; + Eigen::Matrix result(alpha_ref.size()); + result.val() = log_softmax(val); + result.d() = (d_in.array() - dot_sd).matrix(); + return result; }); } diff --git a/stan/math/fwd/fun/softmax.hpp b/stan/math/fwd/fun/softmax.hpp index bec25e978e5..6ebe0046453 100644 --- a/stan/math/fwd/fun/softmax.hpp +++ b/stan/math/fwd/fun/softmax.hpp @@ -10,6 +10,21 @@ namespace stan { namespace math { +template * = nullptr, + require_not_eigen_vector_t* = nullptr, + require_t>>* = nullptr> +inline auto softmax(const Mat& m) { + const auto& m_ref = to_ref(m); + const auto s = softmax(m_ref.val()); + const auto d_in = m_ref.d().eval(); + // d/dx softmax(x) applied to tangent: s ⊙ (d_in - s · d_in) (per row) + const auto dots = (s.array() * d_in.array()).rowwise().sum().eval(); + plain_type_t result(m_ref.rows(), m_ref.cols()); + result.val() = s; + result.d() = (s.array() * (d_in.array().colwise() - dots.array())).matrix(); + return result; +} + template * = nullptr, require_t>>* = nullptr> inline auto softmax(const RowVec& alpha) { @@ -26,33 +41,13 @@ inline auto softmax(const ColVec& alpha) { return Matrix, Dynamic, 1>(); } const auto& alpha_ref = to_ref(alpha); - - Matrix softmax_alpha_t = softmax(value_of(alpha_ref)); - - Matrix, Dynamic, 1> softmax_alpha(alpha.size()); - for (int k = 0; k < alpha.size(); ++k) { - softmax_alpha.coeffRef(k).val_ = softmax_alpha_t.coeff(k); - softmax_alpha.coeffRef(k).d_ = 0; - } - - for (int m = 0; m < alpha.size(); ++m) { - T negative_alpha_m_d_times_softmax_alpha_t_m - = -alpha_ref.coeff(m).d_ * softmax_alpha_t.coeff(m); - for (int k = 0; k < alpha.size(); ++k) { - if (m == k) { - softmax_alpha.coeffRef(k).d_ - += softmax_alpha_t.coeff(k) - * (alpha_ref.coeff(m).d_ - + 
negative_alpha_m_d_times_softmax_alpha_t_m); - } else { - softmax_alpha.coeffRef(k).d_ - += softmax_alpha_t.coeff(k) - * negative_alpha_m_d_times_softmax_alpha_t_m; - } - } - } - - return softmax_alpha; + const Matrix s = softmax(value_of(alpha_ref)); + const auto d_in = alpha_ref.d().eval(); + const T dot_sd = s.dot(d_in); + Matrix, Dynamic, 1> result(alpha.size()); + result.val() = s; + result.d() = (s.array() * (d_in.array() - dot_sd)).matrix(); + return result; } } // namespace math diff --git a/stan/math/prim/fun/log_softmax.hpp b/stan/math/prim/fun/log_softmax.hpp index 876d75a7f09..acdb730a15c 100644 --- a/stan/math/prim/fun/log_softmax.hpp +++ b/stan/math/prim/fun/log_softmax.hpp @@ -40,7 +40,10 @@ namespace math { * @return log unit simplex result of the softmax transform of the vector. */ template * = nullptr, - require_container_t* = nullptr> + require_container_t* = nullptr, + require_not_t>::value + && !is_eigen_vector>::value>>* = nullptr> inline auto log_softmax(Container&& x) { check_nonzero_size("log_softmax", "v", x); return make_holder( @@ -52,6 +55,26 @@ inline auto log_softmax(Container&& x) { to_ref(std::forward(x))); } +/** + * Return the log softmax of the rows of the specified matrix. + * Each row is transformed independently; the result has the same shape + * as the input. + * + * @tparam Mat type of input matrix + * @param[in] m Matrix to transform row-wise. + * @return Log-softmax applied row-wise. 
+ */ +template * = nullptr, + require_not_eigen_vector_t* = nullptr> +inline plain_type_t log_softmax(const Mat& m) { + check_nonzero_size("log_softmax", "m", m); + const auto& m_ref = to_ref(m); + const auto shifted + = (m_ref.array().colwise() - m_ref.rowwise().maxCoeff().array()).eval(); + const auto exp_s = shifted.exp().eval(); + return (shifted.colwise() - exp_s.rowwise().sum().log()).matrix(); +} + } // namespace math } // namespace stan #endif diff --git a/stan/math/prim/fun/softmax.hpp b/stan/math/prim/fun/softmax.hpp index 6ab4e89ef1c..446ea73a741 100644 --- a/stan/math/prim/fun/softmax.hpp +++ b/stan/math/prim/fun/softmax.hpp @@ -45,13 +45,31 @@ namespace math { template * = nullptr> inline plain_type_t softmax(const Vec& v) { - using std::exp; if (v.size() == 0) { return v; } const auto& v_ref = to_ref(v); const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp().eval(); - return (theta.array() / theta.sum()).matrix(); + return (theta / theta.sum()).matrix(); +} + +/** + * Return the softmax of the rows of the specified matrix. + * Each row is transformed independently; the result is a row-stochastic + * matrix whose rows each sum to one. + * + * @tparam Mat type of input matrix + * @param[in] m Matrix to transform row-wise. + * @return Row-stochastic matrix result of applying softmax to each row. 
+ */ +template * = nullptr, + require_not_eigen_vector_t* = nullptr> +inline plain_type_t softmax(const Mat& m) { + const auto& m_ref = to_ref(m); + const auto shifted + = (m_ref.array().colwise() - m_ref.rowwise().maxCoeff().array()).eval(); + const auto exp_s = shifted.exp().eval(); + return (exp_s.colwise() / exp_s.rowwise().sum()).matrix(); } } // namespace math diff --git a/stan/math/rev/fun/log_softmax.hpp b/stan/math/rev/fun/log_softmax.hpp index ceca0538b80..6481c38dd53 100644 --- a/stan/math/rev/fun/log_softmax.hpp +++ b/stan/math/rev/fun/log_softmax.hpp @@ -60,7 +60,7 @@ inline auto log_softmax(const T& x) { } template * = nullptr, - require_not_t>>* = nullptr> + require_eigen_col_vector_t* = nullptr> inline auto log_softmax(const T& x) { const int a_size = x.size(); @@ -103,20 +103,64 @@ inline auto log_softmax(const T& x) { * @return softmax of the input * @throw std::domain_error if the input size is 0 */ -template * = nullptr> +template * = nullptr, + require_t>* = nullptr> inline auto log_softmax(const T& x) { check_nonzero_size("log_softmax", "x", x); - - const auto& theta = (x.val().array() - x.val().maxCoeff()).eval(); - return make_callback_var( - (theta.array() - log(theta.exp().sum())).matrix(), + log_softmax(x.val()).eval(), [x](const auto& res) mutable { + // grad: g - sum(g) * softmax(x), where softmax(x) = exp(log_softmax(x)) x.adj().noalias() += res.adj() - (res.adj().sum() * res.val().array().exp()).matrix(); }); } +/** + * Return the log softmax of the rows of the specified matrix. + * Applied independently to each row. 
+ * + * @tparam T type of input (var_value) + * @param x input matrix + * @return log-softmax applied row-wise + */ +template * = nullptr, + require_t>* = nullptr> +inline auto log_softmax(const T& x) { + check_nonzero_size("log_softmax", "x", x); + return make_callback_var( + Eigen::MatrixXd(log_softmax(x.val())), + [x](const auto& res) mutable { + // grad per row: g - softmax(x) * sum(g), softmax(x) = exp(log_softmax(x)) + const auto row_sums = res.adj().rowwise().sum().eval(); + x.adj().noalias() + += res.adj() + - (res.val().array().exp().colwise() * row_sums.array()) + .matrix(); + }); +} + +/** + * Return the log softmax of the rows of the specified Eigen matrix + * whose entries are vars. Applied independently to each row. + * + * @tparam T type of input (Eigen matrix with var scalar) + * @param x input matrix + * @return log-softmax applied row-wise + */ +template * = nullptr, + require_not_t>>* = nullptr> +inline auto log_softmax(const T& x) { + check_nonzero_size("log_softmax", "x", x); + plain_type_t result(x.rows(), x.cols()); + for (int i = 0; i < x.rows(); ++i) { + result.row(i) = log_softmax(x.row(i)); + } + return result; +} + /** * Return the log softmax of the specified `std::vector` or * `std::vector` of containers. diff --git a/stan/math/rev/fun/softmax.hpp b/stan/math/rev/fun/softmax.hpp index 03841e0574e..cbe0efe6c87 100644 --- a/stan/math/rev/fun/softmax.hpp +++ b/stan/math/rev/fun/softmax.hpp @@ -23,7 +23,9 @@ namespace math { * @return Softmax of the input. * @throw std::domain_error If the input vector is size 0. */ -template * = nullptr> +template * = nullptr, + require_t>* = nullptr> inline auto softmax(const Mat& alpha) { using mat_plain = plain_type_t; using ret_type = return_var_matrix_t; @@ -45,6 +47,39 @@ inline auto softmax(const Mat& alpha) { return ret_type(res); } +/** + * Return the softmax of the rows of the specified matrix. + * Softmax is applied independently to each row, producing a + * row-stochastic matrix. 
+ * + * @param m Unconstrained input matrix. + * @return Row-stochastic matrix result. + */ +template * = nullptr, + require_t>* = nullptr> +inline auto softmax(const Mat& m) { + using mat_plain = plain_type_t; + using ret_type = return_var_matrix_t; + if (m.size() == 0) { + return ret_type(m); + } + arena_t m_arena = m; + using double_mat_t + = Eigen::Matrix; + arena_t res_val = softmax(value_of(m_arena)); + arena_t res = res_val; + + reverse_pass_callback([res_val, res, m_arena]() mutable { + const auto& g = to_ref(res.adj()); + const auto dots = (res_val.array() * g.array()).rowwise().sum().eval(); + m_arena.adj() += (res_val.array() * (g.array().colwise() - dots.array())) + .matrix(); + }); + + return ret_type(res); +} + } // namespace math } // namespace stan #endif diff --git a/test/unit/math/mix/fun/log_softmax_test.cpp b/test/unit/math/mix/fun/log_softmax_test.cpp index 5d8915fea10..e9c69e02da8 100644 --- a/test/unit/math/mix/fun/log_softmax_test.cpp +++ b/test/unit/math/mix/fun/log_softmax_test.cpp @@ -33,6 +33,17 @@ TEST(MathMixMatFun, logSoftmax) { stan::test::expect_ad(f, x3c); stan::test::expect_ad_matvar(f, x3c); + // Matrices (row-wise log_softmax) + Eigen::MatrixXd mx2(2, 2); + mx2 << -1, 1, 0, 2; + stan::test::expect_ad(f, mx2); + stan::test::expect_ad_matvar(f, mx2); + + Eigen::MatrixXd mx3(2, 3); + mx3 << -1, 1, 10, 0.5, -1, 3; + stan::test::expect_ad(f, mx3); + stan::test::expect_ad_matvar(f, mx3); + // Row Vectors Eigen::RowVectorXd rx0(0); // error case stan::test::expect_ad(f, rx0); diff --git a/test/unit/math/mix/fun/softmax_test.cpp b/test/unit/math/mix/fun/softmax_test.cpp index d22fa33e37e..6fc5b27e279 100644 --- a/test/unit/math/mix/fun/softmax_test.cpp +++ b/test/unit/math/mix/fun/softmax_test.cpp @@ -43,6 +43,17 @@ TEST(MathMixMatFun, softmax) { stan::test::expect_ad(tols, f, d4); expect_ad_matvar(f, d4); + // Matrices (row-wise softmax) + Eigen::MatrixXd ma(2, 3); + ma << -1, 1, 10, 0.5, -1, 3; + stan::test::expect_ad(tols, f, ma); + 
expect_ad_matvar(f, ma); + + Eigen::MatrixXd mb(3, 2); + mb << 0, 1, -1, 2, 3, -2; + stan::test::expect_ad(tols, f, mb); + expect_ad_matvar(f, mb); + // Row vectors Eigen::RowVectorXd ra(0); stan::test::expect_ad(tols, f, ra); diff --git a/test/unit/math/prim/fun/log_softmax_test.cpp b/test/unit/math/prim/fun/log_softmax_test.cpp index 27f682a17ff..18247e0a387 100644 --- a/test/unit/math/prim/fun/log_softmax_test.cpp +++ b/test/unit/math/prim/fun/log_softmax_test.cpp @@ -66,6 +66,26 @@ TEST(MathMatrixPrimMat, log_softmax) { // x3 << -1.0, 1.0, 10.0; // test_log_softmax(x3); } +TEST(MathMatrixPrimMat, log_softmax_matrix) { + using Eigen::Dynamic; + using Eigen::Matrix; + using stan::math::log_softmax; + using stan::math::softmax; + + Matrix m(2, 3); + m << -1.0, 1.0, 10.0, 0.5, -1.0, 3.0; + Matrix result = log_softmax(m); + + EXPECT_EQ(m.rows(), result.rows()); + EXPECT_EQ(m.cols(), result.cols()); + // each row matches per-row log_softmax and is consistent with log(softmax) + for (int i = 0; i < result.rows(); ++i) { + Matrix expected = log_softmax(m.row(i)); + for (int j = 0; j < result.cols(); ++j) + EXPECT_FLOAT_EQ(expected(j), result(i, j)); + } +} + TEST(MathMatrixPrimMat, log_softmax_exception) { using stan::math::log_softmax; stan::math::vector_d v0; // size == 0 diff --git a/test/unit/math/prim/fun/softmax_test.cpp b/test/unit/math/prim/fun/softmax_test.cpp index 356856f4050..c61570145e1 100644 --- a/test/unit/math/prim/fun/softmax_test.cpp +++ b/test/unit/math/prim/fun/softmax_test.cpp @@ -29,6 +29,27 @@ TEST(MathMatrixPrimMat, softmax) { EXPECT_FLOAT_EQ(exp(10) / (exp(-1) + exp(1) + exp(10.0)), theta3[2]); } +TEST(MathMatrixPrimMat, softmax_matrix) { + using Eigen::Dynamic; + using Eigen::Matrix; + using stan::math::softmax; + + Matrix m(2, 3); + m << -1.0, 1.0, 10.0, 0.5, -1.0, 3.0; + Matrix theta = softmax(m); + + EXPECT_EQ(m.rows(), theta.rows()); + EXPECT_EQ(m.cols(), theta.cols()); + // each row sums to 1 + for (int i = 0; i < theta.rows(); ++i) { + 
EXPECT_FLOAT_EQ(1.0, theta.row(i).sum()); + // each row matches per-row softmax + Matrix expected = softmax(m.row(i)); + for (int j = 0; j < theta.cols(); ++j) + EXPECT_FLOAT_EQ(expected(j), theta(i, j)); + } +} + TEST(MathMatrixPrimMat, softmax_row_vector) { using Eigen::Dynamic; using Eigen::Matrix; From 5815b2ed2115971fb8e60aa0b3000d006aeb5c95 Mon Sep 17 00:00:00 2001 From: "Jachym.Barvinek" Date: Tue, 28 Apr 2026 21:35:35 +0200 Subject: [PATCH 3/5] additional test for edge cases --- test/unit/math/prim/fun/log_softmax_test.cpp | 30 +++++++++++++++++++ test/unit/math/prim/fun/softmax_test.cpp | 31 ++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/test/unit/math/prim/fun/log_softmax_test.cpp b/test/unit/math/prim/fun/log_softmax_test.cpp index 18247e0a387..de9e6dae847 100644 --- a/test/unit/math/prim/fun/log_softmax_test.cpp +++ b/test/unit/math/prim/fun/log_softmax_test.cpp @@ -1,5 +1,6 @@ #include #include +#include #include inline void test_log_softmax( @@ -66,6 +67,35 @@ TEST(MathMatrixPrimMat, log_softmax) { // x3 << -1.0, 1.0, 10.0; // test_log_softmax(x3); } +TEST(MathMatrixPrimMat, log_softmax_neg_inf) { + using Eigen::Dynamic; + using Eigen::Matrix; + using stan::math::log_softmax; + constexpr double neg_inf = -std::numeric_limits::infinity(); + + // -inf in a vector stays -inf in the output; the rest get the + // proper restricted log-softmax. + Matrix v(3); + v << neg_inf, 1.0, 2.0; + Matrix result = log_softmax(v); + const double lse_finite = std::log(exp(1.0) + exp(2.0)); + EXPECT_EQ(neg_inf, result[0]); + EXPECT_FLOAT_EQ(1.0 - lse_finite, result[1]); + EXPECT_FLOAT_EQ(2.0 - lse_finite, result[2]); + + // Row-wise on a matrix. 
+ Matrix m(2, 3); + m << neg_inf, 1.0, 2.0, // + 0.0, neg_inf, 0.0; + Matrix mres = log_softmax(m); + EXPECT_EQ(neg_inf, mres(0, 0)); + EXPECT_FLOAT_EQ(1.0 - lse_finite, mres(0, 1)); + EXPECT_FLOAT_EQ(2.0 - lse_finite, mres(0, 2)); + EXPECT_FLOAT_EQ(-std::log(2.0), mres(1, 0)); + EXPECT_EQ(neg_inf, mres(1, 1)); + EXPECT_FLOAT_EQ(-std::log(2.0), mres(1, 2)); +} + TEST(MathMatrixPrimMat, log_softmax_matrix) { using Eigen::Dynamic; using Eigen::Matrix; diff --git a/test/unit/math/prim/fun/softmax_test.cpp b/test/unit/math/prim/fun/softmax_test.cpp index c61570145e1..ad4fe248afa 100644 --- a/test/unit/math/prim/fun/softmax_test.cpp +++ b/test/unit/math/prim/fun/softmax_test.cpp @@ -1,5 +1,6 @@ #include #include +#include TEST(MathMatrixPrimMat, softmax) { using Eigen::Dynamic; @@ -29,6 +30,36 @@ TEST(MathMatrixPrimMat, softmax) { EXPECT_FLOAT_EQ(exp(10) / (exp(-1) + exp(1) + exp(10.0)), theta3[2]); } +TEST(MathMatrixPrimMat, softmax_neg_inf) { + using Eigen::Dynamic; + using Eigen::Matrix; + using stan::math::softmax; + constexpr double neg_inf = -std::numeric_limits::infinity(); + + // -inf in a vector pins that component to exactly 0; the rest renormalize. + Matrix v(3); + v << neg_inf, 1.0, 2.0; + Matrix theta = softmax(v); + EXPECT_FLOAT_EQ(0.0, theta[0]); + EXPECT_FLOAT_EQ(exp(1.0) / (exp(1.0) + exp(2.0)), theta[1]); + EXPECT_FLOAT_EQ(exp(2.0) / (exp(1.0) + exp(2.0)), theta[2]); + EXPECT_FLOAT_EQ(1.0, theta.sum()); + + // Row-wise on a matrix: each row independently handles -inf. 
+ Matrix m(2, 3); + m << neg_inf, 1.0, 2.0, // + 0.0, neg_inf, 0.0; + Matrix result = softmax(m); + EXPECT_FLOAT_EQ(0.0, result(0, 0)); + EXPECT_FLOAT_EQ(exp(1.0) / (exp(1.0) + exp(2.0)), result(0, 1)); + EXPECT_FLOAT_EQ(exp(2.0) / (exp(1.0) + exp(2.0)), result(0, 2)); + EXPECT_FLOAT_EQ(0.5, result(1, 0)); + EXPECT_FLOAT_EQ(0.0, result(1, 1)); + EXPECT_FLOAT_EQ(0.5, result(1, 2)); + EXPECT_FLOAT_EQ(1.0, result.row(0).sum()); + EXPECT_FLOAT_EQ(1.0, result.row(1).sum()); +} + TEST(MathMatrixPrimMat, softmax_matrix) { using Eigen::Dynamic; using Eigen::Matrix; From 3271c41412176c3818a73f9b3c4ee3aeb40cd938 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 28 Apr 2026 16:28:04 -0400 Subject: [PATCH 4/5] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/fwd/fun/log_softmax.hpp | 3 ++- stan/math/prim/fun/log_softmax.hpp | 6 +++--- stan/math/rev/fun/log_softmax.hpp | 9 ++++----- stan/math/rev/fun/softmax.hpp | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/stan/math/fwd/fun/log_softmax.hpp b/stan/math/fwd/fun/log_softmax.hpp index 2c317016cea..de116e0b219 100644 --- a/stan/math/fwd/fun/log_softmax.hpp +++ b/stan/math/fwd/fun/log_softmax.hpp @@ -31,7 +31,8 @@ inline auto log_softmax(const Mat& m) { const auto exp_s = shifted.exp().eval(); const auto row_sums = exp_s.rowwise().sum().eval(); const auto lsm_val = (shifted.colwise() - row_sums.log()).matrix().eval(); - // softmax values needed for the tangent: d_in - softmax(x) ⊙ dot(softmax(x), d_in) + // softmax(x) feeds only the per-row dot product; the tangent is + // d_in - dot(softmax(x), d_in) const auto s = (exp_s.colwise() / row_sums).eval(); const auto d_in = m_ref.d().eval(); const auto dots = (s.array() * d_in.array()).rowwise().sum().eval(); diff --git a/stan/math/prim/fun/log_softmax.hpp b/stan/math/prim/fun/log_softmax.hpp index acdb730a15c..19266a6c6b4 100644 --- a/stan/math/prim/fun/log_softmax.hpp +++ b/stan/math/prim/fun/log_softmax.hpp @@ -41,9 +41,9 @@ 
namespace math { */ template * = nullptr, require_container_t* = nullptr, - require_not_t>::value - && !is_eigen_vector>::value>>* = nullptr> + require_not_t>::value + && !is_eigen_vector>::value>>* = nullptr> inline auto log_softmax(Container&& x) { check_nonzero_size("log_softmax", "v", x); return make_holder( diff --git a/stan/math/rev/fun/log_softmax.hpp b/stan/math/rev/fun/log_softmax.hpp index 6481c38dd53..8beee1e3836 100644 --- a/stan/math/rev/fun/log_softmax.hpp +++ b/stan/math/rev/fun/log_softmax.hpp @@ -109,8 +109,7 @@ template * = nullptr, inline auto log_softmax(const T& x) { check_nonzero_size("log_softmax", "x", x); return make_callback_var( - log_softmax(x.val()).eval(), - [x](const auto& res) mutable { + log_softmax(x.val()).eval(), [x](const auto& res) mutable { // grad: g - sum(g) * softmax(x), where softmax(x) = exp(log_softmax(x)) x.adj().noalias() += res.adj() - (res.adj().sum() * res.val().array().exp()).matrix(); @@ -131,9 +130,9 @@ template * = nullptr, inline auto log_softmax(const T& x) { check_nonzero_size("log_softmax", "x", x); return make_callback_var( - Eigen::MatrixXd(log_softmax(x.val())), - [x](const auto& res) mutable { - // grad per row: g - softmax(x) * sum(g), softmax(x) = exp(log_softmax(x)) + Eigen::MatrixXd(log_softmax(x.val())), [x](const auto& res) mutable { + // grad per row: g - softmax(x) * sum(g), softmax(x) = + // exp(log_softmax(x)) const auto row_sums = res.adj().rowwise().sum().eval(); x.adj().noalias() += res.adj() diff --git a/stan/math/rev/fun/softmax.hpp b/stan/math/rev/fun/softmax.hpp index cbe0efe6c87..6f8711fb31d 100644 --- a/stan/math/rev/fun/softmax.hpp +++ b/stan/math/rev/fun/softmax.hpp @@ -73,8 +73,8 @@ inline auto softmax(const Mat& m) { reverse_pass_callback([res_val, res, m_arena]() mutable { const auto& g = to_ref(res.adj()); const auto dots = (res_val.array() * g.array()).rowwise().sum().eval(); - m_arena.adj() += (res_val.array() * (g.array().colwise() - dots.array())) - .matrix(); + m_arena.adj() 
+ += (res_val.array() * (g.array().colwise() - dots.array())).matrix(); }); return ret_type(res); From 008a1310d4d76be2d2e3ad8d40360f06fb2438e2 Mon Sep 17 00:00:00 2001 From: "Jachym.Barvinek" Date: Wed, 29 Apr 2026 08:11:40 +0200 Subject: [PATCH 5/5] fix documentation --- stan/math/fwd/fun/log_softmax.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/stan/math/fwd/fun/log_softmax.hpp b/stan/math/fwd/fun/log_softmax.hpp index 2c317016cea..33b2d58080b 100644 --- a/stan/math/fwd/fun/log_softmax.hpp +++ b/stan/math/fwd/fun/log_softmax.hpp @@ -12,12 +12,14 @@ namespace stan { namespace math { /** - * Return the log softmax of the specified vector or container of vectors. + * Return the log softmax of the rows of the specified matrix. + * Each row is transformed independently; the result has the same shape + * as the input. * - * @tparam T Type of input vector or matrix. - * @param[in] x Unconstrained input vector. - * @return Softmax of the input. - * @throw std::domain_error If the input vector is size 0. + * @tparam Mat type of input matrix (Eigen matrix with fvar scalar) + * @param[in] m Matrix to transform row-wise. + * @return Log-softmax applied row-wise. + * @throw std::domain_error If the input matrix is size 0. */ template * = nullptr, require_not_eigen_vector_t* = nullptr,