Jean-Marc Valin | 0bcf788 | 2017-08-03 20:12:57 -0400 | [diff] [blame] | 1 | #!/usr/bin/python |
| 2 | |
| 3 | from __future__ import print_function |
| 4 | |
| 5 | from keras.models import Sequential |
| 6 | from keras.layers import Dense |
| 7 | from keras.layers import LSTM |
| 8 | from keras.layers import GRU |
| 9 | from keras.models import load_model |
| 10 | from keras import backend as K |
| 11 | import sys |
| 12 | import re |
| 13 | import numpy as np |
| 14 | |
def printVector(f, ft, vector, name):
    """Write *vector* as a quantized C array to *f* and as plain integers to *ft*.

    Every element is multiplied by 256, rounded to an int and clamped to
    [-128, 127]. The same integers go to both outputs.

    Args:
        f: writable text stream that receives
            ``static const rnn_weight <name>[<len>] = { ... };``
            with values separated by ", " and a line break after every
            eighth value.
        ft: writable text stream that receives the integers separated by
            single spaces, terminated by a newline.
        vector: array-like of weights; flattened with ``np.reshape(..., (-1))``.
        name: C identifier used for the emitted array.
    """
    v = np.reshape(vector, (-1))
    # Quantize once so both outputs are guaranteed to carry identical values.
    # NOTE(review): the clamp assumes rnn_weight is an 8-bit signed type, as the
    # pre-existing 127 upper bound implies. Previously only the top end was
    # clamped, so values below -128 would overflow. Confirm against rnn.h.
    quantized = [max(-128, min(127, int(round(256 * x)))) for x in v]

    f.write('static const rnn_weight {}[{}] = {{\n '.format(name, len(v)))
    parts = []
    for i, q in enumerate(quantized):
        if i:
            # Break the line after every 8th value; otherwise a single space.
            parts.append(',\n ' if i % 8 == 0 else ', ')
        parts.append(str(q))
    f.write(''.join(parts))
    f.write('\n};\n\n')

    ft.write(' '.join(str(q) for q in quantized))
    ft.write("\n")
| 35 | |
def printLayer(f, ft, layer):
    """Emit one Keras layer's weights and its C layer descriptor.

    A layer with more than two weight tensors is treated as a GRU layer
    (input weights, recurrent weights, bias). Any other layer is treated as a
    dense layer (weights, bias).

    To *ft* this writes a header line "<inputs> <units> <activation code>"
    (1 = SIGMOID, 2 = RELU, 0 = anything else), followed by the weight
    vectors. To *f* it writes the C arrays plus a ``GRULayer`` or
    ``DenseLayer`` initializer named after the layer.

    Args:
        f: writable text stream for the generated C source.
        ft: writable text stream for the text model file.
        layer: Keras layer; only ``get_weights()``, ``activation`` and
            ``name`` are read.

    Raises:
        ValueError: if no activation name can be parsed from
            ``str(layer.activation)`` (previously an opaque AttributeError).
    """
    weights = layer.get_weights()
    match = re.search('function (.*) at', str(layer.activation))
    if match is None:
        raise ValueError('cannot determine activation of layer {}: {!r}'
                         .format(layer.name, layer.activation))
    activation = match.group(1).upper()

    is_gru = len(weights) > 2
    n_inputs = weights[0].shape[0]
    # For GRU layers the kernel's second axis is three times the unit count.
    # Floor division keeps this an int on Python 3; plain "/" emitted e.g.
    # "24.0" into both the C initializer and the text model file.
    n_units = weights[0].shape[1] // 3 if is_gru else weights[0].shape[1]
    activation_code = {'SIGMOID': 1, 'RELU': 2}.get(activation, 0)
    ft.write('{} {} {}\n'.format(n_inputs, n_units, activation_code))

    printVector(f, ft, weights[0], layer.name + '_weights')
    if is_gru:
        printVector(f, ft, weights[1], layer.name + '_recurrent_weights')
    printVector(f, ft, weights[-1], layer.name + '_bias')

    name = layer.name
    if is_gru:
        f.write('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
                .format(name, name, name, name, n_inputs, n_units, activation))
    else:
        f.write('static const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
                .format(name, name, name, n_inputs, n_units, activation))
