Skip to content

Commit

Permalink
optimize permutedim and pass chain_diff
Browse files Browse the repository at this point in the history
  • Loading branch information
billera committed Dec 31, 2024
1 parent 8f7fa5b commit fbb84d9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
14 changes: 12 additions & 2 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ function (ipa::Union{IPCrossA, IPA})(
end

if customgrad
return ipa_customgrad(ipa, TiL, siL, zij, mask; rope)
return ipa_customgrad(ipa, TiL, siL, zij, mask, rope = rope, chain_diffs = chain_diffs)
end

if !isnothing(zij)
Expand Down Expand Up @@ -268,7 +268,14 @@ function (ipa::Union{IPCrossA, IPA})(
return si
end

function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,AbstractArray}, S::AbstractArray, zij::AbstractArray, mask::AbstractArray, rope::Union{IPARoPE, Nothing} = nothing, chain_diffs = 1)
function ipa_customgrad(
ipa::Union{IPCrossA, IPA},
Ti::Tuple{AbstractArray,AbstractArray},
S::AbstractArray,
zij::Union{AbstractArray, Nothing},
mask::AbstractArray;
rope = nothing,
chain_diffs = 1)
# Get relevant parameters from our ipa struct.
l = ipa.layers
dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings
Expand Down Expand Up @@ -312,6 +319,8 @@ function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,Abstr
vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:))

if !isnothing(rope)
# @show size(qh)
#@show size(kh)
qhTkh = dotproducts(rope, qh, kh; chain_diffs)
else
qhTkh = dotproducts(qh, kh)
Expand Down Expand Up @@ -638,6 +647,7 @@ function expand(
layer.ipa_linear(o), cache
end


sumdrop(x; dims) = dropdims(sum(x; dims); dims)
sumdrop(f, x; dims) = dropdims(sum(f, x; dims); dims)

Expand Down
20 changes: 13 additions & 7 deletions src/rope.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,24 @@ function RoPEdotproducts(iparope::IPARoPE, q, k; chain_diffs = nothing)
chain_diffs is either nothing or a array of 0's and 1's describing the ij-pair as pertaining to the same chain if the entry at ij is 1, else 0.
"""
function dotproducts(iparope::IPARoPE, qh::AbstractArray{T, 4}, kh::AbstractArray{T, 4}; chain_diffs = 1) where T<: Real
# O(N) permutedims, shouldn't be too bad.
qropshape = permutedims(qh, (1,3,2,4))
kropshape = permutedims(kh, (1,3,2,4))
rotq, rotk = permutedims(iparope.rope(qropshape), (1,3,2,4)), permutedims(iparope.rope(kropshape), (1,3,2,4))
rotqTrotk = dotproducts(rotq, rotk)
kropshape = permutedims(kh, (1,3,2,4))
rotq, rotk = permutedims(iparope.rope(qropshape), (2,1,3,4)), iparope.rope(kropshape)
rotqTrotk = permutedims(batched_mul(
rotq,
rotk
), (3,1,2,4))

# when things are from different chain, we rotate only the queries by a fixed amount
if chain_diffs != 1
#return qropshape
rotq2 = permutedims(iparope.fixed_rope(qropshape), (1,3,2,4))
rotq2Trotk2 = dotproducts(rotq2, kh)
rotq2 = permutedims(iparope.fixed_rope(qropshape), (2,1,3,4))
rotq2Trotk2 = permutedims(batched_mul(
rotq2,
kropshape
), (3,1,2,4))
# unsqueeze chain diffs to shape 1, framesR, framesL
rotqTrotk = unsqueeze(chain_diffs, 1) .* rotqTrotk .+ (1 .- unsqueeze(chain_diffs, 1) .* rotq2Trotk2)
rotqTrotk = Flux.unsqueeze(chain_diffs, 1) .* rotqTrotk .+ (1 .- Flux.unsqueeze(chain_diffs, 1) .* rotq2Trotk2)
end
return rotqTrotk
end
Expand Down

0 comments on commit fbb84d9

Please sign in to comment.