diff --git a/src/layers.jl b/src/layers.jl index 5e8753a..38de2f6 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -268,7 +268,7 @@ function (ipa::Union{IPCrossA, IPA})( return si end -function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,AbstractArray}, S::AbstractArray, zij::AbstractArray, mask::AbstractArray, rope::Union{IPARoPE, Nothing} = nothing) +function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,AbstractArray}, S::AbstractArray, zij::AbstractArray, mask::AbstractArray, rope::Union{IPARoPE, Nothing} = nothing, chain_diffs = 1) # Get relevant parameters from our ipa struct. l = ipa.layers dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings