Classes and tests for an audio classifier. The class can be used to classify whether a frame of audio contains speech or music. The classifier uses the music/speech classifier in Opus.

R=andrew@webrtc.org, henrik.lundin@webrtc.org, turaj@webrtc.org

Review URL: https://webrtc-codereview.appspot.com/5549004

git-svn-id: http://webrtc.googlecode.com/svn/trunk@5677 4adac7df-926f-26a2-2b94-8c16560cd09d
diff --git a/webrtc/modules/audio_coding/neteq4/audio_classifier.cc b/webrtc/modules/audio_coding/neteq4/audio_classifier.cc
new file mode 100644
index 0000000..a272fbc
--- /dev/null
+++ b/webrtc/modules/audio_coding/neteq4/audio_classifier.cc
@@ -0,0 +1,71 @@
+/*
+ *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "webrtc/modules/audio_coding/neteq4/audio_classifier.h"
+
+#include <assert.h>
+#include <string.h>
+
+namespace webrtc {
+
+static const int kDefaultSampleRateHz = 48000;  // Only 48 kHz is handled.
+static const int kDefaultFrameRateHz = 50;  // I.e., 20 ms frames.
+static const int kDefaultFrameSizeSamples =
+    kDefaultSampleRateHz / kDefaultFrameRateHz;  // 960 samples per channel.
+static const float kDefaultThreshold = 0.5f;  // Music if probability > this.
+
+AudioClassifier::AudioClassifier()
+    : analysis_info_(),
+      is_music_(false),
+      music_probability_(0),
+      // The call below assigns the pointer to a static constant struct
+      // rather than creating a new struct, which is why |celt_mode_| does
+      // not need to be deleted (see the empty destructor).
+      celt_mode_(opus_custom_mode_create(kDefaultSampleRateHz,
+                                         kDefaultFrameSizeSamples,
+                                         NULL)),
+      analysis_state_() {
+  assert(celt_mode_);  // Only checked in builds where assert() is active.
+}
+
+AudioClassifier::~AudioClassifier() {}  // Nothing to release.
+
+bool AudioClassifier::Analysis(const int16_t* input,
+                               int input_length,
+                               int channels) {
+  // Only mono or stereo are allowed. Checked first so that a bad |channels|
+  // (e.g. 0) is caught here and never reaches the length check below.
+  assert(channels == 1 || channels == 2);
+  // Must be exactly one 20 ms frame per channel at 48 kHz sampling.
+  assert(input_length == channels * kDefaultFrameSizeSamples);
+
+  // Call Opus' classifier, defined in
+  // "third_party/opus/src/src/analysis.h", with lsb_depth = 16.
+  // Also uses a down-mixing function downmix_int, defined in
+  // "third_party/opus/src/src/opus_private.h", with
+  // constants c1 = 0, and c2 = -2.
+  run_analysis(&analysis_state_,
+               celt_mode_,
+               input,
+               kDefaultFrameSizeSamples,
+               kDefaultFrameSizeSamples,
+               0,   // c1
+               -2,  // c2
+               channels,
+               kDefaultSampleRateHz,
+               16,  // lsb_depth
+               downmix_int,
+               &analysis_info_);
+  music_probability_ = analysis_info_.music_prob;
+  is_music_ = music_probability_ > kDefaultThreshold;  // Cached for is_music().
+  return is_music_;
+}
+
+}  // namespace webrtc
diff --git a/webrtc/modules/audio_coding/neteq4/audio_classifier.h b/webrtc/modules/audio_coding/neteq4/audio_classifier.h
new file mode 100644
index 0000000..7451d3e
--- /dev/null
+++ b/webrtc/modules/audio_coding/neteq4/audio_classifier.h
@@ -0,0 +1,59 @@
+/*
+ *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef WEBRTC_MODULES_AUDIO_CODING_NETEQ4_AUDIO_CLASSIFIER_H_
+#define WEBRTC_MODULES_AUDIO_CODING_NETEQ4_AUDIO_CLASSIFIER_H_
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+#include "third_party/opus/src/celt/celt.h"
+#include "third_party/opus/src/src/analysis.h"
+#include "third_party/opus/src/src/opus_private.h"
+#if defined(__cplusplus)
+}
+#endif
+
+#include "webrtc/system_wrappers/interface/scoped_ptr.h"
+#include "webrtc/typedefs.h"
+
+namespace webrtc {
+
+// This class provides a speech/music classification and is a wrapper over the
+// Opus classifier. It currently only supports 48 kHz mono or stereo with a
+// frame size of 20 ms.
+
+class AudioClassifier {
+ public:
+  AudioClassifier();
+  virtual ~AudioClassifier();
+
+  // Classifies one frame in |input|; returns true if it is deemed music.
+  // input_length   : must be channels * 960;
+  // channels       : must be 1 (mono) or 2 (stereo).
+  bool Analysis(const int16_t* input, int input_length, int channels);
+
+  // Latest classification: true = music, false = speech (initially false).
+  bool is_music() const { return is_music_; }
+
+  // Latest music probability reported by the classifier (initially 0).
+  float music_probability() const { return music_probability_; }
+
+ private:
+  AnalysisInfo analysis_info_;  // Filled in by each run_analysis() call.
+  bool is_music_;
+  float music_probability_;
+  const CELTMode* celt_mode_;  // Points to a static struct; never deleted.
+  TonalityAnalysisState analysis_state_;  // State handed to run_analysis().
+};
+
+}  // namespace webrtc
+
+#endif  // WEBRTC_MODULES_AUDIO_CODING_NETEQ4_AUDIO_CLASSIFIER_H_
diff --git a/webrtc/modules/audio_coding/neteq4/audio_classifier_unittest.cc b/webrtc/modules/audio_coding/neteq4/audio_classifier_unittest.cc
new file mode 100644
index 0000000..0a66718
--- /dev/null
+++ b/webrtc/modules/audio_coding/neteq4/audio_classifier_unittest.cc
@@ -0,0 +1,75 @@
+/*
+ *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "webrtc/modules/audio_coding/neteq4/audio_classifier.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "webrtc/test/testsupport/fileutils.h"
+
+namespace webrtc {
+
+static const size_t kFrameSize = 960;  // Samples per channel in one frame.
+
+TEST(AudioClassifierTest, AllZeroInput) {
+  int16_t in_mono[kFrameSize] = {0};
+
+  // Feed 100 all-zero mono frames so the classifier can converge from its
+  // default; the decision expected after that for silence is "music".
+  AudioClassifier zero_classifier;
+  for (int i = 0; i < 100; ++i) {
+    zero_classifier.Analysis(in_mono, kFrameSize, 1);
+  }
+  EXPECT_TRUE(zero_classifier.is_music());
+}
+
+void RunAnalysisTest(const std::string& audio_filename,
+                     const std::string& data_filename,
+                     size_t channels) {
+  AudioClassifier classifier;
+  scoped_ptr<int16_t[]> in(new int16_t[channels * kFrameSize]);
+  bool is_music_ref;  // Reference decision for the current frame.
+  // Audio is read as raw int16_t frames; the data file has one bool per frame.
+  FILE* audio_file = fopen(audio_filename.c_str(), "rb");
+  ASSERT_TRUE(audio_file != NULL) << "Failed to open file " << audio_filename
+                                  << std::endl;
+  FILE* data_file = fopen(data_filename.c_str(), "rb");
+  ASSERT_TRUE(data_file != NULL) << "Failed to open file " << data_filename
+                                 << std::endl;
+  while (fread(in.get(), sizeof(int16_t), channels * kFrameSize, audio_file) ==
+         channels * kFrameSize) {  // Stops at the first incomplete frame.
+    bool is_music =
+        classifier.Analysis(in.get(), channels * kFrameSize, channels);
+    EXPECT_EQ(is_music, classifier.is_music());
+    ASSERT_EQ(1u, fread(&is_music_ref, sizeof(is_music_ref), 1, data_file));
+    EXPECT_EQ(is_music_ref, is_music);
+  }
+  fclose(audio_file);
+  fclose(data_file);
+}
+
+TEST(AudioClassifierTest, DoAnalysisMono) {
+  RunAnalysisTest(test::ResourcePath("short_mixed_mono_48", "pcm"),  // Audio.
+                  test::ResourcePath("short_mixed_mono_48", "dat"),  // Ref.
+                  1);  // Number of channels.
+}
+
+TEST(AudioClassifierTest, DoAnalysisStereo) {
+  RunAnalysisTest(test::ResourcePath("short_mixed_stereo_48", "pcm"),  // Audio.
+                  test::ResourcePath("short_mixed_stereo_48", "dat"),  // Ref.
+                  2);  // Number of channels.
+}
+
+}  // namespace webrtc
diff --git a/webrtc/modules/audio_coding/neteq4/neteq.gypi b/webrtc/modules/audio_coding/neteq4/neteq.gypi
index 4660109..afcefbe 100644
--- a/webrtc/modules/audio_coding/neteq4/neteq.gypi
+++ b/webrtc/modules/audio_coding/neteq4/neteq.gypi
@@ -16,6 +16,7 @@
       'iSAC',
       'iSACFix',
       'CNG',
+      '<(DEPTH)/third_party/opus/opus.gyp:opus',
       '<(webrtc_root)/common_audio/common_audio.gyp:common_audio',
       '<(webrtc_root)/system_wrappers/source/system_wrappers.gyp:system_wrappers',
     ],
@@ -38,20 +39,25 @@
         '<@(neteq_defines)',
       ],
       'include_dirs': [
-        'interface',
-        '<(webrtc_root)',
+        # Need Opus header files for the audio classifier.
+        '<(DEPTH)/third_party/opus/src/celt',
       ],
       'direct_dependent_settings': {
         'include_dirs': [
-          'interface',
-          '<(webrtc_root)',
+          # Need Opus header files for the audio classifier.
+          '<(DEPTH)/third_party/opus/src/celt',
         ],
       },
+      'export_dependent_settings': [
+        '<(DEPTH)/third_party/opus/opus.gyp:opus',
+      ],
       'sources': [
         'interface/audio_decoder.h',
         'interface/neteq.h',
         'accelerate.cc',
         'accelerate.h',
+        'audio_classifier.cc',
+        'audio_classifier.h',
         'audio_decoder_impl.cc',
         'audio_decoder_impl.h',
         'audio_decoder.cc',
diff --git a/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi b/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi
index 419aefa..e1fcae7 100644
--- a/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi
+++ b/webrtc/modules/audio_coding/neteq4/neteq_tests.gypi
@@ -141,6 +141,17 @@
     },
 
     {
+      'target_name': 'audio_classifier_test',
+      'type': 'executable',
+      'dependencies': [
+        'NetEq4',
+      ],
+      'sources': [
+        'test/audio_classifier_test.cc',
+      ],
+    },
+
+    {
       'target_name': 'neteq4_speed_test',
       'type': 'executable',
       'dependencies': [
diff --git a/webrtc/modules/audio_coding/neteq4/test/audio_classifier_test.cc b/webrtc/modules/audio_coding/neteq4/test/audio_classifier_test.cc
new file mode 100644
index 0000000..730406b
--- /dev/null
+++ b/webrtc/modules/audio_coding/neteq4/test/audio_classifier_test.cc
@@ -0,0 +1,105 @@
+/*
+ *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "webrtc/modules/audio_coding/neteq4/audio_classifier.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <string>
+#include <iostream>
+
+#include "webrtc/system_wrappers/interface/scoped_ptr.h"
+
+int main(int argc, char* argv[]) {
+  if (argc != 5) {
+    std::cout << "Usage: " << argv[0] <<
+        " channels output_type <input file name> <output file name> "
+        << std::endl << std::endl;
+    std::cout << "Where channels can be 1 (mono) or 2 (interleaved stereo),";
+    std::cout << " outputs can be 1 (classification (boolean)) or 2";
+    std::cout << " (classification and music probability (float)),"
+        << std::endl;
+    std::cout << "and the sampling frequency is assumed to be 48 kHz."
+        << std::endl;
+    return -1;
+  }
+
+  const int kFrameSizeSamples = 960;  // 20 ms per channel at 48 kHz.
+  int channels = atoi(argv[1]);
+  if (channels < 1 || channels > 2) {
+    std::cout << "Disallowed number of channels  " << channels << std::endl;
+    return -1;
+  }
+
+  int outputs = atoi(argv[2]);
+  if (outputs < 1 || outputs > 2) {
+    std::cout << "Disallowed number of outputs  " << outputs << std::endl;
+    return -1;
+  }
+
+  const int data_size = channels * kFrameSizeSamples;  // Samples per frame.
+  webrtc::scoped_ptr<int16_t[]> in(new int16_t[data_size]);
+
+  std::string input_filename = argv[3];
+  std::string output_filename = argv[4];
+
+  std::cout << "Input file: " << input_filename << std::endl;
+  std::cout << "Output file: " << output_filename << std::endl;
+
+  FILE* in_file = fopen(input_filename.c_str(), "rb");
+  if (!in_file) {
+    std::cout << "Cannot open input file " << input_filename << std::endl;
+    return -1;
+  }
+
+  FILE* out_file = fopen(output_filename.c_str(), "wb");
+  if (!out_file) {
+    std::cout << "Cannot open output file " << output_filename << std::endl;
+    return -1;
+  }
+  // For each full frame, write the decision and optionally the probability.
+  webrtc::AudioClassifier classifier;
+  int frame_counter = 0;
+  int music_counter = 0;
+  while (fread(in.get(), sizeof(*in.get()), data_size, in_file) ==
+         static_cast<size_t>(data_size)) {  // A trailing partial frame is skipped.
+    bool is_music = classifier.Analysis(in.get(), data_size, channels);
+    if (!fwrite(&is_music, sizeof(is_music), 1, out_file)) {
+      std::cout << "Error writing." << std::endl;
+      return -1;
+    }
+    if (is_music) {
+      music_counter++;
+    }
+    std::cout << "frame " << frame_counter << " decision " << is_music;
+    if (outputs == 2) {
+      float music_prob = classifier.music_probability();
+      if (!fwrite(&music_prob, sizeof(music_prob), 1, out_file)) {
+        std::cout << "Error writing." << std::endl;
+        return -1;
+      }
+      std::cout << " music prob " << music_prob;
+    }
+    std::cout << std::endl;
+    frame_counter++;
+  }
+  std::cout << frame_counter << " frames processed." << std::endl;
+  if (frame_counter > 0) {
+    // Scale the music fraction by 100 so the value matches the "percent" label.
+    float music_percentage = 100.0f * music_counter / frame_counter;
+    std::cout << music_percentage << " percent music." << std::endl;
+  }
+  fclose(in_file);
+  fclose(out_file);
+  return 0;
+}
diff --git a/webrtc/modules/modules.gyp b/webrtc/modules/modules.gyp
index 10da58a..5d0827c 100644
--- a/webrtc/modules/modules.gyp
+++ b/webrtc/modules/modules.gyp
@@ -115,6 +115,7 @@
             'audio_coding/codecs/isac/fix/source/transform_unittest.cc',
             'audio_coding/codecs/isac/main/source/isac_unittest.cc',
             'audio_coding/codecs/opus/opus_unittest.cc',
+            'audio_coding/neteq4/audio_classifier_unittest.cc',
             'audio_coding/neteq4/audio_multi_vector_unittest.cc',
             'audio_coding/neteq4/audio_vector_unittest.cc',
             'audio_coding/neteq4/background_noise_unittest.cc',
diff --git a/webrtc/modules/modules_unittests.isolate b/webrtc/modules/modules_unittests.isolate
index 06f8a2d..e4139ba 100644
--- a/webrtc/modules/modules_unittests.isolate
+++ b/webrtc/modules/modules_unittests.isolate
@@ -15,6 +15,12 @@
           '../../../data/',
           '../../../resources/',
         ],
+        'isolate_dependency_tracked': [
+          '../../../resources/short_mixed_mono_48.dat',
+          '../../../resources/short_mixed_mono_48.pcm',
+          '../../../resources/short_mixed_stereo_48.dat',
+          '../../../resources/short_mixed_stereo_48.pcm',
+        ],
       },
     }],
     ['OS=="linux" or OS=="mac" or OS=="win"', {
@@ -72,6 +78,10 @@
           '../../resources/remote_bitrate_estimator/VideoSendersTest_BweTest_SteadyLoss_0_TOF.bin',
           '../../resources/remote_bitrate_estimator/VideoSendersTest_BweTest_UnlimitedSpeed_0_AST.bin',
           '../../resources/remote_bitrate_estimator/VideoSendersTest_BweTest_UnlimitedSpeed_0_TOF.bin',
+          '../../resources/short_mixed_mono_48.dat',
+          '../../resources/short_mixed_mono_48.pcm',
+          '../../resources/short_mixed_stereo_48.dat',
+          '../../resources/short_mixed_stereo_48.pcm',
           '../../resources/sprint-downlink.rx',
           '../../resources/sprint-uplink.rx',
           '../../resources/synthetic-trace.rx',