Skip to content

Commit

Permalink
Merge pull request #61 from tjjarvinen/mlj
Browse files Browse the repository at this point in the history
MLJ extension
  • Loading branch information
wcwitt authored Aug 9, 2023
2 parents 24d656f + 6f5d589 commit 2c61394
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 1 deletion.
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@ SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[extensions]
ACEfit_PythonCall_ext = "PythonCall"
ACEfit_MLJLinearModels_ext = [ "MLJ", "MLJLinearModels" ]
ACEfit_MLJScikitLearnInterface_ext = ["MLJScikitLearnInterface", "PythonCall", "MLJ"]

[compat]
julia = "1.9"
IterativeSolvers = "0.9.2"
MLJ = "0.19"
MLJLinearModels = "0.9"
MLJScikitLearnInterface = "0.5"
LowRankApprox = "0.5.3"
Optim = "1.7"
ParallelDataTransfer = "0.5.0"
Expand Down
29 changes: 29 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,35 @@ using ACEfit
using PythonCall
```

## MLJ solvers

To use [MLJ](https://github.com/alan-turing-institute/MLJ.jl) solvers you need to load MLJ in addition to ACEfit

```julia
using ACEfit
using MLJ
```

After that you need to load an appropriate MLJ solver. Take a look on available MLJ [solvers](https://alan-turing-institute.github.io/MLJ.jl/dev/model_browser/). Note that only [MLJScikitLearnInterface.jl](https://github.com/JuliaAI/MLJScikitLearnInterface.jl) and [MLJLinearModels.jl](https://github.com/JuliaAI/MLJLinearModels.jl) have extension available. To use other MLJ solvers please file an issue.

You need to load the solver and then create a solver structure

```julia
# Load ARD solver
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface

# Create the solver itself and give it parameters
solver = ARDRegressor(
n_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
)
```

After this you can use the MLJ solver like any other solver.

## Index

```@index
```

Expand Down
53 changes: 53 additions & 0 deletions ext/ACEfit_MLJLinearModels_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module ACEfit_MLJLinearModels_ext

using MLJ
using ACEfit
using MLJLinearModels

"""
ACEfit.solve(solver, A, y)
Overloads `ACEfit.solve` to use MLJLinearModels solvers,
when `solver` is [MLJLinearModels](https://github.com/JuliaAI/MLJLinearModels.jl) solver.
# Example
```julia
using MLJ
using ACEfit
# Load Lasso solver
LassoRegressor = @load LassoRegressor pkg=MLJLinearModels
# Create the solver itself and give it parameters
solver = LassoRegressor(
lambda = 0.2,
fit_intercept = false
# insert more fit params
)
# fit ACE model
linear_fit(training_data, basis, solver)
# or lower level
ACEfit.fit(solver, A, y)
```
"""
function ACEfit.solve(solver::Union{
MLJLinearModels.ElasticNetRegressor,
MLJLinearModels.HuberRegressor,
MLJLinearModels.LADRegressor,
MLJLinearModels.LassoRegressor,
MLJLinearModels.LinearRegressor,
MLJLinearModels.QuantileRegressor,
MLJLinearModels.RidgeRegressor,
MLJLinearModels.RobustRegressor,
},
A, y)
Atable = MLJ.table(A)
mach = machine(solver, Atable, y)
MLJ.fit!(mach)
params = fitted_params(mach)
return Dict{String, Any}("C" => map( x->x.second, params.coefs) )
end

end
46 changes: 46 additions & 0 deletions ext/ACEfit_MLJScikitLearnInterface_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module ACEfit_MLJScikitLearnInterface_ext

using ACEfit
using MLJ
using MLJScikitLearnInterface
using PythonCall


"""
ACEfit.solve(solver, A, y)
Overloads `ACEfit.solve` to use scikitlearn solvers from MLJ.
# Example
```julia
using MLJ
using ACEfit
# Load ARD solver
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface
# Create the solver itself and give it parameters
solver = ARDRegressor(
n_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
# more params
)
# fit ACE model
linear_fit(training_data, basis, solver)
# or lower level
ACEfit.fit(solver, A, y)
```
"""
function ACEfit.solve(solver, A, y)
Atable = MLJ.table(A)
mach = machine(solver, Atable, y)
MLJ.fit!(mach)
params = fitted_params(mach)
c = params.coef
return Dict{String, Any}("C" => pyconvert(Array, c) )
end

end
6 changes: 5 additions & 1 deletion src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ struct SKLEARN_BRR
n_iter::Integer
end

SKLEARN_BRR(; tol = 1e-3, n_iter = 300) = SKLEARN_BRR(tol, n_iter)
function SKLEARN_BRR(; tol = 1e-3, n_iter = 300)
@warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this."
SKLEARN_BRR(tol, n_iter)
end

# solve(solver::SKLEARN_BRR, ...) is implemented in ext/

Expand All @@ -140,6 +143,7 @@ struct SKLEARN_ARD
end

function SKLEARN_ARD(; n_iter = 300, tol = 1e-3, threshold_lambda = 10000)
@warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this."
SKLEARN_ARD(n_iter, tol, threshold_lambda)
end

Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ using Test
@testset "Bayesian Linear" begin include("test_bayesianlinear.jl") end

@testset "Linear Solvers" begin include("test_linearsolvers.jl") end

@testset "MLJ Solvers" begin include("test_mlj.jl") end
end
42 changes: 42 additions & 0 deletions test/test_mlj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using ACEfit
using LinearAlgebra
using MLJ
using MLJScikitLearnInterface

@info("Test MLJ interface on overdetermined system")
Nobs = 10_000
Nfeat = 100
A = randn(Nobs, Nfeat) / sqrt(Nobs)
y = randn(Nobs)
P = Diagonal(1.0 .+ rand(Nfeat))


@info(" ... MLJLinearModels LinearRegressor")
LinearRegressor = @load LinearRegressor pkg=MLJLinearModels
solver = LinearRegressor()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)


@info(" ... MLJLinearModels LassoRegressor")
LassoRegressor = @load LassoRegressor pkg=MLJLinearModels
solver = LassoRegressor()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)


@info(" ... MLJ SKLearn ARD")
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface
solver = ARDRegressor(
n_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)

0 comments on commit 2c61394

Please sign in to comment.