Surround masking rewrite

The per-channel band energies used for surround masking are now computed
directly with an MDCT in the multistream encoder, instead of encoding a
stereo downmix with an extra CELT encoder just to obtain its band energies.
diff --git a/src/opus_multistream_encoder.c b/src/opus_multistream_encoder.c
index 4cddbff..bc86489 100644
--- a/src/opus_multistream_encoder.c
+++ b/src/opus_multistream_encoder.c
@@ -38,6 +38,10 @@
 #include "os_support.h"
 #include "analysis.h"
 #include "mathops.h"
+#include "mdct.h"
+#include "modes.h"
+#include "bands.h"
+#include "quant_bands.h"
 
 typedef struct {
    int nb_streams;
@@ -57,6 +61,15 @@
       {5, 3, {0, 6, 1, 2, 3, 4, 5, 7}}, /* 8: 7.1 surround */
 };
 
+typedef void (*opus_copy_channel_in_func)(
+  opus_val16 *dst,
+  int dst_stride,
+  const void *src,
+  int src_stride,
+  int src_channel,
+  int frame_size
+);
+
 struct OpusMSEncoder {
    TonalityAnalysisState analysis;
    ChannelLayout layout;
@@ -66,8 +79,37 @@
    opus_int32 bitrate_bps;
    opus_val32 subframe_mem[3];
    /* Encoder states go here */
+   /* then opus_val32 window_mem[channels*120]; */
+   /* then opus_val32 preemph_mem[channels]; */
 };
 
+/* Returns the per-channel MDCT history used by surround_analysis()
+   (channels*120 values), stored right after the last stream encoder. */
+static opus_val32 *ms_get_window_mem(OpusMSEncoder *st)
+{
+   int s;
+   char *ptr;
+   int coupled_size, mono_size;
+
+   coupled_size = opus_encoder_get_size(2);
+   mono_size = opus_encoder_get_size(1);
+   ptr = (char*)st + align(sizeof(OpusMSEncoder));
+   for (s=0;s<st->layout.nb_streams;s++)
+   {
+      if (s < st->layout.nb_coupled_streams)
+         ptr += align(coupled_size);
+      else
+         ptr += align(mono_size);
+   }
+   return (opus_val32*)ptr;
+}
+
+/* Returns the per-channel pre-emphasis state (channels values), which
+   immediately follows the window memory. */
+static opus_val32 *ms_get_preemph_mem(OpusMSEncoder *st)
+{
+   return ms_get_window_mem(st) + st->layout.nb_channels*120;
+}
 
 static int validate_encoder_layout(const ChannelLayout *layout)
 {
@@ -88,6 +140,162 @@
    return 1;
 }
 
