Conflux.jl is a toolkit designed to enable data parallelism for Flux.jl models by simplifying the process of replicating them across multiple GPUs on a single node, and by leveraging NCCL.jl for efficient inter-GPU communication. This package aims to provide a straightforward and intuitive interface for multi-GPU training, requiring minimal changes to existing code and training loops.
- Easy replication of objects across multiple GPUs with the `replicate` function.
- Efficient synchronization of models and averaging of gradients with the `allreduce!` function, which takes an operator (e.g. `avg`) and a set of replicas, and reduces all their parameters with the given operator, leaving the replicas identical.
- A `withdevices` function that runs a block of code on each device asynchronously and collects each device's result.
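To make the semantics of `allreduce!(avg, replicas...)` concrete, here is a CPU-only sketch in plain Julia. It is not Conflux's implementation, which communicates through NCCL.jl on the GPUs. The function `naive_allreduce_avg!` is hypothetical, and it assumes each replica is a vector of parameter arrays with matching shapes rather than a Flux model.

```julia
# Hypothetical CPU-only illustration of "allreduce with avg" semantics.
# Each replica is assumed to be a vector of parameter arrays with matching shapes.
# This is NOT Conflux's implementation, which reduces GPU buffers through NCCL.
function naive_allreduce_avg!(replicas)
    n = length(replicas)
    for p in eachindex(first(replicas))
        # Element-wise mean of parameter p across all replicas
        mean_p = sum(r[p] for r in replicas) ./ n
        # Write the mean back into every replica, so they all end up identical
        for r in replicas
            copyto!(r[p], mean_p)
        end
    end
    return replicas
end

# Two "replicas", each holding a weight matrix and a bias vector
r1 = [Float32[1 2; 3 4], Float32[0, 0]]
r2 = [Float32[3 4; 5 6], Float32[2, 4]]
naive_allreduce_avg!([r1, r2])
# r1 and r2 now both hold Float32[2 3; 4 5] and Float32[1, 2]
```

Because every replica receives the same reduced values, averaging the gradients this way gives each device the same update, which is what keeps identically initialised replicas in sync.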
See the documentation for more details, examples, and important caveats.
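The `withdevices` pattern (run one block per device concurrently, then collect the results) can be sketched on the CPU with Julia tasks. `naive_withdevices` is hypothetical and is not how Conflux implements it. In particular, a real version would also have to make the right GPU current inside each task (for example with CUDA.jl's `device!`) before calling the block, which this sketch omits. It illustrates why the example below can index the results as `states[i]` and `∇models[i]`: the results come back in device order.

```julia
# Hypothetical CPU-only sketch of a "run f on every device" helper.
# `f` receives a single `(i, device)` tuple, matching the `do (i, device)` blocks
# in the example; results are returned in the same order as `devices`.
function naive_withdevices(f, devices)
    tasks = [Threads.@spawn f((i, d)) for (i, d) in enumerate(devices)]
    return fetch.(tasks)  # wait for every task and collect its result
end

results = naive_withdevices([:gpu0, :gpu1]) do (i, device)
    (i, device)  # stand-in for per-device work such as computing gradients
end
```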
The package can be installed with the Julia package manager. From the Julia REPL, type `]` to enter the Pkg REPL mode and run:

```
pkg> add
```
```julia
# Specify the default devices to use, e.g. restrict CUDA to GPUs 0 and 1
# (set this before any GPU is used):
# ENV["CUDA_VISIBLE_DEVICES"] = "0,1"

using Conflux
using Flux, Optimisers

model = Chain(Dense(1 => 256, tanh), Dense(256 => 512, tanh), Dense(512 => 256, tanh), Dense(256 => 1))

# This uses all available devices. To use specific devices, pass them as a second argument.
models = replicate(model)

opt = Optimisers.Adam(0.001f0)

# Instantiate the optimiser states on each device
states = Conflux.withdevices() do (i, device)
    Optimisers.setup(opt, model) |> device
end

# A single batch, stored on the CPU (Float32 to match the model's parameters).
# Every device receives this same batch; a more sophisticated mechanism could give each device different data.
X = rand(Float32, 1, 16)
Y = X .^ 2

loss(y, Y) = sum(abs2, y .- Y)

losses = []
losses_lock = ReentrantLock()  # the per-device blocks may run concurrently

for epoch in 1:10
    # Get the gradients for each batch on each device
    ∇models = Conflux.withdevices() do (i, device)
        x, y = device(X), device(Y)
        # The second return value is a tuple because `Flux.withgradient` takes `args...`, and the model is the first argument.
        l, (∇model,) = Flux.withgradient(m -> loss(m(x), y), models[i])
        lock(() -> push!(losses, l), losses_lock)
        ∇model  # the block's last value is what gets collected into ∇models
    end

    # Average the gradients across devices
    allreduce!(avg, ∇models...)

    # Update the models on each device
    Conflux.withdevices() do (i, device)
        Optimisers.update!(states[i], models[i], ∇models[i])
    end

    # Optionally synchronize the models and optimiser states, in case the parameters diverge
    #allreduce!(avg, models...)
    #allreduce!(avg, states...)
end
```
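The example feeds the same batch to every device. One simple way to give each device its own share of the data is to split the batch by columns, which are the observations in Flux's features × batch layout. `shard_columns` is a hypothetical helper in plain Julia, not part of Conflux; it uses a strided split so that it always returns exactly `n` shards (some may be empty if `n` exceeds the number of columns).

```julia
# Hypothetical helper: split the columns (observations) of X into n strided shards.
# Shard i holds columns i, i+n, i+2n, ...; every column lands in exactly one shard.
function shard_columns(X::AbstractMatrix, n::Integer)
    return [X[:, i:n:size(X, 2)] for i in 1:n]
end

X_demo = reshape(Float32.(1:16), 1, 16)
shards = shard_columns(X_demo, 2)  # in the example, n would be length(models)
```

Splitting `Y` with the same `n` selects the matching columns, so inside the gradient block each device could use `device(Xshards[i])` and `device(Yshards[i])` in place of `device(X)` and `device(Y)`, where `Xshards` and `Yshards` are illustrative names for the two results.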