blob: 1230760e5a793f77835d7adaf76c5658af9ff7b3 [file] [log] [blame]
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include "kiss_fft.h"
#include "common.h"
#include <math.h>
#include "pitch.h"
#include "rnn.h"
#include "rnn_data.h"
/* Frame geometry. FRAME_SIZE evaluates to 480 samples, WINDOW_SIZE to 960
   (two frames) and FREQ_SIZE to 481, the number of FFT bins that
   forward_transform() keeps and inverse_transform() expands again by
   conjugate symmetry. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
#define WINDOW_SIZE (2*FRAME_SIZE)
#define FREQ_SIZE (FRAME_SIZE + 1)
/* Pitch analysis limits, in samples. PITCH_BUF_SIZE (1728) is the length of
   the per-state sample history used by the pitch search. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
/* NOTE: the argument is expanded twice, so it must be free of side effects. */
#define SQUARE(x) ((x)*(x))
/* SMOOTH_BANDS selects the band functions that spread each bin over two
   neighbouring bands (22 bands) instead of the disjoint-band ones (21 bands). */
#define SMOOTH_BANDS 1
#if SMOOTH_BANDS
#define NB_BANDS 22
#else
#define NB_BANDS 21
#endif
/* Number of past cepstral vectors kept in DenoiseState.cepstral_mem. */
#define CEPS_MEM 8
/* Number of leading cepstral coefficients that get delta features. */
#define NB_DELTA_CEPS 6
/* Length of the feature vector filled by frame_analysis() (42 when
   NB_BANDS is 22). */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
/* Non-zero builds the training-data main() and changes the silence handling
   in frame_analysis(); zero builds the file-to-file denoising main(). */
#define TRAINING 0
/* Band edge table. Each entry is shifted left by FRAME_SIZE_SHIFT by the
   band functions to obtain the index of the first FFT bin of the band. The
   comment inside the table gives the edge frequency for each entry. */
static const opus_int16 eband5ms[] = {
/*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
/* Tables shared by every DenoiseState; filled lazily by check_init(). */
typedef struct {
/* Non-zero once check_init() has populated the fields below. */
int init;
/* FFT state of size 2*FRAME_SIZE returned by opus_fft_alloc_twiddles(). */
kiss_fft_state *kfft;
/* Rising half of the analysis/synthesis window; apply_window() mirrors it. */
float half_window[FRAME_SIZE];
/* NB_BANDS x NB_BANDS cosine table used by dct() (and the disabled idct()). */
float dct_table[NB_BANDS*NB_BANDS];
} CommonState;
/* Per-stream denoiser state; zeroed by rnnoise_init(). */
typedef struct {
/* Previous input frame, used as the first half of the analysis window. */
float analysis_mem[FRAME_SIZE];
/* Ring buffer of the last CEPS_MEM cepstral vectors. */
float cepstral_mem[CEPS_MEM][NB_BANDS];
/* Index of the next cepstral_mem row to overwrite. */
int memid;
/* Second half of the previous windowed synthesis frame (overlap-add). */
float synthesis_mem[FRAME_SIZE];
/* Most recent PITCH_BUF_SIZE input samples, used for pitch analysis. */
float pitch_buf[PITCH_BUF_SIZE];
/* Not referenced by the code visible in this file. */
float pitch_enh_buf[PITCH_BUF_SIZE];
/* Gain and period of the previous frame, fed back into remove_doubling(). */
float last_gain;
int last_period;
/* Memory of the input biquad run by rnnoise_process_frame(). */
float mem_hp_x[2];
/* Recurrent network state passed to compute_rnn(). */
RNNState rnn;
} DenoiseState;
#if SMOOTH_BANDS
void compute_band_energy(float *bandE, const kiss_fft_cpx *X) {
int i;
float sum[NB_BANDS] = {0};
for (i=0;i<NB_BANDS-1;i++)
{
int j;
int band_size;
band_size = (eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;
for (j=0;j<band_size;j++) {
float tmp;
float frac = (float)j/band_size;
tmp = SQUARE(X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].r);
tmp += SQUARE(X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].i);
sum[i] += (1-frac)*tmp;
sum[i+1] += frac*tmp;
}
}
sum[0] *= 2;
sum[NB_BANDS-1] *= 2;
for (i=0;i<NB_BANDS;i++)
{
bandE[i] = sum[i];
}
}
void compute_band_corr(float *bandE, const kiss_fft_cpx *X, const kiss_fft_cpx *P) {
int i;
float sum[NB_BANDS] = {0};
for (i=0;i<NB_BANDS-1;i++)
{
int j;
int band_size;
band_size = (eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;
for (j=0;j<band_size;j++) {
float tmp;
float frac = (float)j/band_size;
tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].r * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].r;
tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].i * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].i;
sum[i] += (1-frac)*tmp;
sum[i+1] += frac*tmp;
}
}
sum[0] *= 2;
sum[NB_BANDS-1] *= 2;
for (i=0;i<NB_BANDS;i++)
{
bandE[i] = sum[i];
}
}
/* Expands NB_BANDS per-band values into FREQ_SIZE per-bin gains.
 *
 * Bins between two consecutive band edges receive a linear interpolation of
 * the two band values. Bins at or above the last band edge are not covered by
 * any interval and are left at zero.
 *
 * g:     output, FREQ_SIZE values.
 * bandE: input, NB_BANDS values.
 */