+static void channel_pos(int channels, int pos[8])
+{
+   /* Position in the mix: 0 don't mix, 1: left, 2: center, 3:right */
+   if (channels==4)
+   {
+      pos[0]=1;
+      pos[1]=3;
+      pos[2]=1;
+      pos[3]=3;
+   } else if (channels==3||channels==5||channels==6)
+   {
+      pos[0]=1;
+      pos[1]=2;
+      pos[2]=3;
+      pos[3]=1;
+      pos[4]=3;
+      pos[5]=0;
+   } else if (channels==7)
+   {
+      pos[0]=1;
+      pos[1]=2;
+      pos[2]=3;
+      pos[3]=1;
+      pos[4]=3;
+      pos[5]=2;
+      pos[6]=0;
+   } else if (channels==8)
+   {
+      pos[0]=1;
+      pos[1]=2;
+      pos[2]=3;
+      pos[3]=1;
+      pos[4]=3;
+      pos[5]=1;
+      pos[6]=3;
+      pos[7]=0;
+   }
+}
+
+/* Surround masking analysis. For every input channel c, computes the
+   log-energy of its 21 CELT bands (pre-emphasis followed by an MDCT) minus
+   a left/centre/right log-mask accumulated over all channels, and stores it
+   in bandLogE[21*c..21*c+20]. Channels with pos==0 (see channel_pos()) get 0.
+   len is the frame size in samples at `rate`; mem holds `overlap` samples of
+   MDCT history per channel and preemph_mem one pre-emphasis state per
+   channel, and both are updated for the next call. */
+static void surround_analysis(const CELTMode *celt_mode, const void *pcm, opus_val16 *bandLogE, opus_val32 *mem, opus_val32 *preemph_mem,
+      int len, int overlap, int channels, int rate, opus_copy_channel_in_func copy_channel_in
+)
+{
+   int c;
+   int i;
+   int LM;
+   int pos[8] = {0};
+   int upsample;
+   int frame_size;
+   int freq_size;
+   opus_val32 bandE[21];
+   opus_val32 maskE[3][21];
+   opus_val16 maskLogE[3][21];
+   VARDECL(opus_val32, in);
+   VARDECL(opus_val16, x);
+   VARDECL(opus_val32, freq);
+   SAVE_STACK;
+
+   /* preemphasis() zero-stuffs the input by `upsample` (presumably
+      48000/rate), so len input samples become frame_size analysis samples. */
+   upsample = resampling_factor(rate);
+   frame_size = len*upsample;
+
+   /* Largest LM <= maxLM whose MDCT (shortMdctSize<<LM bins) fits in the
+      frame; exact whenever frame_size == shortMdctSize<<LM. Longer frames
+      only have their first freq_size samples analysed. */
+   LM = celt_mode->maxLM;
+   while (LM>0 && (celt_mode->shortMdctSize<<LM) > frame_size)
+      LM--;
+   freq_size = celt_mode->shortMdctSize<<LM;
+
+   ALLOC(in, frame_size+overlap, opus_val32);
+   ALLOC(x, len, opus_val16);
+   ALLOC(freq, freq_size, opus_val32);
+
+   channel_pos(channels, pos);
+
+   for (c=0;c<2;c++)
+      for (i=0;i<21;i++)
+         maskE[c][i] = 0;
+
+   for (c=0;c<channels;c++)
+   {
+      OPUS_COPY(in, mem+c*overlap, overlap);
+      (*copy_channel_in)(x, 1, pcm, channels, c, len);
+      preemphasis(x, in+overlap, frame_size, 1, upsample, celt_mode->preemph, preemph_mem+c, 0);
+      clt_mdct_forward(&celt_mode->mdct, in, freq, celt_mode->window, overlap, celt_mode->maxLM-LM, 1);
+      if (upsample != 1)
+      {
+         /* Only bins below the input's Nyquist frequency carry signal:
+            scale them by upsample and clear the rest. */
+         int bound = freq_size/upsample;
+         for (i=0;i<bound;i++)
+            freq[i] *= upsample;
+         for (;i<freq_size;i++)
+            freq[i] = 0;
+      }
+
+      compute_band_energies(celt_mode, freq, bandE, 21, 1, 1<<LM);
+      /* FIXME: Figure out how to square bandE[] in fixed-point */
+      if (pos[c]==1)
+      {
+         for (i=0;i<21;i++)
+            maskE[0][i] += bandE[i]*bandE[i];
+      } else if (pos[c]==3)
+      {
+         for (i=0;i<21;i++)
+            maskE[1][i] += bandE[i]*bandE[i];
+      } else if (pos[c]==2)
+      {
+         for (i=0;i<21;i++)
+         {
+            maskE[0][i] += HALF32(bandE[i]*bandE[i]);
+            maskE[1][i] += HALF32(bandE[i]*bandE[i]);
+         }
+      }
+      amp2Log2(celt_mode, 21, 21, bandE, bandLogE+21*c, 1);
+      /* Keep the last `overlap` samples as MDCT history for the next call. */
+      OPUS_COPY(mem+c*overlap, in+frame_size, overlap);
+   }
+   /* maskE[0]: left, maskE[1]: right, maskE[2]: centre (the smaller of the two). */
+   for (i=0;i<21;i++)
+      maskE[2][i] = MIN32(maskE[0][i],maskE[1][i]);
+   /* NOTE: divides by channels-1, so this assumes channels > 1. */
+   for (c=0;c<3;c++)
+      for (i=0;i<21;i++)
+         maskE[c][i] = sqrt(maskE[c][i]*2/(channels-1));
+   /* maskLogE is indexed by pos-1: 0 = left, 1 = centre, 2 = right. */
+   /* Left mask */
+   amp2Log2(celt_mode, 21, 21, &maskE[0][0], &maskLogE[0][0], 1);
+   /* Right mask */
+   amp2Log2(celt_mode, 21, 21, &maskE[1][0], &maskLogE[2][0], 1);
+   /* Centre mask */
+   amp2Log2(celt_mode, 21, 21, &maskE[2][0], &maskLogE[1][0], 1);
+   for (c=0;c<channels;c++)
+   {
+      opus_val16 *mask;
+      if (pos[c]!=0)
+      {
+         mask = &maskLogE[pos[c]-1][0];
+         for (i=0;i<21;i++)
+            bandLogE[21*c+i] = bandLogE[21*c+i] - mask[i];
+      } else {
+         for (i=0;i<21;i++)
+            bandLogE[21*c+i] = 0;
+      }
+   }
+   RESTORE_STACK;
+}
 
 opus_int32 opus_multistream_encoder_get_size(int nb_streams, int nb_coupled_streams)
 {
@@ -132,7 +342,9 @@
       return 0;
    size = opus_multistream_encoder_get_size(nb_streams, nb_coupled_streams);
    if (channels>2)
-      size += align(opus_encoder_get_size(2));
+   {
+      size += channels*(120*sizeof(opus_val32) + sizeof(opus_val32));
+   }
    return size;
 }
 
