blob: 06d7e1a4c93994a64a4164cd620b1a462c4fe9f7 [file] [log] [blame]
Jean-Marc Valincf473ce2017-08-03 15:26:05 -04001#!/usr/bin/python
2
3from __future__ import print_function
4
5import keras
6from keras.models import Sequential
7from keras.models import Model
8from keras.layers import Input
9from keras.layers import Dense
10from keras.layers import LSTM
11from keras.layers import GRU
12from keras.layers import SimpleRNN
13from keras.layers import Dropout
14from keras.layers import concatenate
15from keras import losses
16from keras import regularizers
Jean-Marc Valin4d1e6302017-08-14 12:48:27 -040017from keras.constraints import min_max_norm
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040018import h5py
19
Jean-Marc Valin4d1e6302017-08-14 12:48:27 -040020from keras.constraints import Constraint
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040021from keras import backend as K
22import numpy as np
23
Jean-Marc Valin54eeea72017-08-08 11:20:29 -040024#import tensorflow as tf
25#from keras.backend.tensorflow_backend import set_session
26#config = tf.ConfigProto()
27#config.gpu_options.per_process_gpu_memory_fraction = 0.42
28#set_session(tf.Session(config=config))
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040029
30
def my_crossentropy(y_true, y_pred):
    """Confidence-weighted binary cross-entropy (the loss paired with vad_output).

    Each element's cross-entropy is scaled by ``2*|y_true - 0.5|``: targets of
    exactly 0 or 1 get full weight, a target of 0.5 gets weight 0.

    Args:
        y_true: target tensor.
        y_pred: predicted probabilities (the VAD head ends in a sigmoid).

    Returns:
        Weighted cross-entropy averaged over the last axis.
    """
    # Pass target/output by keyword. Keras 2.0.7 swapped the positional order
    # of K.binary_crossentropy from (output, target) to (target, output), so a
    # positional call (y_pred, y_true) silently treats the labels as the
    # predictions on current versions. The parameter *names* are the same
    # under both orders, so keywords are correct everywhere.
    return K.mean(2*K.abs(y_true-0.5) *
                  K.binary_crossentropy(target=y_true, output=y_pred),
                  axis=-1)
33
def mymask(y_true):
    """Return the per-element loss weight ``min(y_true + 1, 1)``.

    Any target >= 0 yields weight 1 and a target of -1 yields weight 0.
    Presumably -1 marks "no target available" in the training data; this is
    not visible here, so confirm against the feature extractor.
    """
    shifted = y_true + 1.
    return K.minimum(shifted, 1.)
36
def msse(y_true, y_pred):
    """Masked mean squared error between the square roots of y_pred and y_true.

    Elements are weighted by ``mymask(y_true)``, so masked targets (-1)
    contribute nothing. The result is averaged over the last axis.
    """
    root_diff = K.sqrt(y_pred) - K.sqrt(y_true)
    weighted = mymask(y_true) * K.square(root_diff)
    return K.mean(weighted, axis=-1)
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040039
def mycost(y_true, y_pred):
    """Masked loss for the denoise_output head.

    With ``d = sqrt(y_pred) - sqrt(y_true)``, each element contributes
    ``10*d**4 + d**2 + 0.01*BCE(y_true, y_pred)``, weighted by
    ``mymask(y_true)`` (targets of -1 get weight 0). The result is averaged
    over the last axis.

    NOTE(review): masked targets of -1 still pass through K.sqrt. This relies
    on the backend not producing NaN there (the Keras TensorFlow backend clips
    the input to >= 0 first). Confirm this if another backend is used.
    """
    sq_err = K.square(K.sqrt(y_pred) - K.sqrt(y_true))
    # Keyword arguments keep the target/output order correct across the
    # Keras 2.0.7 change to K.binary_crossentropy's positional order.
    bce = K.binary_crossentropy(target=y_true, output=y_pred)
    return K.mean(mymask(y_true) * (10*K.square(sq_err) + sq_err + 0.01*bce),
                  axis=-1)
Jean-Marc Valincf473ce2017-08-03 15:26:05 -040042
def my_accuracy(y_true, y_pred):
    """Confidence-weighted binary accuracy.

    Returns the mean over the last axis of
    ``2*|y_true - 0.5| * (round(y_pred) == y_true)``, so 0/1 targets count
    fully and 0.5 targets not at all. The result is not normalised by the sum
    of the weights.
    """
    # K.equal yields a boolean tensor, and multiplying a float tensor by a
    # bool one is rejected by TensorFlow, so cast it to the float type first.
    correct = K.cast(K.equal(y_true, K.round(y_pred)), K.floatx())
    return K.mean(2*K.abs(y_true-0.5) * correct, axis=-1)
