Skip to content

Commit

Permalink
Pairwise and colwise with convenience functions (#224)
Browse files Browse the repository at this point in the history
* Pairwise and colwise with convenience functions

* Add tests

* Bump version
  • Loading branch information
devmotion authored Sep 4, 2021
1 parent 4f98bc5 commit a43d76b
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 25 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
name = "Distances"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.3"
version = "0.10.4"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
julia = "1"
StatsAPI = "1"
julia = "1"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand Down
4 changes: 2 additions & 2 deletions src/bhattacharyya.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ end

# Bhattacharyya distance
# Calling a `BhattacharyyaDist` instance on `(a, b)` returns the negative
# logarithm of `bhattacharyya_coeff(a, b)`.
(::BhattacharyyaDist)(a, b) = -log(bhattacharyya_coeff(a, b))
bhattacharyya(a, b) = BhattacharyyaDist()(a, b)
const bhattacharyya = BhattacharyyaDist()

# Hellinger distance
# Calling a `HellingerDist` instance on `(a, b)` returns
# sqrt(1 - bhattacharyya_coeff(a, b)).
(::HellingerDist)(a, b) = sqrt(1 - bhattacharyya_coeff(a, b))
hellinger(a, b) = HellingerDist()(a, b)
const hellinger = HellingerDist()
2 changes: 1 addition & 1 deletion src/haversine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ function (dist::SphericalAngle)(x, y)
2 * asin( min(a, one(a)) ) # take care of floating point errors
end

spherical_angle(x, y) = SphericalAngle()(x, y)
const spherical_angle = SphericalAngle()

result_type(::Union{Haversine, SphericalAngle}, ::Type, ::Type) = Float64
40 changes: 20 additions & 20 deletions src/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ end
# Euclidean
# Per-element term: `abs2` of the difference of the two entries.
@inline eval_op(::Euclidean, ai, bi) = abs2(ai - bi)
# Final step: square root of `s`.
# NOTE(review): `s` is presumably the reduction of the per-element terms —
# confirm against the generic evaluation code, which is not shown here.
eval_end(::Euclidean, s) = sqrt(s)
euclidean(a, b) = Euclidean()(a, b)
const euclidean = Euclidean()

# Weighted Euclidean
@inline eval_op(::WeightedEuclidean, ai, bi, wi) = abs2(ai - bi) * wi
Expand All @@ -350,15 +350,15 @@ peuclidean(a, b, p) = PeriodicEuclidean(p)(a, b)

# SqEuclidean
@inline eval_op(::SqEuclidean, ai, bi) = abs2(ai - bi)
sqeuclidean(a, b) = SqEuclidean()(a, b)
const sqeuclidean = SqEuclidean()

# Weighted Squared Euclidean
@inline eval_op(::WeightedSqEuclidean, ai, bi, wi) = abs2(ai - bi) * wi
wsqeuclidean(a, b, w) = WeightedSqEuclidean(w)(a, b)

# Cityblock
@inline eval_op(::Cityblock, ai, bi) = abs(ai - bi)
cityblock(a, b) = Cityblock()(a, b)
const cityblock = Cityblock()

# Weighted City Block
@inline eval_op(::WeightedCityblock, ai, bi, wi) = abs((ai - bi) * wi)
Expand All @@ -367,14 +367,14 @@ wcityblock(a, b, w) = WeightedCityblock(w)(a, b)
# Total variation
@inline eval_op(::TotalVariation, ai, bi) = abs(ai - bi)
eval_end(::TotalVariation, s) = s / 2
totalvariation(a, b) = TotalVariation()(a, b)
const totalvariation = TotalVariation()

# Chebyshev
# Per-element term: absolute difference of the two entries.
@inline eval_op(::Chebyshev, ai, bi) = abs(ai - bi)
# Combine two partial values by keeping the larger one.
@inline eval_reduce(::Chebyshev, s1, s2) = max(s1, s2)
# if only NaN, will output NaN
# Starting value: absolute difference of the first elements of `a` and `b`.
Base.@propagate_inbounds eval_start(::Chebyshev, a, b) = abs(first(a) - first(b))
chebyshev(a, b) = Chebyshev()(a, b)
const chebyshev = Chebyshev()

# Minkowski
@inline eval_op(dist::Minkowski, ai, bi) = abs(ai - bi)^dist.p
Expand All @@ -390,7 +390,7 @@ wminkowski(a, b, w, p::Real) = WeightedMinkowski(w, p)(a, b)
# The result type for `Hamming` is `Int`, whatever the two type arguments are.
result_type(::Hamming, ::Type, ::Type) = Int # fallback for Hamming
# Starting value is the integer 0; `a` and `b` are not inspected.
eval_start(d::Hamming, a, b) = 0
# Per-element term: 1 when the two entries differ, 0 otherwise.
@inline eval_op(::Hamming, ai, bi) = ai != bi ? 1 : 0
hamming(a, b) = Hamming()(a, b)
const hamming = Hamming()

# WeightedHamming
@inline eval_op(::WeightedHamming, ai, bi, wi) = ai != bi ? wi : zero(eltype(wi))
Expand All @@ -409,27 +409,27 @@ function eval_end(::CosineDist, s)
ab, a2, b2 = s
max(1 - ab / (sqrt(a2) * sqrt(b2)), 0)
end
cosine_dist(a, b) = CosineDist()(a, b)
const cosine_dist = CosineDist()

# CorrDist
# Subtract `mean(x)` from every element of `x` (broadcasted subtraction).
_centralize(x) = x .- mean(x)
# Calling a `CorrDist` instance applies `CosineDist()` to the centralized inputs.
(::CorrDist)(a, b) = CosineDist()(_centralize(a), _centralize(b))
# For two `Number`s, `CosineDist()` is applied to `zero(mean(a))` and
# `zero(mean(b))` instead of the centralized values.
(::CorrDist)(a::Number, b::Number) = CosineDist()(zero(mean(a)), zero(mean(b)))
corr_dist(a, b) = CorrDist()(a, b)
const corr_dist = CorrDist()

# ChiSqDist
# Per-element term: abs2(ai - bi) / (ai + bi), replaced by `zero(d)` when
# ai == bi. `d` is computed before `ifelse` selects, so the division is always
# evaluated.
# NOTE(review): the ai == bi substitution presumably guards the 0/0 case —
# confirm.
@inline eval_op(::ChiSqDist, ai, bi) = (d = abs2(ai - bi) / (ai + bi); ifelse(ai != bi, d, zero(d)))
chisq_dist(a, b) = ChiSqDist()(a, b)
const chisq_dist = ChiSqDist()

# KLDivergence
# Per-element term: ai * log(ai / bi) when ai > 0. Otherwise the result is the
# `zero` of the term obtained by calling `eval_op` with `oneunit(ai)` in place
# of `ai`, so both branches produce a value of the same kind.
@inline eval_op(dist::KLDivergence, ai, bi) =
ai > 0 ? ai * log(ai / bi) : zero(eval_op(dist, oneunit(ai), bi))
kl_divergence(a, b) = KLDivergence()(a, b)
const kl_divergence = KLDivergence()

# GenKLDivergence
# Per-element term: ai * log(ai / bi) - ai + bi when ai > 0. Otherwise the
# result is `bi` converted (via `oftype`) to the type of the term obtained by
# calling `eval_op` with `oneunit(ai)` in place of `ai`.
@inline eval_op(dist::GenKLDivergence, ai, bi) =
ai > 0 ? ai * log(ai / bi) - ai + bi : oftype(eval_op(dist, oneunit(ai), bi), bi)
gkl_divergence(a, b) = GenKLDivergence()(a, b)
const gkl_divergence = GenKLDivergence()

# RenyiDivergence
Base.@propagate_inbounds function eval_start(::RenyiDivergence, a, b)
Expand Down Expand Up @@ -494,7 +494,7 @@ end
tu = u > 0 ? u * log(u) : zero(log(one(T)))
ta + tb - tu
end
js_divergence(a, b) = JSDivergence()(a, b)
const js_divergence = JSDivergence()

# SpanNormDist

Expand All @@ -517,7 +517,7 @@ end

eval_end(::SpanNormDist, s) = s[2] - s[1]
(::SpanNormDist)(a::Number, b::Number) = zero(promote_type(typeof(a), typeof(b)))
spannorm_dist(a, b) = SpanNormDist()(a, b)
const spannorm_dist = SpanNormDist()

# Jaccard

Expand All @@ -537,7 +537,7 @@ end
@inbounds v = 1 - (a[1] / a[2])
return v
end
jaccard(a, b) = Jaccard()(a, b)
const jaccard = Jaccard()

# BrayCurtis

Expand All @@ -557,7 +557,7 @@ end
@inbounds v = a[1] / a[2]
return v
end
braycurtis(a, b) = BrayCurtis()(a, b)
const braycurtis = BrayCurtis()

# Tanimoto

Expand All @@ -583,24 +583,24 @@ end
@inbounds denominator = a[1] + a[4] + 2(a[2] + a[3])
numerator / denominator
end
rogerstanimoto(a, b) = RogersTanimoto()(a, b)
const rogerstanimoto = RogersTanimoto()

# Deviations

# Mean absolute deviation: `cityblock(a, b)` divided by `length(a)`.
(::MeanAbsDeviation)(a, b) = cityblock(a, b) / length(a)
meanad(a, b) = MeanAbsDeviation()(a, b)
const meanad = MeanAbsDeviation()

# Mean squared deviation: `sqeuclidean(a, b)` divided by `length(a)`.
(::MeanSqDeviation)(a, b) = sqeuclidean(a, b) / length(a)
msd(a, b) = MeanSqDeviation()(a, b)
const msd = MeanSqDeviation()

# Root mean squared deviation: square root of `MeanSqDeviation()(a, b)`.
(::RMSDeviation)(a, b) = sqrt(MeanSqDeviation()(a, b))
rmsd(a, b) = RMSDeviation()(a, b)
const rmsd = RMSDeviation()

# Normalized RMS deviation: `RMSDeviation()(a, b)` divided by the spread of `a`
# (its maximum minus its minimum, both taken from `extrema(a)`). Only `a` is
# used for the normalization; `b` does not affect the divisor.
# NOTE(review): when all elements of `a` are equal the divisor is zero —
# confirm that the resulting Inf/NaN is the intended behavior.
function (::NormRMSDeviation)(a, b)
amin, amax = extrema(a)
return RMSDeviation()(a, b) / (amax - amin)
end
nrmsd(a, b) = NormRMSDeviation()(a, b)
const nrmsd = NormRMSDeviation()


###########################################################
Expand Down
57 changes: 57 additions & 0 deletions test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,31 +517,50 @@ end
end

test_colwise(SqEuclidean(), X, Y, T)
test_colwise(sqeuclidean, X, Y, T)
test_colwise(Euclidean(), X, Y, T)
test_colwise(euclidean, X, Y, T)
test_colwise(Cityblock(), X, Y, T)
test_colwise(cityblock, X, Y, T)
test_colwise(TotalVariation(), X, Y, T)
test_colwise(totalvariation, X, Y, T)
test_colwise(Chebyshev(), X, Y, T)
test_colwise(chebyshev, X, Y, T)
test_colwise(Minkowski(2.5), X, Y, T)
test_colwise(Hamming(), A, B, T)
test_colwise(hamming, A, B, T)
test_colwise(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), X, Y, T);

test_colwise(CosineDist(), X, Y, T)
test_colwise(cosine_dist, X, Y, T)
test_colwise(CorrDist(), X, Y, T)
test_colwise(corr_dist, X, Y, T)

test_colwise(ChiSqDist(), X, Y, T)
test_colwise(chisq_dist, X, Y, T)
test_colwise(KLDivergence(), P, Q, T)
test_colwise(kl_divergence, P, Q, T)
test_colwise(GenKLDivergence(), P, Q, T)
test_colwise(gkl_divergence, P, Q, T)
test_colwise(RenyiDivergence(0.0), P, Q, T)
test_colwise(RenyiDivergence(1.0), P, Q, T)
test_colwise(RenyiDivergence(Inf), P, Q, T)
test_colwise(RenyiDivergence(0.5), P, Q, T)
test_colwise(RenyiDivergence(2), P, Q, T)
test_colwise(RenyiDivergence(10), P, Q, T)
test_colwise(JSDivergence(), P, Q, T)
test_colwise(js_divergence, P, Q, T)
test_colwise(SpanNormDist(), X, Y, T)
test_colwise(spannorm_dist, X, Y, T)

test_colwise(BhattacharyyaDist(), X, Y, T)
test_colwise(bhattacharyya, X, Y, T)
test_colwise(HellingerDist(), X, Y, T)
test_colwise(hellinger, X, Y, T)
test_colwise(BrayCurtis(), X, Y, T)
test_colwise(braycurtis, X, Y, T)
test_colwise(Jaccard(), X, Y, T)
test_colwise(jaccard, X, Y, T)

