-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add sort * update * update tests * update * remove generics * update tests * use alloy's buffer allocation
- Loading branch information
1 parent
d1735f7
commit 75b0916
Showing
6 changed files
with
815 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import Metal | ||
|
||
/// Encodes a GPU bitonic sort over the contents of an `MTLBuffer`.
///
/// The sort is split into three kernels: a first pass, a general pass and a
/// final pass. `encode(data:count:in:)` schedules them on a command buffer.
final public class BitonicSort {

    // MARK: - Properties

    private let firstPass: FirstPass
    private let generalPass: GeneralPass
    private let finalPass: FinalPass

    // MARK: - Init

    /// Creates the sorter using the library that `context` provides for `.module`.
    ///
    /// - Parameters:
    ///   - context: Context used to obtain the shader library.
    ///   - scalarType: Scalar type of the elements to sort; selects the kernel variants.
    /// - Throws: Any error raised while obtaining the library or building the passes.
    public convenience init(context: MTLContext,
                            scalarType: MTLPixelFormat.ScalarType) throws {
        try self.init(library: context.library(for: .module),
                      scalarType: scalarType)
    }

    /// Creates the sorter from an already loaded shader library.
    ///
    /// - Parameters:
    ///   - library: Library containing the bitonic sort kernels.
    ///   - scalarType: Scalar type of the elements to sort; selects the kernel variants.
    /// - Throws: Any error raised while building one of the three passes.
    public init(library: MTLLibrary,
                scalarType: MTLPixelFormat.ScalarType) throws {
        self.firstPass = try .init(library: library,
                                   scalarType: scalarType)
        self.generalPass = try .init(library: library,
                                     scalarType: scalarType)
        self.finalPass = try .init(library: library,
                                   scalarType: scalarType)
    }

    // MARK: - Encode

    /// Shorthand for `encode(data:count:in:)`.
    public func callAsFunction(data: MTLBuffer,
                               count: Int,
                               in commandBuffer: MTLCommandBuffer) {
        self.encode(data: data,
                    count: count,
                    in: commandBuffer)
    }

    /// Encodes all passes required to sort `data` into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - data: Buffer holding the elements to sort. The element stride is
    ///     derived as `data.length / count`.
    ///   - count: Number of elements in `data`.
    ///     NOTE(review): presumably expected to be a power of two, such as the
    ///     `paddedCount` returned by `buffer(from:device:options:)` - this is
    ///     not enforced here.
    ///   - commandBuffer: Command buffer that receives the compute passes.
    public func encode(data: MTLBuffer,
                       count: Int,
                       in commandBuffer: MTLCommandBuffer) {
        // Zero or one element is already sorted. Returning early also avoids a
        // division by zero below and a dispatch with a zero-sized threadgroup.
        guard count > 1 else { return }

        let elementStride = data.length / count
        // One thread is launched per pair of elements.
        let gridSize = count >> 1
        let unitSize = min(gridSize,
                           self.generalPass
                               .pipelineState
                               .maxTotalThreadsPerThreadgroup)

        var params = SIMD2<UInt32>(repeating: 1)

        self.firstPass(data: data,
                       elementStride: elementStride,
                       gridSize: gridSize,
                       unitSize: unitSize,
                       in: commandBuffer)
        // The remaining passes start from twice the threadgroup width.
        params.x = .init(unitSize << 1)

        // `params.x` doubles on every outer iteration until it reaches `count`.
        while params.x < count {
            params.y = params.x
            params.x <<= 1
            repeat {
                if unitSize < params.y {
                    // `params.y` is still larger than one threadgroup:
                    // run the general pass and halve it.
                    self.generalPass(data: data,
                                     params: params,
                                     gridSize: gridSize,
                                     unitSize: unitSize,
                                     in: commandBuffer)
                    params.y >>= 1
                } else {
                    // `params.y` fits within one threadgroup: the final pass
                    // is the last dispatch of this outer iteration.
                    self.finalPass(data: data,
                                   elementStride: elementStride,
                                   params: params,
                                   gridSize: gridSize,
                                   unitSize: unitSize,
                                   in: commandBuffer)
                    params.y = .zero
                }
            } while params.y > .zero
        }
    }