45
class WeightClip(Constraint):
    '''Constraint that clips every weight element into the range [-c, c].

    Intended for use as a Keras kernel/recurrent/bias constraint.
    '''
    def __init__(self, c=2, name='WeightClip'):
        # ``name`` is accepted and ignored only because get_config() emits it.
        # Without it, rebuilding the constraint from its own config, e.g.
        # WeightClip(**clip.get_config()) as done when deserializing a saved
        # model, raises TypeError.
        self.c = c

    def __call__(self, p):
        # Element-wise clip; the tensor's shape is unchanged.
        return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
                'c': self.c}
58
# L2 penalty factor applied to both the input and recurrent kernels of every GRU.
reg = 0.000001
# One shared constraint instance: every constrained weight stays in [-0.499, 0.499].
constraint = WeightClip(0.499)


def _clipped():
    """Return kwargs that clip a layer's kernel and bias with ``constraint``."""
    return dict(kernel_constraint=constraint, bias_constraint=constraint)


def _regularized_gru(units, activation, name):
    """Build a sequence-returning GRU with L2-penalised, clipped weights.

    Fresh ``regularizers.l2`` objects are created on every call, one per
    kernel, exactly as in an inline layer definition.
    """
    return GRU(units,
               activation=activation,
               recurrent_activation='sigmoid',
               return_sequences=True,
               name=name,
               kernel_regularizer=regularizers.l2(reg),
               recurrent_regularizer=regularizers.l2(reg),
               recurrent_constraint=constraint,
               **_clipped())


print('Build model...')
# Variable-length sequences of 42 input features per frame.
main_input = Input(shape=(None, 42), name='main_input')
tmp = Dense(24, activation='tanh', name='input_dense', **_clipped())(main_input)

# VAD branch: its GRU state also feeds both later GRUs.
vad_gru = _regularized_gru(24, 'tanh', 'vad_gru')(tmp)
vad_output = Dense(1, activation='sigmoid', name='vad_output', **_clipped())(vad_gru)

noise_input = keras.layers.concatenate([tmp, vad_gru, main_input])
noise_gru = _regularized_gru(48, 'relu', 'noise_gru')(noise_input)

denoise_input = keras.layers.concatenate([vad_gru, noise_gru, main_input])
denoise_gru = _regularized_gru(96, 'tanh', 'denoise_gru')(denoise_input)
denoise_output = Dense(22, activation='sigmoid', name='denoise_output', **_clipped())(denoise_gru)

model = Model(inputs=main_input, outputs=[denoise_output, vad_output])

# Losses and loss weights are listed in the same order as the model outputs.
model.compile(optimizer='adam',
              loss=[mycost, my_crossentropy],
              loss_weights=[10, 0.5],
              metrics=[msse])
80
81
batch_size = 32

print('Loading data...')
with h5py.File('training.h5', 'r') as hf:
    all_data = hf['data'][:]
print('done.')

# Frames per training sequence; rows after the last complete window are dropped.
window_size = 2000

nb_sequences = len(all_data)//window_size
print(nb_sequences, ' sequences')
if nb_sequences == 0:
    # Without one full window every training array below would be empty.
    raise ValueError('training.h5 has %d rows, fewer than window_size=%d'
                     % (len(all_data), window_size))
nb_frames = nb_sequences*window_size

# Column layout of each row, as sliced below:
#   [0, 42)   x_train     - model input (matches main_input's 42 features)
#   [42, 64)  y_train     - 22 targets for denoise_output
#   [64, 86)  noise_train - 22 columns, not used by model.fit below
#   86        vad_train   - target for vad_output
# Every slice is copied rather than viewed. A basic slice followed by an
# axis-splitting reshape is a view that keeps the whole all_data buffer
# alive, which would defeat the release of all_data below.
x_train = np.copy(all_data[:nb_frames, :42])
x_train = np.reshape(x_train, (nb_sequences, window_size, 42))

y_train = np.copy(all_data[:nb_frames, 42:64])
y_train = np.reshape(y_train, (nb_sequences, window_size, 22))

noise_train = np.copy(all_data[:nb_frames, 64:86])
noise_train = np.reshape(noise_train, (nb_sequences, window_size, 22))

vad_train = np.copy(all_data[:nb_frames, 86:87])
vad_train = np.reshape(vad_train, (nb_sequences, window_size, 1))

# Drop the full matrix now that every needed column has been copied out.
del all_data

print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape)

print('Train...')
model.fit(x_train, [y_train, vad_train],
          batch_size=batch_size,
          epochs=120,
          validation_split=0.1)
model.save("weights.hdf5")