Skip to content

Commit

Permalink
fix: improve compatibility with the reference implementation in FFT c…
Browse files Browse the repository at this point in the history
…alculation and port missing FFT tests (#10)
  • Loading branch information
darksv authored Oct 12, 2024
1 parent ec4f062 commit b33fe97
Showing 1 changed file with 153 additions and 10 deletions.
163 changes: 153 additions & 10 deletions chromaprint/src/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,14 @@ impl<C: FeatureVectorConsumer> AudioConsumer<f64> for Fft<C> {
assert_eq!(self.fft_buffer_complex.len(), self.frame_size);
assert_eq!(self.window.len(), self.frame_size);

for (i, (output, input))
in self.fft_buffer_complex.iter_mut().zip(window).enumerate() {
for (i, (output, input)) in self.fft_buffer_complex.iter_mut().zip(window).enumerate() {
output.re = input * self.window[i];
output.im = 0.0;
}

self.fft_plan.process_with_scratch(&mut self.fft_buffer_complex, &mut self.fft_scratch);

self.fft_frame[0] = self.fft_buffer_complex[0].re.powi(2);
self.fft_frame[self.frame_size / 2] = self.fft_buffer_complex[0].im.powi(2);
for i in 1..self.frame_size / 2 {
for i in 0..self.frame_size / 2 {
self.fft_frame[i] = self.fft_buffer_complex[i].norm_sqr();
}

Expand All @@ -84,10 +81,12 @@ impl<C: FeatureVectorConsumer> AudioConsumer<f64> for Fft<C> {
return;
}

if self.ring_buf.len() < self.frame_size {
self.ring_buf.resize(self.frame_size, 0.0);
self.consume(&[]);
}
// It makes sense to pad the remaining samples with zeros and process the last frame,
// but the reference implementation doesn't do it.
// if self.ring_buf.len() < self.frame_size {
// self.ring_buf.resize(self.frame_size, 0.0);
// self.consume(&[]);
// }
}
}

Expand All @@ -97,4 +96,148 @@ fn make_hamming_window(size: usize, scale: f64) -> Box<[f64]> {
window.push(scale * (0.54 - 0.46 * f64::cos(2.0 * std::f64::consts::PI * (i as f64) / (size as f64 - 1.0))));
}
window.into_boxed_slice()
}
}

#[cfg(test)]
mod tests {
use crate::fft::Fft;
use crate::stages::{AudioConsumer, FeatureVectorConsumer, Stage};

struct Collector {
frames: Vec<Vec<f64>>,
}

impl Collector {
fn new() -> Self {
Self { frames: vec![] }
}
}

impl Stage for Collector {
type Output = [Vec<f64>];

fn output(&self) -> &Self::Output {
&self.frames
}
}

impl FeatureVectorConsumer for Collector {
fn consume(&mut self, features: &[f64]) {
self.frames.push(features.to_vec());
}

fn reset(&mut self) {
self.frames.clear();
}
}

#[test]
fn sine() {
let nframes = 3;
let frame_size = 32;
let overlap = 8;

let sample_rate = 1000;
let freq = 7 * (sample_rate / 2) / (frame_size / 2);

let mut input = vec![0.0; frame_size + (nframes - 1) * (frame_size - overlap)];
for i in 0..input.len() {
input[i] = f64::sin(i as f64 * freq as f64 * 2.0 * std::f64::consts::PI / sample_rate as f64);
}

let collector = Collector::new();
let mut fft = Fft::new(frame_size, overlap, collector);

assert_eq!(frame_size, fft.frame_size);
assert_eq!(overlap, fft.frame_overlap);

let chunk_size = 100;
for chunk in input.chunks(chunk_size) {
fft.consume(chunk);
}

assert_eq!(nframes, fft.output().len());

let expected_spectrum = [
2.87005e-05,
0.00011901,
0.00029869,
0.000667172,
0.00166813,
0.00605612,
0.228737,
0.494486,
0.210444,
0.00385322,
0.00194379,
0.00124616,
0.000903851,
0.000715237,
0.000605707,
0.000551375,
0.000534304,
];

for (frame_idx, frame) in fft.output().iter().enumerate() {
for i in 0..frame.len() {
let magnitude = f64::sqrt(frame[i]) / frame.len() as f64;
let expected_mag = expected_spectrum[i];
if (expected_mag - magnitude).abs() > 0.001 {
panic!("different magnitude for frame {frame_idx} at offset {i}: s[{i}]={magnitude} (!= {expected_mag})");
}
}
}
}

#[test]
fn dc() {
let nframes = 3;
let frame_size = 32;
let overlap = 8;

let input = vec![0.5; frame_size + (nframes - 1) * (frame_size - overlap)];

let collector = Collector::new();
let mut fft = Fft::new(frame_size, overlap, collector);

assert_eq!(frame_size, fft.frame_size);
assert_eq!(overlap, fft.frame_overlap);

let chunk_size = 100;
for chunk in input.chunks(chunk_size) {
fft.consume(chunk);
}

assert_eq!(nframes, fft.output().len());

let expected_spectrum = [
0.494691,
0.219547,
0.00488079,
0.00178991,
0.000939219,
0.000576082,
0.000385808,
0.000272904,
0.000199905,
0.000149572,
0.000112947,
8.5041e-05,
6.28312e-05,
4.4391e-05,
2.83757e-05,
1.38507e-05,
0.0,
];

for (frame_idx, frame) in fft.output().iter().enumerate() {
for i in 0..frame.len() {
let magnitude = f64::sqrt(frame[i]) / frame.len() as f64;
let expected_mag = expected_spectrum[i];
if (expected_mag - magnitude).abs() > 0.001 {
panic!("different magnitude for frame {frame_idx} at offset {i}: s[{i}]={magnitude} (!= {expected_mag})");
}
}
}
}
}

0 comments on commit b33fe97

Please sign in to comment.