From 9e2380993857d9028c7106e15fb3108cfe05153d Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 30 Oct 2021 21:01:04 +0200 Subject: [PATCH] Add specialized `pairwise` methods to `*msd` (#232) --- Project.toml | 2 +- src/metrics.jl | 51 +++++++++++++++++++++++++++++++++++++++------- test/test_dists.jl | 6 ++++++ 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 2bf7023..2da6062 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Distances" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.5" +version = "0.10.6" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/metrics.jl b/src/metrics.jl index c479605..f593b68 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -731,10 +731,50 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE r end +# MeanSqDeviation, RMSDeviation, NormRMSDeviation +function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix, b::AbstractMatrix) + _pairwise!(r, SqEuclidean(), a, b) + # TODO: Replace by rdiv!(r, size(a, 1)) once julia compat ≥v1.2 + s = size(a, 1) + @simd for I in eachindex(r) + @inbounds r[I] /= s + end + return r +end +_pairwise!(r::AbstractMatrix, dist::RMSDeviation, a::AbstractMatrix, b::AbstractMatrix) = + sqrt!(_pairwise!(r, MeanSqDeviation(), a, b)) +function _pairwise!(r::AbstractMatrix, dist::NormRMSDeviation, a::AbstractMatrix, b::AbstractMatrix) + _pairwise!(r, RMSDeviation(), a, b) + @views for (i, j) in zip(axes(r, 1), axes(a, 2)) + amin, amax = extrema(a[:,j]) + r[i,:] ./= amax - amin + end + return r +end + +function _pairwise!(r::AbstractMatrix, dist::MeanSqDeviation, a::AbstractMatrix) + _pairwise!(r, SqEuclidean(), a) + # TODO: Replace by rdiv!(r, size(a, 1)) once julia compat ≥v1.2 + s = size(a, 1) + @simd for I in eachindex(r) + @inbounds r[I] /= s + end + return r +end +_pairwise!(r::AbstractMatrix, dist::RMSDeviation, a::AbstractMatrix) = + sqrt!(_pairwise!(r, MeanSqDeviation(), a)) +function _pairwise!(r::AbstractMatrix, dist::NormRMSDeviation, a::AbstractMatrix) + _pairwise!(r, RMSDeviation(), a) + @views for (i, j) in zip(axes(r, 1), axes(a, 2)) + amin, amax = extrema(a[:,j]) + r[i,:] ./= amax - amin + end + return r +end + # CosineDist -function _pairwise!(r::AbstractMatrix, ::CosineDist, - a::AbstractMatrix, b::AbstractMatrix) +function _pairwise!(r::AbstractMatrix, ::CosineDist, a::AbstractMatrix, b::AbstractMatrix) require_one_based_indexing(r, a, b) m, na, nb = get_pairwise_dims(r, a, b) inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(b)))) === eltype(r) @@ -772,10 +812,7 @@ end # 2. pre-calculated `_centralize_colwise` avoids four times of redundant computations # of `_centralize` -- ~4x speed up _centralize_colwise(x::AbstractMatrix) = x .- mean(x, dims=1) -function _pairwise!(r::AbstractMatrix, ::CorrDist, - a::AbstractMatrix, b::AbstractMatrix) +_pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix, b::AbstractMatrix) = _pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b)) -end -function _pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix) +_pairwise!(r::AbstractMatrix, ::CorrDist, a::AbstractMatrix) = _pairwise!(r, CosineDist(), _centralize_colwise(a)) -end diff --git a/test/test_dists.jl b/test/test_dists.jl index 086f3a0..9dd6dad 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -686,6 +686,12 @@ end test_pairwise(cityblock, X, Y, T) test_pairwise(TotalVariation(), X, Y, T) test_pairwise(totalvariation, X, Y, T) + test_pairwise(MeanSqDeviation(), X, Y, T) + test_pairwise(msd, X, Y, T) + test_pairwise(RMSDeviation(), X, Y, T) + test_pairwise(rmsd, X, Y, T) + test_pairwise(NormRMSDeviation(), X, Y, T) + test_pairwise(nrmsd, X, Y, T) test_pairwise(Chebyshev(), X, Y, T) test_pairwise(chebyshev, X, Y, T) test_pairwise(Minkowski(2.5), X, Y, T)