New flow/state abstractions #10

Draft
wants to merge 6 commits into
base: main

@AntonOresten (Member) commented Jan 22, 2025

This PR is breaking and makes a few key changes:

  • Flow is replaced by an abstract type, AbstractFlow, with a subtype ManifoldFlow that is parameterized by the type of the manifold. This keeps the code generic while still letting us dispatch on the manifold for specialized methods and GPU compatibility. Examples:
    • EuclideanFlow is now ManifoldFlow{<:Manifolds.Euclidean}
    • RotationalFlow is now ManifoldFlow{<:Manifolds.SpecialOrthogonal}
    • RelaxedDiscreteFlow is now Flow{<:Manifolds.ProbabilitySimplex}
  • FlowState{T,N} is now BatchedState{T,N} <: AbstractArray{State{T,N}}, where State{T,N} wraps an array that represents a single point on the manifold, and N is the dimensionality of that array. The old N also counted the batch dimensions.
    • BatchedState is an array of states, but we use ArraysOfArrays.ArrayOfSimilarArrays to make each element a view of the underlying array, which can have an arbitrary number of batch dimensions (the underlying array is accessible with flatarray). This allows generic code (e.g. interpolate(::ManifoldFlow, ::BatchedState, ...)) to dispatch to another, simpler method that only has to worry about the shortest geodesic between two single states.
  • Types are parameterized by the structure of the manifold, instead of concrete subtypes being created to accommodate specific dimensionalities.
  • statesize is a helper function for checking compatibility between a Flow and a State/BatchedState.
    • statesize(Flow(Manifolds.Euclidean(3))) == (3,) since points in Euclidean space are represented as arrays, in this case vectors of length 3.
    • statesize(Flow(Manifolds.SpecialOrthogonal(3))) == (3, 3) since rotations in the special orthogonal group are represented as matrices.
    • statesize(State(rand(3, 3))) == statesize(BatchedState(rand(3, 3, 10))) == (3, 3) since the State is manifold-agnostic.
    • statesize(BatchedState(rand(3, 3, 10, 20, 30), 3)) == (3, 3) since the BatchedState can have multiple batch dimensions (e.g. for 2D images).
    • What's neat is that this abstraction also supports Euclidean spaces whose points are not represented as vectors, as is the case with Manifolds.Euclidean(3, 3).
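
The pattern in the list above can be sketched in plain Base Julia. This is an illustration, not the code in this PR: ToyManifold, ToyEuclidean, ToyRotations, getstate and batchsize are hypothetical stand-ins (the PR uses Manifolds.jl types and ArraysOfArrays), and reading the second BatchedState argument as the number of batch dimensions is an assumption inferred from the statesize examples.

```julia
# Stand-in manifolds (hypothetical; the PR parameterizes on Manifolds.jl types).
abstract type ToyManifold end
struct ToyEuclidean{N} <: ToyManifold
    dims::NTuple{N,Int}
end
ToyEuclidean(dims::Int...) = ToyEuclidean(dims)
struct ToyRotations <: ToyManifold
    n::Int
end

abstract type AbstractFlow end
struct ManifoldFlow{M<:ToyManifold} <: AbstractFlow
    manifold::M
end

# A single point on a manifold, stored as an N-dimensional array.
struct State{T,N,A<:AbstractArray{T,N}}
    value::A
    State(value::AbstractArray{T,N}) where {T,N} = new{T,N,typeof(value)}(value)
end

# Many states in one flat array; the trailing dimensions are batch dimensions.
# Assumption: the second argument is the number of batch dimensions.
struct BatchedState{T,N,A<:AbstractArray{T}}
    flat::A
    function BatchedState(flat::AbstractArray{T}, nbatch::Integer=1) where {T}
        N = ndims(flat) - nbatch
        N >= 0 || throw(ArgumentError("more batch dimensions than array dimensions"))
        return new{T,N,typeof(flat)}(flat)
    end
end

# statesize dispatches on the manifold for flows and is manifold-agnostic for states.
statesize(f::ManifoldFlow{<:ToyEuclidean}) = f.manifold.dims
statesize(f::ManifoldFlow{ToyRotations}) = (f.manifold.n, f.manifold.n)
statesize(s::State) = size(s.value)
statesize(b::BatchedState{T,N}) where {T,N} = ntuple(i -> size(b.flat, i), N)
batchsize(b::BatchedState{T,N}) where {T,N} = size(b.flat)[N+1:end]

# Each element of the batch is a view into the flat array, not a copy.
getstate(b::BatchedState{T,N}, I...) where {T,N} =
    State(view(b.flat, ntuple(_ -> Colon(), N)..., I...))

# Single-state method: the straight-line geodesic in Euclidean space.
interpolate(::ManifoldFlow{<:ToyEuclidean}, a::State, b::State, t::Real) =
    State((1 - t) .* a.value .+ t .* b.value)

# Generic batched method: loops over the batch and delegates to the single-state method.
function interpolate(f::ManifoldFlow, a::BatchedState{T,N}, b::BatchedState{T,N}, t::Real) where {T,N}
    result = BatchedState(similar(a.flat), ndims(a.flat) - N)
    for I in CartesianIndices(batchsize(a))
        getstate(result, I).value .= interpolate(f, getstate(a, I), getstate(b, I), t).value
    end
    return result
end

flow = ManifoldFlow(ToyEuclidean(3))
x0 = BatchedState(zeros(3, 10))
x1 = BatchedState(ones(3, 10))
@assert statesize(flow) == statesize(x0) == (3,)   # compatibility check
xt = interpolate(flow, x0, x1, 0.25)
```

In this sketch only the single-state interpolate method knows anything about the geometry; the batched method works for any manifold that defines the single-state one.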

Open question: does the Gaussian noise perturb! method make sense for ProbabilitySimplex (see the LinearFlow method), and should Flow{<:ProbabilitySimplex} belong to the LinearFlow union type?
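
To make the question concrete, here is a small Base Julia sketch of what happens when noise is added in ambient coordinates to a point on the simplex. It makes no claim about what perturb! does in this PR: on_simplex is a hypothetical helper, and the noise vector is a fixed stand-in for a single Gaussian draw so that the result is reproducible.

```julia
# Hypothetical helper: a point is on the probability simplex if all entries
# are non-negative and they sum to one (up to a tolerance).
on_simplex(p; atol=1e-8) = all(>=(0), p) && isapprox(sum(p), 1; atol=atol)

p = [0.7, 0.2, 0.1]            # a valid point on the simplex
noise = [0.05, -0.25, 0.02]    # fixed stand-in for one Gaussian sample
q = p .+ noise                 # ambient-space perturbation

println(on_simplex(p))         # the original point satisfies both constraints
println(on_simplex(q))         # the perturbed point has a negative entry and no longer sums to one
```

Unconstrained additive noise generally breaks both the non-negativity and the sum-to-one constraints, which is why the answer depends on whether the perturbed point is projected or renormalized afterwards.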

Checklist

  • Write documentation
  • Clean up geometry.jl
  • Review old code that was commented out (ancient scrolls of wisdom)
  • Moar tests

@AntonOresten AntonOresten requested a review from murrellb January 22, 2025 17:00
@AntonOresten AntonOresten marked this pull request as draft January 22, 2025 17:00