    // MARK: - Buffer creation

    /// Creates a buffer from `array`, padded with `T.max` up to a power-of-two length.
    ///
    /// - Parameters:
    ///   - array: Values to upload.
    ///   - device: Device that allocates the buffer.
    ///   - options: Resource options forwarded to the allocation.
    /// - Returns: The buffer and the padded element count it contains.
    /// - Throws: Any error raised by the buffer allocation.
    public static func buffer<T: FixedWidthInteger>(from array: [T],
                                                    device: MTLDevice,
                                                    options: MTLResourceOptions = []) throws -> (buffer: MTLBuffer, paddedCount: Int) {
        return try Self.buffer(from: array,
                               paddingValue: T.max,
                               device: device,
                               options: options)
    }

    /// Creates a buffer from `array`, padded with `T.greatestFiniteMagnitude`
    /// up to a power-of-two length.
    ///
    /// - Parameters:
    ///   - array: Values to upload.
    ///   - device: Device that allocates the buffer.
    ///   - options: Resource options forwarded to the allocation.
    /// - Returns: The buffer and the padded element count it contains.
    /// - Throws: Any error raised by the buffer allocation.
    public static func buffer<T: FloatingPoint>(from array: [T],
                                                device: MTLDevice,
                                                options: MTLResourceOptions = []) throws -> (buffer: MTLBuffer, paddedCount: Int) {
        return try Self.buffer(from: array,
                               paddingValue: T.greatestFiniteMagnitude,
                               device: device,
                               options: options)
    }

    /// Pads `array` with `paddingValue` up to the next power of two and uploads it.
    ///
    /// An empty array is padded to a single `paddingValue` element, so the
    /// returned `paddedCount` is always at least 1.
    private static func buffer<T: Numeric>(from array: [T],
                                           paddingValue: T,
                                           device: MTLDevice,
                                           options: MTLResourceOptions = []) throws -> (buffer: MTLBuffer, paddedCount: Int) {
        // Smallest power of two that is >= array.count, computed with integer
        // arithmetic: a `Float` logarithm traps for an empty array and loses
        // precision for counts above 2^24.
        let paddedCount = array.count > 1
            ? 1 << (Int.bitWidth - (array.count - 1).leadingZeroBitCount)
            : 1

        var paddedArray = array
        if paddedCount > array.count {
            paddedArray.reserveCapacity(paddedCount)
            paddedArray.append(contentsOf: repeatElement(paddingValue,
                                                         count: paddedCount - array.count))
        }
        return try (buffer: device.buffer(with: paddedArray, options: options),
                    paddedCount: paddedCount)
    }

}
106 changes: 106 additions & 0 deletions
106
Sources/Alloy/Encoders/BitonicSort/BitonicSortFinalPass.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import Metal | ||
|
||
extension BitonicSort {

    /// Compute pass that wraps the `bitonicSortFinalPass_<scalarType>` kernel.
    final class FinalPass {

        // MARK: - Properties

        let pipelineState: MTLComputePipelineState
        private let deviceSupportsNonuniformThreadgroups: Bool

        // MARK: - Init

        /// Builds the pass using the library that `context` provides for `.module`.
        convenience init(context: MTLContext,
                         scalarType: MTLPixelFormat.ScalarType) throws {
            try self.init(
                library: context.library(for: .module),
                scalarType: scalarType
            )
        }