void interp_band_gain(float *g, const float *bandE) {
  int i;
  /* memset() counts bytes: clear all FREQ_SIZE floats, not FREQ_SIZE bytes,
     so the bins above the last band edge never keep stale contents. */
  memset(g, 0, FREQ_SIZE*sizeof(*g));
  for (i=0;i<NB_BANDS-1;i++)
  {
    int j;
    int band_size;
    band_size = (eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;
    for (j=0;j<band_size;j++) {
      float frac = (float)j/band_size;
      g[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j] = (1-frac)*bandE[i] + frac*bandE[i+1];
    }
  }
}
#else
void compute_band_energy(float *bandE, const kiss_fft_cpx *X) {
int i;
for (i=0;i<NB_BANDS;i++)
{
int j;
opus_val32 sum = 0;
for (j=0;j<(eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;j++) {
sum += SQUARE(X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].r);
sum += SQUARE(X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].i);
}
bandE[i] = sum;
}
}
/* Disjoint-band variant (built when SMOOTH_BANDS is 0): every bin of a band
 * receives that band's value. Bins at or above the last band edge are left at
 * zero.
 *
 * g:     output, FREQ_SIZE values.
 * bandE: input, NB_BANDS values.
 */
void interp_band_gain(float *g, const float *bandE) {
  int i;
  /* memset() counts bytes: clear all FREQ_SIZE floats, not FREQ_SIZE bytes. */
  memset(g, 0, FREQ_SIZE*sizeof(*g));
  for (i=0;i<NB_BANDS;i++)
  {
    int j;
    for (j=0;j<(eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;j++)
      g[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j] = bandE[i];
  }
}
#endif
/* Process-wide lookup tables. Zero-initialised as a file-scope object and
   populated on first use by check_init(). NOTE(review): the lazy
   initialisation has no locking in this file; confirm callers do not race on
   the first call. */
CommonState common;
/* Fills the shared tables in `common` the first time it is called; later
 * calls return immediately.
 *
 * Builds the FFT state for 2*FRAME_SIZE points, the rising half of the
 * window (sin(pi/2 * sin^2(pi/2 * (n+.5)/FRAME_SIZE))) and the NB_BANDS x
 * NB_BANDS cosine table whose first column is scaled by sqrt(.5).
 */
static void check_init() {
  int n, row, col;
  if (!common.init) {
    common.kfft = opus_fft_alloc_twiddles(2*FRAME_SIZE, NULL, NULL, NULL, 0);
    for (n = 0; n < FRAME_SIZE; n++) {
      const double s = sin(.5*M_PI*(n+.5)/FRAME_SIZE);
      common.half_window[n] = sin(.5*M_PI*s * s);
    }
    for (row = 0; row < NB_BANDS; row++) {
      float *line = &common.dct_table[row*NB_BANDS];
      for (col = 0; col < NB_BANDS; col++) {
        float c = cos((row+.5)*col*M_PI/NB_BANDS);
        if (col == 0) c *= sqrt(.5);
        line[col] = c;
      }
    }
    common.init = 1;
  }
}
/* Transforms NB_BANDS input values with the cosine table built by
 * check_init(): out[i] = sqrt(2/22) * sum_j in[j] * dct_table[j][i].
 *
 * NOTE(review): the scale uses the literal 22 rather than NB_BANDS; the two
 * only agree when SMOOTH_BANDS is enabled.
 *
 * out: NB_BANDS results.
 * in:  NB_BANDS inputs.
 */
static void dct(float *out, const float *in) {
  const double scale = sqrt(2./22);
  int coef;
  check_init();
  for (coef = 0; coef < NB_BANDS; coef++) {
    /* Walk down column `coef` of the row-major table. */
    const float *column = &common.dct_table[coef];
    float acc = 0;
    int band;
    for (band = 0; band < NB_BANDS; band++) {
      acc += in[band] * column[band*NB_BANDS];
    }
    out[coef] = acc*scale;
  }
}
#if 0
/* Counterpart of dct() that reads the cosine table row-wise:
 * out[i] = sqrt(2/22) * sum_j in[j] * dct_table[i][j].
 * Currently compiled out by the surrounding #if 0.
 *
 * out: NB_BANDS results.
 * in:  NB_BANDS inputs.
 */
static void idct(float *out, const float *in) {
  const double scale = sqrt(2./22);
  int band;
  check_init();
  for (band = 0; band < NB_BANDS; band++) {
    const float *line = &common.dct_table[band*NB_BANDS];
    float acc = 0;
    int coef;
    for (coef = 0; coef < NB_BANDS; coef++) {
      acc += in[coef] * line[coef];
    }
    out[band] = acc*scale;
  }
}
#endif
/* Runs the shared FFT on WINDOW_SIZE real samples and returns the first
 * FREQ_SIZE bins.
 *
 * out: FREQ_SIZE complex bins.
 * in:  WINDOW_SIZE real samples.
 */
static void forward_transform(kiss_fft_cpx *out, const float *in) {
  kiss_fft_cpx time_buf[WINDOW_SIZE];
  kiss_fft_cpx freq_buf[WINDOW_SIZE];
  int n;
  check_init();
  /* Promote the real input to complex with a zero imaginary part. */
  for (n = 0; n < WINDOW_SIZE; n++) {
    time_buf[n].r = in[n];
    time_buf[n].i = 0;
  }
  opus_fft(common.kfft, time_buf, freq_buf, 0);
  memcpy(out, freq_buf, FREQ_SIZE*sizeof(freq_buf[0]));
}
static void inverse_transform(float *out, const kiss_fft_cpx *in) {
int i;
kiss_fft_cpx x[WINDOW_SIZE];
kiss_fft_cpx y[WINDOW_SIZE];
check_init();
for (i=0;i<FREQ_SIZE;i++) {
x[i] = in[i];
}
for (;i<WINDOW_SIZE;i++) {
x[i].r = x[WINDOW_SIZE - i].r;
x[i].i = -x[WINDOW_SIZE - i].i;
}
opus_fft(common.kfft, x, y, 0);
/* output in reverse order for IFFT. */
out[0] = WINDOW_SIZE*y[0].r;
for (i=1;i<WINDOW_SIZE;i++) {
out[i] = WINDOW_SIZE*y[WINDOW_SIZE - i].r;
}
}
/* Multiplies WINDOW_SIZE samples in place by the symmetric window whose
 * rising half is stored in common.half_window.
 *
 * x: WINDOW_SIZE samples, modified in place.
 */
static void apply_window(float *x) {
  float *head = x;
  float *tail = x + WINDOW_SIZE - 1;
  int n;
  check_init();
  for (n = 0; n < FRAME_SIZE; n++) {
    const float w = common.half_window[n];
    *head++ *= w;
    *tail-- *= w;
  }
}
/* Resets a denoiser state by clearing every byte of it.
 *
 * st: state to reset; must point to a valid DenoiseState.
 * Returns 0.
 */
int rnnoise_init(DenoiseState *st) {
  memset(st, 0, sizeof(DenoiseState));
  return 0;
}
/* Allocates a DenoiseState on the heap and initialises it with
 * rnnoise_init().
 *
 * Returns the new state, or NULL if the allocation failed. The caller owns
 * the returned memory and releases it with free().
 */
DenoiseState *rnnoise_create() {
  DenoiseState *st = malloc(sizeof *st);
  /* Never hand a NULL pointer to rnnoise_init(): it would memset() it. */
  if (st == NULL) return NULL;
  rnnoise_init(st);
  return st;
}
/* Analyses one FRAME_SIZE input frame.
 *
 * st:       per-stream state; analysis_mem, pitch_buf, last_period, last_gain
 *           and (when features are computed) cepstral_mem/memid are updated.
 * X:        output, FREQ_SIZE spectrum bins of the windowed frame.
 * Ex:       output, NB_BANDS band energies of X.
 * features: output, NB_FEATURES values, or NULL to skip feature extraction.
 * in:       FRAME_SIZE input samples.
 *
 * Returns 1 when features were requested, TRAINING is 0 and the summed band
 * energy is below 0.04 (features are then cleared and the cepstral history is
 * left untouched). Otherwise returns TRAINING && E < 0.1, which is always 0
 * when TRAINING is 0.
 */
static int frame_analysis(DenoiseState *st, kiss_fft_cpx *X, float *Ex, float *features, const float *in) {
float x[WINDOW_SIZE];
int i;
float E = 0;
/* Analysis buffer = previous frame followed by the current one. */
RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
for (i=0;i<FRAME_SIZE;i++) x[FRAME_SIZE + i] = in[i];
/* Remember the current frame for the next call. */
RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
apply_window(x);
forward_transform(X, x);
compute_band_energy(Ex, X);
/* Pitch analysis and pitch-correlation features. */
if (1) {
float p[WINDOW_SIZE];
kiss_fft_cpx P[WINDOW_SIZE];
float Ep[NB_BANDS], Exp[NB_BANDS];
float pitch_buf[PITCH_BUF_SIZE>>1];
int pitch_index;
float gain;
float *(pre[1]);
/* Slide the sample history by one frame and append the new samples. */
RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
pre[0] = &st->pitch_buf[0];
/* NOTE(review): pitch_downsample/pitch_search/remove_doubling are defined
   elsewhere (pitch.h); the local pitch_buf is half the history length, so
   the downsampling factor is presumably 2 - confirm. */
pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
/* Convert the search result into a lag counted back from the newest samples. */
pitch_index = PITCH_MAX_PERIOD-pitch_index;
gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
st->last_period = pitch_index;
st->last_gain = gain;
/* p is the analysis window delayed by pitch_index samples. */
for (i=0;i<WINDOW_SIZE;i++)
p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
apply_window(p);
forward_transform(P, p);
compute_band_energy(Ep, P);
compute_band_corr(Exp, X, P);
/* Normalise the band cross-correlation; .001 keeps the divisor non-zero. */
for (i=0;i<NB_BANDS;i++) Exp[i] = Exp[i]/sqrt(.001+Ex[i]*Ep[i]);
if (features) {
float tmp[NB_BANDS];
dct(tmp, Exp);
/* First NB_DELTA_CEPS DCT coefficients of the pitch correlation, with fixed
   offsets subtracted from the first two. */
for (i=0;i<NB_DELTA_CEPS;i++) features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
/* Pitch period feature, offset by 300 and scaled by .01. */
features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
}
}
{
if (features != NULL) {
float *ceps_0, *ceps_1, *ceps_2;
float spec_variability = 0;
float Ly[NB_BANDS];
E = 0;
/* Log band energies and total energy of the frame. */
for (i=0;i<NB_BANDS;i++) {
Ly[i] = log10(1e-2+Ex[i]);
E += Ex[i];
}
if (!TRAINING && E < 0.04) {
/* If there's no audio, avoid messing up the state. */
RNN_CLEAR(features, NB_FEATURES);
return 1;
}
/* Cepstrum of the log band energies, with fixed offsets on the first two
   coefficients. */
dct(features, Ly);
features[0] -= 12;
features[1] -= 4;
/* ceps_0 is the slot for the current frame; ceps_1 and ceps_2 are the
   previous two frames, with wrap-around in the ring buffer. */
ceps_0 = st->cepstral_mem[st->memid];
ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
for (i=0;i<NB_BANDS;i++) ceps_0[i] = features[i];
st->memid++;
/* For the first NB_DELTA_CEPS coefficients: three-frame sum, first
   difference and second difference. */
for (i=0;i<NB_DELTA_CEPS;i++) {
features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
}
/* Spectral variability features. */
/* Wrap the ring-buffer write index. */
if (st->memid == CEPS_MEM) st->memid = 0;
/* For each stored cepstrum, find the squared distance to its nearest other
   stored cepstrum, and accumulate those minima. */
for (i=0;i<CEPS_MEM;i++)
{
int j;
float mindist = 1e15f;
for (j=0;j<CEPS_MEM;j++)
{
int k;
float dist=0;
for (k=0;k<NB_BANDS;k++)
{
float tmp;
tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
dist += tmp*tmp;
}
if (j!=i)
mindist = MIN32(mindist, dist);
}
spec_variability += mindist;
}
/* Mean nearest-neighbour distance, with a fixed offset. */
features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
}
}
/* E stays 0 when features is NULL, so in a TRAINING build that call path
   always reports silence. */
return TRAINING && E < 0.1;
}
/* Turns a spectrum back into FRAME_SIZE output samples by overlap-add.
 *
 * The inverse transform is windowed; its first half is added to the tail
 * saved from the previous call and its second half is saved for the next one.
 *
 * st:  state holding synthesis_mem (read, then overwritten).
 * out: FRAME_SIZE output samples.
 * y:   FREQ_SIZE spectrum bins.
 */
static void frame_synthesis(DenoiseState *st, float *out, const kiss_fft_cpx *y) {
  float frame[WINDOW_SIZE];
  const float *tail = &frame[FRAME_SIZE];
  int n;
  inverse_transform(frame, y);
  apply_window(frame);
  for (n = 0; n < FRAME_SIZE; n++) {
    out[n] = frame[n] + st->synthesis_mem[n];
  }
  for (n = 0; n < FRAME_SIZE; n++) {
    st->synthesis_mem[n] = tail[n];
  }
}
/* Second-order recursive filter with two state variables.
 *
 * For every sample: y = x + mem[0], then
 *   mem[0] = mem[1] + (b[0]*x - a[0]*y)
 *   mem[1] =           b[1]*x - a[1]*y
 * The state updates are evaluated in double precision and stored as float.
 * The input sample is read before the output is written, so y may be the
 * same buffer as x.
 *
 * y:   N output samples.
 * mem: two filter state values, updated in place.
 * x:   N input samples.
 * b:   two feed-forward coefficients.
 * a:   two feedback coefficients.
 * N:   number of samples; nothing happens when N <= 0.
 */
static void biquad(float *y, float mem[2], const float *x, const float *b, const float *a, int N) {
  int n = 0;
  while (n < N) {
    const float sample = x[n];
    const float result = sample + mem[0];
    const double in_d = sample;
    const double out_d = result;
    mem[0] = mem[1] + (b[0]*in_d - a[0]*out_d);
    mem[1] = (b[1]*in_d - a[1]*out_d);
    y[n] = result;
    n++;
  }
}
/* Denoises one frame of FRAME_SIZE samples.
 *
 * The input is first run through a biquad (coefficients a_hp/b_hp, state in
 * st->mem_hp_x), then analysed. Unless the frame is reported as silent, the
 * network produces NB_BANDS gains which are interpolated to per-bin gains and
 * applied to the spectrum. The (possibly unmodified) spectrum is then
 * resynthesised into `out`.
 *
 * st:  denoiser state, updated.
 * out: FRAME_SIZE output samples; may be the same buffer as `in`, because the
 *      input is copied into a local buffer before `out` is written.
 * in:  FRAME_SIZE input samples.
 */
void rnnoise_process_frame(DenoiseState *st, float *out, const float *in) {
  static const float a_hp[2] = {-1.99599, 0.99600};
  static const float b_hp[2] = {-2, 1};
  kiss_fft_cpx Y[FREQ_SIZE];
  float filtered[FRAME_SIZE];
  float Ey[NB_BANDS];
  float features[NB_FEATURES];
  float band_gain[NB_BANDS];
  /* Only element 0 is set to 1; the remaining elements start at 0. */
  float bin_gain[FREQ_SIZE]={1};
  float vad_prob;

  biquad(filtered, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
  if (!frame_analysis(st, Y, Ey, features, filtered)) {
    int k;
    compute_rnn(&st->rnn, band_gain, &vad_prob, features);
    interp_band_gain(bin_gain, band_gain);
    for (k = 0; k < FREQ_SIZE; k++) {
      Y[k].r *= bin_gain[k];
      Y[k].i *= bin_gain[k];
    }
  }
  frame_synthesis(st, out, Y);
}
#if TRAINING
/* Returns a pseudo-random value in [-0.5, 0.5], drawn from rand(). */
static float uni_rand() {
  const double unit = rand()/(double)RAND_MAX;
  return unit-.5;
}
/* Fills two coefficient pairs with random values in [-0.375, 0.375].
 *
 * The random numbers are drawn in the order a[0], a[1], b[0], b[1].
 *
 * a, b: outputs, two values each.
 */
static void rand_resp(float *a, float *b) {
  int k;
  for (k = 0; k < 2; k++) {
    a[k] = .75*uni_rand();
  }
  for (k = 0; k < 2; k++) {
    b[k] = .75*uni_rand();
  }
}
/* Training-data generator (built when TRAINING is non-zero).
 *
 * Usage: prog <speech> <noise> <output denoised>
 *
 * Reads raw 16-bit PCM speech and noise, applies random gains and random
 * biquad responses, mixes them, and writes one record per frame to stdout:
 * NB_FEATURES features, NB_BANDS target gains, NB_BANDS noise log-energies
 * and one VAD value, all as raw floats. The third file is opened but only
 * written by code that is currently disabled with #if 0.
 *
 * Returns 0 on success and 1 on a usage, file-open or allocation error.
 */
int main(int argc, char **argv) {
  int i;
  int ret = 1;
  int count=0;
  static const float a_hp[2] = {-1.99599, 0.99600};
  static const float b_hp[2] = {-2, 1};
  float a_noise[2] = {0};
  float b_noise[2] = {0};
  float a_sig[2] = {0};
  float b_sig[2] = {0};
  float mem_hp_x[2]={0};
  float mem_hp_n[2]={0};
  float mem_resp_x[2]={0};
  float mem_resp_n[2]={0};
  float x[FRAME_SIZE];
  float n[FRAME_SIZE];
  float xn[FRAME_SIZE];
  int vad_cnt=0;
  int gain_change_count=0;
  float speech_gain = 1, noise_gain = 1;
  FILE *f1 = NULL, *f2 = NULL, *fout = NULL;
  DenoiseState *st = NULL;
  DenoiseState *noise_state = NULL;
  DenoiseState *noisy = NULL;
  if (argc!=4) {
    fprintf(stderr, "usage: %s <speech> <noise> <output denoised>\n", argv[0]);
    return 1;
  }
  /* The files hold raw PCM: open them in binary mode so no platform performs
     newline translation on the samples. */
  f1 = fopen(argv[1], "rb");
  if (f1 == NULL) {
    perror(argv[1]);
    goto cleanup;
  }
  f2 = fopen(argv[2], "rb");
  if (f2 == NULL) {
    perror(argv[2]);
    goto cleanup;
  }
  fout = fopen(argv[3], "wb");
  if (fout == NULL) {
    perror(argv[3]);
    goto cleanup;
  }
  st = rnnoise_create();
  noise_state = rnnoise_create();
  noisy = rnnoise_create();
  if (st == NULL || noise_state == NULL || noisy == NULL) {
    fprintf(stderr, "error: out of memory\n");
    goto cleanup;
  }
  /* Skip the first 150 frames of the noise file. */
  for(i=0;i<150;i++) {
    short tmp[FRAME_SIZE];
    if (fread(tmp, sizeof(short), FRAME_SIZE, f2) != FRAME_SIZE) break;
  }
  while (1) {
    kiss_fft_cpx X[FREQ_SIZE], Y[FREQ_SIZE], N[FREQ_SIZE];
    float Ex[NB_BANDS], Ey[NB_BANDS], En[NB_BANDS];
    float Ln[NB_BANDS];
    float features[NB_FEATURES];
    float g[NB_BANDS];
    float gf[FREQ_SIZE]={1};
    short tmp[FRAME_SIZE];
    float vad=0;
    float vad_prob;
    float E=0;
    /* Periodically draw new speech/noise gains and filter responses. */
    if (++gain_change_count > 101*300) {
      speech_gain = pow(10., (-40+(rand()%60))/20.);
      noise_gain = pow(10., (-30+(rand()%40))/20.);
      if (rand()%10==0) noise_gain = 0;
      noise_gain *= speech_gain;
      if (rand()%10==0) speech_gain = 0;
      gain_change_count = 0;
      rand_resp(a_noise, b_noise);
      rand_resp(a_sig, b_sig);
    }
    /* Stop on a short read: this covers end of file and read errors alike
       (testing feof() alone would spin forever on an I/O error). */
    if (fread(tmp, sizeof(short), FRAME_SIZE, f1) != FRAME_SIZE) break;
    for (i=0;i<FRAME_SIZE;i++) x[i] = speech_gain*tmp[i];
    if (fread(tmp, sizeof(short), FRAME_SIZE, f2) != FRAME_SIZE) break;
    for (i=0;i<FRAME_SIZE;i++) n[i] = noise_gain*tmp[i];
    biquad(x, mem_hp_x, x, b_hp, a_hp, FRAME_SIZE);
    biquad(x, mem_resp_x, x, b_sig, a_sig, FRAME_SIZE);
    biquad(n, mem_hp_n, n, b_hp, a_hp, FRAME_SIZE);
    biquad(n, mem_resp_n, n, b_noise, a_noise, FRAME_SIZE);
    for (i=0;i<FRAME_SIZE;i++) xn[i] = x[i] + n[i];
    for (i=0;i<FRAME_SIZE;i++) E += x[i]*(float)x[i];
    /* Energy-based voice activity label with a hangover counter. */
    if (E > 1e9f*speech_gain*speech_gain) {
      vad_cnt=0;
    } else if (E > 1e8f*speech_gain*speech_gain) {
      vad_cnt -= 5;
      if (vad_cnt < 0) vad_cnt = 0;
    } else {
      vad_cnt++;
      if (vad_cnt > 15) vad_cnt = 15;
    }
    if (vad_cnt >= 10) vad = 0;
    else if (vad_cnt > 0) vad = 0.5f;
    else vad = 1.f;
    frame_analysis(st, X, Ex, NULL, x);
    frame_analysis(noise_state, N, En, NULL, n);
    for (i=0;i<NB_BANDS;i++) Ln[i] = log10(1e-2+En[i]);
    int silence = frame_analysis(noisy, Y, Ey, features, xn);
    //printf("%f %d\n", noisy->last_gain, noisy->last_period);
    /* Target gain per band: clean/noisy amplitude ratio capped at 1, or -1
       when the noisy frame is flagged as silent. */
    for (i=0;i<NB_BANDS;i++) {
      g[i] = sqrt((Ex[i]+1e-2)/(Ey[i]+1e-2));
      if (g[i] > 1) g[i] = 1;
      if (silence) g[i] = -1;
    }
    count++;
#if 0
    for (i=0;i<NB_FEATURES;i++) printf("%f ", features[i]);
    for (i=0;i<NB_BANDS;i++) printf("%f ", g[i]);
    for (i=0;i<NB_BANDS;i++) printf("%f ", Ln[i]);
    printf("%f\n", vad);
#endif
#if 1
    fwrite(features, sizeof(float), NB_FEATURES, stdout);
    fwrite(g, sizeof(float), NB_BANDS, stdout);
    fwrite(Ln, sizeof(float), NB_BANDS, stdout);
    fwrite(&vad, sizeof(float), 1, stdout);
#endif
#if 0
    compute_rnn(&noisy->rnn, g, &vad_prob, features);
    //for (i=0;i<NB_BANDS;i++) scanf("%f", &g[i]);
    interp_band_gain(gf, g);
#if 1
    for (i=0;i<FREQ_SIZE;i++) {
      Y[i].r *= gf[i];
      Y[i].i *= gf[i];
    }
#endif
    frame_synthesis(noisy, xn, Y);
    for (i=0;i<FRAME_SIZE;i++) tmp[i] = xn[i];
    fwrite(tmp, sizeof(short), FRAME_SIZE, fout);
#endif
  }
  fprintf(stderr, "matrix size: %d x %d\n", count, NB_FEATURES + 2*NB_BANDS + 1);
  ret = 0;
cleanup:
  free(noisy);
  free(noise_state);
  free(st);
  if (fout != NULL) fclose(fout);
  if (f2 != NULL) fclose(f2);
  if (f1 != NULL) fclose(f1);
  return ret;
}
#else
/* Command-line denoiser (built when TRAINING is zero).
 *
 * Usage: prog <noisy speech> <output denoised>
 *
 * Reads raw native-endian 16-bit PCM in frames of FRAME_SIZE samples, runs
 * each frame through rnnoise_process_frame() and writes the result in the
 * same format. A trailing partial frame is dropped.
 *
 * Returns 0 on success and 1 on a usage, file, allocation or I/O error.
 */
int main(int argc, char **argv) {
  int i;
  int ret = 1;
  float x[FRAME_SIZE];
  FILE *f1 = NULL, *fout = NULL;
  DenoiseState *st = NULL;
  if (argc!=3) {
    fprintf(stderr, "usage: %s <noisy speech> <output denoised>\n", argv[0]);
    return 1;
  }
  /* Raw PCM must be opened in binary mode so no platform performs newline
     translation on the samples. */
  f1 = fopen(argv[1], "rb");
  if (f1 == NULL) {
    perror(argv[1]);
    goto cleanup;
  }
  fout = fopen(argv[2], "wb");
  if (fout == NULL) {
    perror(argv[2]);
    goto cleanup;
  }
  st = rnnoise_create();
  if (st == NULL) {
    fprintf(stderr, "error: out of memory\n");
    goto cleanup;
  }
  while (1) {
    short tmp[FRAME_SIZE];
    /* A short read means end of file or a read error; either way stop
       (testing feof() alone would spin forever on an I/O error). */
    if (fread(tmp, sizeof(short), FRAME_SIZE, f1) != FRAME_SIZE) break;
    for (i=0;i<FRAME_SIZE;i++) x[i] = tmp[i];
    rnnoise_process_frame(st, x, x);
    for (i=0;i<FRAME_SIZE;i++) tmp[i] = x[i];
    if (fwrite(tmp, sizeof(short), FRAME_SIZE, fout) != FRAME_SIZE) {
      perror(argv[2]);
      goto cleanup;
    }
  }
  if (ferror(f1)) {
    perror(argv[1]);
    goto cleanup;
  }
  ret = 0;
cleanup:
  free(st);
  /* fclose() flushes buffered output; a failure there means lost data. */
  if (fout != NULL && fclose(fout) != 0) ret = 1;
  if (f1 != NULL) fclose(f1);
  return ret;
}
#endif