Gregor Richards | f30741b | 2018-08-28 10:40:28 -0400 | [diff] [blame] | 60 | |
def structLayer(f, layer):
    """Write one layer's entry in the generated ``RNNModel`` initializer.

    The entry is the layer's unit count followed by the address of the
    layer descriptor that printLayer emits under the layer's name. A layer
    with more than two weight tensors is treated as a GRU layer, whose
    kernel width is three times its unit count.

    Args:
        f: writable text stream for the generated C source.
        layer: Keras layer; only ``get_weights()`` and ``name`` are read.
    """
    weights = layer.get_weights()
    width = weights[0].shape[1]
    # Floor division: on Python 3, "/" wrote e.g. "24.0" into the C initializer.
    units = width // 3 if len(weights) > 2 else width
    f.write(' {},\n'.format(units))
    f.write(' &{},\n'.format(layer.name))
Jean-Marc Valin | 0bcf788 | 2017-08-03 20:12:57 -0400 | [diff] [blame] | 69 | |
| 70 | |
def foo(c, name):
    """Placeholder for the custom ``WeightClip`` object when loading the model.

    It is registered under ``'WeightClip'`` in the ``custom_objects`` passed
    to ``load_model``. Both arguments are accepted and ignored, and the
    function returns None. Presumably Keras calls it with the saved
    constraint's config, so the constraint is effectively dropped; it is not
    needed to dump the weights.
    """
    del c, name  # intentionally unused
Jean-Marc Valin | 4d1e630 | 2017-08-14 12:48:27 -0400 | [diff] [blame] | 73 | |
def mean_squared_sqrt_error(y_true, y_pred):
    """Mean over the last axis of the squared difference of square roots.

    Computes mean((sqrt(y_pred) - sqrt(y_true)) ** 2, axis=-1) with the Keras
    backend. In this script it is only used to satisfy ``custom_objects``
    when loading the model.
    """
    root_diff = K.sqrt(y_pred) - K.sqrt(y_true)
    return K.mean(K.square(root_diff), axis=-1)
| 76 | |
| 77 | |
# Usage: <script> <keras_model> <c_output> <text_output> <model_name>
# Check the arguments before doing any work. argv[4] used to be read only
# after both output files had been written, leaving partial output behind.
if len(sys.argv) < 5:
    sys.exit('usage: {} <keras_model> <c_output> <text_output> <model_name>'.format(sys.argv[0]))

# The custom loss/constraint names saved with the model must be resolvable
# for load_model; nothing below uses them afterwards.
model = load_model(sys.argv[1], custom_objects={'msse': mean_squared_sqrt_error, 'mean_squared_sqrt_error': mean_squared_sqrt_error, 'my_crossentropy': mean_squared_sqrt_error, 'mycost': mean_squared_sqrt_error, 'WeightClip': foo})

model_name = sys.argv[4]
# Only layers that own weights are exported.
weighted_layers = [layer for layer in model.layers if len(layer.get_weights()) > 0]
# Names of GRU layers (more than two weight tensors); only referenced by the
# disabled RNNState emitter below.
layer_list = [layer.name for layer in model.layers if len(layer.get_weights()) > 2]

# "with" guarantees both outputs are flushed and closed, even on error.
# Previously ft was never closed and f leaked on exceptions.
with open(sys.argv[2], 'w') as f, open(sys.argv[3], 'w') as ft:
    f.write('/*This file is automatically generated from a Keras model*/\n\n')
    f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n#include "rnn_data.h"\n\n')
    ft.write('rnnoise-nu model file version 1\n')

    for layer in weighted_layers:
        printLayer(f, ft, layer)

    f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(model_name))
    for layer in weighted_layers:
        structLayer(f, layer)
    f.write('};\n')

# Disabled: RNNState struct emitter ("hf" is not defined in this script).
#hf.write('struct RNNState {\n')
#for i, name in enumerate(layer_list):
# hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper()))
#hf.write('};\n')