Skip to content

Commit

Permalink
Refactor fallbacks (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliohm authored Jan 28, 2025
1 parent 37f1e06 commit 65e6df4
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 14 deletions.
16 changes: 14 additions & 2 deletions src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ Predict one or multiple variables `vars` at geometry `gₒ` with
given geostatistical `model`.
"""
predict(model::FittedGeoStatsModel, var::AbstractString, gₒ) = predict(model, Symbol(var), gₒ)
predict(model::FittedGeoStatsModel, vars, gₒ) = [predict(model, var, gₒ) for var in vars]
function predict(model::FittedGeoStatsModel, vars, gₒ)
if length(vars) > 1
throw(ArgumentError("cannot use univariate model to predict multiple variables"))
else
[predict(model, first(vars), gₒ)]
end
end

"""
predictprob(model, vars, gₒ)
Expand All @@ -42,7 +48,13 @@ Predict distribution of one or multiple variables `vars` at
geometry `gₒ` with given geostatistical `model`.
"""
predictprob(model::FittedGeoStatsModel, var::AbstractString, gₒ) = predictprob(model, Symbol(var), gₒ)
predictprob(model::FittedGeoStatsModel, vars, gₒ) = [predictprob(model, var, gₒ) for var in vars]
function predictprob(model::FittedGeoStatsModel, vars, gₒ)
if length(vars) > 1
throw(ArgumentError("cannot use univariate model to predict multiple variables"))
else
[predictprob(model, first(vars), gₒ)]
end
end

"""
status(fitted)
Expand Down
2 changes: 1 addition & 1 deletion test/idw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
@test pred isa Composition
end

@testset "Single/Multiple" begin
@testset "Fallbacks" begin
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
idw = GeoStatsModels.fit(IDW(), d)
pred1 = GeoStatsModels.predict(idw, :z, Point(0.0, 0.0))
Expand Down
2 changes: 1 addition & 1 deletion test/krig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@
@test all((0), var.(dkdist))
end

@testset "Single/Multiple" begin
@testset "Fallbacks" begin
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
γ = GaussianVariogram()
ok = GeoStatsModels.fit(OK(γ), d)
Expand Down
2 changes: 1 addition & 1 deletion test/lwr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
@test pred isa Composition
end

@testset "Single/Multiple" begin
@testset "Fallbacks" begin
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
lwr = GeoStatsModels.fit(LWR(), d)
pred1 = GeoStatsModels.predict(lwr, :z, Point(0.0, 0.0))
Expand Down
5 changes: 2 additions & 3 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

# fitpredict with IDW
pset = PointSet(rand(rng, Point, 3))
gtb = georef((a=[1, 2, 3], b=[4, 5, 6]), pset)
gtb = georef((z=[1, 2, 3],), pset)
pred = GeoStatsModels.fitpredict(IDW(), gtb, pset, neighbors=false)
@test pred.a == gtb.a
@test pred.b == gtb.b
@test pred.z == gtb.z
@test pred.geometry == gtb.geometry

# also works with views
Expand Down
2 changes: 1 addition & 1 deletion test/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
@test pred isa Composition
end

@testset "Single/Multiple" begin
@testset "Fallbacks" begin
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
nn = GeoStatsModels.fit(NN(), d)
pred1 = GeoStatsModels.predict(nn, :z, Point(0.0, 0.0))
Expand Down
9 changes: 4 additions & 5 deletions test/poly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@

# correct schema
rng = StableRNG(42)
d = georef((a=rand(rng, 10), b=rand(rng, 10)), rand(rng, Point, 10))
d = georef((z=rand(rng, 10),), rand(rng, Point, 10))
= fitpredict(Polynomial(), d)
= values(d̄)
@test propertynames(t̄) == (:a, :b)
@test eltype(t̄.a) == Float64
@test eltype(t̄.b) == Float64
@test propertynames(t̄) == (:z,)
@test eltype(t̄.z) == Float64

# latlon coordinates
d = georef((; z=[1, 2, 3]), Point.([LatLon(0, 0), LatLon(0, 1), LatLon(1, 0)]))
Expand All @@ -68,7 +67,7 @@
@test pred isa Composition
end

@testset "Single/Multiple" begin
@testset "Fallbacks" begin
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
poly = GeoStatsModels.fit(Polynomial(), d)
pred1 = GeoStatsModels.predict(poly, :z, Point(0.0, 0.0))
Expand Down

0 comments on commit 65e6df4

Please sign in to comment.