Skip to content

Commit

Permalink
Merge pull request #649 from genn-team/xorwow_mem_reduce
Browse files Browse the repository at this point in the history
Don't keep box muller transform state between kernel launches
  • Loading branch information
neworderofjamie authored Jan 28, 2025
2 parents 3704a7d + beb19d4 commit e9e40b9
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 97 deletions.
10 changes: 3 additions & 7 deletions include/genn/backends/cuda/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
AtomicOperation op = AtomicOperation::ADD,
AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -191,10 +188,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
virtual std::unique_ptr<Runtime::ArrayBase> createArray(const Type::ResolvedType &type, size_t count,
VarLocation location, bool uninitialized) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code to allocate variable with a size known at runtime
virtual void genLazyVariableDynamicAllocation(CodeStream &os,
const Type::ResolvedType &type, const std::string &name, VarLocation loc,
Expand Down Expand Up @@ -234,6 +227,9 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
return m_ChosenDevice.totalConstMem - getPreferences<Preferences>().constantCacheOverhead;
}

//! Get the internal type that the population RNG gets loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const final;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const final;

Expand Down
10 changes: 3 additions & 7 deletions include/genn/backends/hip/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
AtomicOperation op = AtomicOperation::ADD,
AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -182,10 +179,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
virtual std::unique_ptr<Runtime::ArrayBase> createArray(const Type::ResolvedType &type, size_t count,
VarLocation location, bool uninitialized) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code to allocate variable with a size known at runtime
virtual void genLazyVariableDynamicAllocation(CodeStream &os,
const Type::ResolvedType &type, const std::string &name, VarLocation loc,
Expand Down Expand Up @@ -225,6 +218,9 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
return m_ChosenDevice.totalConstMem - getPreferences<Preferences>().constantCacheOverhead;
}

//! Get the internal type that the population RNG gets loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const final;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const final;

Expand Down
27 changes: 16 additions & 11 deletions include/genn/genn/code_generator/backendCUDAHIP.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
m_RandPrefix(randPrefix), m_CCLPrefix(cclPrefix)
{}

//--------------------------------------------------------------------------
// Declared virtuals
//--------------------------------------------------------------------------

//--------------------------------------------------------------------------
// CodeGenerator::BackendSIMT virtuals
//--------------------------------------------------------------------------
Expand All @@ -87,16 +83,18 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
//! For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence
virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const final;

//! Generate a preamble to add substitution name for population RNG
virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const final;

//! If required, generate a postamble for population RNG
/*! For example, in OpenCL, this is used to write local RNG state back to global memory*/
virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const final;

//! Generate code to skip ahead local copy of global RNG
virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<NeuronUpdateGroupMerged> &env) const final;

//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> &env) const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -118,6 +116,10 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final;
virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<GeNN::Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code for pushing a variable with a size known at runtime to the 'device'
virtual void genLazyVariableDynamicPush(CodeStream &os,
const Type::ResolvedType &type, const std::string &name,
Expand Down Expand Up @@ -173,6 +175,9 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
//! Get the safe amount of constant cache we can use
virtual size_t getChosenDeviceSafeConstMemBytes() const = 0;

//! Get the internal type that the population RNG gets loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const = 0;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const = 0;

Expand Down
9 changes: 4 additions & 5 deletions include/genn/genn/code_generator/backendSIMT.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ class GENN_EXPORT BackendSIMT : public BackendBase
//! For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence
virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const = 0;

//! Generate a preamble to add substitution name for population RNG
virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const = 0;
//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<NeuronUpdateGroupMerged> &env) const = 0;

//! If required, generate a postamble for population RNG
/*! For example, in OpenCL, this is used to write local RNG state back to global memory*/
virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const = 0;
//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void buildPopulationRNGEnvironment(EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> &env) const = 0;

//! Generate code to skip ahead local copy of global RNG
virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const = 0;
Expand Down
81 changes: 62 additions & 19 deletions include/genn/genn/code_generator/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,14 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P

~EnvironmentExternalDynamicBase()
{
// Loop through initialisers
std::vector<std::string> initialiserCode(m_Initialisers.size());
std::vector<std::string> finaliserCode(m_Finalisers.size());

// Because initialisers may refer to other initialisers,
// keep evaluating initialisers until no new ones are founf
// Because initialisers and finalisers may refer to others,
// keep evaluating them until no new ones are found
bool anyReferences;
do {
// Loop through initialiser
// Loop through initialisers
anyReferences = false;
for(size_t i = 0; i < m_Initialisers.size(); i++) {
// If initialiser has been referenced
Expand All @@ -279,6 +279,20 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P
anyReferences = true;
}
}

// Loop through finalisers
for(size_t i = 0; i < m_Finalisers.size(); i++) {
// If finaliser has been referenced
auto &finaliser = m_Finalisers[i];
if (finaliser.first) {
// Evaluate lazy string into vector
finaliserCode[i] = finaliser.second.str();

// Clear referenced flag and set flag to ensure another iteration occurs
finaliser.first = false;
anyReferences = true;
}
}
} while(anyReferences);

// Write out generated initialiser code
Expand All @@ -291,6 +305,14 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P

// Write contents to context stream
getContextStream() << m_ContentsStream.str();

// Write out generated finaliser code
// **NOTE** in order
for(const auto &i : finaliserCode) {
if(!i.empty()) {
getContextStream() << i << std::endl;
}
}
}

//------------------------------------------------------------------------
Expand All @@ -310,10 +332,15 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P
m_Initialisers.at(i).first = true;
}

// If this identifier relies on any finaliser statements, mark these finalisers as required
for(size_t i : std::get<2>(env->second)) {
m_Finalisers.at(i).first = true;
}

// Perform any type-specific logic to mark this identifier as required
this->setRequired(std::get<2>(env->second));
this->setRequired(std::get<3>(env->second));

return this->getNameInternal(std::get<2>(env->second));
return this->getNameInternal(std::get<3>(env->second));
}
}

Expand All @@ -336,8 +363,13 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P
m_Initialisers.at(i).first = true;
}

// If this identifier relies on any finaliser statements, mark these finalisers as required
for(size_t i : std::get<2>(env->second)) {
m_Finalisers.at(i).first = true;
}

// Perform any type-specific logic to mark this identifier as required
this->setRequired(std::get<2>(env->second));
this->setRequired(std::get<3>(env->second));

// Return type of variables
return {std::get<0>(env->second)};
Expand All @@ -350,15 +382,21 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P
return (m_Initialisers.size() - 1);
}

//! Register a finaliser statement built lazily from a format string
/*! The finaliser starts out unreferenced; it is only marked as required when an
    identifier listing its index is looked up in this environment.
    \param format format string the finaliser's LazyString is built from
    \return index of the new finaliser, for use in the finalisers list of an identifier */
size_t addFinaliser(const std::string &format)
{
    // The new entry is appended at the end, so its index is the current size
    const size_t finaliserIndex = m_Finalisers.size();
    m_Finalisers.emplace_back(false, LazyString{format, *this});
    return finaliserIndex;
}

protected:
//------------------------------------------------------------------------
// Protected API
//------------------------------------------------------------------------
//! Map an identifier to a type (for type-checking), lists of initialisers and finalisers, and a payload
void addInternal(const GeNN::Type::ResolvedType &type, const std::string &name, const typename P::Payload &payload,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
if(!m_Environment.try_emplace(name, type, initialisers, payload).second) {
if(!m_Environment.try_emplace(name, type, initialisers, finalisers, payload).second) {
throw std::runtime_error("Redeclaration of '" + std::string{name} + "'");
}
}
Expand All @@ -370,8 +408,9 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P
std::ostringstream m_ContentsStream;
CodeStream m_Contents;

std::unordered_map<std::string, std::tuple<Type::ResolvedType, std::vector<size_t>, typename P::Payload>> m_Environment;
std::unordered_map<std::string, std::tuple<Type::ResolvedType, std::vector<size_t>, std::vector<size_t>, typename P::Payload>> m_Environment;
std::vector<std::pair<bool, LazyString>> m_Initialisers;
std::vector<std::pair<bool, LazyString>> m_Finalisers;
};

