blob: 2e9568ba4e15b7174716bc3644899fba56064d62 [file] [log] [blame]
tlegrand@chromium.orge3ea0492013-10-23 09:13:50 +00001/* Copyright (c) 2008-2011 Octasic Inc.
2 Written by Jean-Marc Valin */
3/*
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions
6 are met:
7
8 - Redistributions of source code must retain the above copyright
9 notice, this list of conditions and the following disclaimer.
10
11 - Redistributions in binary form must reproduce the above copyright
12 notice, this list of conditions and the following disclaimer in the
13 documentation and/or other materials provided with the distribution.
14
15 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
19 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*/
27
28
#include "mlp_train.h"
#include <errno.h>
#include <limits.h>
#include <pthread.h>
#include <semaphore.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
37
/* Set to 1 by handler() when SIGTERM/SIGINT/SIGHUP arrives; the training loop
   polls it to stop early. An object written by a signal handler and read by
   normal code must be volatile sig_atomic_t (C11 7.14.1.1p5); a plain int is
   undefined behavior and may be cached by the compiler. */
volatile sig_atomic_t stopped = 0;

/* Signal handler: request that training stop after the current epoch.
   Re-installs itself because some platforms reset the disposition to SIG_DFL
   when a handler installed with signal() is invoked. */
void handler(int sig)
{
   stopped = 1;
   signal(sig, handler);
}
45
46MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
47{
48 int i, j, k;
49 MLPTrain *net;
50 int inDim, outDim;
51 net = malloc(sizeof(*net));
52 net->topo = malloc(nbLayers*sizeof(net->topo[0]));
53 for (i=0;i<nbLayers;i++)
54 net->topo[i] = topo[i];
55 inDim = topo[0];
56 outDim = topo[nbLayers-1];
57 net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
58 net->weights = malloc((nbLayers-1)*sizeof(net->weights));
59 net->best_weights = malloc((nbLayers-1)*sizeof(net->weights));
60 for (i=0;i<nbLayers-1;i++)
61 {
62 net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
63 net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
64 }
65 double inMean[inDim];
66 for (j=0;j<inDim;j++)
67 {
68 double std=0;
69 inMean[j] = 0;
70 for (i=0;i<nbSamples;i++)
71 {
72 inMean[j] += inputs[i*inDim+j];
73 std += inputs[i*inDim+j]*inputs[i*inDim+j];
74 }
75 inMean[j] /= nbSamples;
76 std /= nbSamples;
77 net->in_rate[1+j] = .5/(.0001+std);
78 std = std-inMean[j]*inMean[j];
79 if (std<.001)
80 std = .001;
81 std = 1/sqrt(inDim*std);
82 for (k=0;k<topo[1];k++)
83 net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
84 }
85 net->in_rate[0] = 1;
86 for (j=0;j<topo[1];j++)
87 {
88 double sum = 0;
89 for (k=0;k<inDim;k++)
90 sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
91 net->weights[0][j*(topo[0]+1)] = -sum;
92 }
93 for (j=0;j<outDim;j++)
94 {
95 double mean = 0;
96 double std;
97 for (i=0;i<nbSamples;i++)
98 mean += outputs[i*outDim+j];
99 mean /= nbSamples;
100 std = 1/sqrt(topo[nbLayers-2]);
101 net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
102 for (k=0;k<topo[nbLayers-2];k++)
103 net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
104 }
105 return net;
106}
107
/* Capacity of the fixed-size per-sample scratch arrays (hidden activations,
   outputs, errors) in compute_gradient(); hidden and output layer sizes must
   not exceed it. */
#define MAX_NEURONS 100
/* Capacity of GradientArg.error_rate[]: the most outputs a gradient worker can
   report per-output error counts for. */
