diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index f9efd689..3440d890 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -30,6 +30,9 @@ struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: Continuous α0::Ts lmnB::Ts function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} + if !all(x -> x > zero(x), alpha) + throw(ArgumentError("All elements of the alpha tensor should be positive")) + end alpha0 = sum(alpha; dims = 1) lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0) new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB)