diff --git a/webrtc/modules/audio_processing/audio_processing_tests.gypi b/webrtc/modules/audio_processing/audio_processing_tests.gypi index 19b9ddf596..0314c69b04 100644 --- a/webrtc/modules/audio_processing/audio_processing_tests.gypi +++ b/webrtc/modules/audio_processing/audio_processing_tests.gypi @@ -70,7 +70,7 @@ '<(webrtc_root)/test/test.gyp:test_support', ], 'sources': [ - 'intelligibility/intelligibility_proc.cc', + 'intelligibility/test/intelligibility_proc.cc', ], }, # intelligibility_proc ], diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc index 3029e21619..1e766875ca 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc @@ -17,8 +17,8 @@ #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h" -#include -#include +#include +#include #include #include @@ -27,26 +27,24 @@ #include "webrtc/common_audio/vad/include/webrtc_vad.h" #include "webrtc/common_audio/window_generator.h" +namespace webrtc { + +namespace { + +const int kErbResolution = 2; +const int kWindowSizeMs = 2; +const int kChunkSizeMs = 10; // Size provided by APM. +const float kClipFreq = 200.0f; +const float kConfigRho = 0.02f; // Default production and interpretation SNR. +const float kKbdAlpha = 1.5f; +const float kLambdaBot = -1.0f; // Extreme values in bisection +const float kLambdaTop = -10e-18f; // search for lamda. + +} // namespace + using std::complex; using std::max; using std::min; - -namespace webrtc { - -const int IntelligibilityEnhancer::kErbResolution = 2; -const int IntelligibilityEnhancer::kWindowSizeMs = 2; -const int IntelligibilityEnhancer::kChunkSizeMs = 10; // Size provided by APM. -const int IntelligibilityEnhancer::kAnalyzeRate = 800; -const int IntelligibilityEnhancer::kVarianceRate = 2; -const float IntelligibilityEnhancer::kClipFreq = 200.0f; -const float IntelligibilityEnhancer::kConfigRho = 0.02f; -const float IntelligibilityEnhancer::kKbdAlpha = 1.5f; - -// To disable gain update smoothing, set gain limit to be VERY high. -// TODO(ekmeyerson): Add option to disable gain smoothing altogether -// to avoid the extra computation. -const float IntelligibilityEnhancer::kGainChangeLimit = 0.0125f; - using VarianceType = intelligibility::VarianceArray::StepType; IntelligibilityEnhancer::TransformCallback::TransformCallback( @@ -93,7 +91,7 @@ IntelligibilityEnhancer::IntelligibilityEnhancer(int erb_resolution, noise_variance_(freqs_, VarianceType::kStepInfinite, 475, 0.01f), filtered_clear_var_(new float[bank_size_]), filtered_noise_var_(new float[bank_size_]), - filter_bank_(nullptr), + filter_bank_(bank_size_), center_freqs_(new float[bank_size_]), rho_(new float[bank_size_]), gains_eq_(new float[bank_size_]), @@ -149,7 +147,7 @@ IntelligibilityEnhancer::IntelligibilityEnhancer(int erb_resolution, IntelligibilityEnhancer::~IntelligibilityEnhancer() { WebRtcVad_Free(vad_low_); WebRtcVad_Free(vad_high_); - free(filter_bank_); + free(temp_out_buffer_); } void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio) { @@ -203,8 +201,6 @@ void IntelligibilityEnhancer::DispatchAudio( void IntelligibilityEnhancer::ProcessClearBlock(const complex* in_block, complex* out_block) { - float power_target; - if (block_count_ < 2) { memset(out_block, 0, freqs_ * sizeof(*out_block)); ++block_count_; @@ -216,8 +212,8 @@ void IntelligibilityEnhancer::ProcessClearBlock(const complex* in_block, // based on experiments with different cutoffs. if (has_voice_low_ || true) { clear_variance_.Step(in_block, false); - power_target = std::accumulate(clear_variance_.variance(), - clear_variance_.variance() + freqs_, 0.0f); + const float power_target = std::accumulate( + clear_variance_.variance(), clear_variance_.variance() + freqs_, 0.0f); if (block_count_ % analysis_rate_ == analysis_rate_ - 1) { AnalyzeClearBlock(power_target); @@ -239,35 +235,46 @@ void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) { FilterVariance(clear_variance_.variance(), filtered_clear_var_.get()); FilterVariance(noise_variance_.variance(), filtered_noise_var_.get()); - // Bisection search for optimal |lambda| - - float lambda_bot = -1.0f, lambda_top = -10e-18f, lambda; - float power_bot, power_top, power; - SolveForGainsGivenLambda(lambda_top, start_freq_, gains_eq_.get()); - power_top = + SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get()); + const float power_top = DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); - SolveForGainsGivenLambda(lambda_bot, start_freq_, gains_eq_.get()); - power_bot = + SolveForGainsGivenLambda(kLambdaBot, start_freq_, gains_eq_.get()); + const float power_bot = DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); - DCHECK(power_target >= power_bot && power_target <= power_top); + if (power_target >= power_bot && power_target <= power_top) { + SolveForLambda(power_target, power_bot, power_top); + UpdateErbGains(); + } // Else experiencing variance underflow, so do nothing. +} - float power_ratio = 2.0f; // Ratio of achieved power to target power. +void IntelligibilityEnhancer::SolveForLambda(float power_target, + float power_bot, + float power_top) { const float kConvergeThresh = 0.001f; // TODO(ekmeyerson): Find best values const int kMaxIters = 100; // for these, based on experiments. + + const float reciprocal_power_target = 1.f / power_target; + float lambda_bot = kLambdaBot; + float lambda_top = kLambdaTop; + float power_ratio = 2.0f; // Ratio of achieved power to target power. int iters = 0; - while (fabs(power_ratio - 1.0f) > kConvergeThresh && iters <= kMaxIters) { - lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f; + while (std::fabs(power_ratio - 1.0f) > kConvergeThresh && + iters <= kMaxIters) { + const float lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f; SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get()); - power = DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); + const float power = + DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); if (power < power_target) { lambda_bot = lambda; } else { lambda_top = lambda; } - power_ratio = fabs(power / power_target); + power_ratio = std::fabs(power * reciprocal_power_target); ++iters; } +} +void IntelligibilityEnhancer::UpdateErbGains() { // (ERB gain) = filterbank' * (freq gain) float* gains = gain_applier_.target(); for (int i = 0; i < freqs_; ++i) { @@ -303,12 +310,8 @@ void IntelligibilityEnhancer::CreateErbBank() { center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq; } - filter_bank_ = static_cast( - malloc(sizeof(*filter_bank_) * bank_size_ + - sizeof(**filter_bank_) * freqs_ * bank_size_)); for (int i = 0; i < bank_size_; ++i) { - filter_bank_[i] = - reinterpret_cast(filter_bank_ + bank_size_) + freqs_ * i; + filter_bank_[i].resize(freqs_); } for (int i = 1; i <= bank_size_; ++i) { @@ -388,7 +391,7 @@ void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda, void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) { for (int i = 0; i < bank_size_; ++i) { - result[i] = DotProduct(filter_bank_[i], var, freqs_); + result[i] = DotProduct(filter_bank_[i].data(), var, freqs_); } } diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h index 8125707f12..df47de5978 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h @@ -16,6 +16,7 @@ #define WEBRTC_MODULES_AUDIO_PROCESSING_INTELLIGIBILITY_INTELLIGIBILITY_ENHANCER_H_ #include +#include #include "webrtc/base/scoped_ptr.h" #include "webrtc/common_audio/lapped_transform.h" @@ -83,6 +84,8 @@ class IntelligibilityEnhancer { AudioSource source_; }; friend class TransformCallback; + FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestErbCreation); + FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestSolveForGains); // Sends streams to ProcessClearBlock or ProcessNoiseBlock based on source. void DispatchAudio(AudioSource source, @@ -97,6 +100,12 @@ class IntelligibilityEnhancer { // Computes and sets modified gains. void AnalyzeClearBlock(float power_target); + // Bisection search for optimal |lambda|. + void SolveForLambda(float power_target, float power_bot, float power_top); + + // Transforms freq gains to ERB gains. + void UpdateErbGains(); + // Updates variance calculation for noise input with |in_block|. void ProcessNoiseBlock(const std::complex* in_block, std::complex* out_block); @@ -118,16 +127,6 @@ class IntelligibilityEnhancer { // Returns dot product of vectors specified by size |length| arrays |a|,|b|. static float DotProduct(const float* a, const float* b, int length); - static const int kErbResolution; - static const int kWindowSizeMs; - static const int kChunkSizeMs; - static const int kAnalyzeRate; // Default for |analysis_rate_|. - static const int kVarianceRate; // Default for |variance_rate_|. - static const float kClipFreq; - static const float kConfigRho; // Default production and interpretation SNR. - static const float kKbdAlpha; - static const float kGainChangeLimit; - const int freqs_; // Num frequencies in frequency domain. const int window_size_; // Window size in samples; also the block size. const int chunk_length_; // Chunk size in samples. @@ -142,7 +141,7 @@ class IntelligibilityEnhancer { intelligibility::VarianceArray noise_variance_; rtc::scoped_ptr filtered_clear_var_; rtc::scoped_ptr filtered_noise_var_; - float** filter_bank_; // TODO(ekmeyerson): Switch to using ChannelBuffer. + std::vector> filter_bank_; rtc::scoped_ptr center_freqs_; int start_freq_; rtc::scoped_ptr rho_; // Production and interpretation SNR. diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc new file mode 100644 index 0000000000..490db2c646 --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Unit tests for intelligibility enhancer. +// + +#include +#include +#include +#include + +#include "testing/gtest/include/gtest/gtest.h" +#include "webrtc/base/arraysize.h" +#include "webrtc/common_audio/signal_processing/include/signal_processing_library.h" +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h" + +namespace webrtc { + +namespace { + +// Target output for ERB create test. Generated with matlab. +const float kTestCenterFreqs[] = { + 13.169f, 26.965f, 41.423f, 56.577f, 72.461f, 89.113f, 106.57f, 124.88f, + 144.08f, 164.21f, 185.34f, 207.5f, 230.75f, 255.16f, 280.77f, 307.66f, + 335.9f, 365.56f, 396.71f, 429.44f, 463.84f, 500.f}; +const float kTestFilterBank[][2] = {{0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.f}, + {0.055556f, 0.2f}, + {0, 0.2f}, + {0, 0.2f}, + {0, 0.2f}, + {0, 0.2f}}; +static_assert(arraysize(kTestCenterFreqs) == arraysize(kTestFilterBank), + "Test filterbank badly initialized."); + +// Target output for gain solving test. Generated with matlab. +const int kTestStartFreq = 12; // Lowest integral frequency for ERBs. +const float kTestZeroVar[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; +static_assert(arraysize(kTestCenterFreqs) == arraysize(kTestZeroVar), + "Variance test data badly initialized."); +const float kTestNonZeroVarLambdaTop[] = { + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 0.f, 0.f, 0.0351f, 0.0636f, 0.0863f, + 0.1037f, 0.1162f, 0.1236f, 0.1251f, 0.1189f, 0.0993f}; +static_assert(arraysize(kTestCenterFreqs) == + arraysize(kTestNonZeroVarLambdaTop), + "Variance test data badly initialized."); +const float kMaxTestError = 0.005f; + +// Enhancer initialization parameters. +const int kSamples = 2000; +const int kErbResolution = 2; +const int kSampleRate = 1000; +const int kFragmentSize = kSampleRate / 100; +const int kNumChannels = 1; +const float kDecayRate = 0.9f; +const int kWindowSize = 800; +const int kAnalyzeRate = 800; +const int kVarianceRate = 2; +const float kGainLimit = 0.1f; + +} // namespace + +using std::vector; +using intelligibility::VarianceArray; + +class IntelligibilityEnhancerTest : public ::testing::Test { + protected: + IntelligibilityEnhancerTest() + : enh_(kErbResolution, + kSampleRate, + kNumChannels, + VarianceArray::kStepInfinite, + kDecayRate, + kWindowSize, + kAnalyzeRate, + kVarianceRate, + kGainLimit), + clear_data_(kSamples), + noise_data_(kSamples), + orig_data_(kSamples) {} + + bool CheckUpdate(VarianceArray::StepType step_type) { + IntelligibilityEnhancer enh(kErbResolution, kSampleRate, kNumChannels, + step_type, kDecayRate, kWindowSize, + kAnalyzeRate, kVarianceRate, kGainLimit); + float* clear_cursor = &clear_data_[0]; + float* noise_cursor = &noise_data_[0]; + for (int i = 0; i < kSamples; i += kFragmentSize) { + enh.ProcessCaptureAudio(&noise_cursor); + enh.ProcessRenderAudio(&clear_cursor); + clear_cursor += kFragmentSize; + noise_cursor += kFragmentSize; + } + for (int i = 0; i < kSamples; i++) { + if (std::fabs(clear_data_[i] - orig_data_[i]) > kMaxTestError) { + return true; + } + } + return false; + } + + IntelligibilityEnhancer enh_; + vector clear_data_; + vector noise_data_; + vector orig_data_; +}; + +// For each class of generated data, tests that render stream is +// updated when it should be for each variance update method. +TEST_F(IntelligibilityEnhancerTest, TestRenderUpdate) { + vector step_types; + step_types.push_back(VarianceArray::kStepInfinite); + step_types.push_back(VarianceArray::kStepDecaying); + step_types.push_back(VarianceArray::kStepWindowed); + step_types.push_back(VarianceArray::kStepBlocked); + step_types.push_back(VarianceArray::kStepBlockBasedMovingAverage); + std::fill(noise_data_.begin(), noise_data_.end(), 0.0f); + std::fill(orig_data_.begin(), orig_data_.end(), 0.0f); + for (auto step_type : step_types) { + std::fill(clear_data_.begin(), clear_data_.end(), 0.0f); + EXPECT_FALSE(CheckUpdate(step_type)); + } + std::srand(1); + auto float_rand = []() { return std::rand() * 2.f / RAND_MAX - 1; }; + std::generate(noise_data_.begin(), noise_data_.end(), float_rand); + for (auto step_type : step_types) { + EXPECT_FALSE(CheckUpdate(step_type)); + } + for (auto step_type : step_types) { + std::generate(clear_data_.begin(), clear_data_.end(), float_rand); + orig_data_ = clear_data_; + EXPECT_TRUE(CheckUpdate(step_type)); + } +} + +// Tests ERB bank creation, comparing against matlab output. +TEST_F(IntelligibilityEnhancerTest, TestErbCreation) { + ASSERT_EQ(static_cast(arraysize(kTestCenterFreqs)), enh_.bank_size_); + for (int i = 0; i < enh_.bank_size_; ++i) { + EXPECT_NEAR(kTestCenterFreqs[i], enh_.center_freqs_[i], kMaxTestError); + ASSERT_EQ(static_cast(arraysize(kTestFilterBank[0])), enh_.freqs_); + for (int j = 0; j < enh_.freqs_; ++j) { + EXPECT_NEAR(kTestFilterBank[i][j], enh_.filter_bank_[i][j], + kMaxTestError); + } + } +} + +// Tests analytic solution for optimal gains, comparing +// against matlab output. +TEST_F(IntelligibilityEnhancerTest, TestSolveForGains) { + ASSERT_EQ(kTestStartFreq, enh_.start_freq_); + vector sols(enh_.bank_size_); + float lambda = -0.001f; + for (int i = 0; i < enh_.bank_size_; i++) { + enh_.filtered_clear_var_[i] = 0.0f; + enh_.filtered_noise_var_[i] = 0.0f; + enh_.rho_[i] = 0.02f; + } + enh_.SolveForGainsGivenLambda(lambda, enh_.start_freq_, &sols[0]); + for (int i = 0; i < enh_.bank_size_; i++) { + EXPECT_NEAR(kTestZeroVar[i], sols[i], kMaxTestError); + } + for (int i = 0; i < enh_.bank_size_; i++) { + enh_.filtered_clear_var_[i] = static_cast(i + 1); + enh_.filtered_noise_var_[i] = static_cast(enh_.bank_size_ - i); + } + enh_.SolveForGainsGivenLambda(lambda, enh_.start_freq_, &sols[0]); + for (int i = 0; i < enh_.bank_size_; i++) { + EXPECT_NEAR(kTestNonZeroVarLambdaTop[i], sols[i], kMaxTestError); + } + lambda = -1.0; + enh_.SolveForGainsGivenLambda(lambda, enh_.start_freq_, &sols[0]); + for (int i = 0; i < enh_.bank_size_; i++) { + EXPECT_NEAR(kTestZeroVar[i], sols[i], kMaxTestError); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc index 145cc08728..d67d200689 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc @@ -14,36 +14,32 @@ #include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h" +#include +#include #include -#include -#include using std::complex; +using std::min; -namespace { +namespace webrtc { -// Return |current| changed towards |target|, with the change being at most -// |limit|. -inline float UpdateFactor(float target, float current, float limit) { +namespace intelligibility { + +float UpdateFactor(float target, float current, float limit) { float delta = fabsf(target - current); float sign = copysign(1.0f, target - current); return current + sign * fminf(delta, limit); } -// std::isfinite for complex numbers. -inline bool cplxfinite(complex c) { +bool cplxfinite(complex c) { return std::isfinite(c.real()) && std::isfinite(c.imag()); } -// std::isnormal for complex numbers. -inline bool cplxnormal(complex c) { +bool cplxnormal(complex c) { return std::isnormal(c.real()) && std::isnormal(c.imag()); } -// Apply a small fudge to degenerate complex values. The numbers in the array -// were chosen randomly, so that even a series of all zeroes has some small -// variability. -inline complex zerofudge(complex c) { +complex zerofudge(complex c) { const static complex fudge[7] = {{0.001f, 0.002f}, {0.008f, 0.001f}, {0.003f, 0.008f}, @@ -59,25 +55,14 @@ inline complex zerofudge(complex c) { return c; } -// Incremental mean computation. Return the mean of the series with the -// mean |mean| with added |data|. -inline complex NewMean(complex mean, - complex data, - int count) { +complex NewMean(complex mean, complex data, int count) { return mean + (data - mean) / static_cast(count); } -inline void AddToMean(complex data, int count, complex* mean) { +void AddToMean(complex data, int count, complex* mean) { (*mean) = NewMean(*mean, data, count); } -} // namespace - -using std::min; - -namespace webrtc { - -namespace intelligibility { static const int kWindowBlockSize = 10; @@ -96,7 +81,8 @@ VarianceArray::VarianceArray(int freqs, decay_(decay), history_cursor_(0), count_(0), - array_mean_(0.0f) { + array_mean_(0.0f), + buffer_full_(false) { history_.reset(new rtc::scoped_ptr[]>[freqs_]()); for (int i = 0; i < freqs_; ++i) { history_[i].reset(new complex[window_size_]()); @@ -122,6 +108,9 @@ VarianceArray::VarianceArray(int freqs, case kStepBlocked: step_func_ = &VarianceArray::BlockedStep; break; + case kStepBlockBasedMovingAverage: + step_func_ = &VarianceArray::BlockBasedMovingAverage; + break; } } @@ -223,7 +212,7 @@ void VarianceArray::WindowedStep(const complex* data, bool /*dummy*/) { // history window and a new block is started. The variances for the window // are recomputed from scratch at each of these transitions. void VarianceArray::BlockedStep(const complex* data, bool /*dummy*/) { - int blocks = min(window_size_, history_cursor_); + int blocks = min(window_size_, history_cursor_ + 1); for (int i = 0; i < freqs_; ++i) { AddToMean(data[i], count_ + 1, &sub_running_mean_[i]); AddToMean(data[i] * std::conj(data[i]), count_ + 1, @@ -242,8 +231,8 @@ void VarianceArray::BlockedStep(const complex* data, bool /*dummy*/) { running_mean_[i] = complex(0.0f, 0.0f); running_mean_sq_[i] = complex(0.0f, 0.0f); for (int j = 0; j < min(window_size_, history_cursor_); ++j) { - AddToMean(subhistory_[i][j], j, &running_mean_[i]); - AddToMean(subhistory_sq_[i][j], j, &running_mean_sq_[i]); + AddToMean(subhistory_[i][j], j + 1, &running_mean_[i]); + AddToMean(subhistory_sq_[i][j], j + 1, &running_mean_sq_[i]); } ++history_cursor_; } @@ -254,6 +243,51 @@ void VarianceArray::BlockedStep(const complex* data, bool /*dummy*/) { } } +// Recomputes variances for each window from scratch based on previous window. +void VarianceArray::BlockBasedMovingAverage(const std::complex* data, + bool /*dummy*/) { + // TODO(ekmeyerson) To mitigate potential divergence, add counter so that + // after every so often sums are computed scratch by summing over all + // elements instead of subtracting oldest and adding newest. + for (int i = 0; i < freqs_; ++i) { + sub_running_mean_[i] += data[i]; + sub_running_mean_sq_[i] += data[i] * std::conj(data[i]); + } + ++count_; + + // TODO(ekmeyerson) Make kWindowBlockSize nonconstant to allow + // experimentation with different block size,window size pairs. + if (count_ >= kWindowBlockSize) { + count_ = 0; + + for (int i = 0; i < freqs_; ++i) { + running_mean_[i] -= subhistory_[i][history_cursor_]; + running_mean_sq_[i] -= subhistory_sq_[i][history_cursor_]; + + float scale = 1.f / kWindowBlockSize; + subhistory_[i][history_cursor_] = sub_running_mean_[i] * scale; + subhistory_sq_[i][history_cursor_] = sub_running_mean_sq_[i] * scale; + + sub_running_mean_[i] = std::complex(0.0f, 0.0f); + sub_running_mean_sq_[i] = std::complex(0.0f, 0.0f); + + running_mean_[i] += subhistory_[i][history_cursor_]; + running_mean_sq_[i] += subhistory_sq_[i][history_cursor_]; + + scale = 1.f / (buffer_full_ ? window_size_ : history_cursor_ + 1); + variance_[i] = std::real(running_mean_sq_[i] * scale - + running_mean_[i] * scale * + std::conj(running_mean_[i]) * scale); + } + + ++history_cursor_; + if (history_cursor_ >= window_size_) { + buffer_full_ = true; + history_cursor_ = 0; + } + } +} + void VarianceArray::Clear() { memset(running_mean_.get(), 0, sizeof(*running_mean_.get()) * freqs_); memset(running_mean_sq_.get(), 0, sizeof(*running_mean_sq_.get()) * freqs_); diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h index 075b8ad46b..9908ac0456 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h @@ -23,6 +23,30 @@ namespace webrtc { namespace intelligibility { +// Return |current| changed towards |target|, with the change being at most +// |limit|. +float UpdateFactor(float target, float current, float limit); + +// std::isfinite for complex numbers. +bool cplxfinite(std::complex c); + +// std::isnormal for complex numbers. +bool cplxnormal(std::complex c); + +// Apply a small fudge to degenerate complex values. The numbers in the array +// were chosen randomly, so that even a series of all zeroes has some small +// variability. +std::complex zerofudge(std::complex c); + +// Incremental mean computation. Return the mean of the series with the +// mean |mean| with added |data|. +std::complex NewMean(std::complex mean, + std::complex data, + int count); + +// Updates |mean| with added |data|; +void AddToMean(std::complex data, int count, std::complex* mean); + // Internal helper for computing the variances of a stream of arrays. // The result is an array of variances per position: the i-th variance // is the variance of the stream of data on the i-th positions in the @@ -43,7 +67,8 @@ class VarianceArray { kStepInfinite = 0, kStepDecaying, kStepWindowed, - kStepBlocked + kStepBlocked, + kStepBlockBasedMovingAverage }; // Construct an instance for the given input array length (|freqs|) and @@ -77,6 +102,7 @@ class VarianceArray { void DecayStep(const std::complex* data, bool dummy); void WindowedStep(const std::complex* data, bool dummy); void BlockedStep(const std::complex* data, bool dummy); + void BlockBasedMovingAverage(const std::complex* data, bool dummy); // TODO(ekmeyerson): Switch the following running means // and histories from rtc::scoped_ptr to std::vector. @@ -105,6 +131,7 @@ class VarianceArray { int history_cursor_; int count_; float array_mean_; + bool buffer_full_; void (VarianceArray::*step_func_)(const std::complex*, bool); }; diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils_unittest.cc new file mode 100644 index 0000000000..ca5567cded --- /dev/null +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils_unittest.cc @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// +// Unit tests for intelligibility utils. +// + +#include +#include +#include +#include + +#include "testing/gtest/include/gtest/gtest.h" +#include "webrtc/base/arraysize.h" +#include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h" + +using std::complex; +using std::vector; + +namespace webrtc { + +namespace intelligibility { + +vector>> GenerateTestData(int freqs, int samples) { + vector>> data(samples); + for (int i = 0; i < samples; i++) { + for (int j = 0; j < freqs; j++) { + const float val = 0.99f / ((i + 1) * (j + 1)); + data[i].push_back(complex(val, val)); + } + } + return data; +} + +// Tests UpdateFactor. +TEST(IntelligibilityUtilsTest, TestUpdateFactor) { + EXPECT_EQ(0, intelligibility::UpdateFactor(0, 0, 0)); + EXPECT_EQ(4, intelligibility::UpdateFactor(4, 2, 3)); + EXPECT_EQ(3, intelligibility::UpdateFactor(4, 2, 1)); + EXPECT_EQ(2, intelligibility::UpdateFactor(2, 4, 3)); + EXPECT_EQ(3, intelligibility::UpdateFactor(2, 4, 1)); +} + +// Tests cplxfinite, cplxnormal, and zerofudge. +TEST(IntelligibilityUtilsTest, TestCplx) { + complex t0(1.f, 0.f); + EXPECT_TRUE(intelligibility::cplxfinite(t0)); + EXPECT_FALSE(intelligibility::cplxnormal(t0)); + t0 = intelligibility::zerofudge(t0); + EXPECT_NE(t0.imag(), 0.f); + EXPECT_NE(t0.real(), 0.f); + const complex t1(1.f, std::sqrt(-1.f)); + EXPECT_FALSE(intelligibility::cplxfinite(t1)); + EXPECT_FALSE(intelligibility::cplxnormal(t1)); + const complex t2(1.f, 1.f); + EXPECT_TRUE(intelligibility::cplxfinite(t2)); + EXPECT_TRUE(intelligibility::cplxnormal(t2)); +} + +// Tests NewMean and AddToMean. +TEST(IntelligibilityUtilsTest, TestMeanUpdate) { + const complex data[] = {{3, 8}, {7, 6}, {2, 1}, {8, 9}, {0, 6}}; + const complex means[] = {{3, 8}, {5, 7}, {4, 5}, {5, 6}, {4, 6}}; + complex mean(3, 8); + for (size_t i = 0; i < arraysize(data); i++) { + EXPECT_EQ(means[i], NewMean(mean, data[i], i + 1)); + AddToMean(data[i], i + 1, &mean); + EXPECT_EQ(means[i], mean); + } +} + +// Tests VarianceArray, for all variance step types. +TEST(IntelligibilityUtilsTest, TestVarianceArray) { + const int kFreqs = 10; + const int kSamples = 100; + const int kWindowSize = 10; // Should pass for all kWindowSize > 1. + const float kDecay = 0.5f; + vector step_types; + step_types.push_back(VarianceArray::kStepInfinite); + step_types.push_back(VarianceArray::kStepDecaying); + step_types.push_back(VarianceArray::kStepWindowed); + step_types.push_back(VarianceArray::kStepBlocked); + step_types.push_back(VarianceArray::kStepBlockBasedMovingAverage); + const vector>> test_data( + GenerateTestData(kFreqs, kSamples)); + for (auto step_type : step_types) { + VarianceArray variance_array(kFreqs, step_type, kWindowSize, kDecay); + EXPECT_EQ(0, variance_array.variance()[0]); + EXPECT_EQ(0, variance_array.array_mean()); + variance_array.ApplyScale(2.0f); + EXPECT_EQ(0, variance_array.variance()[0]); + EXPECT_EQ(0, variance_array.array_mean()); + + // Makes sure Step is doing something. + variance_array.Step(&test_data[0][0]); + for (int i = 1; i < kSamples; i++) { + variance_array.Step(&test_data[i][0]); + EXPECT_GE(variance_array.array_mean(), 0.0f); + EXPECT_LE(variance_array.array_mean(), 1.0f); + for (int j = 0; j < kFreqs; j++) { + EXPECT_GE(variance_array.variance()[j], 0.0f); + EXPECT_LE(variance_array.variance()[j], 1.0f); + } + } + variance_array.Clear(); + EXPECT_EQ(0, variance_array.variance()[0]); + EXPECT_EQ(0, variance_array.array_mean()); + } +} + +// Tests exact computation on synthetic data. +TEST(IntelligibilityUtilsTest, TestMovingBlockAverage) { + // Exact, not unbiased estimates. + const float kTestVarianceBufferNotFull = 16.5f; + const float kTestVarianceBufferFull1 = 66.5f; + const float kTestVarianceBufferFull2 = 333.375f; + const int kFreqs = 2; + const int kSamples = 50; + const int kWindowSize = 2; + const float kDecay = 0.5f; + const float kMaxError = 0.0001f; + + VarianceArray variance_array( + kFreqs, VarianceArray::kStepBlockBasedMovingAverage, kWindowSize, kDecay); + + vector>> test_data(kSamples); + for (int i = 0; i < kSamples; i++) { + for (int j = 0; j < kFreqs; j++) { + if (i < 30) { + test_data[i].push_back(complex(static_cast(kSamples - i), + static_cast(i + 1))); + } else { + test_data[i].push_back(complex(0.f, 0.f)); + } + } + } + + for (int i = 0; i < kSamples; i++) { + variance_array.Step(&test_data[i][0]); + for (int j = 0; j < kFreqs; j++) { + if (i < 9) { // In utils, kWindowBlockSize = 10. + EXPECT_EQ(0, variance_array.variance()[j]); + } else if (i < 19) { + EXPECT_NEAR(kTestVarianceBufferNotFull, variance_array.variance()[j], + kMaxError); + } else if (i < 39) { + EXPECT_NEAR(kTestVarianceBufferFull1, variance_array.variance()[j], + kMaxError); + } else if (i < 49) { + EXPECT_NEAR(kTestVarianceBufferFull2, variance_array.variance()[j], + kMaxError); + } else { + EXPECT_EQ(0, variance_array.variance()[j]); + } + } + } +} + +// Tests gain applier. +TEST(IntelligibilityUtilsTest, TestGainApplier) { + const int kFreqs = 10; + const int kSamples = 100; + const float kChangeLimit = 0.1f; + GainApplier gain_applier(kFreqs, kChangeLimit); + const vector>> in_data( + GenerateTestData(kFreqs, kSamples)); + vector>> out_data(GenerateTestData(kFreqs, kSamples)); + for (int i = 0; i < kSamples; i++) { + gain_applier.Apply(&in_data[i][0], &out_data[i][0]); + for (int j = 0; j < kFreqs; j++) { + EXPECT_GT(out_data[i][j].real(), 0.0f); + EXPECT_LT(out_data[i][j].real(), 1.0f); + EXPECT_GT(out_data[i][j].imag(), 0.0f); + EXPECT_LT(out_data[i][j].imag(), 1.0f); + } + } +} + +} // namespace intelligibility + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_proc.cc b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc similarity index 100% rename from webrtc/modules/audio_processing/intelligibility/intelligibility_proc.cc rename to webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc index 9f7d84e701..2f1888d28c 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_proc.cc +++ b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc @@ -16,9 +16,9 @@ #include #include -#include #include #include +#include #include "gflags/gflags.h" #include "testing/gtest/include/gtest/gtest.h" diff --git a/webrtc/modules/modules.gyp b/webrtc/modules/modules.gyp index b06ecc5685..af4c97bb70 100644 --- a/webrtc/modules/modules.gyp +++ b/webrtc/modules/modules.gyp @@ -171,6 +171,8 @@ 'audio_processing/beamformer/mock_nonlinear_beamformer.cc', 'audio_processing/beamformer/mock_nonlinear_beamformer.h', 'audio_processing/echo_cancellation_impl_unittest.cc', + 'audio_processing/intelligibility/intelligibility_enhancer_unittest.cc', + 'audio_processing/intelligibility/intelligibility_utils_unittest.cc', 'audio_processing/splitting_filter_unittest.cc', 'audio_processing/transient/dyadic_decimator_unittest.cc', 'audio_processing/transient/file_utils.cc',