Skip to content

Commit

Permalink
Sort (#120)
Browse files Browse the repository at this point in the history
* add sort

* update

* update tests

* update

* remove generics

* update tests

* use alloy's buffer allocation
  • Loading branch information
eugenebokhan authored Apr 20, 2021
1 parent d1735f7 commit 75b0916
Show file tree
Hide file tree
Showing 6 changed files with 815 additions and 0 deletions.
114 changes: 114 additions & 0 deletions Sources/Alloy/Encoders/BitonicSort/BitonicSort.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import Metal

final public class BitonicSort {

// MARK: - Properties

private let firstPass: FirstPass
private let generalPass: GeneralPass
private let finalPass: FinalPass

// MARK: - Init

public convenience init(context: MTLContext,
scalarType: MTLPixelFormat.ScalarType) throws {
try self.init(library: context.library(for: .module),
scalarType: scalarType)
}

public init(library: MTLLibrary,
scalarType: MTLPixelFormat.ScalarType) throws {
self.firstPass = try .init(library: library,
scalarType: scalarType)
self.generalPass = try .init(library: library,
scalarType: scalarType)
self.finalPass = try .init(library: library,
scalarType: scalarType)
}

// MARK: - Encode

public func callAsFunction(data: MTLBuffer,
count: Int,
in commandeBuffer: MTLCommandBuffer) {
self.encode(data: data,
count: count,
in: commandeBuffer)
}

public func encode(data: MTLBuffer,
count: Int,
in commandBuffer: MTLCommandBuffer) {
let elementStride = data.length / count
let gridSize = count >> 1
let unitSize = min(gridSize,
self.generalPass
.pipelineState
.maxTotalThreadsPerThreadgroup)

var params = SIMD2<UInt32>(repeating: 1)

self.firstPass(data: data,
elementStride: elementStride,
gridSize: gridSize,
unitSize: unitSize,
in: commandBuffer)
params.x = .init(unitSize << 1)

while params.x < count {
params.y = params.x
params.x <<= 1
repeat {
if unitSize < params.y {
self.generalPass(data: data,
params: params,
gridSize: gridSize,
unitSize: unitSize,
in: commandBuffer)
params.y >>= 1
} else {
self.finalPass(data: data,
elementStride: elementStride,
params: params,
gridSize: gridSize,
unitSize: unitSize,
in: commandBuffer)
params.y = .zero
}
} while params.y > .zero
}
}

public static func buffer<T: FixedWidthInteger>(from array: [T],
device: MTLDevice,
options: MTLResourceOptions = []) throws -> (buffer: MTLBuffer, paddedCount: Int) {
return try Self.buffer(from: array,
paddingValue: T.max,
device: device,
options: options)
}

public static func buffer<T: FloatingPoint>(from array: [T],
device: MTLDevice,
options: MTLResourceOptions = []) throws -> (buffer: MTLBuffer, paddedCount: Int) {
return try Self.buffer(from: array,
paddingValue: T.greatestFiniteMagnitude,
device: device,
options: options)
}

private static func buffer<T: Numeric>(from array: [T],
paddingValue: T,
device: MTLDevice,
options: MTLResourceOptions = []) throws -> (buffer: MTLBuffer, paddedCount: Int) {
let paddedCount = 1 << UInt(ceil(log2f(.init(array.count))))
var array = array
if paddedCount > array.count {
array += .init(repeating: paddingValue,
count: paddedCount - array.count)
}
return try (buffer: device.buffer(with: array, options: options),
paddedCount: paddedCount)
}

}
106 changes: 106 additions & 0 deletions Sources/Alloy/Encoders/BitonicSort/BitonicSortFinalPass.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import Metal

extension BitonicSort {

final class FinalPass {

// MARK: - Properties

let pipelineState: MTLComputePipelineState
private let deviceSupportsNonuniformThreadgroups: Bool

// MARK: - Init

convenience init(context: MTLContext,
scalarType: MTLPixelFormat.ScalarType) throws {
try self.init(library: context.library(for: .module),
scalarType: scalarType)
}

init(library: MTLLibrary,
scalarType: MTLPixelFormat.ScalarType) throws {
self.deviceSupportsNonuniformThreadgroups = library.device
.supports(feature: .nonUniformThreadgroups)

let constantValues = MTLFunctionConstantValues()
constantValues.set(self.deviceSupportsNonuniformThreadgroups,
at: 0)

let `extension` = "_" + scalarType.rawValue
self.pipelineState = try library.computePipelineState(function: "bitonicSortFinalPass" + `extension`,
constants: constantValues)
}

// MARK: - Encode

func callAsFunction(data: MTLBuffer,
elementStride: Int,
params: SIMD2<UInt32>,
gridSize: Int,
unitSize: Int,
in commandBuffer: MTLCommandBuffer) {
self.encode(data: data,
elementStride: elementStride,
params: params,
gridSize: gridSize,
unitSize: unitSize,
in: commandBuffer)
}

func callAsFunction(data: MTLBuffer,
elementStride: Int,
params: SIMD2<UInt32>,
gridSize: Int,
unitSize: Int,
using encoder: MTLComputeCommandEncoder) {
self.encode(data: data,
elementStride: elementStride,
params: params,
gridSize: gridSize,
unitSize: unitSize,
using: encoder)
}

func encode(data: MTLBuffer,
elementStride: Int,
params: SIMD2<UInt32>,
gridSize: Int,
unitSize: Int,
in commandBuffer: MTLCommandBuffer) {
commandBuffer.compute { encoder in
encoder.label = "Bitonic Sort Final Pass"
self.encode(data: data,
elementStride: elementStride,
params: params,
gridSize: gridSize,
unitSize: unitSize,
using: encoder)
}
}

func encode(data: MTLBuffer,
elementStride: Int,
params: SIMD2<UInt32>,
gridSize: Int,
unitSize: Int,
using encoder: MTLComputeCommandEncoder) {
encoder.setBuffers(data)
encoder.setValue(UInt32(gridSize), at: 1)
encoder.setValue(params, at: 2)

encoder.setThreadgroupMemoryLength((elementStride * unitSize) << 1,
index: 0)

if self.deviceSupportsNonuniformThreadgroups {
encoder.dispatch1d(state: self.pipelineState,
exactly: gridSize,
threadgroupWidth: unitSize)
} else {
encoder.dispatch1d(state: self.pipelineState,
covering: gridSize,
threadgroupWidth: unitSize)
}
}
}

}
98 changes: 98 additions & 0 deletions Sources/Alloy/Encoders/BitonicSort/BitonicSortFirstPass.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import Metal

extension BitonicSort {

final class FirstPass {

// MARK: - Properties

let pipelineState: MTLComputePipelineState
private let deviceSupportsNonuniformThreadgroups: Bool

// MARK: - Init

convenience init(context: MTLContext,
scalarType: MTLPixelFormat.ScalarType) throws {
try self.init(library: context.library(for: .module),
scalarType: scalarType)
}

init(library: MTLLibrary,
scalarType: MTLPixelFormat.ScalarType) throws {
self.deviceSupportsNonuniformThreadgroups = library.device
.supports(feature: .nonUniformThreadgroups)

let constantValues = MTLFunctionConstantValues()
constantValues.set(self.deviceSupportsNonuniformThreadgroups,
at: 0)

let `extension` = "_" + scalarType.rawValue
self.pipelineState = try library.computePipelineState(function: "bitonicSortFirstPass" + `extension`,
constants: constantValues)
}

// MARK: - Encode

func callAsFunction(data: MTLBuffer,
elementStride: Int,
gridSize: Int,
unitSize: Int,
in commandBuffer: MTLCommandBuffer) {
self.encode(data: data,
elementStride: elementStride,
gridSize: gridSize,
unitSize: unitSize,
in: commandBuffer)
}

func callAsFunction(data: MTLBuffer,
elementStride: Int,
gridSize: Int,
unitSize: Int,
using encoder: MTLComputeCommandEncoder) {
self.encode(data: data,
elementStride: elementStride,
gridSize: gridSize,
unitSize: unitSize,
using: encoder)
}

func encode(data: MTLBuffer,
elementStride: Int,
gridSize: Int,
unitSize: Int,
in commandBuffer: MTLCommandBuffer) {
commandBuffer.compute { encoder in
encoder.label = "Bitonic Sort First Pass"
self.encode(data: data,
elementStride: elementStride,
gridSize: gridSize,
unitSize: unitSize,
using: encoder)
}
}

func encode(data: MTLBuffer,
elementStride: Int,
gridSize: Int,
unitSize: Int,
using encoder: MTLComputeCommandEncoder) {
encoder.setBuffers(data)
encoder.setValue(UInt32(gridSize), at: 1)
encoder.setThreadgroupMemoryLength((elementStride * unitSize) << 1,
index: 0)

if self.deviceSupportsNonuniformThreadgroups {
encoder.dispatch1d(state: self.pipelineState,
exactly: gridSize,
threadgroupWidth: unitSize)
} else {
encoder.dispatch1d(state: self.pipelineState,
covering: gridSize,
threadgroupWidth: unitSize)
}
}

}

}
Loading

0 comments on commit 75b0916

Please sign in to comment.