#define MAX_OUT 10
110
111double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
112{
113 int i,j;
114 int s;
115 int inDim, outDim, hiddenDim;
116 int *topo;
117 double *W0, *W1;
118 double rms=0;
119 int W0_size, W1_size;
120 double hidden[MAX_NEURONS];
121 double netOut[MAX_NEURONS];
122 double error[MAX_NEURONS];
123
124 for (i=0;i<outDim;i++)
125 error_rate[i] = 0;
126 topo = net->topo;
127 inDim = net->topo[0];
128 hiddenDim = net->topo[1];
129 outDim = net->topo[2];
130 W0_size = (topo[0]+1)*topo[1];
131 W1_size = (topo[1]+1)*topo[2];
132 W0 = net->weights[0];
133 W1 = net->weights[1];
134 memset(W0_grad, 0, W0_size*sizeof(double));
135 memset(W1_grad, 0, W1_size*sizeof(double));
136 for (i=0;i<outDim;i++)
137 netOut[i] = outputs[i];
138 for (s=0;s<nbSamples;s++)
139 {
140 float *in, *out;
141 in = inputs+s*inDim;
142 out = outputs + s*outDim;
143 for (i=0;i<hiddenDim;i++)
144 {
145 double sum = W0[i*(inDim+1)];
146 for (j=0;j<inDim;j++)
147 sum += W0[i*(inDim+1)+j+1]*in[j];
148 hidden[i] = tansig_approx(sum);
149 }
150 for (i=0;i<outDim;i++)
151 {
152 double sum = W1[i*(hiddenDim+1)];
153 for (j=0;j<hiddenDim;j++)
154 sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
155 netOut[i] = tansig_approx(sum);
156 error[i] = out[i] - netOut[i];
157 rms += error[i]*error[i];
158 error_rate[i] += fabs(error[i])>1;
159 /*error[i] = error[i]/(1+fabs(error[i]));*/
160 }
161 /* Back-propagate error */
162 for (i=0;i<outDim;i++)
163 {
164 float grad = 1-netOut[i]*netOut[i];
165 W1_grad[i*(hiddenDim+1)] += error[i]*grad;
166 for (j=0;j<hiddenDim;j++)
167 W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
168 }
169 for (i=0;i<hiddenDim;i++)
170 {
171 double grad;
172 grad = 0;
173 for (j=0;j<outDim;j++)
174 grad += error[j]*W1[j*(hiddenDim+1)+i+1];
175 grad *= 1-hidden[i]*hidden[i];
176 W0_grad[i*(inDim+1)] += grad;
177 for (j=0;j<inDim;j++)
178 W0_grad[i*(inDim+1)+j+1] += grad*in[j];
179 }
180 }
181 return rms;
182}
183
/* Number of worker threads mlp_train_backprop() splits each epoch's gradient
   computation across. */
#define NB_THREADS 8

/* Per-worker handshake: the trainer posts sem_begin[i] to start worker i on an
   epoch (or to let it observe `done` and exit); worker i posts sem_end[i] once
   its results in the matching GradientArg are ready. */
sem_t sem_begin[NB_THREADS];
sem_t sem_end[NB_THREADS];
188
/* State shared between mlp_train_backprop() and one gradient worker thread.
   Ownership of the fields is handed back and forth via sem_begin/sem_end. */
