Skip to content

Commit

Permalink
feat: add new TTS examples for Kokoro, Vits, and Matcha models
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jan 18, 2025
1 parent b39c161 commit 8e65bd9
Show file tree
Hide file tree
Showing 12 changed files with 405 additions and 303 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ sherpa-onnx-kws-*
jniLibs/
build/
kokoro-en-*/
matcha-*
/
14 changes: 12 additions & 2 deletions crates/sherpa-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,19 @@ cuda = ["sherpa-rs-sys/cuda"]
directml = ["sherpa-rs-sys/directml"]

[[example]]
name = "tts"
name = "tts_kokoro"
required-features = ["tts"]
path = "../../examples/tts.rs"
path = "../../examples/tts_kokoro.rs"

[[example]]
name = "tts_vits"
required-features = ["tts"]
path = "../../examples/tts_vits.rs"

[[example]]
name = "tts_matcha"
required-features = ["tts"]
path = "../../examples/tts_matcha.rs"

[[example]]
name = "audio_tag"
Expand Down
62 changes: 50 additions & 12 deletions crates/sherpa-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,21 @@ pub mod tts;
#[cfg(feature = "sys")]
pub use sherpa_rs_sys;

use eyre::{bail, Result};
use eyre::{ bail, Result };

pub fn get_default_provider() -> String {
if cfg!(feature = "cuda") {
"cuda"
} else if cfg!(target_os = "macos") {
"coreml"
} else if cfg!(feature = "directml") {
"directml"
} else {
"cpu"
}
.into()
"cpu".into()
// Other providers has many issues with different models!!
// if cfg!(feature = "cuda") {
// "cuda"
// } else if cfg!(target_os = "macos") {
// "coreml"
// } else if cfg!(feature = "directml") {
// "directml"
// } else {
// "cpu"
// }
// .into()
}

pub fn read_audio_file(path: &str) -> Result<(Vec<f32>, u32)> {
Expand All @@ -45,8 +47,44 @@ pub fn read_audio_file(path: &str) -> Result<(Vec<f32>, u32)> {
// Collect samples into a Vec<f32>
let samples: Vec<f32> = reader
.samples::<i16>()
.map(|s| s.unwrap() as f32 / i16::MAX as f32)
.map(|s| (s.unwrap() as f32) / (i16::MAX as f32))
.collect();

Ok((samples, sample_rate))
}

pub fn write_audio_file(path: &str, samples: &[f32], sample_rate: u32) -> Result<()> {
// Create a WAV file writer
let spec = hound::WavSpec {
channels: 1,
sample_rate,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};

let mut writer = hound::WavWriter::create(path, spec)?;

// Convert samples from f32 to i16 and write them to the WAV file
for &sample in samples {
let scaled_sample = (sample * (i16::MAX as f32)).clamp(
i16::MIN as f32,
i16::MAX as f32
) as i16;
writer.write_sample(scaled_sample)?;
}

writer.finalize()?;
Ok(())
}

pub struct OnnxConfig {
pub provider: String,
pub debug: bool,
pub num_threads: i32,
}

impl Default for OnnxConfig {
fn default() -> Self {
Self { provider: get_default_provider(), debug: false, num_threads: 1 }
}
}
189 changes: 0 additions & 189 deletions crates/sherpa-rs/src/tts.rs

This file was deleted.

64 changes: 64 additions & 0 deletions crates/sherpa-rs/src/tts/kokoro.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::{ mem, ptr::null };

use eyre::Result;
use sherpa_rs_sys;
use crate::{ utils::RawCStr, OnnxConfig };

use super::TtsAudio;

pub struct KokoroTts {
tts: *const sherpa_rs_sys::SherpaOnnxOfflineTts,
}

#[derive(Default)]
pub struct KokoroTtsConfig {
pub model: String,
pub voices: String,
pub tokens: String,
pub data_dir: String,
pub length_scale: f32,
pub onnx_config: OnnxConfig,
}

impl KokoroTts {
pub fn new(config: KokoroTtsConfig) -> Self {
let tts = unsafe {
let model = RawCStr::new(&config.model);
let voices = RawCStr::new(&config.voices);
let tokens = RawCStr::new(&config.tokens);
let data_dir = RawCStr::new(&config.data_dir);

let provider = RawCStr::new(&config.onnx_config.provider);

let model_config = sherpa_rs_sys::SherpaOnnxOfflineTtsModelConfig {
vits: mem::zeroed::<_>(),
num_threads: config.onnx_config.num_threads,
debug: config.onnx_config.debug.into(),
provider: provider.as_ptr(),
matcha: mem::zeroed::<_>(),
kokoro: sherpa_rs_sys::SherpaOnnxOfflineTtsKokoroModelConfig {
model: model.as_ptr(),
voices: voices.as_ptr(),
tokens: tokens.as_ptr(),
data_dir: data_dir.as_ptr(),
length_scale: config.length_scale,
},
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
model: model_config,
rule_fars: null(),
rule_fsts: null(),
};
sherpa_rs_sys::SherpaOnnxCreateOfflineTts(&config)
};

Self {
tts,
}
}

pub fn create(&mut self, text: &str, sid: i32, speed: f32) -> Result<TtsAudio> {
unsafe { super::create(self.tts, text, sid, speed) }
}
}
Loading

0 comments on commit 8e65bd9

Please sign in to comment.