@@ -192,10 +404,8 @@
    }
    if (surround)
    {
-      OpusEncoder *downmix_enc;
-      downmix_enc = (OpusEncoder*)ptr;
-      ret = opus_encoder_init(downmix_enc, Fs, 2, OPUS_APPLICATION_AUDIO);
-      if(ret!=OPUS_OK)return ret;
+      OPUS_CLEAR(ms_get_window_mem(st), channels*120);
+      OPUS_CLEAR(ms_get_preemph_mem(st), channels);
    }
    st->surround = surround;
    return OPUS_OK;
@@ -339,22 +549,6 @@
    return st;
 }
 
-typedef void (*opus_copy_channel_in_func)(
-  opus_val16 *dst,
-  int dst_stride,
-  const void *src,
-  int src_stride,
-  int src_channel,
-  int frame_size
-);
-
-typedef void (*opus_surround_downmix_funct)(
-  opus_val16 *dst,
-  const void *src,
-  int channels,
-  int frame_size
-);
-
 static void surround_rate_allocation(
       OpusMSEncoder *st,
       opus_int32 *rate,
@@ -436,8 +630,7 @@
     int frame_size,
     unsigned char *data,
     opus_int32 max_data_bytes,
-    int lsb_depth,
-    opus_surround_downmix_funct surround_downmix
+    int lsb_depth
 #ifndef FIXED_POINT
     , downmix_func downmix
     , const void *pcm_analysis
@@ -451,6 +644,7 @@
    char *ptr;
    int tot_size;
    VARDECL(opus_val16, buf);
+   VARDECL(opus_val16, bandSMR);
    unsigned char tmp_data[MS_FRAME_TMP];
    OpusRepacketizer rp;
    opus_int32 complexity;
@@ -460,9 +654,16 @@
    const CELTMode *celt_mode;
    opus_int32 bitrates[256];
    opus_val16 bandLogE[42];
-   opus_val16 bandLogE_mono[21];
+   opus_val32 *window_mem = NULL;
+   opus_val32 *preemph_mem = NULL;
    ALLOC_STACK;
 
+   if (st->surround)
+   {
+      window_mem = ms_get_window_mem(st);
+      preemph_mem = ms_get_preemph_mem(st);
+   }
+
    ptr = (char*)st + align(sizeof(OpusMSEncoder));
    opus_encoder_ctl((OpusEncoder*)ptr, OPUS_GET_SAMPLE_RATE(&Fs));
    opus_encoder_ctl((OpusEncoder*)ptr, OPUS_GET_COMPLEXITY(&complexity));
@@ -504,42 +705,10 @@
    coupled_size = opus_encoder_get_size(2);
    mono_size = opus_encoder_get_size(1);
 
+   ALLOC(bandSMR, 21*st->layout.nb_channels, opus_val16);
    if (st->surround)
    {
-      int i;
-      unsigned char dummy[512];
-      /* Temporary kludge -- remove */
-      OpusEncoder *downmix_enc;
-
-      ptr = (char*)st + align(sizeof(OpusMSEncoder));
-      for (s=0;s<st->layout.nb_streams;s++)
-      {
-         if (s < st->layout.nb_coupled_streams)
-            ptr += align(coupled_size);
-         else
-            ptr += align(mono_size);
-      }
-      downmix_enc = (OpusEncoder*)ptr;
-      surround_downmix(buf, pcm, st->layout.nb_channels, frame_size);
-      opus_encoder_ctl(downmix_enc, OPUS_SET_ENERGY_SAVE(bandLogE));
-      opus_encoder_ctl(downmix_enc, OPUS_SET_BANDWIDTH(OPUS_BANDWIDTH_FULLBAND));
-      opus_encoder_ctl(downmix_enc, OPUS_SET_FORCE_MODE(MODE_CELT_ONLY));
-      opus_encoder_ctl(downmix_enc, OPUS_SET_FORCE_CHANNELS(2));
-      opus_encode_native(downmix_enc, buf, frame_size, dummy, 512, lsb_depth
-#ifndef FIXED_POINT
-            , &analysis_info
-#endif
-            );
-      /* Combines the left and right mask into a centre mask. We
-         use an approximation for the log of the sum of the energies. */
-      for(i=0;i<21;i++)
-      {
-         opus_val16 diff;
-         diff = ABS16(SUB16(bandLogE[i], bandLogE[21+i]));
-         diff = diff + HALF16(diff);
-         diff = SHR32(HALF32(celt_exp2(-diff)), 16-DB_SHIFT);
-         bandLogE_mono[i] = MAX16(bandLogE[i], bandLogE[21+i]) + diff;
-      }
+      surround_analysis(celt_mode, pcm, bandSMR, window_mem, preemph_mem, frame_size, 120, st->layout.nb_channels, Fs, copy_channel_in);
    }
 
    if (max_data_bytes < 4*st->layout.nb_streams-1)
@@ -583,6 +752,7 @@
       enc = (OpusEncoder*)ptr;
       if (s < st->layout.nb_coupled_streams)
       {
+         int i;
          int left, right;
          left = get_left_channel(&st->layout, s, -1);
          right = get_right_channel(&st->layout, s, -1);
@@ -591,18 +761,28 @@
          (*copy_channel_in)(buf+1, 2,
             pcm, st->layout.nb_channels, right, frame_size);
          ptr += align(coupled_size);
-         /* FIXME: This isn't correct for the coupled center channels in
-            6.1 surround configuration */
          if (st->surround)
-            opus_encoder_ctl(enc, OPUS_SET_ENERGY_MASK(bandLogE));
+         {
+            for (i=0;i<21;i++)
+            {
+               bandLogE[i] = bandSMR[21*left+i];
+               bandLogE[21+i] = bandSMR[21*right+i];
+            }
+         }
       } else {
+         int i;
          int chan = get_mono_channel(&st->layout, s, -1);
          (*copy_channel_in)(buf, 1,
             pcm, st->layout.nb_channels, chan, frame_size);
          ptr += align(mono_size);
          if (st->surround)
-            opus_encoder_ctl(enc, OPUS_SET_ENERGY_MASK(bandLogE_mono));
+         {
+            for (i=0;i<21;i++)
+               bandLogE[i] = bandSMR[21*chan+i];
+         }
       }
+      if (st->surround)
+         opus_encoder_ctl(enc, OPUS_SET_ENERGY_MASK(bandLogE));
       /* number of bytes left (+Toc) */
       curr_max = max_data_bytes - tot_size;
       /* Reserve three bytes for the last stream and four for the others */
@@ -626,50 +806,11 @@
       data += len;
       tot_size += len;
    }
    RESTORE_STACK;
    return tot_size;
 
 }
 
