From 9d89985a43b92330afb55cc7fcc4686f739e1366 Mon Sep 17 00:00:00 2001 From: anton083 Date: Tue, 11 Jun 2024 11:46:13 +0200 Subject: [PATCH] Fixes --- src/InvariantPointAttention.jl | 7 +++---- src/layers.jl | 18 +++++++++++++++++- src/rotational_utils.jl | 20 +++++++++++++++----- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/InvariantPointAttention.jl b/src/InvariantPointAttention.jl index fa11fc2..865b4ff 100644 --- a/src/InvariantPointAttention.jl +++ b/src/InvariantPointAttention.jl @@ -9,11 +9,10 @@ include("grads.jl") include("layers.jl") include("masks.jl") -export IPA -export IPAStructureModuleLayer -export BackboneUpdate export IPA_settings -export IPCrossA +export IPA, IPCrossA +export IPAStructureModuleLayer, IPCrossAStructureModuleLayer +export BackboneUpdate export right_to_left_mask export left_to_right_mask export virtual_residues diff --git a/src/layers.jl b/src/layers.jl index 35a55d6..d8bd34e 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -22,6 +22,18 @@ function (backboneupdate::BackboneUpdate)(Ti, si) end """ + IPA_settings( + dims; + c = 16, + N_head = 12, + N_query_points = 4, + N_point_values = 8, + c_z = 0, + Typ = Float32, + use_softmax1 = false, + scaling_qk = :default, + ) + Returns a tuple of the IPA settings, with defaults for everything except dims. This can be passed to the IPA and IPCrossAStructureModuleLayer. """ IPA_settings( @@ -49,7 +61,11 @@ IPA_settings( """ + IPCrossA(settings) + Invariant Point Cross Attention (IPCrossA). Information flows from L (Keys, Values) to R (Queries). + +Get `settings` with [`IPA_settings`](@ref) """ struct IPCrossA settings::NamedTuple @@ -149,7 +165,7 @@ function (ipa::Union{IPCrossA, IPA})( else use_softmax1 = false end - + rot_TiL, translate_TiL = TiL rot_TiR, translate_TiR = TiR diff --git a/src/rotational_utils.jl b/src/rotational_utils.jl index 9cd72c5..2780dd3 100644 --- a/src/rotational_utils.jl +++ b/src/rotational_utils.jl @@ -79,16 +79,26 @@ Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N to N batches of m points in R3, i.e., mat ∈ R^(3 x m x N) ↦ T(mat) ∈ R^(3 x m x N). Note here that rotations here are represented in matrix form. """ -function T_R3(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N} - return batched_mul(R, x) .+ t +function T_R3(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T + x′ = reshape(x, 3, size(x, 2), :) + R′ = reshape(R, 3, 3, :) + t′ = reshape(t, 3, 1, :) + y′ = batched_mul(R′, x′) .+ t′ + y = reshape(y′, size(x)) + return y end -""" +""" Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N batches of m points in R3, such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x. """ -function T_R3_inv(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N} - return batched_mul_T1(R, x .- t) +function T_R3_inv(y::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T + y′ = reshape(y, 3, size(y, 2), :) + R′ = reshape(R, 3, 3, :) + t′ = reshape(t, 3, 1, :) + x′ = batched_mul(batched_transpose(R′), y′ .- t′) + x = reshape(x′, size(y)) + return x end """