Skip to content

Commit

Permalink
Merge pull request #548 from robertknight/depthwise-conv-trait
Browse files (browse the repository at this point in the history)
Refactor depthwise convolution to support arch/data type-specific kernels
  • Loading branch information
robertknight authored Jan 25, 2025
2 parents b19d123 + 644d49f commit 046e743
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 173 deletions.
14 changes: 7 additions & 7 deletions src/ops/conv.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::tensor_pool::{AutoReturn, TensorPool};
mod depthwise;
mod im2col;

-use depthwise::conv_2d_depthwise;
+use depthwise::DepthwiseConvExecutor;
use im2col::build_im2col;

/// Specialization of conv_2d for pointwise convolutions over one image. This
Expand Down Expand Up @@ -87,7 +87,7 @@ where
/// A value equal to the input channel count convolves each input channel
/// separately with `output_channels / groups` outputs. This is known as
/// depthwise convolution.
-pub fn conv<X, W, Y>(
+pub fn conv<X: GemmInT, W: GemmInT, Y: GemmOutT + Default>(
pool: &TensorPool,
input: TensorView<X>,
kernel: TensorView<W>,
Expand All @@ -98,9 +98,7 @@ pub fn conv<X, W, Y>(
dilations: &[usize],
) -> Result<Tensor<Y>, OpError>
where
-X: std::ops::Mul<W, Output = Y> + Default + GemmInT,
-W: GemmInT,
-Y: Default + std::ops::AddAssign<Y> + GemmOutT,
+DepthwiseConvExecutor<X, W, Y>: Default,
GemmExecutor<W, X, Y>: Default,
{
// Handle 1D convolution by expanding to 2D and then removing the extra
Expand Down Expand Up @@ -214,7 +212,8 @@ where
}

if in_c == out_c && groups == in_c {
-return Ok(conv_2d_depthwise(
+let dw_conv = DepthwiseConvExecutor::default();
+let output = dw_conv.depthwise_conv_2d(
pool,
&input.nd_view(),
&kernel.nd_view(),
Expand All @@ -223,7 +222,8 @@ where
[stride_y, stride_x],
[dilation_y, dilation_x],
[out_h, out_w],
-));
+);
+return Ok(output.into_dyn());
}

let n_patches = out_h * out_w;
Expand Down
Loading

0 comments on commit 046e743

Please sign in to comment.