Skip to content

Commit

Permalink
Merge pull request #16 from siefkenj/buffer
Browse files Browse the repository at this point in the history
Dynamic buffer allocation
  • Loading branch information
alesgenova authored May 4, 2021
2 parents 01bf50c + 3a1e8ec commit beda9dc
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 126 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ categories = ["algorithms", "multimedia::audio", "no-std"]
readme = "README.md"

[dependencies]
object-pool = "0.5.3"
rustfft = { version = "5.0.1", default-features = false }

[dev-dependencies]
Expand Down
7 changes: 3 additions & 4 deletions benches/utils_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@ pub fn pitch_detect_benchmark(c: &mut Criterion) {
.map(|x| (2.0 * std::f64::consts::PI * x as f64 * dt * freq).sin())
.collect();

let mut mcleod_detector = McLeodDetector::new(SIZE, PADDING);
let mut autocorrelation_detector = AutocorrelationDetector::new(SIZE, PADDING);
let mut yin_detector = YINDetector::new(SIZE, PADDING);

c.bench_function("McLeod get_pitch", |b| {
let mut mcleod_detector = McLeodDetector::new(SIZE, PADDING);
b.iter(|| {
mcleod_detector
.get_pitch(
Expand All @@ -53,6 +50,7 @@ pub fn pitch_detect_benchmark(c: &mut Criterion) {
});

c.bench_function("Autocorrelation get_pitch", |b| {
let mut autocorrelation_detector = AutocorrelationDetector::new(SIZE, PADDING);
b.iter(|| {
autocorrelation_detector
.get_pitch(
Expand All @@ -65,6 +63,7 @@ pub fn pitch_detect_benchmark(c: &mut Criterion) {
});
});
c.bench_function("YIN get_pitch", |b| {
let mut yin_detector = YINDetector::new(SIZE, PADDING);
b.iter(|| {
yin_detector
.get_pitch(
Expand Down
24 changes: 5 additions & 19 deletions src/detector/autocorrelation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ where
T: Float,
{
pub fn new(size: usize, padding: usize) -> Self {
let internals = DetectorInternals::new(1, 2, size, padding);
let internals = DetectorInternals::new(size, padding);
AutocorrelationDetector { internals }
}
}
Expand All @@ -35,30 +35,16 @@ where
clarity_threshold: T,
) -> Option<Pitch<T>> {
assert_eq!(signal.len(), self.internals.size);
assert!(
self.internals.has_sufficient_buffers(1, 2),
"McLeodDetector requires at least 1 real and 2 complex buffers"
);

if square_sum(signal) < power_threshold {
return None;
}

let mut iter = self.internals.complex_buffers.iter_mut();
let signal_complex = iter.next().unwrap();
let scratch = iter.next().unwrap();
let result = &mut self.internals.buffers.get_real_buffer()[..];

let mut iter = self.internals.real_buffers.iter_mut();
let autocorr = iter.next().unwrap();
autocorrelation(signal, &self.internals.buffers, result);
let clarity_threshold = clarity_threshold * result[0];

autocorrelation(signal, signal_complex, scratch, autocorr);
let clarity_threshold = clarity_threshold * autocorr[0];

pitch_from_peaks(
autocorr,
sample_rate,
clarity_threshold,
PeakCorrection::None,
)
pitch_from_peaks(result, sample_rate, clarity_threshold, PeakCorrection::None)
}
}
108 changes: 35 additions & 73 deletions src/detector/internals.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use rustfft::num_complex::Complex;
use rustfft::FftPlanner;

use crate::utils::buffer::copy_real_to_complex;
use crate::utils::buffer::new_complex_buffer;
use crate::utils::buffer::new_real_buffer;
use crate::utils::buffer::ComplexComponent;
use crate::utils::buffer::{copy_complex_to_real, square_sum};
use crate::utils::buffer::{copy_real_to_complex, BufferPool};
use crate::utils::peak::choose_peak;
use crate::utils::peak::correct_peak;
use crate::utils::peak::detect_peaks;
Expand All @@ -22,59 +19,42 @@ where

/// Data structure to hold any buffers needed for pitch computation.
/// For WASM it's best to allocate buffers once rather than allocate and
/// free buffers repeatedly.
/// free buffers repeatedly, so we use a `BufferPool` object to manage the buffers.
pub struct DetectorInternals<T>
where
T: Float,
{
pub size: usize,
pub padding: usize,
pub real_buffers: Vec<Vec<T>>,
pub complex_buffers: Vec<Vec<Complex<T>>>,
pub buffers: BufferPool<T>,
}

impl<T> DetectorInternals<T>
where
T: Float,
{
pub fn new(
n_real_buffers: usize,
n_complex_buffers: usize,
size: usize,
padding: usize,
) -> Self {
let real_buffers: Vec<Vec<T>> = (0..n_real_buffers)
.map(|_| new_real_buffer(size + padding))
.collect();

let complex_buffers: Vec<Vec<Complex<T>>> = (0..n_complex_buffers)
.map(|_| new_complex_buffer(size + padding))
.collect();
pub fn new(size: usize, padding: usize) -> Self {
let buffers = BufferPool::new(size + padding);

DetectorInternals {
size,
padding,
real_buffers,
complex_buffers,
buffers,
}
}

// Check whether there are at least the appropriate number of real and complex buffers.
pub fn has_sufficient_buffers(&self, n_real_buffers: usize, n_complex_buffers: usize) -> bool {
self.real_buffers.len() >= n_real_buffers && self.complex_buffers.len() >= n_complex_buffers
}
}

/// Compute the autocorrelation of `signal` to `result`. All buffers but `signal`
/// may be used as scratch.
pub fn autocorrelation<T>(
signal: &[T],
signal_complex: &mut [Complex<T>],
scratch: &mut [Complex<T>],
result: &mut [T],
) where
pub fn autocorrelation<T>(signal: &[T], buffers: &BufferPool<T>, result: &mut [T])
where
T: Float,
{
let (signal_complex, scratch) = (
&mut buffers.get_complex_buffer(),
&mut buffers.get_complex_buffer(),
);

let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(signal_complex.len());
let inv_fft = planner.plan_fft_inverse(signal_complex.len());
Expand Down Expand Up @@ -128,22 +108,19 @@ where
result[signal.len()..].iter_mut().for_each(|r| *r = last);
}

pub fn normalized_square_difference<T>(
signal: &[T],
scratch0: &mut [Complex<T>],
scratch1: &mut [Complex<T>],
scratch2: &mut [T],
result: &mut [T],
) where
pub fn normalized_square_difference<T>(signal: &[T], buffers: &BufferPool<T>, result: &mut [T])
where
T: Float + std::iter::Sum,
{
let two = T::from_usize(2).unwrap();

autocorrelation(signal, scratch0, scratch1, result);
m_of_tau(signal, Some(result[0]), scratch2);
let scratch = &mut buffers.get_real_buffer()[..];

autocorrelation(signal, buffers, result);
m_of_tau(signal, Some(result[0]), scratch);
result
.iter_mut()
.zip(scratch2)
.zip(scratch)
.for_each(|(r, s)| *r = two * *r / *s)
}

Expand All @@ -157,23 +134,25 @@ pub fn normalized_square_difference<T>(
pub fn windowed_autocorrelation<T>(
signal: &[T],
window_size: usize,
(scratch1, scratch2, scratch3): (&mut [Complex<T>], &mut [Complex<T>], &mut [Complex<T>]),
buffers: &BufferPool<T>,
result: &mut [T],
) where
T: Float + std::iter::Sum,
{
assert!(
scratch1.len() >= signal.len() && scratch2.len() >= signal.len(),
"`scratch1`/`scratch2` must have a length at least equal to `signal`."
buffers.buffer_size >= signal.len(),
"Buffers must have a length at least equal to `signal`."
);

let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(signal.len());
let inv_fft = planner.plan_fft_inverse(signal.len());

let signal_complex = &mut scratch1[..signal.len()];
let truncated_signal_complex = &mut scratch2[..signal.len()];
let scratch = &mut scratch3[..signal.len()];
let (signal_complex, truncated_signal_complex, scratch) = (
&mut buffers.get_complex_buffer()[..signal.len()],
&mut buffers.get_complex_buffer()[..signal.len()],
&mut buffers.get_complex_buffer()[..signal.len()],
);

// To achieve the windowed autocorrelation, we compute the cross correlation between
// the original signal and the signal truncated to lie in `0..window_size`
Expand Down Expand Up @@ -212,7 +191,7 @@ pub fn windowed_autocorrelation<T>(
pub fn windowed_square_error<T>(
signal: &[T],
window_size: usize,
(scratch1, scratch2, scratch3): (&mut [Complex<T>], &mut [Complex<T>], &mut [Complex<T>]),
buffers: &BufferPool<T>,
result: &mut [T],
) where
T: Float + std::iter::Sum,
Expand All @@ -223,11 +202,12 @@ pub fn windowed_square_error<T>(
);

let two = T::from_f64(2.).unwrap();

// The windowed square error function, d(t), can be computed
// as d(t) = pow_0^w + pow_t^{t+w} - 2*windowed_autocorrelation(t)
// where pow_a^b is the sum of the square of `signal` on the window `a..b`
// We proceed accordingly.
windowed_autocorrelation(signal, window_size, (scratch1, scratch2, scratch3), result);
windowed_autocorrelation(signal, window_size, buffers, result);
let mut windowed_power = square_sum(&signal[..window_size]);
let power = windowed_power;

Expand Down Expand Up @@ -270,11 +250,7 @@ mod tests {
let signal: Vec<f64> = vec![0., 1., 2., 0., -1., -2.];
let window_size: usize = 3;

let (scratch1, scratch2, scratch3) = (
&mut vec![Complex { re: 0., im: 0. }; signal.len()],
&mut vec![Complex { re: 0., im: 0. }; signal.len()],
&mut vec![Complex { re: 0., im: 0. }; signal.len()],
);
let buffers = &mut BufferPool::new(signal.len());

let result: Vec<f64> = (0..window_size)
.map(|i| {
Expand All @@ -287,12 +263,7 @@ mod tests {
.collect();

let mut computed_result = vec![0.; window_size];
windowed_autocorrelation(
&signal,
window_size,
(scratch1, scratch2, scratch3),
&mut computed_result,
);
windowed_autocorrelation(&signal, window_size, buffers, &mut computed_result);
// Using an FFT loses precision; we don't care that much, so round generously.
computed_result
.iter_mut()
Expand All @@ -306,11 +277,7 @@ mod tests {
let signal: Vec<f64> = vec![0., 1., 2., 0., -1., -2.];
let window_size: usize = 3;

let (scratch1, scratch2, scratch3) = (
&mut vec![Complex { re: 0., im: 0. }; signal.len()],
&mut vec![Complex { re: 0., im: 0. }; signal.len()],
&mut vec![Complex { re: 0., im: 0. }; signal.len()],
);
let buffers = &mut BufferPool::new(signal.len());

let result: Vec<f64> = (0..window_size)
.map(|i| {
Expand All @@ -323,12 +290,7 @@ mod tests {
.collect();

let mut computed_result = vec![0.; window_size];
windowed_square_error(
&signal,
window_size,
(scratch1, scratch2, scratch3),
&mut computed_result,
);
windowed_square_error(&signal, window_size, buffers, &mut computed_result);
// Using an FFT loses precision; we don't care that much, so round generously.
computed_result
.iter_mut()
Expand Down
19 changes: 4 additions & 15 deletions src/detector/mcleod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
T: Float + std::iter::Sum,
{
pub fn new(size: usize, padding: usize) -> Self {
let internals = DetectorInternals::new(2, 2, size, padding);
let internals = DetectorInternals::new(size, padding);
McLeodDetector { internals }
}
}
Expand All @@ -36,26 +36,15 @@ where
clarity_threshold: T,
) -> Option<Pitch<T>> {
assert_eq!(signal.len(), self.internals.size);
assert!(
self.internals.has_sufficient_buffers(2, 2),
"McLeodDetector requires at least 2 real and 2 complex buffers"
);

if square_sum(signal) < power_threshold {
return None;
}
let result = &mut self.internals.buffers.get_real_buffer();

let mut iter = self.internals.complex_buffers.iter_mut();
let signal_complex = iter.next().unwrap();
let scratch0 = iter.next().unwrap();

let mut iter = self.internals.real_buffers.iter_mut();
let scratch1 = iter.next().unwrap();
let peaks = iter.next().unwrap();

normalized_square_difference(signal, signal_complex, scratch0, scratch1, peaks);
normalized_square_difference(signal, &self.internals.buffers, result);
pitch_from_peaks(
peaks,
result,
sample_rate,
clarity_threshold,
PeakCorrection::Quadratic,
Expand Down
Loading

0 comments on commit beda9dc

Please sign in to comment.