w = rand(T, m)

Expand Down Expand Up @@ -602,29 +621,50 @@ end
Q = rand(T, m, ny)

test_pairwise(SqEuclidean(), X, Y, T)
test_pairwise(sqeuclidean, X, Y, T)
test_pairwise(Euclidean(), X, Y, T)
test_pairwise(euclidean, X, Y, T)
test_pairwise(Cityblock(), X, Y, T)
test_pairwise(cityblock, X, Y, T)
test_pairwise(TotalVariation(), X, Y, T)
test_pairwise(totalvariation, X, Y, T)
test_pairwise(Chebyshev(), X, Y, T)
test_pairwise(chebyshev, X, Y, T)
test_pairwise(Minkowski(2.5), X, Y, T)
test_pairwise(Hamming(), A, B, T)
test_pairwise(hamming, A, B, T)

test_pairwise(CosineDist(), X, Y, T)
test_pairwise(cosine_dist, X, Y, T)
test_pairwise(CosineDist(), A, B, T)
test_pairwise(cosine_dist, A, B, T)
test_pairwise(CorrDist(), X, Y, T)
test_pairwise(corr_dist, X, Y, T)

test_pairwise(ChiSqDist(), X, Y, T)
test_pairwise(chisq_dist, X, Y, T)
test_pairwise(KLDivergence(), P, Q, T)
test_pairwise(kl_divergence, P, Q, T)
test_pairwise(GenKLDivergence(), P, Q, T)
test_pairwise(gkl_divergence, P, Q, T)
test_pairwise(RenyiDivergence(0.0), P, Q, T)
test_pairwise(RenyiDivergence(1.0), P, Q, T)
test_pairwise(RenyiDivergence(Inf), P, Q, T)
test_pairwise(RenyiDivergence(0.5), P, Q, T)
test_pairwise(RenyiDivergence(2), P, Q, T)
test_pairwise(JSDivergence(), P, Q, T)
test_pairwise(js_divergence, P, Q, T)
test_pairwise(SpanNormDist(), X, Y, T)
test_pairwise(spannorm_dist, X, Y, T)