struct GradientArg {
   int id;                      /* index into args[] and the semaphore arrays */
   int done;                    /* set by the trainer to make the worker exit */
   MLPTrain *net;               /* network whose current weights are evaluated */
   float *inputs;               /* this worker's slice of the input samples */
   float *outputs;              /* matching slice of target outputs */
   int nbSamples;               /* number of samples in the slice */
   double *W0_grad;             /* gradient buffers set by the worker thread; */
   double *W1_grad;             /* they point into that thread's stack */
   double rms;                  /* sum of squared errors over the slice */
   double error_rate[MAX_OUT];  /* per-output count of samples with |error| > 1 */
};
201
202void *gradient_thread_process(void *_arg)
203{
204 int W0_size, W1_size;
205 struct GradientArg *arg = _arg;
206 int *topo = arg->net->topo;
207 W0_size = (topo[0]+1)*topo[1];
208 W1_size = (topo[1]+1)*topo[2];
209 double W0_grad[W0_size];
210 double W1_grad[W1_size];
211 arg->W0_grad = W0_grad;
212 arg->W1_grad = W1_grad;
213 while (1)
214 {
215 sem_wait(&sem_begin[arg->id]);
216 if (arg->done)
217 break;
218 arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
219 sem_post(&sem_end[arg->id]);
220 }
221 fprintf(stderr, "done\n");
222 return NULL;
223}
224
225float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
226{
227 int i, j;
228 int e;
229 float best_rms = 1e10;
230 int inDim, outDim, hiddenDim;
231 int *topo;
232 double *W0, *W1, *best_W0, *best_W1;
233 double *W0_old, *W1_old;
234 double *W0_old2, *W1_old2;
235 double *W0_grad, *W1_grad;
236 double *W0_oldgrad, *W1_oldgrad;
237 double *W0_rate, *W1_rate;
238 double *best_W0_rate, *best_W1_rate;
239 int W0_size, W1_size;
240 topo = net->topo;
241 W0_size = (topo[0]+1)*topo[1];
242 W1_size = (topo[1]+1)*topo[2];
243 struct GradientArg args[NB_THREADS];
244 pthread_t thread[NB_THREADS];
245 int samplePerPart = nbSamples/NB_THREADS;
246 int count_worse=0;
247 int count_retries=0;
248
249 topo = net->topo;
250 inDim = net->topo[0];
251 hiddenDim = net->topo[1];
252 outDim = net->topo[2];
253 W0 = net->weights[0];
254 W1 = net->weights[1];
255 best_W0 = net->best_weights[0];
256 best_W1 = net->best_weights[1];
257 W0_old = malloc(W0_size*sizeof(double));
258 W1_old = malloc(W1_size*sizeof(double));
259 W0_old2 = malloc(W0_size*sizeof(double));
260 W1_old2 = malloc(W1_size*sizeof(double));
261 W0_grad = malloc(W0_size*sizeof(double));
262 W1_grad = malloc(W1_size*sizeof(double));
263 W0_oldgrad = malloc(W0_size*sizeof(double));
264 W1_oldgrad = malloc(W1_size*sizeof(double));
265 W0_rate = malloc(W0_size*sizeof(double));
266 W1_rate = malloc(W1_size*sizeof(double));
267 best_W0_rate = malloc(W0_size*sizeof(double));
268 best_W1_rate = malloc(W1_size*sizeof(double));
269 memcpy(W0_old, W0, W0_size*sizeof(double));
270 memcpy(W0_old2, W0, W0_size*sizeof(double));
271 memset(W0_grad, 0, W0_size*sizeof(double));
272 memset(W0_oldgrad, 0, W0_size*sizeof(double));
273 memcpy(W1_old, W1, W1_size*sizeof(double));
274 memcpy(W1_old2, W1, W1_size*sizeof(double));
275 memset(W1_grad, 0, W1_size*sizeof(double));
276 memset(W1_oldgrad, 0, W1_size*sizeof(double));
277
278 rate /= nbSamples;
279 for (i=0;i<hiddenDim;i++)
280 for (j=0;j<inDim+1;j++)
281 W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
282 for (i=0;i<W1_size;i++)
283 W1_rate[i] = rate;
284
285 for (i=0;i<NB_THREADS;i++)
286 {
287 args[i].net = net;
288 args[i].inputs = inputs+i*samplePerPart*inDim;
289 args[i].outputs = outputs+i*samplePerPart*outDim;
290 args[i].nbSamples = samplePerPart;
291 args[i].id = i;
292 args[i].done = 0;
293 sem_init(&sem_begin[i], 0, 0);
294 sem_init(&sem_end[i], 0, 0);
295 pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
296 }
297 for (e=0;e<nbEpoch;e++)
298 {
299 double rms=0;
300 double error_rate[2] = {0,0};
301 for (i=0;i<NB_THREADS;i++)
302 {
303 sem_post(&sem_begin[i]);
304 }
305 memset(W0_grad, 0, W0_size*sizeof(double));
306 memset(W1_grad, 0, W1_size*sizeof(double));
307 for (i=0;i<NB_THREADS;i++)
308 {
309 sem_wait(&sem_end[i]);
310 rms += args[i].rms;
311 error_rate[0] += args[i].error_rate[0];
312 error_rate[1] += args[i].error_rate[1];
313 for (j=0;j<W0_size;j++)
314 W0_grad[j] += args[i].W0_grad[j];
315 for (j=0;j<W1_size;j++)
316 W1_grad[j] += args[i].W1_grad[j];
317 }
318
319 float mean_rate = 0, min_rate = 1e10;
320 rms = (rms/(outDim*nbSamples));
321 error_rate[0] = (error_rate[0]/(nbSamples));
322 error_rate[1] = (error_rate[1]/(nbSamples));
323 fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
324 if (rms < best_rms)
325 {
326 best_rms = rms;
327 for (i=0;i<W0_size;i++)
328 {
329 best_W0[i] = W0[i];
330 best_W0_rate[i] = W0_rate[i];
331 }
332 for (i=0;i<W1_size;i++)
333 {
334 best_W1[i] = W1[i];
335 best_W1_rate[i] = W1_rate[i];
336 }
337 count_worse=0;
338 count_retries=0;
339 } else {
340 count_worse++;
341 if (count_worse>30)
342 {
343 count_retries++;
344 count_worse=0;
345 for (i=0;i<W0_size;i++)
346 {
347 W0[i] = best_W0[i];
348 best_W0_rate[i] *= .7;
349 if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
350 W0_rate[i] = best_W0_rate[i];
351 W0_grad[i] = 0;
352 }
353 for (i=0;i<W1_size;i++)
354 {
355 W1[i] = best_W1[i];
356 best_W1_rate[i] *= .8;
357 if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
358 W1_rate[i] = best_W1_rate[i];
359 W1_grad[i] = 0;
360 }
361 }
362 }
363 if (count_retries>10)
364 break;
365 for (i=0;i<W0_size;i++)
366 {
367 if (W0_oldgrad[i]*W0_grad[i] > 0)
368 W0_rate[i] *= 1.01;
369 else if (W0_oldgrad[i]*W0_grad[i] < 0)
370 W0_rate[i] *= .9;
371 mean_rate += W0_rate[i];
372 if (W0_rate[i] < min_rate)
373 min_rate = W0_rate[i];
374 if (W0_rate[i] < 1e-15)
375 W0_rate[i] = 1e-15;
376 /*if (W0_rate[i] > .01)
377 W0_rate[i] = .01;*/
378 W0_oldgrad[i] = W0_grad[i];
379 W0_old2[i] = W0_old[i];
380 W0_old[i] = W0[i];
381 W0[i] += W0_grad[i]*W0_rate[i];
382 }
383 for (i=0;i<W1_size;i++)
384 {
385 if (W1_oldgrad[i]*W1_grad[i] > 0)
386 W1_rate[i] *= 1.01;
387 else if (W1_oldgrad[i]*W1_grad[i] < 0)
388 W1_rate[i] *= .9;
389 mean_rate += W1_rate[i];
390 if (W1_rate[i] < min_rate)
391 min_rate = W1_rate[i];
392 if (W1_rate[i] < 1e-15)
393 W1_rate[i] = 1e-15;
394 W1_oldgrad[i] = W1_grad[i];
395 W1_old2[i] = W1_old[i];
396 W1_old[i] = W1[i];
397 W1[i] += W1_grad[i]*W1_rate[i];
398 }
399 mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
400 fprintf (stderr, "%g %d", mean_rate, e);
401 if (count_retries)
402 fprintf(stderr, " %d", count_retries);
403 fprintf(stderr, "\n");
404 if (stopped)
405 break;
406 }
407 for (i=0;i<NB_THREADS;i++)
408 {
409 args[i].done = 1;
410 sem_post(&sem_begin[i]);
411 pthread_join(thread[i], NULL);
412 fprintf (stderr, "joined %d\n", i);
413 }
414 free(W0_old);
415 free(W1_old);
416 free(W0_grad);
417 free(W1_grad);
418 free(W0_rate);
419 free(W1_rate);
420 return best_rms;
421}
422
423int main(int argc, char **argv)
424{
425 int i, j;
426 int nbInputs;
427 int nbOutputs;
428 int nbHidden;
429 int nbSamples;
430 int nbEpoch;
431 int nbRealInputs;
432 unsigned int seed;
433 int ret;
434 float rms;
435 float *inputs;
436 float *outputs;
437 if (argc!=6)
438 {
439 fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
440 return 1;
441 }
442 nbInputs = atoi(argv[1]);
443 nbHidden = atoi(argv[2]);
444 nbOutputs = atoi(argv[3]);
445 nbSamples = atoi(argv[4]);
446 nbEpoch = atoi(argv[5]);
447 nbRealInputs = nbInputs;
448 inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
449 outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));
450
451 seed = time(NULL);
452 /*seed = 1361480659;*/
453 fprintf (stderr, "Seed is %u\n", seed);
454 srand(seed);
455 build_tansig_table();
456 signal(SIGTERM, handler);
457 signal(SIGINT, handler);
458 signal(SIGHUP, handler);
459 for (i=0;i<nbSamples;i++)
460 {
461 for (j=0;j<nbRealInputs;j++)
462 ret = scanf(" %f", &inputs[i*nbInputs+j]);
463 for (j=0;j<nbOutputs;j++)
464 ret = scanf(" %f", &outputs[i*nbOutputs+j]);
465 if (feof(stdin))
466 {
467 nbSamples = i;
468 break;
469 }
470 }
471 int topo[3] = {nbInputs, nbHidden, nbOutputs};
472 MLPTrain *net;
473
474 fprintf (stderr, "Got %d samples\n", nbSamples);
475 net = mlp_init(topo, 3, inputs, outputs, nbSamples);
476 rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
477 printf ("#include \"mlp.h\"\n\n");
478 printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
479 printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
480 printf ("\n/* hidden layer */\n");
481 for (i=0;i<(topo[0]+1)*topo[1];i++)
482 {
483 printf ("%gf, ", net->weights[0][i]);
484 if (i%5==4)
485 printf("\n");
486 }
487 printf ("\n/* output layer */\n");
488 for (i=0;i<(topo[1]+1)*topo[2];i++)
489 {
490 printf ("%g, ", net->weights[1][i]);
491 if (i%5==4)
492 printf("\n");
493 }
494 printf ("};\n\n");
495 printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
496 printf ("const MLP net = {\n");
497 printf ("\t3,\n");
498 printf ("\ttopo,\n");
499 printf ("\tweights\n};\n");
500 return 0;
501}