From 2aa166dada0303269ab4fd707dcc85c80d9cd5c2 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 9 Jan 2025 13:36:42 +0100 Subject: [PATCH] Introduce error msg for negative alpha --- src/distributions/tensor_dirichlet.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index f9efd689..3440d890 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -30,6 +30,9 @@ struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: Continuous α0::Ts lmnB::Ts function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} + if !all(x -> x > zero(x), alpha) + throw(ArgumentError("All elements of the alpha tensor should be positive")) + end alpha0 = sum(alpha; dims = 1) lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0) new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB)