Skip to content

Commit

Permalink
[Update] Initial fix for GPU version.
Browse files Browse the repository at this point in the history
  • Loading branch information
latrocinia committed Aug 9, 2016
1 parent 9d2685d commit e57ec4e
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 134 deletions.
236 changes: 136 additions & 100 deletions powerfit/kernels.cl
Original file line number Diff line number Diff line change
@@ -1,114 +1,150 @@
__kernel
void rotate_grids_and_multiply(
read_only image3d_t template, read_only image3d_t mask,
global float* rotmat, sampler_t s_linear, sampler_t s_nearest,
float4 center, int4 shape, int radius,
global float* rot_template, global float* rot_mask,
global float* rot_mask2, int nrot
#define SQUARE(a) ((a) * (a))

// To be defined on compile time
#define SHAPE_X $shape_x
#define SHAPE_Y $shape_y
#define SHAPE_Z $shape_z
#define LLENGTH $llength

#define LLENGTH2 (LLENGTH * LLENGTH)
#define SLICE ((SHAPE_X * SHAPE_Y))
#define SIZE ((SHAPE_Z * SLICE))


kernel
void rotate_grid3d(
global float *grid, float16 rotmat, global float *out, int nearest
)
{
/*Rotate the template and mask grid, and also calculate the mask2 grid
*
* Parameters
* ----------
* template
*
* mask
*
* rotmat
* Array that holds all the rotations.
*
* s_linear : sampler_t
* Sampler with LINEAR property.
*
* s_nearest : sampler_t
* Sampler with NEAREST property.
*
* center : float4
* Center around which the images are rotated.
*
*
* shape : int4
* Contains the shape of output arrays, with the fourth element the size.
*
* radius : int
* Largest radius of image from center. All voxels within this radius
* will be rotated
*
* nrot : uint
* Index of the initial rotation that is sampled.
*
* Notes
* -----
*/

/* there is an offset of a half when sampling images properly */
const float OFFSET = 0.5f;
int radius2 = radius * radius;

int slice, rotmat_offset;
float4 weight, dist2;
float4 coor, coor_z, coor_zy, coor_zyx;
int4 index;
int z, y, x;

size_t zid = get_global_id(0);
size_t yid = get_global_id(1);
size_t xid = get_global_id(2);
size_t zstride = get_global_size(0);
size_t ystride = get_global_size(1);
size_t xstride = get_global_size(2);

/* Some precalculations */
slice = shape.s2 * shape.s1;
coor_zyx.s3 = 0;

rotmat_offset = nrot * 9;
coor.s0 = center.s0 + OFFSET;
coor.s1 = center.s1 + OFFSET;
coor.s2 = center.s2 + OFFSET;

/* Loop over the grids */
for (z = zid - radius; z <= radius; z += zstride) {
dist2.s2 = z * z;
coor_z.s0 = rotmat[rotmat_offset + 2] * z + coor.s0;
coor_z.s1 = rotmat[rotmat_offset + 5] * z + coor.s1;
coor_z.s2 = rotmat[rotmat_offset + 8] * z + coor.s2;

index.s0 = z * slice;
/* Wraparound the z-coordinate */
// Rotate grid around the origin. Only grid points within LLENGTH of the
// origin are rotated. Nearest neighbour interpolation.

int zid = get_global_id(0);
int yid = get_global_id(1);
int xid = get_global_id(2);
int zstride = get_global_size(0);
int ystride = get_global_size(1);
int xstride = get_global_size(2);

int z, y, x, x0, y0, z0, x1, y1, z1, offset0, offset1, grid_ind;
float dx, dy, dz, dx1, dy1, dz1, c00, c10, c01, c11, c0, c1, c;
float3 dist2, coor_z, coor_zy, coor_zyx;
int3 out_ind;


for (z = zid - LLENGTH; z <= LLENGTH; z += zstride) {
dist2.s2 = SQUARE(z);
if (dist2.s2 > LLENGTH2)
continue;

coor_z.s0 = rotmat.s2 * z;
coor_z.s1 = rotmat.s5 * z;
coor_z.s2 = rotmat.s8 * z;

out_ind.s0 = z * SLICE;
if (z < 0)
index.s0 += shape.s3;
out_ind.s0 += SIZE;

for (y = yid - LLENGTH; y <= LLENGTH; y += ystride) {
dist2.s1 = SQUARE(y) + dist2.s2;
if (dist2.s1 > LLENGTH2)
continue;

for (y = yid - radius; y <= radius; y += ystride) {
dist2.s1 = y * y + dist2.s2;
coor_zy.s0 = rotmat[rotmat_offset + 1] * y + coor_z.s0;
coor_zy.s1 = rotmat[rotmat_offset + 4] * y + coor_z.s1;
coor_zy.s2 = rotmat[rotmat_offset + 7] * y + coor_z.s2;
coor_zy.s0 = rotmat.s1 * y + coor_z.s0;
coor_zy.s1 = rotmat.s4 * y + coor_z.s1;
coor_zy.s2 = rotmat.s7 * y + coor_z.s2;

index.s1 = index.s0 + y * shape.s2;
/* Wraparound the y-coordinate */
out_ind.s1 = out_ind.s0 + y * SHAPE_X;
if (y < 0)
index.s1 += slice;
out_ind.s1 += SLICE;

for (x = xid - radius; x <= radius; x += xstride) {
dist2.s0 = x * x + dist2.s1;
if (dist2.s0 > radius2)
for (x = xid - LLENGTH; x <= LLENGTH; x += xstride) {
dist2.s0 = SQUARE(x) + dist2.s1;
if (dist2.s0 > LLENGTH2)
continue;
coor_zyx.s0 = rotmat.s0 * x + coor_zy.s0;
coor_zyx.s1 = rotmat.s3 * x + coor_zy.s1;
coor_zyx.s2 = rotmat.s6 * x + coor_zy.s2;

coor_zyx.s0 = rotmat[rotmat_offset + 0] * x + coor_zy.s0;
coor_zyx.s1 = rotmat[rotmat_offset + 3] * x + coor_zy.s1;
coor_zyx.s2 = rotmat[rotmat_offset + 6] * x + coor_zy.s2;

index.s2 = index.s1 + x;
out_ind.s2 = out_ind.s1 + x;
if (x < 0)
index.s2 += shape.s2;
out_ind.s2 += SHAPE_X;

if (nearest > 0) {

x0 = (int) round(coor_zyx.s0);
y0 = (int) round(coor_zyx.s1);
z0 = (int) round(coor_zyx.s2);

grid_ind = z0 * SLICE + y0 * SHAPE_X + x0;
if (x0 < 0)
grid_ind += SHAPE_X;
if (y0 < 0)
grid_ind += SLICE;
if (z0 < 0)
grid_ind += SIZE;

out[out_ind.s2] = grid[grid_ind];

} else {
x0 = (int) floor(coor_zyx.s0);
y0 = (int) floor(coor_zyx.s1);
z0 = (int) floor(coor_zyx.s2);
x1 = x0 + 1;
y1 = y0 + 1;
z1 = z0 + 1;

// Grid index
grid_ind = z0 * SLICE + y0 * SHAPE_X + x0;
if (x0 < 0)
grid_ind += SHAPE_X;
if (y0 < 0)
grid_ind += SLICE;
if (z0 < 0)
grid_ind += SIZE;

offset1 = 1;
if (x1 == 0)
offset1 -= SHAPE_X;
c00 = grid[grid_ind] * dx1 +
grid[grid_ind + offset1] * dx;

offset0 = SHAPE_X;
if (y1 == 0)
offset0 -= SLICE;
offset1 = offset0 + 1;
if (x1 == 0)
offset1 -= SHAPE_X;
c10 = grid[grid_ind + offset0] * dx1 +
grid[grid_ind + offset1] * dx;

offset0 = SLICE;
if (z1 == 0)
offset0 -= SIZE;
offset1 = offset0 + 1;
if (x1 == 0)
offset1 -= SHAPE_X;
c01 = grid[grid_ind + offset0] * dx1 +
grid[grid_ind + offset1] * dx;

offset0 = SLICE + SHAPE_X;
if (z1 == 0)
offset0 -= SIZE;
if (y1 == 0)
offset0 -= SLICE;
offset1 = offset0 + 1;
if (x1 == 0)
offset1 -= SHAPE_X;
c01 = grid[grid_ind + offset0] * dx1 +
grid[grid_ind + offset1] * dx;

c0 = c00 * dy1 + c10 * dy;
c1 = c01 * dy1 + c11 * dy;

c = c0 * dz1 + c1 * dz;

weight = read_imagef(template, s_linear, coor_zyx);
rot_template[index.s2] = weight.s0;
weight = read_imagef(mask, s_nearest, coor_zyx);
rot_mask[index.s2] = weight.s0;
rot_mask2[index.s2] = weight.s0 * weight.s0;
out[out_ind.s2] = c;
}
}
}
}
Expand Down
64 changes: 30 additions & 34 deletions powerfit/powerfitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os.path
from time import time, sleep
from multiprocessing import RawValue, Lock, Process, cpu_count
from string import Template

import numpy as np
from numpy.fft import irfftn as np_irfftn, rfftn as np_rfftn
Expand Down Expand Up @@ -170,15 +171,8 @@ def __init__(self, target, laplace=False):
self._template = None
self._mask = None
self._laplace = laplace
# get center of grid
self._center = self._get_center(self._target.shape)
self._lcc_mask = self._get_lcc_mask(self._target)

@staticmethod
def _get_center(shape):
"""Get the center of the grid to rotate around"""
#self._center = (np.asarray(template.shape, dtype=np.float64)[::-1] - 1)/ 2
return (np.asarray(shape, dtype=np.float64) / 2)[::-1]
self._rmax = min(target.shape) // 2

@staticmethod
def _get_lcc_mask(target):
Expand Down Expand Up @@ -214,7 +208,6 @@ def mask(self, mask):
# multiply again for core-weighted correlation score
self._template *= self._mask
# calculate the maximum radius
self._rmax = min(self._mask.shape) // 2

@staticmethod
def _laplace_filter(array):
Expand Down Expand Up @@ -424,7 +417,6 @@ def __init__(self, target, queue, laplace=False):
self._k.multiply(self._gtarget, self._gtarget, self._target2)
self._rfftn(self._target2, self._ft_target2)

self._gcenter = np.asarray(list(self._center) + [0], dtype=np.float32)
self._gshape = np.asarray(
list(self._target.shape) + [np.product(self._target.shape)],
dtype=np.int32)
Expand Down Expand Up @@ -474,18 +466,12 @@ def mask(self, mask):
BaseCorrelator.mask.fset(self, mask)
self._norm_factor = np.float32(self._norm_factor)
self._rmax = np.int32(self._rmax)
self._gtemplate = cl.image_from_array(
self._ctx, self._template.astype(np.float32)
self._gtemplate = cl_array.to_device(
self._queue, self._template.astype(np.float32)
)
self._gmask = cl.image_from_array(
self._ctx, self._mask.astype(np.float32)
self._gmask = cl_array.to_device(
self._queue, self._mask.astype(np.float32)
)
max_items = self._queue.device.max_compute_units * 32 * 16
gws = [0] * 3
gws[0] = min(2 * self._rmax, max_items)
gws[1] = min(max_items // gws[0], 2 * self._rmax)
gws[2] = min(max(max_items // (gws[0] * gws[0]), 1), 2 * self._rmax)
self._gws = tuple(gws)

@property
def rotations(self):
Expand All @@ -494,8 +480,8 @@ def rotations(self):
@rotations.setter
def rotations(self, rotations):
BaseCorrelator.rotations.fset(self, rotations)
self._grotations = cl_array.to_device(self._queue,
rotations.ravel().astype(np.float32))
self._cl_rotations = np.zeros((self._rotations.shape[0], 16), dtype=np.float32)
self._cl_rotations[:, :9] = self._rotations.reshape(-1, 9)

def scan(self):
super(GPUCorrelator, self).scan()
Expand All @@ -505,12 +491,13 @@ def scan(self):
time0 = time()
for n in xrange(0, self._rotations.shape[0]):

args = (self._gtemplate, self._gmask, self._grotations.data,
self._k._sampler_linear, self._k._sampler_nearest,
self._gcenter, self._gshape, self._rmax,
self._rot_template.data, self._rot_mask.data,
self._rot_mask2.data, np.int32(n))
self._k.rotate_grids_and_multiply(self._queue, self._gws, None, *args)
rotmat = self._cl_rotations[n]

self._k.rotate_grid3d(self._queue, self._gtemplate, rotmat, self._rot_template)
self._k.rotate_grid3d(self._queue, self._gmask, rotmat,
self._rot_mask, nearest=True)
self._k.multiply(self._rot_mask, self._rot_mask, self._rot_mask2)

self._rfftn(self._rot_template, self._ft_template)
self._rfftn(self._rot_mask, self._ft_mask)
self._rfftn(self._rot_mask2, self._ft_mask2)
Expand Down Expand Up @@ -541,11 +528,16 @@ def _print_progress(n, nrot, time0):
stdout.flush()

def _generate_kernels(self):
self._k = CLKernels(self._ctx)
kernel_values = {'shape_x': self._shape[2],
'shape_y': self._shape[1],
'shape_z': self._shape[0],
'llength': self._rmax,
}
self._k = CLKernels(self._ctx, kernel_values)


class CLKernels(object):
def __init__(self, ctx):
def __init__(self, ctx, values):
self.multiply = ElementwiseKernel(ctx,
"float *x, float *y, float *z",
"z[i] = x[i] * y[i];"
Expand Down Expand Up @@ -573,10 +565,14 @@ def __init__(self, ctx):
t = Template(f.read()).substitute(**values)

self._program = cl.Program(ctx, t).build()
self._gws_rotate_grid3d(96, 64, 1)

self.rotate_grids_and_multiply = self._program.rotate_grids_and_multiply

self._gws_rotate_grid3d = (96, 64, 1)

def rotate_grid3d(self, queue, grid, rotmat, out, nearest=False):
_nearest = np.int32(0)
if nearest:
_nearest = np.int32(1)
args = (grid.data, rotmat, out.data, _nearest)
self._program.rotate_grid3d(queue, self._gws_rotate_grid3d, None, *args)


class grfftn_builder(object):
Expand Down

0 comments on commit e57ec4e

Please sign in to comment.