Skip to content

Commit

Permalink
use wrap namespace functions to cast to vec3 and select the correct o…
Browse files Browse the repository at this point in the history
…verload
  • Loading branch information
DomFijan committed Jan 14, 2025
1 parent 0c1622c commit 6049a52
Showing 1 changed file with 28 additions and 46 deletions.
74 changes: 28 additions & 46 deletions freud/environment/export-MatchEnv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,59 +50,41 @@ void compute_env_rmsd_min(const std::shared_ptr<EnvironmentRMSDMinimizer>& env_r
env_rmsd_min->compute(nq, nlist, qargs, motif_data, motif_size, registration);
}

std::map<unsigned int, unsigned int> compute_minimize_RMSD(
const box::Box& box,
const nb_array<float, nanobind::shape<-1, 3>>& refPoints1,
nb_array<float, nanobind::shape<-1, 3>>& refPoints2,
unsigned int numRef,
float& min_rmsd,
bool registration)
{
auto* refPoints1_data = reinterpret_cast<vec3<float>*>(refPoints1.data());
auto* refPoints2_data = reinterpret_cast<vec3<float>*>(refPoints2.data());
return minimizeRMSD(box, refPoints1_data, refPoints2_data, numRef, min_rmsd, registration);
}

std::map<unsigned int, unsigned int> compute_is_similar(
const box::Box& box,
const nb_array<float, nanobind::shape<-1, 3>>& refPoints1,
nb_array<float, nanobind::shape<-1, 3>>& refPoints2,
unsigned int numRef,
float threshold_sq,
bool registration)
{
auto* refPoints1_data = reinterpret_cast<vec3<float>*>(refPoints1.data());
auto* refPoints2_data = reinterpret_cast<vec3<float>*>(refPoints2.data());
return isSimilar(box, refPoints1_data, refPoints2_data, numRef, threshold_sq, registration);
}

};

namespace detail {

// 1. Define the function pointer type for the overload that returns a std::map<unsigned int, unsigned int>.
using MinimizeRMSD_Coords = std::map<unsigned int, unsigned int> (*)(
const box::Box&,
const vec3<float>*,
vec3<float>*,
unsigned int,
float&,
bool);

using IsSimilar_Coords = std::map<unsigned int, unsigned int> (*)(
const box::Box&,
const vec3<float>*,
vec3<float>*,
unsigned int,
float,
bool);

void export_MatchEnv(nb::module_& module)
{
module.def(
"minimizeRMSD",
(MinimizeRMSD_Coords) &minimizeRMSD,
"Compute a map of matching indices between two sets of points. This overload also potentially\n"
"modifies the second set of points in-place if registration=True.\n\n"
"Args:\n"
" box (Box): Simulation box.\n"
" refPoints1 (array of vec3<float>): Points in the first environment.\n"
" refPoints2 (array of vec3<float>): Points in the second environment (modified if registration=True).\n"
" numRef (int): Number of points.\n"
" min_rmsd (float): Updated by reference to the final RMSD.\n"
" registration (bool): If True, perform a brute-force alignment.\n\n"
"Returns:\n"
" dict(int->int): Index mapping from the first set of points to the second set.");
module.def("minimizeRMSD", &wrap::compute_minimize_RMSD);

module.def(
"isSimilar",
(IsSimilar_Coords) &isSimilar,
"Check if two sets of points can be matched (i.e., are 'similar') within a given distance threshold.\n"
"Potentially modifies the second set in-place if registration=True.\n\n"
"Args:\n"
" box (Box): Simulation box.\n"
" refPoints1 (array of vec3<float>): Points in the first environment.\n"
" refPoints2 (array of vec3<float>): Points in the second environment.\n"
" numRef (int): Number of points.\n"
" threshold_sq (float): Square of the max distance allowed to consider points matching.\n"
" registration (bool): If True, attempt brute force alignment.\n\n"
"Returns:\n"
" dict(int->int): Index mapping if the environments match, else an empty map."
);
module.def("isSimilar", &wrap::compute_is_similar);

nb::class_<MatchEnv>(module, "MatchEnv")
.def(nb::init<>())
Expand Down

0 comments on commit 6049a52

Please sign in to comment.