//----------------------------------------------------------------------------
Expand All @@ -389,9 +428,10 @@ class EnvironmentExternal : public EnvironmentExternalDynamicBase<EnvironmentSub
//------------------------------------------------------------------------
//! Map a type (for type-checking) and a value (for pretty-printing) to an identifier
void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
addInternal(type, name, LazyString{value, *this}, initialisers);
addInternal(type, name, LazyString{value, *this},
initialisers, finalisers);
}
};

Expand Down Expand Up @@ -433,9 +473,10 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
//------------------------------------------------------------------------
//! Map a type and a value to an identifier
void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
this->addInternal(type, name, std::make_tuple(false, LazyString{value, *this}, std::nullopt), initialisers);
this->addInternal(type, name, std::make_tuple(false, LazyString{value, *this}, std::nullopt),
initialisers, finalisers);
}

//! Map a type (for type-checking) and a group merged field to back it to an identifier
Expand Down Expand Up @@ -472,7 +513,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName,
GetFieldNonNumericValueFunc getFieldValue, const std::string &indexSuffix = "",
GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD,
const std::vector<size_t> &initialisers = {})
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
typename G::Field field{fieldName, fieldType, mergedFieldType,
[getFieldValue](Runtime::Runtime &r, const GroupInternal &g, size_t i)
Expand All @@ -482,7 +523,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
getFieldValue(r, g, i));
}};
this->addInternal(type, name, std::make_tuple(false, LazyString{indexSuffix, *this}, std::make_optional(field)),
initialisers);
initialisers, finalisers);
}

//! Map a type (for type-checking) and a group merged field to back it to an identifier
Expand All @@ -495,9 +536,11 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
//! Map a type (for type-checking) and a group merged field to back it to an identifier
void addField(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &fieldName,
GetFieldNonNumericValueFunc getFieldValue, const std::string &indexSuffix = "",
GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, const std::vector<size_t> &initialisers = {})
GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD,
const std::vector<size_t> &initialisers = {}, const std::vector<size_t> &finalisers = {})
{
addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers);
addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType,
initialisers, finalisers);
}

void addParams(const Snippet::Base::ParamVec &params, const std::string &fieldSuffix,
Expand Down Expand Up @@ -611,7 +654,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
}

template<typename I>
void addInitialiserDerivedParams(const std::string &fieldSuffix, GetInitialiserFn<I> getInitialiser)
void addInitialiserDerivedParams(const std::string &fieldSuffix, GetInitialiserFn<I> getInitialiser)
{
// Loop through params
const auto &initialiser = std::invoke(getInitialiser, this->getGroup().getArchetype());
Expand Down
6 changes: 6 additions & 0 deletions include/genn/genn/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ struct GENN_EXPORT ResolvedType
return ResolvedType{Value{name, sizeof(T), ffiType, device, false, std::nullopt}, isConst};
}

//! Create a value type whose size is supplied explicitly by the caller
/*! Builds the same Value initialiser list as the templated overload above, except that
    the second entry is the size argument rather than sizeof(T).
    The last two Value entries are fixed at false and std::nullopt.
    \param name     name stored in the Value
    \param size     size stored in the Value (occupies the slot the templated overload fills with sizeof(T))
    \param isConst  const-qualification flag passed to ResolvedType
    \param ffiType  libffi type descriptor stored in the Value; defaults to nullptr
    \param device   device flag stored in the Value; defaults to false */
static ResolvedType createValue(const std::string &name, size_t size, bool isConst = false,
ffi_type *ffiType = nullptr, bool device = false)
{
return ResolvedType{Value{name, size, ffiType, device, false, std::nullopt}, isConst};
}

static ResolvedType createFunction(const ResolvedType &returnType, const std::vector<ResolvedType> &argTypes,
FunctionFlags flags=FunctionFlags{0})
{
Expand Down
Loading

0 comments on commit e9e40b9

Please sign in to comment.