-static void channel_pos(int channels, int pos[8])
-{
-   /* Position in the mix: 0 don't mix, 1: left, 2: center, 3:right */
-   if (channels==4)
-   {
-      pos[0]=1;
-      pos[1]=3;
-      pos[2]=1;
-      pos[3]=3;
-   } else if (channels==3||channels==5||channels==6)
-   {
-      pos[0]=1;
-      pos[1]=2;
-      pos[2]=3;
-      pos[3]=1;
-      pos[4]=3;
-      pos[5]=0;
-   } else if (channels==7)
-   {
-      pos[0]=1;
-      pos[1]=2;
-      pos[2]=3;
-      pos[3]=1;
-      pos[4]=3;
-      pos[5]=2;
-      pos[6]=0;
-   } else if (channels==8)
-   {
-      pos[0]=1;
-      pos[1]=2;
-      pos[2]=3;
-      pos[3]=1;
-      pos[4]=3;
-      pos[5]=1;
-      pos[6]=3;
-      pos[7]=0;
-   }
-}
-
 #if !defined(DISABLE_FLOAT_API)
 static void opus_copy_channel_in_float(
   opus_val16 *dst,
@@ -690,57 +832,6 @@
       dst[i*dst_stride] = float_src[i*src_stride+src_channel];
 #endif
 }