        /// Builds the pass from `library`, selecting the kernel variant for `scalarType`.
        ///
        /// Function constant 0 is set to whether the library's device supports
        /// non-uniform threadgroups.
        init(library: MTLLibrary,
             scalarType: MTLPixelFormat.ScalarType) throws {
            let supportsNonuniform = library.device
                .supports(feature: .nonUniformThreadgroups)
            self.deviceSupportsNonuniformThreadgroups = supportsNonuniform

            let constants = MTLFunctionConstantValues()
            constants.set(supportsNonuniform, at: 0)

            let functionName = "bitonicSortFinalPass" + "_" + scalarType.rawValue
            self.pipelineState = try library.computePipelineState(
                function: functionName,
                constants: constants
            )
        }

        // MARK: - Encode

        /// Shorthand for `encode(data:elementStride:params:gridSize:unitSize:in:)`.
        func callAsFunction(data: MTLBuffer,
                            elementStride: Int,
                            params: SIMD2<UInt32>,
                            gridSize: Int,
                            unitSize: Int,
                            in commandBuffer: MTLCommandBuffer) {
            self.encode(
                data: data,
                elementStride: elementStride,
                params: params,
                gridSize: gridSize,
                unitSize: unitSize,
                in: commandBuffer
            )
        }

        /// Shorthand for `encode(data:elementStride:params:gridSize:unitSize:using:)`.
        func callAsFunction(data: MTLBuffer,
                            elementStride: Int,
                            params: SIMD2<UInt32>,
                            gridSize: Int,
                            unitSize: Int,
                            using encoder: MTLComputeCommandEncoder) {
            self.encode(
                data: data,
                elementStride: elementStride,
                params: params,
                gridSize: gridSize,
                unitSize: unitSize,
                using: encoder
            )
        }

        /// Encodes the pass on a labelled compute encoder obtained from `commandBuffer`.
        func encode(data: MTLBuffer,
                    elementStride: Int,
                    params: SIMD2<UInt32>,
                    gridSize: Int,
                    unitSize: Int,
                    in commandBuffer: MTLCommandBuffer) {
            commandBuffer.compute { computeEncoder in
                computeEncoder.label = "Bitonic Sort Final Pass"
                self.encode(
                    data: data,
                    elementStride: elementStride,
                    params: params,
                    gridSize: gridSize,
                    unitSize: unitSize,
                    using: computeEncoder
                )
            }
        }

        /// Binds the resources and dispatches the kernel on `encoder`.
        ///
        /// - Parameters:
        ///   - data: Buffer holding the elements being sorted.
        ///   - elementStride: Size in bytes of one element.
        ///   - params: Pass parameters forwarded to the kernel at index 2.
        ///   - gridSize: Number of threads to cover; also passed to the kernel at index 1.
        ///   - unitSize: Threadgroup width.
        ///   - encoder: Encoder that records the dispatch.
        func encode(data: MTLBuffer,
                    elementStride: Int,
                    params: SIMD2<UInt32>,
                    gridSize: Int,
                    unitSize: Int,
                    using encoder: MTLComputeCommandEncoder) {
            encoder.setBuffers(data)
            encoder.setValue(UInt32(gridSize), at: 1)
            encoder.setValue(params, at: 2)

            // Threadgroup memory holds two elements per thread in the group.
            let threadgroupMemoryLength = (elementStride * unitSize) << 1
            encoder.setThreadgroupMemoryLength(threadgroupMemoryLength,
                                               index: 0)

            guard self.deviceSupportsNonuniformThreadgroups else {
                // Without non-uniform threadgroups, dispatch whole groups that cover the grid.
                encoder.dispatch1d(state: self.pipelineState,
                                   covering: gridSize,
                                   threadgroupWidth: unitSize)
                return
            }
            encoder.dispatch1d(state: self.pipelineState,
                               exactly: gridSize,
                               threadgroupWidth: unitSize)
        }
    }

}
98 changes: 98 additions & 0 deletions
98
Sources/Alloy/Encoders/BitonicSort/BitonicSortFirstPass.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import Metal | ||
|
||
extension BitonicSort {

    /// Compute pass that wraps the `bitonicSortFirstPass_<scalarType>` kernel.
    final class FirstPass {

        // MARK: - Properties