test_pairwise(BhattacharyyaDist(), X, Y, T)
test_pairwise(bhattacharyya, X, Y, T)
test_pairwise(HellingerDist(), X, Y, T)
test_pairwise(hellinger, X, Y, T)
test_pairwise(BrayCurtis(), X, Y, T)
test_pairwise(braycurtis, X, Y, T)
test_pairwise(Jaccard(), X, Y, T)
test_pairwise(jaccard, X, Y, T)
test_pairwise(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), X, Y, T)

w = rand(m)
Expand All @@ -648,6 +688,7 @@ end
Y = rand(T, m, ny)
test_pairwise(Haversine(), X, Y, T)
test_pairwise(SphericalAngle(), X, Y, T)
test_pairwise(spherical_angle, X, Y, T)
end

@testset "pairwise metrics on complex arrays" begin
Expand Down Expand Up @@ -691,18 +732,30 @@ end
a = rand(1:3, nx)
b = rand(1:3, ny)
test_scalar_pairwise(SqEuclidean(), x, y, T)
test_scalar_pairwise(sqeuclidean, x, y, T)
test_scalar_pairwise(Euclidean(), x, y, T)
test_scalar_pairwise(euclidean, x, y, T)
test_scalar_pairwise(Cityblock(), x, y, T)
test_scalar_pairwise(cityblock, x, y, T)
test_scalar_pairwise(TotalVariation(), x, y, T)
test_scalar_pairwise(totalvariation, x, y, T)
test_scalar_pairwise(Chebyshev(), x, y, T)
test_scalar_pairwise(chebyshev, x, y, T)
test_scalar_pairwise(Minkowski(2.5), x, y, T)
test_scalar_pairwise(Hamming(), a, b, T)
test_scalar_pairwise(hamming, a, b, T)
test_scalar_pairwise(CosineDist(), x, y, T)
test_scalar_pairwise(cosine_dist, x, y, T)
test_scalar_pairwise(CosineDist(), a, b, T)
test_scalar_pairwise(cosine_dist, a, b, T)
test_scalar_pairwise(ChiSqDist(), x, y, T)
test_scalar_pairwise(chisq_dist, x, y, T)
test_scalar_pairwise(KLDivergence(), x, y, T)
test_scalar_pairwise(kl_divergence, x, y, T)
test_scalar_pairwise(JSDivergence(), x, y, T)
test_scalar_pairwise(js_divergence, x, y, T)
test_scalar_pairwise(BrayCurtis(), x, y, T)
test_scalar_pairwise(braycurtis, x, y, T)
w = rand(1, 1)
test_scalar_pairwise(WeightedSqEuclidean(w), x, y, T)
test_scalar_pairwise(WeightedEuclidean(w), x, y, T)
Expand All @@ -726,9 +779,13 @@ end
X = [0.3 0.3 + eps()]

@test all(x -> x >= 0, pairwise(SqEuclidean(), X; dims = 2))
@test all(x -> x >= 0, pairwise(sqeuclidean, X; dims = 2))
@test all(x -> x >= 0, pairwise(SqEuclidean(), X, X; dims = 2))
@test all(x -> x >= 0, pairwise(sqeuclidean, X, X; dims = 2))
@test all(x -> x >= 0, pairwise(Euclidean(), X; dims = 2))
@test all(x -> x >= 0, pairwise(euclidean, X; dims = 2))
@test all(x -> x >= 0, pairwise(Euclidean(), X, X; dims = 2))
@test all(x -> x >= 0, pairwise(euclidean, X, X; dims = 2))
@test all(x -> x >= 0, pairwise(WeightedSqEuclidean([1.0]), X; dims = 2))
@test all(x -> x >= 0, pairwise(WeightedSqEuclidean([1.0]), X, X; dims = 2))
@test all(x -> x >= 0, pairwise(SqMahalanobis(ones(1, 1)), X; dims = 2))
Expand Down

2 comments on commit a43d76b

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/44203

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.4 -m "<description of version>" a43d76b23874c10d8ab44909f7d4816ba88ef7db
git push origin v0.10.4

Please sign in to comment.