-
-static void opus_surround_downmix_float(
-  opus_val16 *dst,
-  const void *src,
-  int channels,
-  int frame_size
-)
-{
-   const float *float_src;
-   opus_int32 i;
-   int pos[8] = {0};
-   int c;
-   float_src = (const float *)src;
-
-   channel_pos(channels, pos);
-   for (i=0;i<2*frame_size;i++)
-      dst[i]=0;
-
-   for (c=0;c<channels;c++)
-   {
-      if (pos[c]==1)
-      {
-         for (i=0;i<frame_size;i++)
-#if defined(FIXED_POINT)
-            dst[2*i] += SHR16(FLOAT2INT16(float_src[i*channels+c]),3);
-#else
-            dst[2*i] += float_src[i*channels+c];
-#endif
-      } else if (pos[c]==3)
-      {
-         for (i=0;i<frame_size;i++)
-#if defined(FIXED_POINT)
-            dst[2*i+1] += SHR16(FLOAT2INT16(float_src[i*channels+c]),3);
-#else
-            dst[2*i+1] += float_src[i*channels+c];
-#endif
-      } else if (pos[c]==2)
-      {
-         for (i=0;i<frame_size;i++)
-         {
-#if defined(FIXED_POINT)
-            dst[2*i] += SHR32(MULT16_16(QCONST16(.70711f,15), FLOAT2INT16(float_src[i*channels+c])),3+15);
-            dst[2*i+1] += SHR32(MULT16_16(QCONST16(.70711f,15), FLOAT2INT16(float_src[i*channels+c])),3+15);
-#else
-            dst[2*i] += .707f*float_src[i*channels+c];
-            dst[2*i+1] += .707f*float_src[i*channels+c];
-#endif
-         }
-      }
-   }
-}
 #endif
 
 static void opus_copy_channel_in_short(
@@ -763,57 +854,6 @@
 #endif
 }
 
-static void opus_surround_downmix_short(
-  opus_val16 *dst,
-  const void *src,
-  int channels,
-  int frame_size
-)
-{
-   const opus_int16 *short_src;
-   opus_int32 i;
-   int pos[8] = {0};
-   int c;
-   short_src = (const opus_int16 *)src;
-
-   channel_pos(channels, pos);
-   for (i=0;i<2*frame_size;i++)
-      dst[i]=0;
-
-   for (c=0;c<channels;c++)
-   {
-      if (pos[c]==1)
-      {
-         for (i=0;i<frame_size;i++)
-#if defined(FIXED_POINT)
-            dst[2*i] += SHR16(short_src[i*channels+c],3);
-#else
-            dst[2*i] += (1/32768.f)*short_src[i*channels+c];
-#endif
-      } else if (pos[c]==3)
-      {
-         for (i=0;i<frame_size;i++)
-#if defined(FIXED_POINT)
-            dst[2*i+1] += SHR16(short_src[i*channels+c],3);
-#else
-            dst[2*i+1] += (1/32768.f)*short_src[i*channels+c];
-#endif
-      } else if (pos[c]==2)
-      {
-         for (i=0;i<frame_size;i++)
-         {
-#if defined(FIXED_POINT)
-            dst[2*i] += SHR32(MULT16_16(QCONST16(.70711f,15), short_src[i*channels+c]),3+15);
-            dst[2*i+1] += SHR32(MULT16_16(QCONST16(.70711f,15), short_src[i*channels+c]),3+15);
-#else
-            dst[2*i] += (.707f/32768.f)*short_src[i*channels+c];
-            dst[2*i+1] += (.707f/32768.f)*short_src[i*channels+c];
-#endif
-         }
-      }
-   }
-}
-
 
 #ifdef FIXED_POINT
 int opus_multistream_encode(
@@ -825,7 +865,7 @@
 )
 {
    return opus_multistream_encode_native(st, opus_copy_channel_in_short,
-      pcm, frame_size, data, max_data_bytes, 16, opus_surround_downmix_short);
+      pcm, frame_size, data, max_data_bytes, 16);
 }
 
 #ifndef DISABLE_FLOAT_API
@@ -838,7 +878,7 @@
 )
 {
    return opus_multistream_encode_native(st, opus_copy_channel_in_float,
-      pcm, frame_size, data, max_data_bytes, 16, opus_surround_downmix_float);
+      pcm, frame_size, data, max_data_bytes, 16);
 }
 #endif
 
@@ -855,7 +895,7 @@
 {
    int channels = st->layout.nb_streams + st->layout.nb_coupled_streams;
    return opus_multistream_encode_native(st, opus_copy_channel_in_float,
-      pcm, frame_size, data, max_data_bytes, 24, opus_surround_downmix_float, downmix_float, pcm+channels*st->analysis.analysis_offset);
+      pcm, frame_size, data, max_data_bytes, 24, downmix_float, pcm+channels*st->analysis.analysis_offset);
 }
 
 int opus_multistream_encode(
@@ -868,7 +908,7 @@
 {
    int channels = st->layout.nb_streams + st->layout.nb_coupled_streams;
    return opus_multistream_encode_native(st, opus_copy_channel_in_short,
-      pcm, frame_size, data, max_data_bytes, 16, opus_surround_downmix_short, downmix_int, pcm+channels*st->analysis.analysis_offset);
+      pcm, frame_size, data, max_data_bytes, 16, downmix_int, pcm+channels*st->analysis.analysis_offset);
 }
 #endif