Skip to content

Commit

Permalink
rope expand
Browse files Browse the repository at this point in the history
  • Loading branch information
billera authored Jan 21, 2025
1 parent 065ee34 commit 4514973
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,13 +472,15 @@ function IPACache(settings::NamedTuple, batchsize::Integer)
IPACache(0, 0, batchsize, qh, kh, vh, qhp, khp, vhp)
end


function expand(
ipa::IPCrossA,
cache::IPACache,
TiL::Tuple, siL::AbstractArray, ΔL::Integer,
TiR::Tuple, siR::AbstractArray, ΔR::Integer;
zij = nothing,
mask = 0,
rope = nothing
)
dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings
if haskey(ipa.settings, :use_softmax1) #For compat
Expand All @@ -494,7 +496,11 @@ function expand(
gamma_h = softplus(clamp.(layer.gamma_h,Typ(-100), Typ(100))) #Clamping

Δqh = reshape(calldense(layer.proj_qh, siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, B))
Δkh = reshape(calldense(layer.proj_kh, siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B))
Δkh = reshape(calldense(layer.proj_kh, siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B))
if !isnothing(rope)
Δqh = rope(Δqh)
Δkh = rope(Δkh)
end
Δvh = reshape(calldense(layer.proj_vh, siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B))

Δqhp = reshape(calldense(layer.proj_qhp, siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B))
Expand Down

0 comments on commit 4514973

Please sign in to comment.