        let pipelineState: MTLComputePipelineState
        private let deviceSupportsNonuniformThreadgroups: Bool

        // MARK: - Init

        /// Builds the pass using the library that `context` provides for `.module`.
        convenience init(context: MTLContext,
                         scalarType: MTLPixelFormat.ScalarType) throws {
            try self.init(
                library: context.library(for: .module),
                scalarType: scalarType
            )
        }

        /// Builds the pass from `library`, selecting the kernel variant for `scalarType`.
        ///
        /// Function constant 0 is set to whether the library's device supports
        /// non-uniform threadgroups.
        init(library: MTLLibrary,
             scalarType: MTLPixelFormat.ScalarType) throws {
            let supportsNonuniform = library.device
                .supports(feature: .nonUniformThreadgroups)
            self.deviceSupportsNonuniformThreadgroups = supportsNonuniform

            let constants = MTLFunctionConstantValues()
            constants.set(supportsNonuniform, at: 0)

            let functionName = "bitonicSortFirstPass" + "_" + scalarType.rawValue
            self.pipelineState = try library.computePipelineState(
                function: functionName,
                constants: constants
            )
        }

        // MARK: - Encode

        /// Shorthand for `encode(data:elementStride:gridSize:unitSize:in:)`.
        func callAsFunction(data: MTLBuffer,
                            elementStride: Int,
                            gridSize: Int,
                            unitSize: Int,
                            in commandBuffer: MTLCommandBuffer) {
            self.encode(
                data: data,
                elementStride: elementStride,
                gridSize: gridSize,
                unitSize: unitSize,
                in: commandBuffer
            )
        }

        /// Shorthand for `encode(data:elementStride:gridSize:unitSize:using:)`.
        func callAsFunction(data: MTLBuffer,
                            elementStride: Int,
                            gridSize: Int,
                            unitSize: Int,
                            using encoder: MTLComputeCommandEncoder) {
            self.encode(
                data: data,
                elementStride: elementStride,
                gridSize: gridSize,
                unitSize: unitSize,
                using: encoder
            )
        }

        /// Encodes the pass on a labelled compute encoder obtained from `commandBuffer`.
        func encode(data: MTLBuffer,
                    elementStride: Int,
                    gridSize: Int,
                    unitSize: Int,
                    in commandBuffer: MTLCommandBuffer) {
            commandBuffer.compute { computeEncoder in
                computeEncoder.label = "Bitonic Sort First Pass"
                self.encode(
                    data: data,
                    elementStride: elementStride,
                    gridSize: gridSize,
                    unitSize: unitSize,
                    using: computeEncoder
                )
            }
        }

        /// Binds the resources and dispatches the kernel on `encoder`.
        ///
        /// - Parameters:
        ///   - data: Buffer holding the elements being sorted.
        ///   - elementStride: Size in bytes of one element.
        ///   - gridSize: Number of threads to cover; also passed to the kernel at index 1.
        ///   - unitSize: Threadgroup width.
        ///   - encoder: Encoder that records the dispatch.
        func encode(data: MTLBuffer,
                    elementStride: Int,
                    gridSize: Int,
                    unitSize: Int,
                    using encoder: MTLComputeCommandEncoder) {
            encoder.setBuffers(data)
            encoder.setValue(UInt32(gridSize), at: 1)

            // Threadgroup memory holds two elements per thread in the group.
            let threadgroupMemoryLength = (elementStride * unitSize) << 1
            encoder.setThreadgroupMemoryLength(threadgroupMemoryLength,
                                               index: 0)

            guard self.deviceSupportsNonuniformThreadgroups else {
                // Without non-uniform threadgroups, dispatch whole groups that cover the grid.
                encoder.dispatch1d(state: self.pipelineState,
                                   covering: gridSize,
                                   threadgroupWidth: unitSize)
                return
            }
            encoder.dispatch1d(state: self.pipelineState,
                               exactly: gridSize,
                               threadgroupWidth: unitSize)
        }

    }

}
Oops, something went wrong.