1#include "mtmd-audio.h"
2
3#define _USE_MATH_DEFINES // for M_PI
4#include <cmath>
5#include <cstdint>
6#include <cstring>
7#include <thread>
8#include <vector>
9#include <fstream>
10#include <algorithm>
11
12// some of the code here is copied from whisper.cpp
13
constexpr bool DEBUG = false; // set to true to print non-zero filterbank entries and dump the final spectrogram to log_mel_spectrogram.json
15
16void mtmd_audio_cache::fill_sin_cos_table(int n) {
17 sin_vals.resize(n);
18 cos_vals.resize(n);
19 for (int i = 0; i < n; i++) {
20 double theta = (2 * M_PI * i) / n;
21 sin_vals[i] = sinf(theta);
22 cos_vals[i] = cosf(theta);
23 }
24}
25
26void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
27 hann_window.resize(length);
28 int offset = -1;
29 if (periodic) {
30 offset = 0;
31 }
32 for (int i = 0; i < length; i++) {
33 hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
34 }
35}
36
// Build a (n_mel x n_fft_bins) triangular mel filterbank into `filters`,
// using the Slaney mel scale (librosa's default: linear below 1 kHz,
// logarithmic above). fmax <= 0 means "use the Nyquist frequency".
// When slaney_area_norm is set, each triangle is scaled by 2 / bandwidth so
// all filters have (approximately) equal area, matching librosa norm="slaney".
// `scale` is an extra multiplier applied to every coefficient.
void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
                                                  int n_fft,
                                                  int sample_rate,
                                                  float fmin,
                                                  float fmax,
                                                  bool slaney_area_norm,
                                                  float scale) {
    GGML_ASSERT(n_mel > 0 && n_fft > 1);
    if (fmax <= 0.0f) {
        fmax = 0.5f * sample_rate; // default upper edge: Nyquist
    }

    // Slaney scale (matches librosa default)
    const double min_log_hz  = 1000.0;            // crossover from linear to log region
    const double lin_slope   = 3 / 200.;          // mels per Hz in the linear region
    const double min_log_mel = min_log_hz * lin_slope;
    const double log_step    = log(6.4) / 27.0;   // log spacing above the crossover
    auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
        return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
    };
    auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
        return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
    };

    // frequency step per FFT bin
    const double bin_hz_step = double(sample_rate) / double(n_fft);

    // mel grid: n_mel + 2 equally-spaced edges between fmin and fmax
    const double m_lo = hz_to_mel(fmin);
    const double m_hi = hz_to_mel(fmax);
    std::vector<double> mel_pts(n_mel + 2);
    for (int i = 0; i < n_mel + 2; ++i) {
        mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
    }

    // convert the edges back to Hz
    std::vector<double> hz_pts(n_mel + 2);
    for (int i = 0; i < n_mel + 2; ++i) {
        hz_pts[i] = mel_to_hz(mel_pts[i]);
    }

    const int n_fft_bins = n_fft / 2 + 1; // DC .. Nyquist

    // filterbank: filter m is a triangle rising from hz_pts[m] to hz_pts[m+1]
    // and falling back to zero at hz_pts[m+2]
    std::vector<float> out(n_mel * n_fft_bins, 0);
    for (int m = 0; m < n_mel; ++m) {
        const double f_left   = hz_pts[m];
        const double f_center = hz_pts[m + 1];
        const double f_right  = hz_pts[m + 2];

        // clamp slopes away from 0 to avoid division by zero on degenerate grids
        const double denom_l = std::max(1e-30, f_center - f_left);
        const double denom_r = std::max(1e-30, f_right - f_center);
        const double enorm   = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;

        for (int k = 0; k < n_fft_bins; ++k) {
            const double f = k * bin_hz_step;
            double w = 0.0;
            if (f >= f_left && f <= f_center) {
                w = (f - f_left) / denom_l;       // rising edge
            } else if (f > f_center && f <= f_right) {
                w = (f_right - f) / denom_r;      // falling edge
            }
            out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
        }
    }

    filters.n_mel = n_mel;
    filters.n_fft = n_fft;  // NOTE: stores the full FFT size, while data has n_fft_bins columns per mel
    filters.data  = std::move(out);

    if (DEBUG) { // debug
        for (size_t i = 0; i < filters.data.size(); ++i) {
            if (filters.data[i] != 0.0f) {
                printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
            }
        }
    }
}
115
116// Unified DFT implementation for both forward and inverse transforms
117// Template parameters:
118// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling
119// true = IDFT with exp(+2πi·k·n/N), scales by 1/N
120// RealInput: true = input is real-valued (stride 1), avoids imaginary computations
121// false = input is complex-valued (interleaved real/imag, stride 2)
122template <bool Inverse, bool RealInput>
123static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) {
124 const int n_sin_cos_vals = cache.sin_vals.size();
125 const int sin_cos_step = n_sin_cos_vals / N;
126
127 constexpr float sign = Inverse ? 1.0f : -1.0f;
128 const float scale = Inverse ? (1.0f / N) : 1.0f;
129
130 for (int k = 0; k < N; k++) {
131 float re = 0;
132 float im = 0;
133
134 for (int n = 0; n < N; n++) {
135 int idx = (k * n * sin_cos_step) % n_sin_cos_vals;
136 float cos_val = cache.cos_vals[idx];
137 float sin_val = cache.sin_vals[idx];
138
139 if constexpr (RealInput) {
140 // Real input: in_im = 0, simplifies to:
141 // re += in_re * cos_val
142 // im += sign * in_re * sin_val
143 float in_re = in[n];
144 re += in_re * cos_val;
145 im += sign * in_re * sin_val;
146 } else {
147 float in_re = in[n * 2 + 0];
148 float in_im = in[n * 2 + 1];
149 // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i
150 re += in_re * cos_val - sign * in_im * sin_val;
151 im += sign * in_re * sin_val + in_im * cos_val;
152 }
153 }
154
155 out[k * 2 + 0] = re * scale;
156 out[k * 2 + 1] = im * scale;
157 }
158}
159
// Cooley-Tukey FFT/IFFT unified implementation (radix-2, recursive).
// Any size that becomes odd during halving falls back to the O(N^2) DFT, so
// arbitrary N is supported. `in` doubles as scratch: the even/odd split is
// staged right after the live samples (callers must size it accordingly —
// 2N floats for real input, 4N for complex; see fft_in/ifft_in allocations).
// `out` must provide extra scratch past the first 2N floats for the
// recursion's sub-results (callers allocate ~8x the frame size — TODO confirm
// the exact bound).
// Template parameters:
//   Inverse:   false = FFT with exp(-2πi·k/N), no scaling
//              true  = IFFT with exp(+2πi·k/N), scales by 0.5 at each level
//                      (net effect over log2(N) levels: 1/N)
//   RealInput: true  = input is real-valued (stride 1)
//              false = input is complex-valued (interleaved real/imag, stride 2)
template <bool Inverse, bool RealInput>
static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
    const int n_sin_cos_vals = cache.sin_vals.size();

    // base case: the transform of a single sample is the sample itself
    if (N == 1) {
        out[0] = in[0];
        if constexpr (RealInput) {
            out[1] = 0.0f;
        } else {
            out[1] = in[1];
        }
        return;
    }

    const int half_N = N / 2;
    if (N - half_N * 2 == 1) {
        // Odd N: fall back to DFT
        dft_impl<Inverse, RealInput>(cache, in, N, out);
        return;
    }

    // Split into even and odd samples and transform each half.
    // The odd buffer reuses the even buffer: this is safe because the even
    // half's recursion has fully completed before the odd samples are copied.
    if constexpr (RealInput) {
        // Real input: stride is 1, copy only real values
        float * even = in + N;
        for (int i = 0; i < half_N; ++i) {
            even[i] = in[2 * i];
        }
        float * even_fft = out + 2 * N;
        fft_impl<Inverse, true>(cache, even, half_N, even_fft);

        float * odd = even;
        for (int i = 0; i < half_N; ++i) {
            odd[i] = in[2 * i + 1];
        }
        float * odd_fft = even_fft + N;
        fft_impl<Inverse, true>(cache, odd, half_N, odd_fft);
    } else {
        // Complex input: stride is 2, copy complex pairs
        float * even = in + N * 2;
        for (int i = 0; i < half_N; ++i) {
            even[i * 2 + 0] = in[2 * i * 2 + 0];
            even[i * 2 + 1] = in[2 * i * 2 + 1];
        }
        float * even_fft = out + 2 * N;
        fft_impl<Inverse, false>(cache, even, half_N, even_fft);

        float * odd = even;
        for (int i = 0; i < half_N; ++i) {
            odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0];
            odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1];
        }
        float * odd_fft = even_fft + N;
        fft_impl<Inverse, false>(cache, odd, half_N, odd_fft);
    }

    float * even_fft = out + 2 * N;
    float * odd_fft = even_fft + N;

    const int sin_cos_step = n_sin_cos_vals / N;

    constexpr float sign = Inverse ? 1.0f : -1.0f;
    constexpr float scale = Inverse ? 0.5f : 1.0f;

    // butterfly combine: X[k] = E[k] + w^k*O[k], X[k + N/2] = E[k] - w^k*O[k]
    for (int k = 0; k < half_N; k++) {
        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
        float re = cache.cos_vals[idx];
        float im = sign * cache.sin_vals[idx];

        float re_odd = odd_fft[2 * k + 0];
        float im_odd = odd_fft[2 * k + 1];

        out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd);
        out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd);

        out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd);
        out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd);
    }
}
245
// Forward FFT for real input (used by the mel spectrogram path).
// `in` is clobbered as scratch; `out` receives N interleaved complex pairs
// (plus recursion scratch beyond 2*N floats, see fft_impl).
static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
    fft_impl<false, true>(cache, in, N, out);
}
250
// Inverse FFT for complex input (used by the streaming ISTFT).
// `in` holds N interleaved complex pairs and is clobbered as scratch;
// `out` receives N interleaved complex pairs scaled by 1/N overall.
static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
    fft_impl<true, false>(cache, in, N, out);
}
255
// Parameters controlling log-mel spectrogram extraction.
// All fields are explicitly initialized so a default-constructed instance is
// in a well-defined state (previously the first five ints were left
// uninitialized while the rest had defaults).
struct filter_params {
    int32_t n_mel            = 0;     // number of mel filterbank bins
    int32_t n_fft_bins       = 0;     // 1 + n_fft/2 (DC .. Nyquist)
    int32_t hann_window_size = 0;     // analysis window length in samples
    int32_t hop_length       = 0;     // frame step in samples
    int32_t sample_rate      = 0;     // input sample rate in Hz
    bool    center_padding   = false; // zero-pad n_fft/2 samples on both sides (librosa center=True style)
    float   preemph          = 0.f;   // pre-emphasis coefficient; 0 disables
    bool    use_natural_log  = false; // ln() of mel energies instead of log10()
    bool    norm_per_feature = false; // per-mel-bin mean/variance normalization
};
267
// Compute log-mel columns i = ith, ith + n_threads, ith + 2*n_threads, ...
// so that n_threads workers cover all frames without overlap.
// Each frame is Hann-windowed, FFT'd, reduced through the mel filterbank and
// log-compressed into out.data, which is laid out mel-major: [n_mel][n_len].
// Frames past the end of the signal are filled with the log of silence.
static void log_mel_spectrogram_worker_thread(int ith,
                                              const float * hann,
                                              const std::vector<float> & samples,
                                              int n_samples,
                                              int frame_size,
                                              int frame_step,
                                              int n_threads,
                                              const filter_params & params,
                                              const mtmd_audio_cache & cache,
                                              mtmd_audio_mel & out) {
    std::vector<float> fft_in(frame_size * 2, 0.0);   // windowed frame + scratch for the recursive FFT
    std::vector<float> fft_out(frame_size * 2 * 2 * 2); // complex spectrum + recursion scratch

    int n_fft_bins = params.n_fft_bins;
    int i = ith;

    const auto & filters = cache.filters;

    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
    GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
    GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size());
    // calculate FFT only when fft_in are not all zero
    for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
        const int offset = i * frame_step;

        // apply Hann window (~10% faster)
        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
            fft_in[j] = hann[j] * samples[offset + j];
        }

        // fill the rest with zeros (frame extends past the end of the signal)
        if (n_samples - offset < frame_size) {
            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
        }

        // FFT
        fft(cache, fft_in.data(), frame_size, fft_out.data());

        // Calculate modulus^2 of complex numbers; written in-place over the
        // start of fft_out (safe: write index j is always < read index 2*j for j > 0)
        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
        for (int j = 0; j < n_fft_bins; j++) {
            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
        }

        // mel spectrogram: dot each filterbank row with the power spectrum
        for (int j = 0; j < out.n_mel; j++) {
            double sum = 0.0;
            // unroll loop (suggested by GH user @lunixbochs)
            int k = 0;
            for (k = 0; k < n_fft_bins - 3; k += 4) {
                size_t idx = size_t(j) * size_t(n_fft_bins) + size_t(k);
                sum +=
                    fft_out[k + 0] * filters.data[idx + 0] +
                    fft_out[k + 1] * filters.data[idx + 1] +
                    fft_out[k + 2] * filters.data[idx + 2] +
                    fft_out[k + 3] * filters.data[idx + 3];
            }
            // handle n_fft remainder
            for (; k < n_fft_bins; k++) {
                sum += fft_out[k] * filters.data[j * n_fft_bins + k];
            }
            // log-compress with a floor to avoid log(0)
            sum = params.use_natural_log
                ? log(sum + 5.960464477539063e-08)   // +2^-24, matches NeMo-style epsilon
                : log10(std::max(sum, 1e-10));
            out.data[j * out.n_len + i] = sum;
        }
    }

    // Remaining columns lie entirely in the zero padding: fill with log-silence
    double sum = params.use_natural_log ? log(1e-10) : log10(1e-10);
    for (; i < out.n_len; i += n_threads) {
        for (int j = 0; j < out.n_mel; j++) {
            out.data[j * out.n_len + i] = sum;
        }
    }
}
344
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
// Compute a log-mel spectrogram of `samples` into `out` using `n_threads`
// workers. Padding, log base and normalization behavior are selected by
// `params` (whisper-style vs conformer-style). Returns false on invalid
// input (too-short audio for reflective padding, size overflow).
static bool log_mel_spectrogram(
        const float * samples,
        const int n_samples_in,
        const int n_threads,
        const filter_params & params,
        const mtmd_audio_cache & cache,
        mtmd_audio_mel & out) {
    //const int64_t t_start_us = ggml_time_us();

    out.n_len_org = n_samples_in;
    int n_samples = n_samples_in;

    // Hann window
    const float * hann = cache.hann_window.data();
    const int frame_size = (params.n_fft_bins - 1) * 2;  // recover n_fft from the bin count
    const int frame_step = params.hop_length;

    // Padding
    std::vector<float> samples_padded;
    if (params.center_padding) {
        // zero-pad frame_size/2 on both sides so frames are centered on samples
        const auto pad_amount = frame_size / 2;
        samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
        std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
        samples = samples_padded.data();
        n_samples = samples_padded.size();
        // NOTE: n_samples now includes the padding in this branch only
    } else {
        // existing padding logic (whisper-style)
        int64_t stage_1_pad = params.sample_rate * 30;
        int64_t stage_2_pad = frame_size / 2;
        samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
        std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
        // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
        std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
        // reflective pad 200 samples at the beginning of audio
        if (n_samples < stage_2_pad + 1) {
            // TODO: Handle short audio differently or return error
            return false;
        }
        std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
        // NOTE: n_samples intentionally stays at the unpadded length here,
        // so out.n_len below is computed from the original signal length
    }

    // preemphasis: x[i] -= coeff * x[i-1] over the unpadded region
    if (params.preemph) {
        const int pad_amount = frame_size / 2;
        const float preemph = 0.97f;
        float prev = samples_padded[pad_amount];
        for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
            float cur = samples_padded[i];
            samples_padded[i] = cur - preemph * prev;
            prev = cur;
        }
    }

    // pad hann window if it's smaller than frame_size (centered, zeros outside)
    // TODO: probably unnecessary here? (or better doing it in g_cache?)
    std::vector<float> hann_window_padded;
    if (params.hann_window_size < frame_size) {
        hann_window_padded.resize(frame_size);
        const int padding = (frame_size - params.hann_window_size) / 2;
        std::copy(hann, hann + params.hann_window_size, &hann_window_padded[padding]);
        hann = hann_window_padded.data();
    }


    out.n_mel = params.n_mel;
    out.n_len = (n_samples - frame_size) / frame_step + 1;
    // TODO: handle these checks better
    if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) {
        LOG_ERR("%s: size overflow\n", __func__);
        return false;
    }
    if (n_samples < frame_size) {
        LOG_ERR("%s: not enough samples after padding\n", __func__);
        return false;
    }
    out.data.resize(out.n_mel * out.n_len);

    // run n_threads workers over interleaved frame indices; the main thread
    // participates as worker 0
    {
        std::vector<std::thread> workers(n_threads - 1);
        for (int iw = 0; iw < n_threads - 1; ++iw) {
            workers[iw] =
                std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples,
                            frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out));
        }

        // main thread
        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params,
                                          cache, out);
        for (int iw = 0; iw < n_threads - 1; ++iw) {
            workers[iw].join();
        }
    }

    // number of frames covering the original (unpadded) signal
    const int effective_n_len = n_samples_in / frame_step;
    if (params.norm_per_feature) {
        // per-mel-bin mean/variance normalization over the effective frames,
        // then zero out the padded tail
        for (int i = 0; i < out.n_mel; i++) {
            double mean = 0;
            for (int j = 0; j < effective_n_len; ++j) {
                mean += out.data[i * out.n_len + j];
            }
            mean /= effective_n_len;

            double var = 0.0;
            for (int j = 0; j < effective_n_len; ++j) {
                const double value = out.data[i * out.n_len + j] - mean;
                var += value * value;
            }
            var /= effective_n_len - 1; // unbiased
            const double mstd = std::sqrt(var + 1e-5);

            for (int j = 0; j < effective_n_len; ++j) {
                auto &value = out.data[i * out.n_len + j];
                value = (value - mean) / mstd;
            }

            // pad the rest with zeros
            for (int j = effective_n_len; j < out.n_len; ++j) {
                out.data[i * out.n_len + j] = 0.0;
            }
        }
    } else {
        // clamping and normalization (whisper-style):
        // clamp to max - 8.0 dB, then map into roughly [-1, 1]
        double mmax = -1e20;
        for (int i = 0; i < out.n_mel*out.n_len; i++) {
            if (out.data[i] > mmax) {
                mmax = out.data[i];
            }
        }

        mmax -= 8.0;

        for (int i = 0; i < out.n_mel*out.n_len; i++) {
            if (out.data[i] < mmax) {
                out.data[i] = mmax;
            }
            out.data[i] = (out.data[i] + 4.0)/4.0;
        }
    }

    // Dump log_mel_spectrogram
    if (DEBUG) {
        std::ofstream outFile("log_mel_spectrogram.json");
        outFile << "[";
        for (uint64_t i = 0; i < out.data.size() - 1; i++) {
            outFile << out.data[i] << ", ";
        }
        outFile << out.data[out.data.size() - 1] << "]";
        outFile.close();
    }

    return true;
}
498
499//
500// mtmd_audio_preprocessor_whisper
501//
502
// Precompute the trig tables, Hann window and mel filterbank used by
// log_mel_spectrogram. Must run before preprocess() (which asserts on it).
void mtmd_audio_preprocessor_whisper::initialize() {
    cache.fill_sin_cos_table(hparams.audio_n_fft);
    cache.fill_hann_window(hparams.audio_window_len, true); // periodic window
    cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
}
508
// Whisper-style preprocessing: compute the full log10-mel spectrogram and
// split it into fixed 3000-frame chunks (the size clip.cpp's cgraph accepts).
// Trailing frames that don't fill a whole chunk are dropped — by construction
// the input is padded so the tail is always silence.
// Returns false on empty input or spectrogram failure.
bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples,
                                                 size_t n_samples,
                                                 std::vector<mtmd_audio_mel> & output) {
    if (n_samples == 0) {
        // empty audio
        return false;
    }

    std::vector<float> smpl;
    // if input is too short, pad with zeros
    // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
    // TODO: maybe handle this better
    size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
    if (n_samples < min_samples) {
        smpl.resize(min_samples, 0.0f);
        std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
        samples = smpl.data();
        n_samples = smpl.size();
    }

    // whisper-style extraction: no centering/preemphasis, log10 energies,
    // global max-clamp normalization
    filter_params params;
    params.n_mel = hparams.n_mel_bins;
    params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
    params.hann_window_size = hparams.audio_window_len;
    params.hop_length = hparams.audio_hop_len;
    params.sample_rate = hparams.audio_sample_rate;
    params.center_padding = false;
    params.preemph = 0.0f; // disabled
    params.use_natural_log = false;
    params.norm_per_feature = false;

    // make sure the cache is initialized
    GGML_ASSERT(!cache.sin_vals.empty());
    GGML_ASSERT(!cache.cos_vals.empty());
    GGML_ASSERT(!cache.filters.data.empty());

    mtmd_audio_mel out_full;
    bool ok = log_mel_spectrogram(samples, n_samples,
                                  4, // n_threads
                                  params, cache, out_full);
    if (!ok) {
        return false;
    }

    // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
    // we always expect the mel to have 3000 silent frames at the end
    if (DEBUG) {
        printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
    }
    const size_t frames_per_chunk = 3000;
    GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk);
    for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) {
        int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off);
        if ((size_t) n_len < frames_per_chunk) {
            break; // last uncomplete chunk will always be a padded chunk, safe to ignore
        }

        mtmd_audio_mel out_chunk;
        out_chunk.n_len = n_len;
        out_chunk.n_mel = out_full.n_mel;
        out_chunk.n_len_org = out_full.n_mel; // unused
        // NOTE(review): assigns n_mel to n_len_org — looks like a slip, but the
        // field is marked unused for chunks; confirm before changing
        out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);

        // copy one chunk row-by-row (data is mel-major: [n_mel][n_len])
        for (int i = 0; i < out_full.n_mel; i++) {
            auto src = out_full.data.begin() + i * out_full.n_len + off;
            out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
        }

        output.push_back(std::move(out_chunk));
    }

    return true;
}
582
583//
584// mtmd_audio_preprocessor_conformer
585//
586
// Precompute the trig tables, Hann window and mel filterbank used by
// log_mel_spectrogram. Must run before preprocess() (which asserts on it).
void mtmd_audio_preprocessor_conformer::initialize() {
    cache.fill_sin_cos_table(hparams.audio_n_fft);
    cache.fill_hann_window(hparams.audio_window_len, true); // periodic window
    cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
}
592
593bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples,
594 size_t n_samples,
595 std::vector<mtmd_audio_mel> & output) {
596 // empty audio
597 if (n_samples == 0) {
598 return false;
599 }
600
601 filter_params params;
602 params.n_mel = hparams.n_mel_bins;
603 params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
604 params.hann_window_size = hparams.audio_window_len;
605 params.hop_length = hparams.audio_hop_len;
606 params.sample_rate = hparams.audio_sample_rate;
607 params.center_padding = true;
608 params.preemph = 0.97f;
609 params.use_natural_log = true;
610 params.norm_per_feature = true;
611
612 // make sure the cache is initialized
613 GGML_ASSERT(!cache.sin_vals.empty());
614 GGML_ASSERT(!cache.cos_vals.empty());
615 GGML_ASSERT(!cache.filters.data.empty());
616
617 mtmd_audio_mel out_full;
618 bool ok = log_mel_spectrogram(samples, n_samples,
619 4, // n_threads
620 params, cache, out_full);
621 if (!ok) {
622 return false;
623 }
624
625 output.push_back(std::move(out_full));
626 return true;
627}
628
629//
630// mtmd_audio_streaming_istft implementation
631//
632
// Streaming inverse-STFT: reconstructs time-domain audio one spectrum frame
// at a time via windowed overlap-add.
mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) :
    n_fft(n_fft),
    hop_length(hop_length),
    n_fft_bins(n_fft / 2 + 1),                   // positive-frequency bins (DC .. Nyquist)
    overlap_buffer(n_fft, 0.0f),                 // overlap-add accumulator of reconstructed samples
    window_sum_buffer(n_fft, 0.0f),              // running sum of squared window, for normalization
    padding_to_remove((n_fft - hop_length) / 2), // leading samples to drop (undoes STFT center padding)
    ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT
    ifft_out(n_fft * 2 * 4, 0.0f) {
    cache.fill_sin_cos_table(n_fft);
    cache.fill_hann_window(n_fft, true); // periodic window, matching the analysis side
}
645
646void mtmd_audio_streaming_istft::reset() {
647 std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f);
648 std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f);
649 padding_to_remove = (n_fft - hop_length) / 2;
650}
651
// Convert one complex STFT frame (n_fft_bins interleaved re/im pairs covering
// DC..Nyquist) back to the time domain and overlap-add it into the internal
// state. Returns the next hop_length reconstructed samples — minus any
// initial center-padding still being discarded, so early calls may return
// fewer (or zero) samples.
std::vector<float> mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) {
    std::vector<float> output(hop_length);

    // copy the positive-frequency half into the IFFT input
    for (int j = 0; j < n_fft_bins; j++) {
        ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0];
        ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1];
    }

    // mirror negative frequencies (conjugate symmetry of a real signal)
    for (int j = 1; j < n_fft_bins - 1; j++) {
        int mirror_idx = n_fft - j;
        ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0];
        ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate
    }

    ifft(cache, ifft_in.data(), n_fft, ifft_out.data());

    // overlap-add the windowed frame and accumulate the squared-window sum
    // (denominator of the standard ISTFT normalization)
    for (int j = 0; j < n_fft; j++) {
        window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j];
        overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j]; // real part only
    }

    // the first hop_length samples are now final: normalize and emit them
    for (int i = 0; i < hop_length; i++) {
        if (window_sum_buffer[i] > 1e-8f) {
            output[i] = overlap_buffer[i] / window_sum_buffer[i];
        } else {
            output[i] = overlap_buffer[i]; // window sum ~0: skip normalization
        }
    }

    // shift buffers left by hop_length to make room for the next frame
    std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin());
    std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f);

    std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin());
    std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f);

    // Remove padding if needed (strips the STFT's leading center padding)
    int to_remove = std::min(padding_to_remove, (int) output.size());
    padding_to_remove -= to_remove;
    output.erase(output.begin(), output.begin() + to_remove);

    return output;
}
699
700std::vector<float> mtmd_audio_streaming_istft::flush() {
701 std::vector<float> output;
702
703 // Extract remaining samples from overlap buffer
704 // Continue until we've extracted all meaningful samples
705 int remaining = n_fft - hop_length;
706 while (remaining > 0) {
707 int chunk_size = std::min(remaining, hop_length);
708
709 for (int i = 0; i < chunk_size; i++) {
710 float sample;
711 if (window_sum_buffer[i] > 1e-8f) {
712 sample = overlap_buffer[i] / window_sum_buffer[i];
713 } else {
714 sample = overlap_buffer[i];
715 }
716 output.push_back(sample);
717 }
718
719 // Shift buffers
720 std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin());
721 std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f);
722
723 std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin());
724 std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f);
725
726 remaining -= chunk_size;
727 }
728
729 return output;
730}