From 0c8b88efee5ff3f0bf98e3bb1885c44313bd2184 Mon Sep 17 00:00:00 2001 From: billera Date: Fri, 31 Jan 2025 10:53:16 +0100 Subject: [PATCH] runtest expand --- test/runtests.jl | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index ba6275a..7066fc0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,37 @@ using ChainRulesTestUtils @testset "InvariantPointAttention.jl" begin + @testset "Rope Expand" begin + dims = 8 + settings = IPA_settings(dims) + + # generate random data + L = 10 + R = 10 + B = 1 + siL = randn(Float32, dims, L, B) + siR = siL + #zij = randn(Float32, c_z, R, L, B) + TiL = (get_rotation(L, B), get_translation(L, B)) + TiR = TiL + + # Left and right equal for self attention + TiL == TiR + siL == siR + + # Extend the cache along both left and right + ipa = IPCrossA(settings) + cache = InvariantPointAttention.IPACache(settings, B) + + rope = IPARoPE(ipa.settings.c, 100) + siRs = [] + for i in 1:10 + si, cache = InvariantPointAttention.expand(ipa, cache, TiL, siL, 1, TiR, siR, 1, rope= rope.rope[i:i]) + push!(siRs, si) + end + cat(siRs..., dims = 2) ≈ ipa(TiL, siL, TiR, siR; mask = right_to_left_mask(10), rope = rope[1:10]) + end + @testset "IPAsoftmax_invariance" begin batch_size = 3 framesL = 100