Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
os:
- ubuntu-latest
Expand All @@ -29,7 +29,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
- uses: julia-actions/cache@v2
env:
cache-name: cache-artifacts
with:
Expand Down
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJMultivariateStatsInterface"
uuid = "1b6a4a23-ba22-4f51-9698-8599985d3728"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>", "Thibaut Lienart <thibaut.lienart@gmail.com>", "Okon Samuel <okonsamuel50@gmail.com>"]
version = "0.5.3"
version = "0.6.0"

[deps]
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
Expand All @@ -12,19 +12,20 @@ MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
CategoricalDistributions = "0.1.9"
CategoricalDistributions = "0.2"
Distances = "0.9,0.10"
MLJModelInterface = "1.4"
MultivariateStats = "0.10"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.6"
julia = "1.10"

[extras]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Dates", "MLJBase", "Random", "StableRNGs", "Test"]
test = ["Dates", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "Test"]
23 changes: 13 additions & 10 deletions src/models/discriminant_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ const ERR_LONE_TARGET_CLASS = ArgumentError(
)

function _check_lda_data(model, X, y)
pool = MMI.classes(y[1]) # Class list containing entries in pool of `y`.
pool = CategoricalDistributions.levels(y[1]) # Class list containing entries in pool
# of `y`.
classes_seen = unique(y) # Class list of actual entries seen in `y`.
nc = length(classes_seen) # Number of actual classes seen in `y`.

Expand Down Expand Up @@ -109,7 +110,7 @@ function MMI.predict(m::LDA, (core_res, classes_seen, pool), Xnew)
Pr .*= -1
# apply a softmax transformation
softmax!(Pr)
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
return MMI.UnivariateFinite(classes_seen, Pr)
end

metadata_model(
Expand Down Expand Up @@ -160,7 +161,7 @@ function _check_prob01(priors)
end

@inline function _check_lda_priors(priors::UnivariateFinite, classes_seen, pool)
if MMI.classes(priors) != pool
if CategoricalDistributions.levels(priors) != pool
throw(
ArgumentError(
"UnivariateFinite `priors` must have common pool with training target."
Expand Down Expand Up @@ -236,7 +237,7 @@ function MMI.fitted_params(::BayesianLDA, (core_res, classes_seen, pool, priors
return (
classes = classes_seen,
projection_matrix=MS.projection(core_res),
priors=MMI.UnivariateFinite(classes_seen, priors, pool=pool)
priors=MMI.UnivariateFinite(classes_seen, priors)
)
end

Expand All @@ -261,7 +262,7 @@ function MMI.predict(m::BayesianLDA, (core_res, classes_seen, pool, priors, n),

# apply a softmax transformation to convert Pr to a probability matrix
softmax!(Pr)
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
return MMI.UnivariateFinite(classes_seen, Pr)
end

function MMI.transform(m::T, (core_res, ), X) where T<:Union{LDA, BayesianLDA}
Expand Down Expand Up @@ -353,7 +354,7 @@ function MMI.predict(m::SubspaceLDA, (core_res, outdim, classes_seen, pool), Xne
Pr .*= -1
# apply a softmax transformation
softmax!(Pr)
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
return MMI.UnivariateFinite(classes_seen, Pr)
end

metadata_model(
Expand Down Expand Up @@ -430,7 +431,7 @@ function MMI.fitted_params(
return (
classes = classes_seen,
projection_matrix=core_res.projw * view(core_res.projLDA, :, 1:outdim),
priors=MMI.UnivariateFinite(classes_seen, priors, pool=pool)
priors=MMI.UnivariateFinite(classes_seen, priors)
)
end

Expand Down Expand Up @@ -470,7 +471,7 @@ function MMI.predict(

# apply a softmax transformation to convert Pr to a probability matrix
softmax!(Pr)
return MMI.UnivariateFinite(classes_seen, Pr, pool=pool)
return MMI.UnivariateFinite(classes_seen, Pr)
end

function MMI.transform(
Expand Down Expand Up @@ -724,7 +725,8 @@ The fields of `fitted_params(mach)` are:
section below).

- `priors`: The class priors for classification. As inferred from training target `y`, if
not user-specified. A `UnivariateFinite` object with levels consistent with `levels(y)`.
not user-specified. A `UnivariateFinite` object with levels (classes) consistent with
`levels(y)`.

# Report

Expand Down Expand Up @@ -954,7 +956,8 @@ The fields of `fitted_params(mach)` are:
section below).

- `priors`: The class priors for classification. As inferred from training target `y`, if
not user-specified. A `UnivariateFinite` object with levels consistent with `levels(y)`.
not user-specified. A `UnivariateFinite` object with levels (classes) consistent with
`levels(y)`.

# Report

Expand Down
10 changes: 5 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# internal method essentially the same as Base.replace!(y, (z .=> r)...)
# but more efficient.
# Similar to the behaviour of `Base.replace!` if `z` contains repetitions of values in
# Similar to the behaviour of `Base.replace!` if `z` contains repetitions of values in
# `y` then only the transformation corresponding to the first occurrence is performed
# i.e. `_replace!([1,5,3], [1,4], 4:5)` would return `[4,5,3]` rather than `[5,5,3]`
# (which replaces `1=>4` and then `4=>5`)
function _replace!(y::AbstractVector, z::AbstractVector, r::AbstractVector)
length(r) == length(z) ||
length(r) == length(z) ||
throw(DimensionMismatch("`z` and `r` has to be of the same length"))
@inbounds for i in eachindex(y)
for j in eachindex(z)
for j in eachindex(z)
isequal(z[j], y[i]) && (y[i] = r[j]; break)
end
end
Expand All @@ -35,7 +35,7 @@ Implementation taken from NNlib.jl.
"""
function softmax!(X::AbstractMatrix{<:Real})
    # Row-wise maximum (reduction over the column dimension). Subtracting it
    # before exponentiating keeps every exponent <= 0, so `exp` cannot overflow;
    # the shift cancels out in the normalisation below.
    max_ = maximum(X, dims=2)
    # Exponentiate exactly once, in place. Repeating this statement would turn
    # the result into exp(exp(x - max) - max), which is no longer a softmax.
    X .= exp.(X .- max_)
    # Normalise each row so that its entries sum to one.
    X ./= sum(X, dims=2)
    return X
end
20 changes: 10 additions & 10 deletions test/models/discriminant_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ end
ytest = selectrows(y, test)

BLDA_model = BayesianLDA(regcoef=0)

## Check model `fit`
fitresult, cache, report = fit(BLDA_model, 1, Xtrain, ytrain)
classes_seen, projection_matrix, priors = fitted_params(BLDA_model, fitresult)
@test classes(priors) == classes(y)
@test levels(priors) == levels(y)
@test pdf.(priors, support(priors)) == [491/998, 507/998]
@test classes_seen == ["Up", "Down"]
@test round.((report.class_means)', sigdigits = 3) == [-0.0395 -0.0313; 0.0428 0.0339] #[0.0428 0.0339; -0.0395 -0.0313]

## Check model `predict`
preds = predict(BLDA_model, fitresult, Xtest)
mce = cross_entropy(preds, ytest) |> mean
Expand All @@ -94,7 +94,7 @@ end
fitresult1, cache1, report1 = fit(BLDA_model1, 1, Xtrain, ytrain)
classes_seen1, projection_matrix1, priors1 = fitted_params(BLDA_model1, fitresult1)
BLDA_model2 = BayesianLDA(
regcoef=0, priors=UnivariateFinite(classes(ytrain), [491/998, 507/998])
regcoef=0, priors=UnivariateFinite(levels(ytrain), [491/998, 507/998])
)
fitresult2, cache2, report2 = fit(BLDA_model2, 1, Xtrain, ytrain)
classes_seen2, projection_matrix2, priors2 = fitted_params(BLDA_model2, fitresult2)
Expand Down Expand Up @@ -156,7 +156,7 @@ end
LDA_model1, fitresult1
)
LDA_model2 = BayesianSubspaceLDA(
priors=UnivariateFinite(classes(y), [1/3, 1/3, 1/3])
priors=UnivariateFinite(levels(y), [1/3, 1/3, 1/3])
)
fitresult2, cache2, report2 = fit(LDA_model2, 1, X, y)
classes_seen2, projection_matrix2, priors2 = fitted_params(
Expand Down Expand Up @@ -231,24 +231,24 @@ end
y2 = y[[1,2,1,2]]
@test_throws ArgumentError fit(model, 1, X, y2)

## Check to make sure error is thrown if UnivariateFinite `priors` doesn't have
## Check to make sure error is thrown if UnivariateFinite `priors` doesn't have
## common pool with target vector used in training.
model = BayesianLDA(priors=UnivariateFinite([0.1, 0.5, 0.4], pool=missing))
@test_throws ArgumentError fit(model, 1, X, y)

## Check to make sure error is thrown if keys used in `priors` dictionary aren't in pool
## Check to make sure error is thrown if keys used in `priors` dictionary aren't in pool
## of training target used in training.
model = BayesianLDA(priors=Dict("apples" => 0.1, "oranges"=>0.5, "bannana"=>0.4))
@test_throws ArgumentError fit(model, 1, X, y)

## Check to make sure error is thrown if sum(`priors`) isn't approximately equal to 1.
model = BayesianLDA(priors=UnivariateFinite(classes(y), [0.1, 0.5, 0.4, 0.2]))
model = BayesianLDA(priors=UnivariateFinite(levels(y), [0.1, 0.5, 0.4, 0.2]))
@test_throws ArgumentError fit(model, 1, X, y)

## Check to make sure error is thrown if `priors .< 0` or `priors .> 1`.
model = BayesianLDA(priors=Dict(classes(y) .=> [-0.1, 0.0, 1.0, 0.1]))
model = BayesianLDA(priors=Dict(levels(y) .=> [-0.1, 0.0, 1.0, 0.1]))
@test_throws ArgumentError fit(model, 1, X, y)
model = BayesianLDA(priors=Dict(classes(y) .=> [1.1, 0.0, 0.0, -0.1]))
model = BayesianLDA(priors=Dict(levels(y) .=> [1.1, 0.0, 0.0, -0.1]))
@test_throws ArgumentError fit(model, 1, X, y)

X2 = (x=rand(5),)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ import Dates
import MLJMultivariateStatsInterface: _replace!
import MultivariateStats
import Random
import CategoricalDistributions.levels

using LinearAlgebra
using MLJBase
using StatisticalMeasures
using MLJMultivariateStatsInterface
using StableRNGs
using Test
Expand Down
Loading