diff --git a/stan/math/fwd/meta/operands_and_partials.hpp b/stan/math/fwd/meta/operands_and_partials.hpp index e622d24def8..c771de79287 100644 --- a/stan/math/fwd/meta/operands_and_partials.hpp +++ b/stan/math/fwd/meta/operands_and_partials.hpp @@ -70,11 +70,11 @@ template class operands_and_partials> { public: - internal::ops_partials_edge edge1_; - internal::ops_partials_edge edge2_; - internal::ops_partials_edge edge3_; - internal::ops_partials_edge edge4_; - internal::ops_partials_edge edge5_; + internal::ops_partials_edge> edge1_; + internal::ops_partials_edge> edge2_; + internal::ops_partials_edge> edge3_; + internal::ops_partials_edge> edge4_; + internal::ops_partials_edge> edge5_; using T_return_type = fvar; explicit operands_and_partials(const Op1& o1) : edge1_(o1) {} operands_and_partials(const Op1& o1, const Op2& o2) diff --git a/stan/math/prim/err.hpp b/stan/math/prim/err.hpp index 969fcc37fb4..e38f0fc53ee 100644 --- a/stan/math/prim/err.hpp +++ b/stan/math/prim/err.hpp @@ -48,6 +48,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/err/hmm_check.hpp b/stan/math/prim/err/hmm_check.hpp new file mode 100644 index 00000000000..acdf3ce903e --- /dev/null +++ b/stan/math/prim/err/hmm_check.hpp @@ -0,0 +1,47 @@ +#ifndef STAN_MATH_PRIM_ERR_HMM_CHECK_HPP +#define STAN_MATH_PRIM_ERR_HMM_CHECK_HPP + +#include +#include + +namespace stan { +namespace math { + +/** + * Check arguments for hidden Markov model functions with a discrete + * latent state (lpdf, rng for latent states, and marginal probabilities + * for latent sates). + * + * @tparam T_omega type of the log likelihood matrix + * @tparam T_Gamma type of the transition matrix + * @tparam T_rho type of the initial guess vector + * @param[in] log_omegas log matrix of observational densities. + * @param[in] Gamma transition density between hidden states. + * @param[in] rho initial state + * @param[in] function the name of the function using the arguments. + * @throw `std::invalid_argument` if Gamma is not square + * or if the size of rho is not the number of rows of log_omegas. + * @throw `std::domain_error` if rho is not a simplex or the rows + * of Gamma are not a simplex. + */ +template +inline void hmm_check( + const Eigen::Matrix& log_omegas, + const Eigen::Matrix& Gamma, + const Eigen::Matrix& rho, const char* function) { + int n_states = log_omegas.rows(); + int n_transitions = log_omegas.cols() - 1; + + check_consistent_size(function, "rho", rho, n_states); + check_simplex(function, "rho", rho); + check_square(function, "Gamma", Gamma); + check_nonzero_size(function, "Gamma", Gamma); + check_multiplicable(function, "Gamma", Gamma, "log_omegas", log_omegas); + for (int i = 0; i < Gamma.rows(); ++i) { + check_simplex(function, "Gamma[i, ]", row(Gamma, i + 1)); + } +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/prim/functor/coupled_ode_system.hpp b/stan/math/prim/functor/coupled_ode_system.hpp index 004dfb8009c..7971074bdb0 100644 --- a/stan/math/prim/functor/coupled_ode_system.hpp +++ b/stan/math/prim/functor/coupled_ode_system.hpp @@ -127,14 +127,16 @@ struct coupled_ode_system_impl { template struct coupled_ode_system : public coupled_ode_system_impl< - std::is_arithmetic>::value, F, - T_initial, Args...> { + std::is_arithmetic>::value, + F, T_initial, Args...> { coupled_ode_system(const F& f, const Eigen::Matrix& y0, std::ostream* msgs, const Args&... args) : coupled_ode_system_impl< - std::is_arithmetic>::value, F, - T_initial, Args...>(f, y0, msgs, args...) {} + std::is_arithmetic>::value, + F, T_initial, Args...>(f, y0, msgs, args...) {} }; } // namespace math diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index 9e3f755bebe..38d7ed19ace 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -15,7 +15,8 @@ namespace math { */ template -std::vector>> +std::vector>> integrate_ode_rk45(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -28,7 +29,9 @@ integrate_ode_rk45(const F& f, const std::vector& y0, const T_t0& t0, = ode_rk45_tol(f_adapted, to_vector(y0), t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> y_converted; + std::vector>> + y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 351387a7abf..ee64d97ddb6 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -22,8 +22,10 @@ namespace internal { template struct integrate_ode_std_vector_interface_adapter { const F f_; + const int num_vars__; - integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {} + integrate_ode_std_vector_interface_adapter(const F& f) + : f_(f), num_vars__(f.num_vars__) {} template auto operator()(const T0& t, const Eigen::Matrix& y, @@ -32,6 +34,62 @@ struct integrate_ode_std_vector_interface_adapter { const std::vector& x_int) const { return to_vector(f_(t, to_array_1d(y), msgs, theta, x, x_int)); } + template + void save_varis(v** p) const { + f_.save_varis(p); + } + void accumulate_adjoints(double* p) const { f_.accumulate_adjoints(p); } + void set_zer_adjoints() const { f_.set_zero_adjoints(); } + + struct DeepCopy_cl__ { + const typename F::DeepCopy__ f_; + const int num_vars__; + struct ValueOf_cl__ { + const typename F::ValueOf__ f_; + const static int num_vars__ = 0; + ValueOf_cl__(const integrate_ode_std_vector_interface_adapter& f) + : f_(f.f_) {} + ValueOf_cl__(const DeepCopy_cl__& f) : f_(f.f_) {} + template + auto operator()(const T0& t, + const Eigen::Matrix& y, + std::ostream* msgs, const std::vector& theta, + const std::vector& x, + const std::vector& x_int) const { + return to_vector(f_(t, to_array_1d(y), msgs, theta, x, x_int)); + } + template + void save_varis(v** p) const { + f_.save_varis(p); + } + void accumulate_adjoints(double* p) const { f_.accumulate_adjoints(p); } + void set_zer_adjoints() const { f_.set_zero_adjoints(); } + using captured_scalar_t__ = double; + using ValueOf__ = ValueOf_cl__; + using DeepCopy__ = ValueOf_cl__; + }; + DeepCopy_cl__(const integrate_ode_std_vector_interface_adapter& f) + : f_(f.f_), num_vars__(f.num_vars__) {} + template + auto operator()(const T0& t, const Eigen::Matrix& y, + std::ostream* msgs, const std::vector& theta, + const std::vector& x, + const std::vector& x_int) const { + return to_vector(f_(t, to_array_1d(y), msgs, theta, x, x_int)); + } + template + void save_varis(v** p) const { + f_.save_varis(p); + } + void accumulate_adjoints(double* p) const { f_.accumulate_adjoints(p); } + void set_zer_adjoints() const { f_.set_zero_adjoints(); } + using captured_scalar_t__ = typename F::captured_scalar_t__; + using ValueOf__ = ValueOf_cl__; + using DeepCopy__ = DeepCopy_cl__; + }; + using captured_scalar_t__ = typename F::captured_scalar_t__; + using DeepCopy__ = DeepCopy_cl__; + using ValueOf__ = typename DeepCopy__::ValueOf__; }; } // namespace internal diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index 9ceb84d8d9d..7fb1cbe1006 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -51,7 +51,8 @@ namespace math { */ template -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol(const F& f, const Eigen::Matrix& y0_arg, T_t0 t0, @@ -89,7 +90,8 @@ ode_rk45_tol(const F& f, absolute_tolerance); check_positive("integrate_ode_rk45", "max_num_steps", max_num_steps); - using return_t = return_type_t; + using return_t = return_type_t; // creates basic or coupled system by template specializations coupled_ode_system coupled_system(f, y0, msgs, args...); @@ -172,8 +174,9 @@ ode_rk45_tol(const F& f, */ template -std::vector< - Eigen::Matrix, Eigen::Dynamic, 1>> +std::vector, + Eigen::Dynamic, 1>> ode_rk45(const F& f, const Eigen::Matrix& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/functor/ode_store_sensitivities.hpp b/stan/math/prim/functor/ode_store_sensitivities.hpp index 8d394a41eb5..68c3b00886a 100644 --- a/stan/math/prim/functor/ode_store_sensitivities.hpp +++ b/stan/math/prim/functor/ode_store_sensitivities.hpp @@ -23,10 +23,11 @@ namespace math { * @param args Extra arguments passed unmodified through to ODE right hand side * @return ODE state */ -template < - typename F, typename T_y0_t0, typename T_t0, typename T_t, typename... Args, - typename - = require_all_arithmetic_t...>> +template ...>> Eigen::VectorXd ode_store_sensitivities( const F& f, const Eigen::VectorXd& coupled_state, const Eigen::Matrix& y0, T_t0 t0, T_t t, diff --git a/stan/math/prim/meta/as_array_or_scalar.hpp b/stan/math/prim/meta/as_array_or_scalar.hpp index 38deb58f38b..6560b52804e 100644 --- a/stan/math/prim/meta/as_array_or_scalar.hpp +++ b/stan/math/prim/meta/as_array_or_scalar.hpp @@ -17,8 +17,8 @@ namespace math { * @return Same value. */ template > -inline const T& as_array_or_scalar(const T& v) { - return v; +inline T as_array_or_scalar(T&& v) { + return std::forward(v); } /** \ingroup type_trait diff --git a/stan/math/prim/meta/broadcast_array.hpp b/stan/math/prim/meta/broadcast_array.hpp index 3c92f4e139b..d7d5c05e3f1 100644 --- a/stan/math/prim/meta/broadcast_array.hpp +++ b/stan/math/prim/meta/broadcast_array.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -28,7 +29,8 @@ class broadcast_array { */ template void operator=(const Y& m) { - prim_ = m[0]; + ref_type_t m_ref = m; + prim_ = m_ref[0]; } }; diff --git a/stan/math/prim/prob.hpp b/stan/math/prim/prob.hpp index 2fc90e047a8..8d7190e0a15 100644 --- a/stan/math/prim/prob.hpp +++ b/stan/math/prim/prob.hpp @@ -134,7 +134,9 @@ #include #include #include -#include +#include +#include +#include #include #include #include diff --git a/stan/math/prim/prob/bernoulli_cdf.hpp b/stan/math/prim/prob/bernoulli_cdf.hpp index d4decf90969..0c36305f65b 100644 --- a/stan/math/prim/prob/bernoulli_cdf.hpp +++ b/stan/math/prim/prob/bernoulli_cdf.hpp @@ -27,21 +27,22 @@ namespace math { template return_type_t bernoulli_cdf(const T_n& n, const T_prob& theta) { using T_partials_return = partials_return_t; + using T_theta_ref = ref_type_t; static const char* function = "bernoulli_cdf"; - check_finite(function, "Probability parameter", theta); - check_bounded(function, "Probability parameter", theta, 0.0, 1.0); check_consistent_sizes(function, "Random variable", n, "Probability parameter", theta); + T_theta_ref theta_ref = theta; + check_bounded(function, "Probability parameter", theta_ref, 0.0, 1.0); if (size_zero(n, theta)) { return 1.0; } T_partials_return P(1.0); - operands_and_partials ops_partials(theta); + operands_and_partials ops_partials(theta_ref); scalar_seq_view n_vec(n); - scalar_seq_view theta_vec(theta); + scalar_seq_view theta_vec(theta_ref); size_t max_size_seq_view = max_size(n, theta); // Explicit return for extreme values diff --git a/stan/math/prim/prob/bernoulli_lccdf.hpp b/stan/math/prim/prob/bernoulli_lccdf.hpp index 4c05d70a90b..0e681d1547d 100644 --- a/stan/math/prim/prob/bernoulli_lccdf.hpp +++ b/stan/math/prim/prob/bernoulli_lccdf.hpp @@ -30,22 +30,23 @@ namespace math { template return_type_t bernoulli_lccdf(const T_n& n, const T_prob& theta) { using T_partials_return = partials_return_t; + using T_theta_ref = ref_type_t; using std::log; static const char* function = "bernoulli_lccdf"; - check_finite(function, "Probability parameter", theta); - check_bounded(function, "Probability parameter", theta, 0.0, 1.0); check_consistent_sizes(function, "Random variable", n, "Probability parameter", theta); + T_theta_ref theta_ref = theta; + check_bounded(function, "Probability parameter", theta_ref, 0.0, 1.0); if (size_zero(n, theta)) { return 0.0; } T_partials_return P(0.0); - operands_and_partials ops_partials(theta); + operands_and_partials ops_partials(theta_ref); scalar_seq_view n_vec(n); - scalar_seq_view theta_vec(theta); + scalar_seq_view theta_vec(theta_ref); size_t max_size_seq_view = max_size(n, theta); // Explicit return for extreme values diff --git a/stan/math/prim/prob/bernoulli_lcdf.hpp b/stan/math/prim/prob/bernoulli_lcdf.hpp index ac6327aeed4..0d57760837f 100644 --- a/stan/math/prim/prob/bernoulli_lcdf.hpp +++ b/stan/math/prim/prob/bernoulli_lcdf.hpp @@ -30,22 +30,23 @@ namespace math { template return_type_t bernoulli_lcdf(const T_n& n, const T_prob& theta) { using T_partials_return = partials_return_t; + using T_theta_ref = ref_type_t; using std::log; static const char* function = "bernoulli_lcdf"; - check_finite(function, "Probability parameter", theta); - check_bounded(function, "Probability parameter", theta, 0.0, 1.0); check_consistent_sizes(function, "Random variable", n, "Probability parameter", theta); + T_theta_ref theta_ref = theta; + check_bounded(function, "Probability parameter", theta_ref, 0.0, 1.0); if (size_zero(n, theta)) { return 0.0; } T_partials_return P(0.0); - operands_and_partials ops_partials(theta); + operands_and_partials ops_partials(theta_ref); scalar_seq_view n_vec(n); - scalar_seq_view theta_vec(theta); + scalar_seq_view theta_vec(theta_ref); size_t max_size_seq_view = max_size(n, theta); // Explicit return for extreme values diff --git a/stan/math/prim/prob/bernoulli_logit_glm_rng.hpp b/stan/math/prim/prob/bernoulli_logit_glm_rng.hpp index ed48b749be6..54377cad1f2 100644 --- a/stan/math/prim/prob/bernoulli_logit_glm_rng.hpp +++ b/stan/math/prim/prob/bernoulli_logit_glm_rng.hpp @@ -38,35 +38,42 @@ namespace math { */ template inline typename VectorBuilder::type bernoulli_logit_glm_rng( - const T_x &x, const T_alpha &alpha, const T_beta &beta, RNG &rng) { + const T_x& x, const T_alpha& alpha, const T_beta& beta, RNG& rng) { using boost::bernoulli_distribution; using boost::variate_generator; + using T_x_ref = ref_type_t; + using T_alpha_ref = ref_type_t; + using T_beta_ref = ref_type_t; - const size_t N = x.row(0).size(); - const size_t M = x.col(0).size(); + const size_t N = x.cols(); + const size_t M = x.rows(); - static const char *function = "bernoulli_logit_glm_rng"; - check_finite(function, "Matrix of independent variables", x); - check_finite(function, "Weight vector", beta); - check_finite(function, "Intercept", alpha); + static const char* function = "bernoulli_logit_glm_rng"; check_consistent_size(function, "Weight vector", beta, N); check_consistent_size(function, "Vector of intercepts", alpha, M); + T_x_ref x_ref = x; + T_alpha_ref alpha_ref = alpha; + T_beta_ref beta_ref = beta; + check_finite(function, "Matrix of independent variables", x_ref); + check_finite(function, "Weight vector", beta_ref); + check_finite(function, "Intercept", alpha_ref); - scalar_seq_view beta_vec(beta); - Eigen::VectorXd beta_vector(N); - for (int i = 0; i < N; ++i) { - beta_vector[i] = beta_vec[i]; - } + const auto& beta_vector = as_column_vector_or_scalar(beta_ref); - Eigen::VectorXd x_beta = x * beta_vector; + Eigen::VectorXd x_beta; + if (is_vector::value) { + x_beta = x_ref * beta_vector; + } else { + x_beta = (x_ref.array() * forward_as(beta_vector)).rowwise().sum(); + } - scalar_seq_view alpha_vec(alpha); + scalar_seq_view alpha_vec(alpha_ref); VectorBuilder output(M); for (size_t m = 0; m < M; ++m) { double theta_m = alpha_vec[m] + x_beta(m); - variate_generator> bernoulli_rng( + variate_generator> bernoulli_rng( rng, bernoulli_distribution<>(inv_logit(theta_m))); output[m] = bernoulli_rng(); } diff --git a/stan/math/prim/prob/bernoulli_logit_lpmf.hpp b/stan/math/prim/prob/bernoulli_logit_lpmf.hpp index 418aacb23ef..589955fbc97 100644 --- a/stan/math/prim/prob/bernoulli_logit_lpmf.hpp +++ b/stan/math/prim/prob/bernoulli_logit_lpmf.hpp @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include #include namespace stan { @@ -28,12 +30,17 @@ namespace math { template return_type_t bernoulli_logit_lpmf(const T_n& n, const T_prob& theta) { using T_partials_return = partials_return_t; + using T_partials_array = Eigen::Array; using std::exp; + using T_n_ref = ref_type_t; + using T_theta_ref = ref_type_t; static const char* function = "bernoulli_logit_lpmf"; - check_bounded(function, "n", n, 0, 1); - check_not_nan(function, "Logit transformed probability parameter", theta); check_consistent_sizes(function, "Random variable", n, "Probability parameter", theta); + T_n_ref n_ref = n; + T_theta_ref theta_ref = theta; + check_bounded(function, "n", n_ref, 0, 1); + check_not_nan(function, "Logit transformed probability parameter", theta_ref); if (size_zero(n, theta)) { return 0.0; @@ -43,38 +50,44 @@ return_type_t bernoulli_logit_lpmf(const T_n& n, const T_prob& theta) { } T_partials_return logp(0.0); - operands_and_partials ops_partials(theta); + operands_and_partials ops_partials(theta_ref); - scalar_seq_view n_vec(n); - scalar_seq_view theta_vec(theta); - size_t N = max_size(n, theta); + const auto& theta_val = value_of(theta_ref); + const auto& theta_arr = as_array_or_scalar(theta_val); + const auto& n_double = value_of_rec(n_ref); - for (size_t n = 0; n < N; n++) { - const T_partials_return theta_dbl = value_of(theta_vec[n]); - - const int sign = 2 * n_vec[n] - 1; - const T_partials_return ntheta = sign * theta_dbl; - const T_partials_return exp_m_ntheta = exp(-ntheta); + auto signs = to_ref_if::value>( + (2 * as_array_or_scalar(n_double) - 1)); + T_partials_array ntheta; + if (is_vector::value || is_vector::value) { + ntheta = forward_as(signs * theta_arr); + } else { + T_partials_return ntheta_s + = forward_as(signs * theta_arr); + ntheta = T_partials_array::Constant(1, 1, ntheta_s); + } + T_partials_array exp_m_ntheta = exp(-ntheta); + static const double cutoff = 20.0; + logp += sum( + (ntheta > cutoff) + .select(-exp_m_ntheta, + (ntheta < -cutoff).select(ntheta, -log1p(exp_m_ntheta)))); - // Handle extreme values gracefully using Taylor approximations. - static const double cutoff = 20.0; - if (ntheta > cutoff) { - logp -= exp_m_ntheta; - } else if (ntheta < -cutoff) { - logp += ntheta; + if (!is_constant_all::value) { + if (is_vector::value) { + ops_partials.edge1_.partials_ = forward_as( + (ntheta > cutoff) + .select(-exp_m_ntheta, + (ntheta >= -cutoff) + .select(signs * exp_m_ntheta / (exp_m_ntheta + 1), + signs))); } else { - logp -= log1p(exp_m_ntheta); - } - - if (!is_constant_all::value) { - if (ntheta > cutoff) { - ops_partials.edge1_.partials_[n] -= exp_m_ntheta; - } else if (ntheta < -cutoff) { - ops_partials.edge1_.partials_[n] += sign; - } else { - ops_partials.edge1_.partials_[n] - += sign * exp_m_ntheta / (exp_m_ntheta + 1); - } + ops_partials.edge1_.partials_[0] + = sum((ntheta > cutoff) + .select(-exp_m_ntheta, (ntheta >= -cutoff) + .select(signs * exp_m_ntheta + / (exp_m_ntheta + 1), + signs))); } } return ops_partials.build(logp); diff --git a/stan/math/prim/prob/bernoulli_logit_rng.hpp b/stan/math/prim/prob/bernoulli_logit_rng.hpp index 5f3ae9903df..7cc493ef018 100644 --- a/stan/math/prim/prob/bernoulli_logit_rng.hpp +++ b/stan/math/prim/prob/bernoulli_logit_rng.hpp @@ -30,10 +30,11 @@ inline typename VectorBuilder::type bernoulli_logit_rng( const T_t& t, RNG& rng) { using boost::bernoulli_distribution; using boost::variate_generator; + ref_type_t t_ref = t; check_finite("bernoulli_logit_rng", "Logit transformed probability parameter", - t); + t_ref); - scalar_seq_view t_vec(t); + scalar_seq_view t_vec(t_ref); size_t N = stan::math::size(t); VectorBuilder output(N); diff --git a/stan/math/prim/prob/bernoulli_lpmf.hpp b/stan/math/prim/prob/bernoulli_lpmf.hpp index 515c100b0a7..7e749f579a4 100644 --- a/stan/math/prim/prob/bernoulli_lpmf.hpp +++ b/stan/math/prim/prob/bernoulli_lpmf.hpp @@ -29,13 +29,16 @@ namespace math { template return_type_t bernoulli_lpmf(const T_n& n, const T_prob& theta) { using T_partials_return = partials_return_t; + using T_theta_ref = ref_type_t; + using T_n_ref = ref_type_t; using std::log; static const char* function = "bernoulli_lpmf"; - check_bounded(function, "n", n, 0, 1); - check_finite(function, "Probability parameter", theta); - check_bounded(function, "Probability parameter", theta, 0.0, 1.0); check_consistent_sizes(function, "Random variable", n, "Probability parameter", theta); + const T_n_ref n_ref = to_ref(n); + const T_theta_ref theta_ref = to_ref(theta); + check_bounded(function, "n", n_ref, 0, 1); + check_bounded(function, "Probability parameter", theta_ref, 0.0, 1.0); if (size_zero(n, theta)) { return 0.0; @@ -45,10 +48,10 @@ return_type_t bernoulli_lpmf(const T_n& n, const T_prob& theta) { } T_partials_return logp(0.0); - operands_and_partials ops_partials(theta); + operands_and_partials ops_partials(theta_ref); - scalar_seq_view n_vec(n); - scalar_seq_view theta_vec(theta); + scalar_seq_view n_vec(n_ref); + scalar_seq_view theta_vec(theta_ref); size_t N = max_size(n, theta); if (size(theta) == 1) { diff --git a/stan/math/prim/prob/bernoulli_rng.hpp b/stan/math/prim/prob/bernoulli_rng.hpp index 8de2f400bda..9e588953712 100644 --- a/stan/math/prim/prob/bernoulli_rng.hpp +++ b/stan/math/prim/prob/bernoulli_rng.hpp @@ -30,10 +30,10 @@ inline typename VectorBuilder::type bernoulli_rng( using boost::bernoulli_distribution; using boost::variate_generator; static const char* function = "bernoulli_rng"; - check_finite(function, "Probability parameter", theta); - check_bounded(function, "Probability parameter", theta, 0.0, 1.0); + ref_type_t theta_ref = theta; + check_bounded(function, "Probability parameter", theta_ref, 0.0, 1.0); - scalar_seq_view theta_vec(theta); + scalar_seq_view theta_vec(theta_ref); size_t N = stan::math::size(theta); VectorBuilder output(N); diff --git a/stan/math/prim/prob/hmm_hidden_state_prob.hpp b/stan/math/prim/prob/hmm_hidden_state_prob.hpp new file mode 100644 index 00000000000..27fc1e0baa8 --- /dev/null +++ b/stan/math/prim/prob/hmm_hidden_state_prob.hpp @@ -0,0 +1,87 @@ +#ifndef STAN_MATH_PRIM_PROB_HMM_HIDDEN_STATE_PROB_HPP +#define STAN_MATH_PRIM_PROB_HMM_HIDDEN_STATE_PROB_HPP + +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * For a hidden Markov model with observation y, hidden state x, + * and parameters theta, compute the marginal probability + * vector for each x, given y and theta, p(x_i | y, theta). + * In this setting, the hidden states are discrete + * and take values over the finite space {1, ..., K}. + * Hence for each hidden variable x, we compute a simplex with K elements. + * The final result is stored in a K by N matrix, where N is the length of x. + * log_omegas is a matrix of observational densities, where + * the (i, j)th entry corresponds to the density of the ith observation, y_i, + * given x_i = j. + * The transition matrix Gamma is such that the (i, j)th entry is the + * probability that x_n = j given x_{n - 1} = i. The rows of Gamma are + * simplexes. + * This function cannot be used to reconstruct the marginal distributon + * of a state sequence given parameters and an observation sequence, + * p(x | y, theta), + * because it only computes marginals on a state-by-state basis. + * + * @tparam T_omega type of the log likelihood matrix + * @tparam T_Gamma type of the transition matrix + * @tparam T_rho type of the initial guess vector + * @param[in] log_omegas log matrix of observational densities + * @param[in] Gamma transition density between hidden states + * @param[in] rho initial state + * @return the posterior probability for each latent state + * @throw `std::invalid_argument` if Gamma is not square + * or if the size of rho is not the number of rows of log_omegas + * @throw `std::domain_error` if rho is not a simplex and of the rows + * of Gamma are not a simplex + */ +template +inline Eigen::MatrixXd hmm_hidden_state_prob( + const Eigen::Matrix& log_omegas, + const Eigen::Matrix& Gamma, + const Eigen::Matrix& rho) { + int n_states = log_omegas.rows(); + int n_transitions = log_omegas.cols() - 1; + + hmm_check(log_omegas, Gamma, rho, "hmm_hidden_state_prob"); + + Eigen::MatrixXd omegas = value_of(log_omegas).array().exp(); + Eigen::VectorXd rho_dbl = value_of(rho); + Eigen::MatrixXd Gamma_dbl = value_of(Gamma); + + Eigen::MatrixXd alphas(n_states, n_transitions + 1); + alphas.col(0) = omegas.col(0).cwiseProduct(rho_dbl); + alphas.col(0) /= alphas.col(0).maxCoeff(); + + Eigen::MatrixXd Gamma_dbl_transpose = Gamma_dbl.transpose(); + for (int n = 0; n < n_transitions; ++n) + alphas.col(n + 1) + = omegas.col(n + 1).cwiseProduct(Gamma_dbl_transpose * alphas.col(n)); + + // Backward pass with running normalization + Eigen::VectorXd beta = Eigen::VectorXd::Ones(n_states); + + alphas.col(n_transitions) /= alphas.col(n_transitions).sum(); + + for (int n = n_transitions; n-- > 0;) { + beta = Gamma_dbl * omegas.col(n + 1).cwiseProduct(beta); + beta /= beta.maxCoeff(); + + // Reuse alphas to store probabilities + alphas.col(n) = alphas.col(n).cwiseProduct(beta); + alphas.col(n) /= alphas.col(n).sum(); + } + + return alphas; +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/prim/prob/hmm_latent_rng.hpp b/stan/math/prim/prob/hmm_latent_rng.hpp new file mode 100644 index 00000000000..fce44a576e4 --- /dev/null +++ b/stan/math/prim/prob/hmm_latent_rng.hpp @@ -0,0 +1,98 @@ +#ifndef STAN_MATH_PRIM_PROB_HMM_LATENT_RNG_HPP +#define STAN_MATH_PRIM_PROB_HMM_LATENT_RNG_HPP + +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * For a hidden Markov model with observation y, hidden state x, + * and parameters theta, generate samples from the posterior distribution + * of the hidden states, x. + * In this setting, the hidden states are discrete + * and takes values over the finite space {1, ..., K}. + * log_omegas is a matrix of observational densities, where + * the (i, j)th entry corresponds to the density of the ith observation, y_i, + * given x_i = j. + * The transition matrix Gamma is such that the (i, j)th entry is the + * probability that x_n = j given x_{n - 1} = i. The rows of Gamma are + * simplexes. + * + * @tparam T_omega type of the log likelihood matrix + * @tparam T_Gamma type of the transition matrix + * @tparam T_rho type of the initial guess vector + * @param[in] log_omegas log matrix of observational densities. + * @param[in] Gamma transition density between hidden states. + * @param[in] rho initial state + * @param[in] rng random number generator + * @return sample from the posterior distribution of the hidden states. + * @throw `std::invalid_argument` if Gamma is not square, when we have + * at least one transition, or if the size of rho is not the + * number of rows of log_omegas. + * @throw `std::domain_error` if rho is not a simplex and of the rows + * of Gamma are not a simplex (when there is at least one transition). + */ +template +inline std::vector hmm_latent_rng( + const Eigen::Matrix& log_omegas, + const Eigen::Matrix& Gamma, + const Eigen::Matrix& rho, RNG& rng) { + int n_states = log_omegas.rows(); + int n_transitions = log_omegas.cols() - 1; + + hmm_check(log_omegas, Gamma, rho, "hmm_latent_rng"); + + Eigen::MatrixXd omegas = value_of(log_omegas).array().exp(); + Eigen::VectorXd rho_dbl = value_of(rho); + Eigen::MatrixXd Gamma_dbl = value_of(Gamma); + + Eigen::MatrixXd alphas(n_states, n_transitions + 1); + alphas.col(0) = omegas.col(0).cwiseProduct(rho_dbl); + alphas.col(0) /= alphas.col(0).maxCoeff(); + + Eigen::MatrixXd Gamma_dbl_transpose = Gamma_dbl.transpose(); + for (int n = 0; n < n_transitions; ++n) { + alphas.col(n + 1) + = omegas.col(n + 1).cwiseProduct(Gamma_dbl_transpose * alphas.col(n)); + alphas.col(n + 1) /= alphas.col(n + 1).maxCoeff(); + } + + Eigen::VectorXd beta = Eigen::VectorXd::Ones(n_states); + + // sample the last hidden state + std::vector hidden_states(n_transitions + 1); + Eigen::VectorXd probs_vec + = alphas.col(n_transitions) / alphas.col(n_transitions).sum(); + std::vector probs(probs_vec.data(), probs_vec.data() + n_states); + boost::random::discrete_distribution<> cat_hidden(probs); + hidden_states[n_transitions] = cat_hidden(rng); + + for (int n = n_transitions; n-- > 0;) { + // sample the nth hidden state conditional on (n + 1)st hidden state + int last_hs = hidden_states[n + 1]; + + probs_vec = alphas.col(n).cwiseProduct(Gamma_dbl.col(last_hs)) + * beta(last_hs) * omegas(last_hs, n + 1); + + probs_vec /= probs_vec.sum(); + std::vector probs(probs_vec.data(), probs_vec.data() + n_states); + boost::random::discrete_distribution<> cat_hidden(probs); + hidden_states[n] = cat_hidden(rng); + + // update backwards state + beta = Gamma_dbl * (omegas.col(n + 1).cwiseProduct(beta)); + beta /= beta.maxCoeff(); + } + + return hidden_states; +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/prim/prob/hmm_marginal_lpdf.hpp b/stan/math/prim/prob/hmm_marginal.hpp similarity index 84% rename from stan/math/prim/prob/hmm_marginal_lpdf.hpp rename to stan/math/prim/prob/hmm_marginal.hpp index 67968c6d9b9..acac3d5398d 100644 --- a/stan/math/prim/prob/hmm_marginal_lpdf.hpp +++ b/stan/math/prim/prob/hmm_marginal.hpp @@ -15,17 +15,14 @@ namespace stan { namespace math { -template * = nullptr, - require_all_eigen_col_vector_t* = nullptr, - require_stan_scalar_t* = nullptr, - require_all_vt_same* = nullptr> -inline auto hmm_marginal_lpdf_val(const T_omega& omegas, - const T_Gamma& Gamma_val, - const T_rho& rho_val, T_alphas& alphas, - T_alpha_log_norm& alpha_log_norms, - T_norm& norm_norm) { +template +inline auto hmm_marginal_val( + const Eigen::Matrix& omegas, + const Eigen::Matrix& Gamma_val, + const Eigen::Matrix& rho_val, + Eigen::Matrix& alphas, + Eigen::Matrix& alpha_log_norms, + T_alpha& norm_norm) { const int n_states = omegas.rows(); const int n_transitions = omegas.cols() - 1; alphas.col(0) = omegas.col(0).cwiseProduct(rho_val); @@ -60,7 +57,6 @@ inline auto hmm_marginal_lpdf_val(const T_omega& omegas, * The transition matrix Gamma is such that the (i, j)th entry is the * probability that x_n = j given x_{n - 1} = i. The rows of Gamma are * simplexes. - * The Gamma argument is only checked if there is at least one transition. * * @tparam T_omega type of the log likelihood matrix * @tparam T_Gamma type of the transition matrix @@ -76,7 +72,7 @@ inline auto hmm_marginal_lpdf_val(const T_omega& omegas, * of Gamma are not a simplex (when there is at least one transition). */ template -inline auto hmm_marginal_lpdf( +inline auto hmm_marginal( const Eigen::Matrix& log_omegas, const Eigen::Matrix& Gamma, const Eigen::Matrix& rho) { @@ -87,15 +83,7 @@ inline auto hmm_marginal_lpdf( int n_states = log_omegas.rows(); int n_transitions = log_omegas.cols() - 1; - check_consistent_size("hmm_marginal_lpdf", "rho", rho, n_states); - check_simplex("hmm_marginal_lpdf", "rho", rho); - check_square("hmm_marginal_lpdf", "Gamma", Gamma); - check_nonzero_size("hmm_marginal_lpdf", "Gamma", Gamma); - check_multiplicable("hmm_marginal_lpdf", "Gamma", Gamma, "log_omegas", - log_omegas); - for (int i = 0; i < Gamma.rows(); ++i) { - check_simplex("hmm_marginal_lpdf", "Gamma[i, ]", row(Gamma, i + 1)); - } + hmm_check(log_omegas, Gamma, rho, "hmm_marginal"); operands_and_partials, Eigen::Matrix, @@ -110,7 +98,7 @@ inline auto hmm_marginal_lpdf( const auto& rho_val = to_ref(value_of(rho)); eig_matrix_partial omegas = value_of(log_omegas).array().exp(); T_partial_type norm_norm; - auto log_marginal_density = hmm_marginal_lpdf_val( + auto log_marginal_density = hmm_marginal_val( omegas, Gamma_val, rho_val, alphas, alpha_log_norms, norm_norm); // Variables required for all three Jacobian-adjoint products. diff --git a/stan/math/rev/functor/coupled_ode_system.hpp b/stan/math/rev/functor/coupled_ode_system.hpp index 61205cee3eb..c74f1a098b7 100644 --- a/stan/math/rev/functor/coupled_ode_system.hpp +++ b/stan/math/rev/functor/coupled_ode_system.hpp @@ -73,6 +73,8 @@ struct coupled_ode_system_impl { const size_t N_; std::ostream* msgs_; + using F_copy = typename F::DeepCopy__; + /** * Construct a coupled ode system from the base system function, * initial state of the base system, parameters, and a stream for @@ -123,6 +125,8 @@ struct coupled_ode_system_impl { for (size_t n = 0; n < N_; ++n) y_vars(n) = z(n); + F_copy f_vars_ = f_; + auto local_args_tuple = apply( [&](auto&&... args) { return std::tuple( @@ -130,14 +134,15 @@ struct coupled_ode_system_impl { }, args_tuple_); - Eigen::Matrix f_y_t_vars - = apply([&](auto&&... args) { return f_(t, y_vars, msgs_, args...); }, - local_args_tuple); + Eigen::Matrix f_y_t_vars = apply( + [&](auto&&... args) { return f_vars_(t, y_vars, msgs_, args...); }, + local_args_tuple); check_size_match("coupled_ode_system", "dy_dt", f_y_t_vars.size(), "states", N_); Eigen::VectorXd args_adjoints(args_vars_); + Eigen::VectorXd f_adjoints(f_.num_vars__); for (size_t i = 0; i < N_; i++) { dz_dt(i) = f_y_t_vars(i).val(); f_y_t_vars(i).grad(); @@ -155,6 +160,7 @@ struct coupled_ode_system_impl { } args_adjoints.setZero(); + f_adjoints.setZero(); apply( [&](auto&&... args) { accumulate_adjoints(args_adjoints.data(), args...); @@ -170,6 +176,17 @@ struct coupled_ode_system_impl { dz_dt[N_ + N_ * y0_vars_ + N_ * j + i] = temp_deriv; } + f_vars_.accumulate_adjoints(f_adjoints.data()); + for (size_t j = 0; j < f_.num_vars__; j++) { + double temp_deriv = f_adjoints(j); + for (size_t k = 0; k < N_; k++) { + temp_deriv += z[N_ + N_ * y0_vars_ + N_ * args_vars_ + N_ * j + k] + * y_vars[k].adj(); + } + + dz_dt[N_ + N_ * y0_vars_ + N_ * args_vars_ + N_ * j + i] = temp_deriv; + } + nested.set_zero_all_adjoints(); } } @@ -179,7 +196,9 @@ struct coupled_ode_system_impl { * * @return size of the coupled system. */ - size_t size() const { return N_ + N_ * y0_vars_ + N_ * args_vars_; } + size_t size() const { + return N_ + N_ * y0_vars_ + N_ * args_vars_ + N_ * f_.num_vars__; + } /** * Returns the initial state of the coupled system. diff --git a/stan/math/rev/functor/cvodes_integrator.hpp b/stan/math/rev/functor/cvodes_integrator.hpp index 7a0ccd0651b..82fa63bc91e 100644 --- a/stan/math/rev/functor/cvodes_integrator.hpp +++ b/stan/math/rev/functor/cvodes_integrator.hpp @@ -30,10 +30,14 @@ namespace math { template class cvodes_integrator { - using T_Return = return_type_t; + using T_Return = return_type_t; using T_y0_t0 = return_type_t; + using F_dbl = typename F::ValueOf__; + const F& f_; + const F_dbl f_dbl_; const Eigen::Matrix y0_; const T_t0 t0_; const std::vector& ts_; @@ -103,9 +107,9 @@ class cvodes_integrator { inline void rhs(double t, const double y[], double dy_dt[]) const { const Eigen::VectorXd y_vec = Eigen::Map(y, N_); - Eigen::VectorXd dy_dt_vec - = apply([&](auto&&... args) { return f_(t, y_vec, msgs_, args...); }, - value_of_args_tuple_); + Eigen::VectorXd dy_dt_vec = apply( + [&](auto&&... args) { return f_dbl_(t, y_vec, msgs_, args...); }, + value_of_args_tuple_); check_size_match("cvodes_integrator", "dy_dt", dy_dt_vec.size(), "states", N_); @@ -122,7 +126,7 @@ class cvodes_integrator { Eigen::MatrixXd Jfy; auto f_wrapped = [&](const Eigen::Matrix& y) { - return apply([&](auto&&... args) { return f_(t, y, msgs_, args...); }, + return apply([&](auto&&... args) { return f_dbl_(t, y, msgs_, args...); }, value_of_args_tuple_); }; @@ -146,12 +150,12 @@ class cvodes_integrator { Eigen::VectorXd z(coupled_state_.size()); Eigen::VectorXd dz_dt; std::copy(y, y + N_, z.data()); - for (std::size_t s = 0; s < y0_vars_ + args_vars_; s++) { + for (std::size_t s = 0; s < y0_vars_ + args_vars_ + f_.num_vars__; s++) { std::copy(NV_DATA_S(yS[s]), NV_DATA_S(yS[s]) + N_, z.data() + (s + 1) * N_); } coupled_ode_(z, dz_dt, t); - for (std::size_t s = 0; s < y0_vars_ + args_vars_; s++) { + for (std::size_t s = 0; s < y0_vars_ + args_vars_ + f_.num_vars__; s++) { std::move(dz_dt.data() + (s + 1) * N_, dz_dt.data() + (s + 2) * N_, NV_DATA_S(ySdot[s])); } @@ -185,6 +189,7 @@ class cvodes_integrator { long int max_num_steps, std::ostream* msgs, const T_Args&... args) : f_(f), + f_dbl_(f), y0_(y0.unaryExpr([](const T_y0& val) { return T_y0_t0(val); })), t0_(t0), ts_(ts), @@ -228,10 +233,10 @@ class cvodes_integrator { A_ = SUNDenseMatrix(N_, N_); LS_ = SUNDenseLinearSolver(nv_state_, A_); - if (y0_vars_ + args_vars_ > 0) { - nv_state_sens_ - = N_VCloneVectorArrayEmpty_Serial(y0_vars_ + args_vars_, nv_state_); - for (std::size_t i = 0; i < y0_vars_ + args_vars_; i++) { + if (y0_vars_ + args_vars_ + f_.num_vars__ > 0) { + nv_state_sens_ = N_VCloneVectorArrayEmpty_Serial( + y0_vars_ + args_vars_ + f_.num_vars__, nv_state_); + for (std::size_t i = 0; i < y0_vars_ + args_vars_ + f_.num_vars__; i++) { NV_DATA_S(nv_state_sens_[i]) = &coupled_state_[N_] + i * N_; } } @@ -241,8 +246,9 @@ class cvodes_integrator { SUNLinSolFree(LS_); SUNMatDestroy(A_); N_VDestroy_Serial(nv_state_); - if (y0_vars_ + args_vars_ > 0) { - N_VDestroyVectorArray_Serial(nv_state_sens_, y0_vars_ + args_vars_); + if (y0_vars_ + args_vars_ + f_.num_vars__ > 0) { + N_VDestroyVectorArray_Serial(nv_state_sens_, + y0_vars_ + args_vars_ + f_.num_vars__); } } @@ -291,11 +297,12 @@ class cvodes_integrator { "CVodeSetJacFn"); // initialize forward sensitivity system of CVODES as needed - if (y0_vars_ + args_vars_ > 0) { + if (y0_vars_ + args_vars_ + f_.num_vars__ > 0) { check_flag_sundials( - CVodeSensInit(cvodes_mem, static_cast(y0_vars_ + args_vars_), - CV_STAGGERED, &cvodes_integrator::cv_rhs_sens, - nv_state_sens_), + CVodeSensInit( + cvodes_mem, + static_cast(y0_vars_ + args_vars_ + f_.num_vars__), + CV_STAGGERED, &cvodes_integrator::cv_rhs_sens, nv_state_sens_), "CVodeSensInit"); if (include_sensitivities_in_errors_) { @@ -331,7 +338,7 @@ class cvodes_integrator { CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL), "CVode"); - if (y0_vars_ + args_vars_ > 0) { + if (y0_vars_ + args_vars_ + f_.num_vars__ > 0) { check_flag_sundials( CVodeGetSens(cvodes_mem, &t_init, nv_state_sens_), "CVodeGetSens"); diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index f108ef9a7e0..863f3a9b12c 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -14,7 +14,8 @@ namespace math { */ template -std::vector>> +std::vector>> integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -28,7 +29,8 @@ integrate_ode_adams(const F& f, const std::vector& y0, = ode_adams_tol(f_adapted, to_vector(y0), t0, ts, relative_tolerance, absolute_tolerance, max_num_steps, msgs, theta, x, x_int); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); diff --git a/stan/math/rev/functor/ode_adams.hpp b/stan/math/rev/functor/ode_adams.hpp index 15beb0a11af..0126c19dfa4 100644 --- a/stan/math/rev/functor/ode_adams.hpp +++ b/stan/math/rev/functor/ode_adams.hpp @@ -10,7 +10,8 @@ namespace stan { namespace math { template -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_adams_tol_sens_error(const F& f, const Eigen::Matrix& y0, @@ -27,7 +28,8 @@ ode_adams_tol_sens_error(const F& f, template -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_adams_tol_error(const F& f, const Eigen::Matrix& y0, @@ -77,7 +79,8 @@ ode_adams_tol_error(const F& f, */ template -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_adams_tol(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, @@ -127,7 +130,8 @@ ode_adams_tol(const F& f, const Eigen::Matrix& y0, */ template -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_adams(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, std::ostream* msgs, diff --git a/stan/math/rev/functor/ode_bdf.hpp b/stan/math/rev/functor/ode_bdf.hpp index 7f16b35c417..d7cb40c3450 100644 --- a/stan/math/rev/functor/ode_bdf.hpp +++ b/stan/math/rev/functor/ode_bdf.hpp @@ -77,8 +77,10 @@ ode_bdf_tol_error(const F& f, */ template -std::vector, - Eigen::Dynamic, 1>> +std::vector, + Eigen::Dynamic, 1>> ode_bdf_tol(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, double relative_tolerance, double absolute_tolerance, @@ -126,8 +128,10 @@ ode_bdf_tol(const F& f, const Eigen::Matrix& y0, */ template -std::vector, - Eigen::Dynamic, 1>> +std::vector, + Eigen::Dynamic, 1>> ode_bdf(const F& f, const Eigen::Matrix& y0, const T_t0& t0, const std::vector& ts, std::ostream* msgs, const T_Args&... args) { diff --git a/stan/math/rev/functor/ode_store_sensitivities.hpp b/stan/math/rev/functor/ode_store_sensitivities.hpp index 725f5731a9e..1b119c9b5ca 100644 --- a/stan/math/rev/functor/ode_store_sensitivities.hpp +++ b/stan/math/rev/functor/ode_store_sensitivities.hpp @@ -30,33 +30,41 @@ namespace math { * @param args Extra arguments passed unmodified through to ODE right hand side * @return ODE state with scalar type var */ -template ...>> +template < + typename F, typename T_y0_t0, typename T_t0, typename T_t, typename... Args, + typename = require_any_autodiff_t...>> Eigen::Matrix ode_store_sensitivities( const F& f, const Eigen::VectorXd& coupled_state, const Eigen::Matrix& y0, const T_t0& t0, const T_t& t, std::ostream* msgs, const Args&... args) { + using F_dbl = typename F::ValueOf__; const size_t N = y0.size(); const size_t y0_vars = count_vars(y0); const size_t args_vars = count_vars(args...); const size_t t0_vars = count_vars(t0); const size_t t_vars = count_vars(t); + Eigen::Matrix yt(N); Eigen::VectorXd y = coupled_state.head(N); Eigen::VectorXd f_y_t; - if (is_var::value) - f_y_t = f(value_of(t), y, msgs, value_of(args)...); - Eigen::VectorXd f_y0_t0; - if (is_var::value) - f_y0_t0 = f(value_of(t0), value_of(y0), msgs, value_of(args)...); + + if (is_var::value || is_var::value) { + F_dbl f_dbl = f; + + if (is_var::value) + f_y_t = f_dbl(value_of(t), y, msgs, value_of(args)...); + + if (is_var::value) + f_y0_t0 = f_dbl(value_of(t0), value_of(y0), msgs, value_of(args)...); + } for (size_t j = 0; j < N; j++) { - const size_t total_vars = y0_vars + args_vars + t0_vars + t_vars; + const size_t total_vars + = y0_vars + args_vars + t0_vars + t_vars + f.num_vars__; vari** varis = ChainableStack::instance_->memalloc_.alloc_array(total_vars); @@ -82,6 +90,14 @@ Eigen::Matrix ode_store_sensitivities( partials_ptr++; } + f.save_varis(varis_ptr); + varis_ptr += f.num_vars__; + for (std::size_t k = 0; k < f.num_vars__; ++k) { + *partials_ptr + = coupled_state(N + N * y0_vars + N * args_vars + N * k + j); + partials_ptr++; + } + varis_ptr = save_varis(varis_ptr, t0); if (t0_vars > 0) { double dyt_dt0 = 0.0; diff --git a/test/unit/math/prim/err/hmm_check_test.cpp b/test/unit/math/prim/err/hmm_check_test.cpp new file mode 100644 index 00000000000..f2862746199 --- /dev/null +++ b/test/unit/math/prim/err/hmm_check_test.cpp @@ -0,0 +1,74 @@ +#include +#include +#include + +TEST(err, hmm_check) { + using Eigen::MatrixXd; + using Eigen::VectorXd; + using stan::math::hmm_check; + + int n_states = 2; + int n_transitions = 2; + MatrixXd log_omegas(n_states, n_transitions + 1); + MatrixXd Gamma(n_states, n_states); + VectorXd rho(n_states); + + for (int i = 0; i < n_states; i++) + for (int j = 0; j < n_transitions + 1; j++) + log_omegas(i, j) = 1; + + rho(0) = 0.65; + rho(1) = 0.35; + Gamma << 0.8, 0.2, 0.6, 0.4; + + // Gamma is not square. + MatrixXd Gamma_rec(n_states, n_states + 1); + EXPECT_THROW_MSG( + hmm_check(log_omegas, Gamma_rec, rho, "hmm_marginal_lpdf"), + std::invalid_argument, + "hmm_marginal_lpdf: Expecting a square matrix; rows of Gamma (2) " + "and columns of Gamma (3) must match in size") + + // Gamma has a column that is not a simplex. + MatrixXd Gamma_bad = Gamma; + Gamma_bad(0, 0) = Gamma(0, 0) + 1; + EXPECT_THROW_MSG(hmm_check(log_omegas, Gamma_bad, rho, "hmm_marginal_lpdf"), + std::domain_error, + "hmm_marginal_lpdf: Gamma[i, ] is not a valid simplex. " + "sum(Gamma[i, ]) = 2, but should be 1") + + // The size of Gamma is 0, even though there is at least one transition + MatrixXd Gamma_empty(0, 0); + EXPECT_THROW_MSG( + hmm_check(log_omegas, Gamma_empty, rho, "hmm_marginal_lpdf"), + std::invalid_argument, + "hmm_marginal_lpdf: Gamma has size 0, but must have a non-zero size") + + // The size of Gamma is inconsistent with that of log_omega + MatrixXd Gamma_wrong_size(n_states + 1, n_states + 1); + + EXPECT_THROW_MSG( + hmm_check(log_omegas, Gamma_wrong_size, rho, "hmm_marginal_lpdf"), + std::invalid_argument, + "hmm_marginal_lpdf: Columns of Gamma (3)" + " and Rows of log_omegas (2) must match in size") + + // rho is not a simplex. + VectorXd rho_bad = rho; + rho_bad(0) = rho(0) + 1; + EXPECT_THROW_MSG(hmm_check(log_omegas, Gamma, rho_bad, "hmm_marginal_lpdf"), + std::domain_error, + "hmm_marginal_lpdf: rho is not a valid simplex. " + "sum(rho) = 2, but should be 1") + + // The size of rho is inconsistent with that of log_omega + VectorXd rho_wrong_size(n_states + 1); + EXPECT_THROW_MSG( + hmm_check(log_omegas, Gamma, rho_wrong_size, "hmm_marginal_lpdf"), + std::invalid_argument, + "hmm_marginal_lpdf: rho has dimension = 3, expecting dimension = 2;" + " a function was called with arguments of different scalar," + " array, vector, or matrix types, and they were not consistently sized;" + " all arguments must be scalars or multidimensional values of" + " the same shape.") +} diff --git a/test/unit/math/prim/prob/hmm_hidden_state_prob_test.cpp b/test/unit/math/prim/prob/hmm_hidden_state_prob_test.cpp new file mode 100644 index 00000000000..747784ab5b9 --- /dev/null +++ b/test/unit/math/prim/prob/hmm_hidden_state_prob_test.cpp @@ -0,0 +1,71 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TEST_F(hmm_test, hidden_state_single_outcome) { + using stan::math::hmm_hidden_state_prob; + + int n_states = 2; + Eigen::MatrixXd Gamma(n_states, n_states); + Gamma << 1, 0, 1, 0; + Eigen::VectorXd rho(n_states); + rho << 1, 0; + + Eigen::MatrixXd prob = hmm_hidden_state_prob(log_omegas_, Gamma, rho); + + for (int i = 0; i < n_transitions_; i++) { + EXPECT_EQ(prob(0, i), 1); + EXPECT_EQ(prob(1, i), 0); + } +} + +TEST_F(hmm_test, hidden_state_identity_transition) { + // With an identity transition matrix, all latent probabilities + // are equal. Setting the log density to 1 for all states makes + // the initial prob drive the subsequent probabilities. + using stan::math::hmm_hidden_state_prob; + int n_states = 2; + Eigen::MatrixXd Gamma = Eigen::MatrixXd::Identity(n_states, n_states); + Eigen::MatrixXd log_omegas + = Eigen::MatrixXd::Ones(n_states, n_transitions_ + 1); + + Eigen::MatrixXd prob = hmm_hidden_state_prob(log_omegas, Gamma, rho_); + + for (int i = 0; i < n_transitions_; i++) { + EXPECT_FLOAT_EQ(prob(0, i), rho_(0)); + EXPECT_FLOAT_EQ(prob(1, i), rho_(1)); + } +} + +TEST(hmm_test_nonstandard, hidden_state_symmetry) { + // In this two states situation, the latent states are + // symmetric, based on the observational log density, + // and transition matrix. + // The initial conditions introduces an asymmetry in the first + // state. The other hidden states all have probability 0.5. + using stan::math::hmm_hidden_state_prob; + int n_states = 2; + int n_transitions = 2; + Eigen::MatrixXd Gamma(n_states, n_states); + Gamma << 0.5, 0.5, 0.5, 0.5; + Eigen::VectorXd rho(n_states); + rho << 0.3, 0.7; + Eigen::MatrixXd log_omegas + = Eigen::MatrixXd::Ones(n_states, n_transitions + 1); + + Eigen::MatrixXd prob = hmm_hidden_state_prob(log_omegas, Gamma, rho); + + EXPECT_FLOAT_EQ(prob(0, 0), 0.3); + EXPECT_FLOAT_EQ(prob(1, 0), 0.7); + + for (int i = 1; i < n_transitions; i++) { + EXPECT_FLOAT_EQ(prob(0, i), 0.5); + EXPECT_FLOAT_EQ(prob(1, i), 0.5); + } +} diff --git a/test/unit/math/prim/prob/hmm_latent_rng_test.cpp b/test/unit/math/prim/prob/hmm_latent_rng_test.cpp new file mode 100644 index 00000000000..9b88ae09c00 --- /dev/null +++ b/test/unit/math/prim/prob/hmm_latent_rng_test.cpp @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TEST(hmm_rng_test, chiSquareGoodnessFitTest) { + // with identity transition and constant log_omegas, the sampled latent + // states are identifcal and follow a Bernoulli distribution parameterized + // by rho. + using stan::math::hmm_latent_rng; + + int n_states = 2; + int n_transitions = 10; + Eigen::MatrixXd Gamma = Eigen::MatrixXd::Identity(n_states, n_states); + Eigen::VectorXd rho(n_states); + rho << 0.65, 0.35; + Eigen::MatrixXd log_omegas + = Eigen::MatrixXd::Ones(n_states, n_transitions + 1); + + boost::random::mt19937 rng; + int N = 10000; + + std::vector expected; + expected.push_back(N * rho(0)); + expected.push_back(N * rho(1)); + + std::vector counts(2); + std::vector state; + for (int i = 0; i < N; ++i) { + state = hmm_latent_rng(log_omegas, Gamma, rho, rng); + for (int j = 1; j < n_states; ++j) + EXPECT_EQ(state[j], state[0]); + ++counts[state[0]]; + } + + assert_chi_squared(counts, expected, 1e-6); +} + +TEST(hmm_rng_test, chiSquareGoodnessFitTest_symmetric) { + // In this two states situation, the latent states are + // symmetric, based on the observational log density, + // and transition matrix. + // The initial conditions introduces an asymmetry in the first + // state. The other hidden states all have probability 0.5. + // Note that the hidden states are also uncorrelated. + using stan::math::hmm_latent_rng; + + int n_states = 2; + int n_transitions = 1; + Eigen::MatrixXd Gamma(n_states, n_states); + Gamma << 0.5, 0.5, 0.5, 0.5; + Eigen::VectorXd rho(n_states); + rho << 0.3, 0.7; + Eigen::MatrixXd log_omegas + = Eigen::MatrixXd::Ones(n_states, n_transitions + 1); + + boost::random::mt19937 rng; + int N = 10000; + + std::vector expected_0; + expected_0.push_back(N * rho(0)); + expected_0.push_back(N * rho(1)); + + std::vector expected_1; + expected_1.push_back(N * 0.5); + expected_1.push_back(N * 0.5); + + std::vector counts_0(2); + std::vector counts_1(2); + // int product = 0; + std::vector states; + int a = 0, b = 0, c = 0, d = 0; + for (int i = 0; i < N; ++i) { + states = hmm_latent_rng(log_omegas, Gamma, rho, rng); + ++counts_0[states[0]]; + ++counts_1[states[1]]; + // product += states[0] * states[1]; + a += (states[0] == 0 && states[1] == 0); + b += (states[0] == 0 && states[1] == 1); + c += (states[0] == 1 && states[1] == 0); + d += (states[0] == 1 && states[1] == 1); + } + + // Test the marginal probabilities of each variable + assert_chi_squared(counts_0, expected_0, 1e-6); + assert_chi_squared(counts_1, expected_1, 1e-6); + + // Test for independence (0 correlation by construction). + // By independence E(XY) = E(X)E(Y). We compute the R.H.S + // analytically and the L.H.S numerically. + std::vector counts_xy(2); + counts_xy[0] = a; + counts_xy[1] = c; + std::vector expected_xy; + expected_xy.push_back(N * rho(0) * 0.5); + expected_xy.push_back(N * rho(1) * 0.5); + assert_chi_squared(counts_xy, expected_xy, 1e-6); + + // DRAFT -- code for chi-squared independence test. + // (overkill, since we have analytical prob for each cell) + // Test that the two states are independent, using a chi squared + // test for independence. + // Eigen::MatrixXd Expected(n_states, (n_transitions + 1)); + // Expected << (a + b) * (a + c), (a + b) * (b + d), + // (c + d) * (a + c), (c + d) * (b + d); + // Expected = Expected / N; + // + // Eigen::MatrixXd Observed(n_states, (n_transitions + 1)); + // Observed << a, b, c, d; + // double chi = 0; + // + // for (int i = 0; i < n_states; ++i) + // for (int j = 0; j < n_transitions + 1; ++j) + // chi += (Observed(i, j) - Expected(i, j)) + // * (Observed(i, j) - Expected(i, j)) / Expected(i, j); + // + // int nu = 1; + // double p_value = exp(stan::math::chi_square_lcdf(chi, nu)); + // double threshold = 0.1; // CHECK -- what is an appropriate threshold? + // EXPECT_TRUE(p_value > threshold); +} diff --git a/test/unit/math/prim/prob/hmm_marginal_test.cpp b/test/unit/math/prim/prob/hmm_marginal_test.cpp index 137a6eb65bc..3f9b653eb31 100644 --- a/test/unit/math/prim/prob/hmm_marginal_test.cpp +++ b/test/unit/math/prim/prob/hmm_marginal_test.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include #include @@ -7,117 +8,13 @@ #include #include -/** - * Wrapper around hmm_marginal_density which passes rho and - * Gamma without the last element of each column. We recover - * the last element using the fact each column sums to 1. - * The purpose of this function is to do finite diff benchmarking, - * without breaking the simplex constraint. - */ -template -inline stan::return_type_t hmm_marginal_test_wrapper( - const Eigen::Matrix& log_omegas, - const Eigen::Matrix& - Gamma_unconstrained, - const std::vector& rho_unconstrained) { - using stan::math::row; - using stan::math::sum; - int n_states = log_omegas.rows(); - - Eigen::Matrix Gamma(n_states, - n_states); - for (int i = 0; i < n_states; i++) { - Gamma(i, n_states - 1) = 1 - sum(row(Gamma_unconstrained, i + 1)); - for (int j = 0; j < n_states - 1; j++) { - Gamma(i, j) = Gamma_unconstrained(i, j); - } - } - - Eigen::Matrix rho(n_states); - rho(1) = 1 - sum(rho_unconstrained); - for (int i = 0; i < n_states - 1; i++) - rho(i) = rho_unconstrained[i]; - - return stan::math::hmm_marginal_lpdf(log_omegas, Gamma, rho); -} - -/** - * In the proposed example, the latent state x determines - * the observational distribution: - * 0: normal(mu, sigma) - * 1: normal(-mu, sigma) - */ -double state_lpdf(double y, double abs_mu, double sigma, int state) { - int x = state == 0 ? 1 : -1; - double chi = (y - x * abs_mu) / sigma; - return -0.5 * chi * chi - 0.5 * std::log(2 * M_PI) - std::log(sigma); -} - -class hmm_marginal_lpdf_test : public ::testing::Test { - protected: - void SetUp() override { - n_states_ = 2; - p1_init_ = 0.65; - gamma1_ = 0.7; - gamma2_ = 0.45; - n_transitions_ = 10; - abs_mu_ = 1; - sigma_ = 1; - - Eigen::VectorXd rho(n_states_); - rho << p1_init_, 1 - p1_init_; - rho_ = rho; - - Eigen::MatrixXd Gamma(n_states_, n_states_); - Gamma << gamma1_, 1 - gamma1_, gamma2_, 1 - gamma2_; - Gamma_ = Gamma; - - Eigen::VectorXd obs_data(n_transitions_ + 1); - obs_data << -0.3315914, -0.1655340, -0.7984021, 0.2364608, -0.4489722, - 2.1831438, -1.4778675, 0.8717423, -1.0370874, 0.1370296, 1.9786208; - obs_data_ = obs_data; - - Eigen::MatrixXd log_omegas(n_states_, n_transitions_ + 1); - for (int n = 0; n < n_transitions_ + 1; n++) { - log_omegas.col(n)[0] = state_lpdf(obs_data[n], abs_mu_, sigma_, 0); - log_omegas.col(n)[1] = state_lpdf(obs_data[n], abs_mu_, sigma_, 1); - } - log_omegas_ = log_omegas; - log_omegas_zero_ = log_omegas.block(0, 0, n_states_, 1); - - std::vector rho_unconstrained(n_states_ - 1); - for (int i = 0; i < rho.size() - 1; i++) - rho_unconstrained[i] = rho(i); - rho_unconstrained_ = rho_unconstrained; - - Gamma_unconstrained_ = Gamma.block(0, 0, n_states_, n_states_ - 1); - } - - int n_states_, n_transitions_; - double abs_mu_, sigma_, p1_init_, gamma1_, gamma2_; - - Eigen::VectorXd rho_; - Eigen::MatrixXd Gamma_; - Eigen::VectorXd obs_data_; - Eigen::MatrixXd log_omegas_; - Eigen::MatrixXd log_omegas_zero_; - - // Construct "unconstrained" versions of rho and Gamma, without - // the final element which can be determnied using the fact - // the columns sum to 1. This allows us to do finite diff tests, - // without violating the simplex constraint of rho and Gamma. - std::vector rho_unconstrained_; - Eigen::MatrixXd Gamma_unconstrained_; - stan::test::ad_tolerances tols_; -}; - // For evaluation of the density, the C++ code is benchmarked against // a forward algorithm written in R. // TODO(charlesm93): Add public repo link with R script. -TEST_F(hmm_marginal_lpdf_test, ten_transitions) { - using stan::math::hmm_marginal_lpdf; +TEST_F(hmm_test, ten_transitions) { + using stan::math::hmm_marginal; - EXPECT_FLOAT_EQ(-18.37417, hmm_marginal_lpdf(log_omegas_, Gamma_, rho_)); + EXPECT_FLOAT_EQ(-18.37417, hmm_marginal(log_omegas_, Gamma_, rho_)); // Differentiation tests auto hmm_functor = [](const auto& log_omegas, const auto& Gamma_unconstrained, @@ -130,10 +27,10 @@ TEST_F(hmm_marginal_lpdf_test, ten_transitions) { rho_unconstrained_); } -TEST_F(hmm_marginal_lpdf_test, zero_transitions) { - using stan::math::hmm_marginal_lpdf; +TEST_F(hmm_test, zero_transitions) { + using stan::math::hmm_marginal; - EXPECT_FLOAT_EQ(-1.520827, hmm_marginal_lpdf(log_omegas_zero_, Gamma_, rho_)); + EXPECT_FLOAT_EQ(-1.520827, hmm_marginal(log_omegas_zero_, Gamma_, rho_)); // Differentiation tests auto hmm_functor = [](const auto& log_omegas, const auto& Gamma_unconstrained, @@ -146,8 +43,8 @@ TEST_F(hmm_marginal_lpdf_test, zero_transitions) { Gamma_unconstrained_, rho_unconstrained_); } -TEST(hmm_marginal_lpdf, one_state) { - using stan::math::hmm_marginal_lpdf; +TEST(hmm_marginal, one_state) { + using stan::math::hmm_marginal; int n_states = 1, p1_init = 1, gamma1 = 1, n_transitions = 10, abs_mu = 1, sigma = 1; Eigen::VectorXd rho(n_states); @@ -161,7 +58,7 @@ TEST(hmm_marginal_lpdf, one_state) { for (int n = 0; n < n_transitions + 1; n++) log_omegas.col(n)[0] = state_lpdf(obs_data[n], abs_mu, sigma, 0); - EXPECT_FLOAT_EQ(-14.89646, hmm_marginal_lpdf(log_omegas, Gamma, rho)); + EXPECT_FLOAT_EQ(-14.89646, hmm_marginal(log_omegas, Gamma, rho)); // Differentiation tests // In the case where we have one state, Gamma and rho @@ -172,17 +69,17 @@ TEST(hmm_marginal_lpdf, one_state) { Eigen::VectorXd rho(1); rho << 1; - return hmm_marginal_lpdf(log_omegas, Gamma, rho); + return hmm_marginal(log_omegas, Gamma, rho); }; stan::test::ad_tolerances tols; stan::test::expect_ad(tols, hmm_functor, log_omegas); } -TEST(hmm_marginal_lpdf, exceptions) { +TEST(hmm_marginal, exceptions) { using Eigen::MatrixXd; using Eigen::VectorXd; - using stan::math::hmm_marginal_lpdf; + using stan::math::hmm_marginal; int n_states = 2; int n_transitions = 2; @@ -200,47 +97,44 @@ TEST(hmm_marginal_lpdf, exceptions) { // Gamma is not square. MatrixXd Gamma_rec(n_states, n_states + 1); - EXPECT_THROW_MSG( - hmm_marginal_lpdf(log_omegas, Gamma_rec, rho), std::invalid_argument, - "hmm_marginal_lpdf: Expecting a square matrix; rows of Gamma (2) " - "and columns of Gamma (3) must match in size"); + EXPECT_THROW_MSG(hmm_marginal(log_omegas, Gamma_rec, rho), + std::invalid_argument, + "hmm_marginal: Expecting a square matrix; rows of Gamma (2) " + "and columns of Gamma (3) must match in size"); // Gamma has a column that is not a simplex. MatrixXd Gamma_bad = Gamma; Gamma_bad(0, 0) = Gamma(0, 0) + 1; - EXPECT_THROW_MSG(hmm_marginal_lpdf(log_omegas, Gamma_bad, rho), - std::domain_error, - "hmm_marginal_lpdf: Gamma[i, ] is not a valid simplex. " + EXPECT_THROW_MSG(hmm_marginal(log_omegas, Gamma_bad, rho), std::domain_error, + "hmm_marginal: Gamma[i, ] is not a valid simplex. " "sum(Gamma[i, ]) = 2, but should be 1") // The size of Gamma is 0, even though there is at least one transition MatrixXd Gamma_empty(0, 0); EXPECT_THROW_MSG( - hmm_marginal_lpdf(log_omegas, Gamma_empty, rho), std::invalid_argument, - "hmm_marginal_lpdf: Gamma has size 0, but must have a non-zero size") + hmm_marginal(log_omegas, Gamma_empty, rho), std::invalid_argument, + "hmm_marginal: Gamma has size 0, but must have a non-zero size") // The size of Gamma is inconsistent with that of log_omega MatrixXd Gamma_wrong_size(n_states + 1, n_states + 1); - EXPECT_THROW_MSG(hmm_marginal_lpdf(log_omegas, Gamma_wrong_size, rho), + EXPECT_THROW_MSG(hmm_marginal(log_omegas, Gamma_wrong_size, rho), std::invalid_argument, - "hmm_marginal_lpdf: Columns of Gamma (3)" + "hmm_marginal: Columns of Gamma (3)" " and Rows of log_omegas (2) must match in size") // rho is not a simplex. VectorXd rho_bad = rho; rho_bad(0) = rho(0) + 1; - EXPECT_THROW_MSG(hmm_marginal_lpdf(log_omegas, Gamma, rho_bad), - std::domain_error, - "hmm_marginal_lpdf: rho is not a valid simplex. " + EXPECT_THROW_MSG(hmm_marginal(log_omegas, Gamma, rho_bad), std::domain_error, + "hmm_marginal: rho is not a valid simplex. " "sum(rho) = 2, but should be 1") // The size of rho is inconsistent with that of log_omega VectorXd rho_wrong_size(n_states + 1); EXPECT_THROW_MSG( - hmm_marginal_lpdf(log_omegas, Gamma, rho_wrong_size), - std::invalid_argument, - "hmm_marginal_lpdf: rho has dimension = 3, expecting dimension = 2;" + hmm_marginal(log_omegas, Gamma, rho_wrong_size), std::invalid_argument, + "hmm_marginal: rho has dimension = 3, expecting dimension = 2;" " a function was called with arguments of different scalar," " array, vector, or matrix types, and they were not consistently sized;" " all arguments must be scalars or multidimensional values of" diff --git a/test/unit/math/prim/prob/hmm_util.hpp b/test/unit/math/prim/prob/hmm_util.hpp new file mode 100644 index 00000000000..e2274076e21 --- /dev/null +++ b/test/unit/math/prim/prob/hmm_util.hpp @@ -0,0 +1,112 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * Wrapper around hmm_marginal_density which passes rho and + * Gamma without the last element of each column. We recover + * the last element using the fact each column sums to 1. + * The purpose of this function is to do finite diff benchmarking, + * without breaking the simplex constraint. + */ +template +inline stan::return_type_t hmm_marginal_test_wrapper( + const Eigen::Matrix& log_omegas, + const Eigen::Matrix& + Gamma_unconstrained, + const std::vector& rho_unconstrained) { + using stan::math::row; + using stan::math::sum; + int n_states = log_omegas.rows(); + + Eigen::Matrix Gamma(n_states, + n_states); + for (int i = 0; i < n_states; i++) { + Gamma(i, n_states - 1) = 1 - sum(row(Gamma_unconstrained, i + 1)); + for (int j = 0; j < n_states - 1; j++) { + Gamma(i, j) = Gamma_unconstrained(i, j); + } + } + + Eigen::Matrix rho(n_states); + rho(1) = 1 - sum(rho_unconstrained); + for (int i = 0; i < n_states - 1; i++) + rho(i) = rho_unconstrained[i]; + + return stan::math::hmm_marginal(log_omegas, Gamma, rho); +} + +/** + * In the proposed example, the latent state x determines + * the observational distribution: + * 0: normal(mu, sigma) + * 1: normal(-mu, sigma) + */ +double state_lpdf(double y, double abs_mu, double sigma, int state) { + int x = state == 0 ? 1 : -1; + double chi = (y - x * abs_mu) / sigma; + return -0.5 * chi * chi - 0.5 * std::log(2 * M_PI) - std::log(sigma); +} + +class hmm_test : public ::testing::Test { + protected: + void SetUp() override { + n_states_ = 2; + p1_init_ = 0.65; + gamma1_ = 0.7; + gamma2_ = 0.45; + n_transitions_ = 10; + abs_mu_ = 1; + sigma_ = 1; + + Eigen::VectorXd rho(n_states_); + rho << p1_init_, 1 - p1_init_; + rho_ = rho; + + Eigen::MatrixXd Gamma(n_states_, n_states_); + Gamma << gamma1_, 1 - gamma1_, gamma2_, 1 - gamma2_; + Gamma_ = Gamma; + + Eigen::VectorXd obs_data(n_transitions_ + 1); + obs_data << -0.3315914, -0.1655340, -0.7984021, 0.2364608, -0.4489722, + 2.1831438, -1.4778675, 0.8717423, -1.0370874, 0.1370296, 1.9786208; + obs_data_ = obs_data; + + Eigen::MatrixXd log_omegas(n_states_, n_transitions_ + 1); + for (int n = 0; n < n_transitions_ + 1; n++) { + log_omegas.col(n)[0] = state_lpdf(obs_data[n], abs_mu_, sigma_, 0); + log_omegas.col(n)[1] = state_lpdf(obs_data[n], abs_mu_, sigma_, 1); + } + log_omegas_ = log_omegas; + log_omegas_zero_ = log_omegas.block(0, 0, n_states_, 1); + + std::vector rho_unconstrained(n_states_ - 1); + for (int i = 0; i < rho.size() - 1; i++) + rho_unconstrained[i] = rho(i); + rho_unconstrained_ = rho_unconstrained; + + Gamma_unconstrained_ = Gamma.block(0, 0, n_states_, n_states_ - 1); + } + + int n_states_, n_transitions_; + double abs_mu_, sigma_, p1_init_, gamma1_, gamma2_; + + Eigen::VectorXd rho_; + Eigen::MatrixXd Gamma_; + Eigen::VectorXd obs_data_; + Eigen::MatrixXd log_omegas_; + Eigen::MatrixXd log_omegas_zero_; + + // Construct "unconstrained" versions of rho and Gamma, without + // the final element which can be determnied using the fact + // the columns sum to 1. This allows us to do finite diff tests, + // without violating the simplex constraint of rho and Gamma. + std::vector rho_unconstrained_; + Eigen::MatrixXd Gamma_unconstrained_; + stan::test::ad_tolerances tols_; +};