From fbb84d9e98509e4b5cff2cb9f3ea795de2e3e5d9 Mon Sep 17 00:00:00 2001
From: billera
Date: Tue, 31 Dec 2024 21:41:26 +0100
Subject: [PATCH] optimize permutedim and pass chain_diff

---
Review note: in src/rope.jl the chain_diffs blend is now written as
  mask .* rotqTrotk .+ (1 .- mask) .* rotq2Trotk2
The previous grouping, (1 .- mask .* rotq2Trotk2), evaluated to
1 - mask*rotq2Trotk2 because `.*` binds tighter than `.-`, so entries with
mask == 0 received the constant 1 instead of the fixed-rotation product.
Only that single added line differs from the patch as generated, so the
post-image blob id on the rope.jl index line no longer matches.
Indentation and blank context lines were reconstructed from the hunk
counts; check them against the tree before applying.

 src/layers.jl | 14 ++++++++++++--
 src/rope.jl   | 20 +++++++++++++-------
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/src/layers.jl b/src/layers.jl
index 38de2f6..3f9d21a 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -147,7 +147,7 @@ function (ipa::Union{IPCrossA, IPA})(
     end
 
     if customgrad
-        return ipa_customgrad(ipa, TiL, siL, zij, mask; rope)
+        return ipa_customgrad(ipa, TiL, siL, zij, mask, rope = rope, chain_diffs = chain_diffs)
     end
 
     if !isnothing(zij)
@@ -268,7 +268,14 @@ function (ipa::Union{IPCrossA, IPA})(
     return si
 end
 
-function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,AbstractArray}, S::AbstractArray, zij::AbstractArray, mask::AbstractArray, rope::Union{IPARoPE, Nothing} = nothing, chain_diffs = 1)
+function ipa_customgrad(
+    ipa::Union{IPCrossA, IPA},
+    Ti::Tuple{AbstractArray,AbstractArray},
+    S::AbstractArray,
+    zij::Union{AbstractArray, Nothing},
+    mask::AbstractArray;
+    rope = nothing,
+    chain_diffs = 1)
     # Get relevant parameters from our ipa struct.
     l = ipa.layers
     dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings
@@ -312,6 +319,8 @@ function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,Abstr
     vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:))
 
     if !isnothing(rope)
+        # @show size(qh)
+        #@show size(kh)
         qhTkh = dotproducts(rope, qh, kh; chain_diffs)
     else
         qhTkh = dotproducts(qh, kh)
@@ -638,6 +647,7 @@ function expand(
     layer.ipa_linear(o), cache
 end
 
+
 sumdrop(x; dims) = dropdims(sum(x; dims); dims)
 sumdrop(f, x; dims) = dropdims(sum(f, x; dims); dims)
 
diff --git a/src/rope.jl b/src/rope.jl
index 5f4082b..da28a2d 100644
--- a/src/rope.jl
+++ b/src/rope.jl
@@ -118,18 +118,24 @@ function RoPEdotproducts(iparope::IPARoPE, q, k; chain_diffs = nothing)
 chain_diffs is either nothing or a array of 0's and 1's describing the ij-pair as pertaining to the same chain if the entry at ij is 1, else 0.
 """
 function dotproducts(iparope::IPARoPE, qh::AbstractArray{T, 4}, kh::AbstractArray{T, 4}; chain_diffs = 1) where T<: Real
-    # O(N) permutedims, shouldn't be too bad.
     qropshape = permutedims(qh, (1,3,2,4))
-    kropshape = permutedims(kh, (1,3,2,4))
-    rotq, rotk = permutedims(iparope.rope(qropshape), (1,3,2,4)), permutedims(iparope.rope(kropshape), (1,3,2,4))
-    rotqTrotk = dotproducts(rotq, rotk)
+    kropshape = permutedims(kh, (1,3,2,4))
+    rotq, rotk = permutedims(iparope.rope(qropshape), (2,1,3,4)), iparope.rope(kropshape)
+    rotqTrotk = permutedims(batched_mul(
+        rotq,
+        rotk
+    ), (3,1,2,4))
+
     # when things are from different chain, we rotate only the queries by a fixed amount
     if chain_diffs != 1
         #return qropshape
-        rotq2 = permutedims(iparope.fixed_rope(qropshape), (1,3,2,4))
-        rotq2Trotk2 = dotproducts(rotq2, kh)
+        rotq2 = permutedims(iparope.fixed_rope(qropshape), (2,1,3,4))
+        rotq2Trotk2 = permutedims(batched_mul(
+            rotq2,
+            kropshape
+        ), (3,1,2,4))
         # unsqueeze chain diffs to shape 1, framesR, framesL
-        rotqTrotk = unsqueeze(chain_diffs, 1) .* rotqTrotk .+ (1 .- unsqueeze(chain_diffs, 1) .* rotq2Trotk2)
+        rotqTrotk = Flux.unsqueeze(chain_diffs, 1) .* rotqTrotk .+ (1 .- Flux.unsqueeze(chain_diffs, 1)) .* rotq2Trotk2
     end
     return rotqTrotk
 end