blob: e1b6935f59f9281ea4459228661013f53b882fcf [file] [log] [blame]
Jean-Marc Valincf473ce2017-08-03 15:26:05 -04001#!/usr/bin/python
2
3from __future__ import print_function
4
5import keras
6from keras.models import Sequential
7from keras.models import Model
8from keras.layers import Input
9from keras.layers import Dense
10from keras.layers import LSTM
11from keras.layers import GRU
12from keras.layers import SimpleRNN
13from keras.layers import Dropout
14from keras.layers import concatenate
15from keras import losses
16from keras import regularizers
17import h5py
18
19from keras import backend as K
20import numpy as np
21
# Optional: uncomment to cap the fraction of GPU memory TensorFlow may claim.
#import tensorflow as tf
#from keras.backend.tensorflow_backend import set_session
#config = tf.ConfigProto()
#config.gpu_options.per_process_gpu_memory_fraction = 0.42
#set_session(tf.Session(config=config))
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040027
28
def my_crossentropy(y_true, y_pred):
    """Confidence-weighted binary cross-entropy, averaged over the last axis.

    Each element's cross-entropy is scaled by ``2 * |y_true - 0.5|``: a
    target of exactly 0.5 contributes nothing, while targets of 0 or 1 are
    counted with full weight.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        The weighted cross-entropy reduced with a mean over ``axis=-1``.
    """
    weight = 2 * K.abs(y_true - 0.5)
    # Pass the operands by keyword: the positional order of
    # K.binary_crossentropy changed between Keras releases, from
    # (output, target) to (target, output), but the parameter names did not.
    # Positional arguments would silently swap prediction and target on
    # newer versions.
    entropy = K.binary_crossentropy(target=y_true, output=y_pred)
    return K.mean(weight * entropy, axis=-1)
31
def mymask(y_true):
    """Return a per-element weight computed as ``min(y_true + 1, 1)``.

    The weight is 1 wherever ``y_true >= 0`` and drops to 0 where
    ``y_true == -1``.
    NOTE(review): presumably a target of -1 flags an entry that the losses
    should ignore -- confirm against the code that generates the data.
    """
    shifted = y_true + 1.
    return K.minimum(shifted, 1.)
34
def msse(y_true, y_pred):
    """Masked mean of the squared difference between square roots.

    Computes ``(sqrt(y_pred) - sqrt(y_true)) ** 2`` element-wise, weights it
    with ``mymask(y_true)`` and averages over the last axis.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        The masked error reduced with a mean over ``axis=-1``.
    """
    root_diff = K.sqrt(y_pred) - K.sqrt(y_true)
    masked_error = mymask(y_true) * K.square(root_diff)
    return K.mean(masked_error, axis=-1)
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040037
def mycost(y_true, y_pred):
    """Masked loss mixing a square-root-domain error with cross-entropy.

    The per-element cost is
    ``(sqrt(y_pred) - sqrt(y_true)) ** 2 + 0.01 * binary_crossentropy``,
    weighted by ``mymask(y_true)`` and averaged over the last axis.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        The masked cost reduced with a mean over ``axis=-1``.
    """
    root_error = K.square(K.sqrt(y_pred) - K.sqrt(y_true))
    # Pass the operands by keyword: the positional order of
    # K.binary_crossentropy changed between Keras releases, from
    # (output, target) to (target, output), but the parameter names did not.
    # Positional arguments would silently swap prediction and target on
    # newer versions.
    entropy = K.binary_crossentropy(target=y_true, output=y_pred)
    return K.mean(mymask(y_true) * (root_error + 0.01*entropy), axis=-1)
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040040
def my_accuracy(y_true, y_pred):
    """Confidence-weighted accuracy of the rounded predictions.

    An element counts as a hit when ``round(y_pred)`` equals ``y_true``.
    Hits are scaled by ``2 * |y_true - 0.5|`` and averaged over the last
    axis, so a target of exactly 0.5 never contributes.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        The weighted hit rate reduced with a mean over ``axis=-1``.
    """
    weight = 2 * K.abs(y_true - 0.5)
    # K.equal yields a boolean tensor on the TensorFlow backend; cast it to
    # the float type before multiplying, otherwise the product with the
    # float weight fails with a dtype mismatch.
    hits = K.cast(K.equal(y_true, K.round(y_pred)), K.floatx())
    return K.mean(weight * hits, axis=-1)
43
# L2 penalty applied to the kernel and recurrent weights of every GRU layer.
reg = 0.000001


def _regularized_gru(units, activation, name):
    """Create a sequence-returning GRU layer with L2-regularized weights."""
    return GRU(
        units,
        activation=activation,
        recurrent_activation='sigmoid',
        return_sequences=True,
        name=name,
        kernel_regularizer=regularizers.l2(reg),
        recurrent_regularizer=regularizers.l2(reg))


print('Build model...')
# Variable-length sequences of 42 input features per step.
main_input = Input(shape=(None, 42), name='main_input')
tmp = Dense(24, activation='tanh', name='input_dense')(main_input)

# First recurrent stage; its state also feeds the single-unit sigmoid output.
vad_gru = _regularized_gru(24, 'tanh', 'vad_gru')(tmp)
vad_output = Dense(1, activation='sigmoid', name='vad_output')(vad_gru)

# Second recurrent stage sees the dense features, the first GRU and the raw input.
noise_input = concatenate([tmp, vad_gru, main_input])
noise_gru = _regularized_gru(48, 'relu', 'noise_gru')(noise_input)

# Third recurrent stage sees both earlier GRUs and the raw input.
denoise_input = concatenate([vad_gru, noise_gru, main_input])
denoise_gru = _regularized_gru(96, 'tanh', 'denoise_gru')(denoise_input)
denoise_output = Dense(22, activation='sigmoid', name='denoise_output')(denoise_gru)

model = Model(inputs=main_input, outputs=[denoise_output, vad_output])

# Losses and their weights are listed in the same order as the model outputs.
model.compile(
    loss=[mycost, my_crossentropy],
    metrics=[msse],
    optimizer='adam',
    loss_weights=[10, 0.5])
64
65
batch_size = 32

# Read the whole 'data' dataset into memory.
print('Loading data...')
with h5py.File('denoise_data6.h5', 'r') as hf:
    all_data = hf['data'][:]
print('done.')

# Number of consecutive rows grouped into one training sequence.
window_size = 2000

nb_sequences = len(all_data)//window_size
print(nb_sequences, ' sequences')

# Only whole windows are used; trailing rows that do not fill one are dropped.
rows_used = nb_sequences*window_size

# Columns 0..41 are the network inputs.
x_train = all_data[:rows_used, :42].reshape((nb_sequences, window_size, 42))

# Columns 42..63 are the targets for the 22-unit output.
y_train = np.copy(all_data[:rows_used, 42:64]).reshape((nb_sequences, window_size, 22))

# Columns 64..85 are extracted as well, although nothing below reads them.
noise_train = np.copy(all_data[:rows_used, 64:86]).reshape((nb_sequences, window_size, 22))

# Column 86 is the target for the single-unit output.
vad_train = np.copy(all_data[:rows_used, 86:87]).reshape((nb_sequences, window_size, 1))

# Drop the reference to the full array so its memory can be reclaimed.
all_data = 0
#x_train = x_train.astype('float32')
#y_train = y_train.astype('float32')

print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape)
94
# Fit both outputs jointly, holding out 10% of the sequences for validation,
# then write the trained model to disk.
print('Train...')
model.fit(
    x=x_train,
    y=[y_train, vad_train],
    batch_size=batch_size,
    epochs=60,
    validation_split=0.1)
model.save("newweights6a2a.hdf5")
Jean-Marc Valin54eeea72017-08-08 11:20:29 -0400100model.save("newweights6a2a.hdf5")