diff --git a/api/audio/echo_canceller3_config.h b/api/audio/echo_canceller3_config.h index 3b7cf25325..7c8ca1b4f7 100644 --- a/api/audio/echo_canceller3_config.h +++ b/api/audio/echo_canceller3_config.h @@ -86,6 +86,8 @@ struct RTC_EXPORT EchoCanceller3Config { float max_h = 1.5f; bool onset_detection = true; size_t num_sections = 1; + bool clamp_quality_estimate_to_zero = true; + bool clamp_quality_estimate_to_one = true; } erle; struct EpStrength { diff --git a/api/audio/echo_canceller3_config_json.cc b/api/audio/echo_canceller3_config_json.cc index d07491d148..c17497a335 100644 --- a/api/audio/echo_canceller3_config_json.cc +++ b/api/audio/echo_canceller3_config_json.cc @@ -197,6 +197,10 @@ void Aec3ConfigFromJsonString(absl::string_view json_string, ReadParam(section, "max_h", &cfg.erle.max_h); ReadParam(section, "onset_detection", &cfg.erle.onset_detection); ReadParam(section, "num_sections", &cfg.erle.num_sections); + ReadParam(section, "clamp_quality_estimate_to_zero", + &cfg.erle.clamp_quality_estimate_to_zero); + ReadParam(section, "clamp_quality_estimate_to_one", + &cfg.erle.clamp_quality_estimate_to_one); } if (rtc::GetValueFromJsonObject(aec3_root, "ep_strength", &section)) { @@ -408,7 +412,11 @@ std::string Aec3ConfigToJsonString(const EchoCanceller3Config& config) { ost << "\"max_h\": " << config.erle.max_h << ","; ost << "\"onset_detection\": " << (config.erle.onset_detection ? "true" : "false") << ","; - ost << "\"num_sections\": " << config.erle.num_sections; + ost << "\"num_sections\": " << config.erle.num_sections << ","; + ost << "\"clamp_quality_estimate_to_zero\": " + << (config.erle.clamp_quality_estimate_to_zero ? "true" : "false") << ","; + ost << "\"clamp_quality_estimate_to_one\": " + << (config.erle.clamp_quality_estimate_to_one ? 
"true" : "false"); ost << "},"; ost << "\"ep_strength\": {"; diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc index 686592398c..7518e3a3ea 100644 --- a/modules/audio_processing/aec3/aec_state.cc +++ b/modules/audio_processing/aec3/aec_state.cc @@ -111,29 +111,23 @@ AecState::AecState(const EchoCanceller3Config& config, new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), config_(config), initial_state_(config_), - delay_state_(config_), + delay_state_(config_, num_capture_channels), transparent_state_(config_), - filter_quality_state_(config_), + filter_quality_state_(config_, num_capture_channels), erl_estimator_(2 * kNumBlocksPerSecond), erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels), - filter_analyzers_(num_capture_channels), + filter_analyzer_(config_, num_capture_channels), echo_audibility_( config_.echo_audibility.use_stationarity_properties_at_init), - reverb_model_estimator_(config_), - subtractor_output_analyzers_(num_capture_channels) { - for (size_t ch = 0; ch < num_capture_channels; ++ch) { - filter_analyzers_[ch] = std::make_unique(config_); - } -} + reverb_model_estimator_(config_, num_capture_channels), + subtractor_output_analyzers_(num_capture_channels) {} AecState::~AecState() = default; void AecState::HandleEchoPathChange( const EchoPathVariability& echo_path_variability) { const auto full_reset = [&]() { - for (auto& filter_analyzer : filter_analyzers_) { - filter_analyzer->Reset(); - } + filter_analyzer_.Reset(); capture_signal_saturation_ = false; strong_not_saturated_render_blocks_ = 0; blocks_with_active_render_ = 0; @@ -161,49 +155,44 @@ void AecState::HandleEchoPathChange( void AecState::Update( const absl::optional& external_delay, rtc::ArrayView>> - adaptive_filter_frequency_response, - rtc::ArrayView> adaptive_filter_impulse_response, + adaptive_filter_frequency_responses, + rtc::ArrayView> adaptive_filter_impulse_responses, const RenderBuffer& 
render_buffer, rtc::ArrayView> E2_main, rtc::ArrayView> Y2, rtc::ArrayView subtractor_output) { - const size_t num_capture_channels = filter_analyzers_.size(); + const size_t num_capture_channels = subtractor_output_analyzers_.size(); RTC_DCHECK_EQ(num_capture_channels, E2_main.size()); RTC_DCHECK_EQ(num_capture_channels, Y2.size()); RTC_DCHECK_EQ(num_capture_channels, subtractor_output.size()); RTC_DCHECK_EQ(num_capture_channels, subtractor_output_analyzers_.size()); RTC_DCHECK_EQ(num_capture_channels, - adaptive_filter_frequency_response.size()); - RTC_DCHECK_EQ(num_capture_channels, adaptive_filter_impulse_response.size()); + adaptive_filter_frequency_responses.size()); + RTC_DCHECK_EQ(num_capture_channels, adaptive_filter_impulse_responses.size()); // Analyze the filter outputs and filters. bool any_filter_converged = false; bool all_filters_diverged = true; - bool any_filter_consistent = false; - float max_echo_path_gain = 0.f; for (size_t ch = 0; ch < subtractor_output.size(); ++ch) { subtractor_output_analyzers_[ch].Update(subtractor_output[ch]); any_filter_converged = any_filter_converged || subtractor_output_analyzers_[ch].ConvergedFilter(); all_filters_diverged = all_filters_diverged && subtractor_output_analyzers_[ch].DivergedFilter(); - - filter_analyzers_[ch]->Update(adaptive_filter_impulse_response[ch], - render_buffer); - any_filter_consistent = - any_filter_consistent || filter_analyzers_[ch]->Consistent(); - max_echo_path_gain = - std::max(max_echo_path_gain, filter_analyzers_[ch]->Gain()); } + bool any_filter_consistent; + float max_echo_path_gain; + filter_analyzer_.Update(adaptive_filter_impulse_responses, render_buffer, + &any_filter_consistent, &max_echo_path_gain); // Estimate the direct path delay of the filter. 
if (config_.filter.use_linear_filter) { - delay_state_.Update(filter_analyzers_, external_delay, + delay_state_.Update(filter_analyzer_.FilterDelaysBlocks(), external_delay, strong_not_saturated_render_blocks_); } const std::vector>& aligned_render_block = - render_buffer.Block(-delay_state_.DirectPathFilterDelay())[0]; + render_buffer.Block(-delay_state_.DirectPathFilterDelays()[0])[0]; // Update render counters. bool active_render = false; @@ -225,13 +214,13 @@ void AecState::Update( std::array X2_reverb; UpdateAndComputeReverb(render_buffer.GetSpectrumBuffer(), - delay_state_.DirectPathFilterDelay(), ReverbDecay(), - &reverb_model_, X2_reverb); + delay_state_.DirectPathFilterDelays()[0], + ReverbDecay(), &reverb_model_, X2_reverb); if (config_.echo_audibility.use_stationarity_properties) { // Update the echo audibility evaluator. echo_audibility_.Update(render_buffer, reverb_model_.reverb(), - delay_state_.DirectPathFilterDelay(), + delay_state_.DirectPathFilterDelays()[0], delay_state_.ExternalDelayReported()); } @@ -241,11 +230,12 @@ void AecState::Update( } // TODO(bugs.webrtc.org/10913): Take all channels into account. - const auto& X2 = render_buffer.Spectrum(delay_state_.DirectPathFilterDelay(), - /*channel=*/0); + const auto& X2 = + render_buffer.Spectrum(delay_state_.DirectPathFilterDelays()[0], + /*channel=*/0); const auto& X2_input_erle = X2_reverb; - erle_estimator_.Update(render_buffer, adaptive_filter_frequency_response[0], + erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses[0], X2_input_erle, Y2[0], E2_main[0], subtractor_output_analyzers_[0].ConvergedFilter(), config_.erle.onset_detection); @@ -262,7 +252,7 @@ void AecState::Update( initial_state_.Update(active_render, SaturatedCapture()); // Detect whether the transparent mode should be activated. 
- transparent_state_.Update(delay_state_.DirectPathFilterDelay(), + transparent_state_.Update(delay_state_.DirectPathFilterDelays()[0], any_filter_consistent, any_filter_converged, all_filters_diverged, active_render, SaturatedCapture()); @@ -277,11 +267,12 @@ void AecState::Update( config_.echo_audibility.use_stationarity_properties && echo_audibility_.IsBlockStationary(); - reverb_model_estimator_.Update(filter_analyzers_[0]->GetAdjustedFilter(), - adaptive_filter_frequency_response[0], - erle_estimator_.GetInstLinearQualityEstimate(), - delay_state_.DirectPathFilterDelay(), - UsableLinearEstimate(), stationary_block); + reverb_model_estimator_.Update( + filter_analyzer_.GetAdjustedFilters(), + adaptive_filter_frequency_responses, + erle_estimator_.GetInstLinearQualityEstimates(), + delay_state_.DirectPathFilterDelays(), + filter_quality_state_.UsableLinearFilterOutputs(), stationary_block); erle_estimator_.Dump(data_dumper_); reverb_model_estimator_.Dump(data_dumper_.get()); @@ -291,7 +282,7 @@ void AecState::Update( data_dumper_->DumpRaw("aec3_usable_linear_estimate", UsableLinearEstimate()); data_dumper_->DumpRaw("aec3_transparent_mode", TransparentMode()); data_dumper_->DumpRaw("aec3_filter_delay", - filter_analyzers_[0]->DelayBlocks()); + filter_analyzer_.MinFilterDelayBlocks()); data_dumper_->DumpRaw("aec3_any_filter_consistent", any_filter_consistent); data_dumper_->DumpRaw("aec3_initial_state", @@ -335,11 +326,13 @@ void AecState::InitialState::InitialState::Update(bool active_render, transition_triggered_ = !initial_state_ && prev_initial_state; } -AecState::FilterDelay::FilterDelay(const EchoCanceller3Config& config) - : delay_headroom_samples_(config.delay.delay_headroom_samples) {} +AecState::FilterDelay::FilterDelay(const EchoCanceller3Config& config, + size_t num_capture_channels) + : delay_headroom_samples_(config.delay.delay_headroom_samples), + filter_delays_blocks_(num_capture_channels, 0) {} void AecState::FilterDelay::Update( - const 
std::vector>& filter_analyzers, + rtc::ArrayView analyzer_filter_delay_estimates_blocks, const absl::optional& external_delay, size_t blocks_with_proper_filter_adaptation) { // Update the delay based on the external delay. @@ -354,14 +347,15 @@ void AecState::FilterDelay::Update( const bool delay_estimator_may_not_have_converged = blocks_with_proper_filter_adaptation < 2 * kNumBlocksPerSecond; if (delay_estimator_may_not_have_converged && external_delay_) { - filter_delay_blocks_ = delay_headroom_samples_ / kBlockSize; + int delay_guess = delay_headroom_samples_ / kBlockSize; + std::fill(filter_delays_blocks_.begin(), filter_delays_blocks_.end(), + delay_guess); } else { - // Conservatively use the min delay among the filters. - filter_delay_blocks_ = filter_analyzers[0]->DelayBlocks(); - for (size_t ch = 1; ch < filter_analyzers.size(); ++ch) { - filter_delay_blocks_ = - std::min(filter_delay_blocks_, filter_analyzers[ch]->DelayBlocks()); - } + RTC_DCHECK_EQ(filter_delays_blocks_.size(), + analyzer_filter_delay_estimates_blocks.size()); + std::copy(analyzer_filter_delay_estimates_blocks.begin(), + analyzer_filter_delay_estimates_blocks.end(), + filter_delays_blocks_.begin()); } } @@ -452,10 +446,15 @@ void AecState::TransparentMode::Update(int filter_delay_blocks, } AecState::FilteringQualityAnalyzer::FilteringQualityAnalyzer( - const EchoCanceller3Config& config) {} + const EchoCanceller3Config& config, + size_t num_capture_channels) + : use_linear_filter_(config.filter.use_linear_filter), + usable_linear_filter_estimates_(num_capture_channels, false) {} void AecState::FilteringQualityAnalyzer::Reset() { - usable_linear_estimate_ = false; + std::fill(usable_linear_filter_estimates_.begin(), + usable_linear_filter_estimates_.end(), false); + overall_usable_linear_estimates_ = false; filter_update_blocks_since_reset_ = 0; } @@ -482,17 +481,24 @@ void AecState::FilteringQualityAnalyzer::Update( sufficient_data_to_converge_at_startup && 
filter_update_blocks_since_reset_ > kNumBlocksPerSecond * 0.2f; - // The linear filter can only be used it has had time to converge. - usable_linear_estimate_ = sufficient_data_to_converge_at_startup && - sufficient_data_to_converge_at_reset; + // The linear filter can only be used if it has had time to converge. + overall_usable_linear_estimates_ = sufficient_data_to_converge_at_startup && + sufficient_data_to_converge_at_reset; // The linear filter can only be used if an external delay or convergence have // been identified - usable_linear_estimate_ = - usable_linear_estimate_ && (external_delay || convergence_seen_); + overall_usable_linear_estimates_ = + overall_usable_linear_estimates_ && (external_delay || convergence_seen_); // If transparent mode is on, deactivate usign the linear filter. - usable_linear_estimate_ = usable_linear_estimate_ && !transparent_mode; + overall_usable_linear_estimates_ = + overall_usable_linear_estimates_ && !transparent_mode; + + if (use_linear_filter_) { + std::fill(usable_linear_filter_estimates_.begin(), + usable_linear_filter_estimates_.end(), + overall_usable_linear_estimates_); + } } void AecState::SaturationDetector::Update( diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h index 7a7a71e8a0..79fe13e431 100644 --- a/modules/audio_processing/aec3/aec_state.h +++ b/modules/audio_processing/aec3/aec_state.h @@ -91,7 +91,9 @@ class AecState { float ErlTimeDomain() const { return erl_estimator_.ErlTimeDomain(); } // Returns the delay estimate based on the linear filter. - int FilterDelayBlocks() const { return delay_state_.DirectPathFilterDelay(); } + int FilterDelayBlocks() const { + return delay_state_.DirectPathFilterDelays()[0]; + } // Returns whether the capture signal is saturated. 
bool SaturatedCapture() const { return capture_signal_saturation_; } @@ -130,8 +132,9 @@ class AecState { void Update( const absl::optional& external_delay, rtc::ArrayView>> - adaptive_filter_frequency_response, - rtc::ArrayView> adaptive_filter_impulse_response, + adaptive_filter_frequency_responses, + rtc::ArrayView> + adaptive_filter_impulse_responses, const RenderBuffer& render_buffer, rtc::ArrayView> E2_main, rtc::ArrayView> Y2, @@ -140,7 +143,7 @@ class AecState { // Returns filter length in blocks. int FilterLengthBlocks() const { // All filters have the same length, so arbitrarily return channel 0 length. - return filter_analyzers_[/*channel=*/0]->FilterLengthBlocks(); + return filter_analyzer_.FilterLengthBlocks(); } private: @@ -178,7 +181,8 @@ class AecState { // AecState. class FilterDelay { public: - explicit FilterDelay(const EchoCanceller3Config& config); + FilterDelay(const EchoCanceller3Config& config, + size_t num_capture_channels); // Returns whether an external delay has been reported to the AecState (from // the delay estimator). @@ -186,18 +190,20 @@ class AecState { // Returns the delay in blocks relative to the beginning of the filter that // corresponds to the direct path of the echo. - int DirectPathFilterDelay() const { return filter_delay_blocks_; } + rtc::ArrayView DirectPathFilterDelays() const { + return filter_delays_blocks_; + } // Updates the delay estimates based on new data. void Update( - const std::vector>& filter_analyzer, + rtc::ArrayView analyzer_filter_delay_estimates_blocks, const absl::optional& external_delay, size_t blocks_with_proper_filter_adaptation); private: const int delay_headroom_samples_; bool external_delay_reported_ = false; - int filter_delay_blocks_ = 0; + std::vector filter_delays_blocks_; absl::optional external_delay_; } delay_state_; @@ -243,11 +249,18 @@ class AecState { // suppressor. 
class FilteringQualityAnalyzer { public: - FilteringQualityAnalyzer(const EchoCanceller3Config& config); + FilteringQualityAnalyzer(const EchoCanceller3Config& config, + size_t num_capture_channels); - // Returns whether the the linear filter can be used for the echo + // Returns whether the linear filter can be used for the echo // canceller output. - bool LinearFilterUsable() const { return usable_linear_estimate_; } + bool LinearFilterUsable() const { return overall_usable_linear_estimates_; } + + // Returns whether an individual filter output can be used for the echo + // canceller output. + const std::vector& UsableLinearFilterOutputs() const { + return usable_linear_filter_estimates_; + } // Resets the state of the analyzer. void Reset(); @@ -260,10 +273,12 @@ class AecState { bool any_filter_converged); private: - bool usable_linear_estimate_ = false; + const bool use_linear_filter_; + bool overall_usable_linear_estimates_ = false; size_t filter_update_blocks_since_reset_ = 0; size_t filter_update_blocks_since_start_ = 0; bool convergence_seen_ = false; + std::vector usable_linear_filter_estimates_; } filter_quality_state_; // Class for detecting whether the echo is to be considered to be @@ -289,7 +304,7 @@ class AecState { size_t strong_not_saturated_render_blocks_ = 0; size_t blocks_with_active_render_ = 0; bool capture_signal_saturation_ = false; - std::vector> filter_analyzers_; + FilterAnalyzer filter_analyzer_; absl::optional external_delay_; EchoAudibility echo_audibility_; ReverbModelEstimator reverb_model_estimator_; diff --git a/modules/audio_processing/aec3/echo_canceller3.cc b/modules/audio_processing/aec3/echo_canceller3.cc index ffff1b6ccf..a7a76d35d9 100644 --- a/modules/audio_processing/aec3/echo_canceller3.cc +++ b/modules/audio_processing/aec3/echo_canceller3.cc @@ -42,6 +42,14 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) { adjusted_cfg.delay.delay_headroom_samples = kBlockSize * 2; } + if 
(field_trial::IsEnabled("WebRTC-Aec3ClampInstQualityToZeroKillSwitch")) { + adjusted_cfg.erle.clamp_quality_estimate_to_zero = false; + } + + if (field_trial::IsEnabled("WebRTC-Aec3ClampInstQualityToOneKillSwitch")) { + adjusted_cfg.erle.clamp_quality_estimate_to_one = false; + } + return adjusted_cfg; } diff --git a/modules/audio_processing/aec3/erle_estimator.cc b/modules/audio_processing/aec3/erle_estimator.cc index 17bb79d690..a3f68d175b 100644 --- a/modules/audio_processing/aec3/erle_estimator.cc +++ b/modules/audio_processing/aec3/erle_estimator.cc @@ -20,7 +20,7 @@ ErleEstimator::ErleEstimator(size_t startup_phase_length_blocks_, size_t num_capture_channels) : startup_phase_length_blocks__(startup_phase_length_blocks_), use_signal_dependent_erle_(config.erle.num_sections > 1), - fullband_erle_estimator_(config.erle.min, config.erle.max_l), + fullband_erle_estimator_(config.erle, num_capture_channels), subband_erle_estimator_(config, num_capture_channels), signal_dependent_erle_estimator_(config, num_capture_channels) { Reset(true); diff --git a/modules/audio_processing/aec3/erle_estimator.h b/modules/audio_processing/aec3/erle_estimator.h index 7f882caa99..cac6741226 100644 --- a/modules/audio_processing/aec3/erle_estimator.h +++ b/modules/audio_processing/aec3/erle_estimator.h @@ -69,10 +69,12 @@ class ErleEstimator { // Returns an estimation of the current linear filter quality based on the // current and past fullband ERLE estimates. The returned value is a float - // between 0 and 1 where 1 indicates that, at this current time instant, the - // linear filter is reaching its maximum subtraction performance. - absl::optional GetInstLinearQualityEstimate() const { - return fullband_erle_estimator_.GetInstLinearQualityEstimate(); + // vector with content between 0 and 1 where 1 indicates that, at this current + // time instant, the linear filter is reaching its maximum subtraction + // performance. 
+ rtc::ArrayView> GetInstLinearQualityEstimates() + const { + return fullband_erle_estimator_.GetInstLinearQualityEstimates(); } void Dump(const std::unique_ptr& data_dumper) const; diff --git a/modules/audio_processing/aec3/filter_analyzer.cc b/modules/audio_processing/aec3/filter_analyzer.cc index 313460fbd4..f5920f0b27 100644 --- a/modules/audio_processing/aec3/filter_analyzer.cc +++ b/modules/audio_processing/aec3/filter_analyzer.cc @@ -47,91 +47,136 @@ size_t FindPeakIndex(rtc::ArrayView filter_time_domain, int FilterAnalyzer::instance_count_ = 0; -FilterAnalyzer::FilterAnalyzer(const EchoCanceller3Config& config) +FilterAnalyzer::FilterAnalyzer(const EchoCanceller3Config& config, + size_t num_capture_channels) : data_dumper_( new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), bounded_erl_(config.ep_strength.bounded_erl), default_gain_(config.ep_strength.default_gain), - h_highpass_(GetTimeDomainLength(config.filter.main.length_blocks), 0.f), - filter_length_blocks_(config.filter.main_initial.length_blocks), - consistent_filter_detector_(config) { + h_highpass_(num_capture_channels, + std::vector( + GetTimeDomainLength(config.filter.main.length_blocks), + 0.f)), + filter_analysis_states_(num_capture_channels, + FilterAnalysisState(config)), + filter_delays_blocks_(num_capture_channels, 0) { Reset(); } FilterAnalyzer::~FilterAnalyzer() = default; void FilterAnalyzer::Reset() { - delay_blocks_ = 0; blocks_since_reset_ = 0; - gain_ = default_gain_; - peak_index_ = 0; ResetRegion(); - consistent_filter_detector_.Reset(); + for (auto& state : filter_analysis_states_) { + state.peak_index = 0; + state.gain = default_gain_; + state.consistent_filter_detector.Reset(); + } + std::fill(filter_delays_blocks_.begin(), filter_delays_blocks_.end(), 0); } -void FilterAnalyzer::Update(rtc::ArrayView filter_time_domain, - const RenderBuffer& render_buffer) { - SetRegionToAnalyze(filter_time_domain); - AnalyzeRegion(filter_time_domain, render_buffer); +void 
FilterAnalyzer::Update( + rtc::ArrayView> filters_time_domain, + const RenderBuffer& render_buffer, + bool* any_filter_consistent, + float* max_echo_path_gain) { + RTC_DCHECK(any_filter_consistent); + RTC_DCHECK(max_echo_path_gain); + RTC_DCHECK_EQ(filters_time_domain.size(), filter_analysis_states_.size()); + RTC_DCHECK_EQ(filters_time_domain.size(), h_highpass_.size()); + + ++blocks_since_reset_; + SetRegionToAnalyze(filters_time_domain[0].size()); + AnalyzeRegion(filters_time_domain, render_buffer); + + // Aggregate the results for all capture channels. + auto& st_ch0 = filter_analysis_states_[0]; + *any_filter_consistent = st_ch0.consistent_estimate; + *max_echo_path_gain = st_ch0.gain; + min_filter_delay_blocks_ = filter_delays_blocks_[0]; + for (size_t ch = 1; ch < filters_time_domain.size(); ++ch) { + auto& st_ch = filter_analysis_states_[ch]; + *any_filter_consistent = + *any_filter_consistent || st_ch.consistent_estimate; + *max_echo_path_gain = std::max(*max_echo_path_gain, st_ch.gain); + min_filter_delay_blocks_ = + std::min(min_filter_delay_blocks_, filter_delays_blocks_[ch]); + } } void FilterAnalyzer::AnalyzeRegion( - rtc::ArrayView filter_time_domain, + rtc::ArrayView> filters_time_domain, const RenderBuffer& render_buffer) { - RTC_DCHECK_LT(region_.start_sample_, filter_time_domain.size()); - RTC_DCHECK_LT(peak_index_, filter_time_domain.size()); - RTC_DCHECK_LT(region_.end_sample_, filter_time_domain.size()); - // Preprocess the filter to avoid issues with low-frequency components in the // filter. 
- PreProcessFilter(filter_time_domain); - data_dumper_->DumpRaw("aec3_linear_filter_processed_td", h_highpass_); + PreProcessFilters(filters_time_domain); + data_dumper_->DumpRaw("aec3_linear_filter_processed_td", h_highpass_[0]); - RTC_DCHECK_EQ(h_highpass_.size(), filter_time_domain.size()); + constexpr float kOneByBlockSize = 1.f / kBlockSize; + for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) { + RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size()); + RTC_DCHECK_LT(filter_analysis_states_[ch].peak_index, + filters_time_domain[0].size()); + RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size()); - peak_index_ = FindPeakIndex(h_highpass_, peak_index_, region_.start_sample_, - region_.end_sample_); - delay_blocks_ = peak_index_ >> kBlockSizeLog2; - UpdateFilterGain(h_highpass_, peak_index_); - filter_length_blocks_ = filter_time_domain.size() * (1.f / kBlockSize); + auto& st_ch = filter_analysis_states_[ch]; + RTC_DCHECK_EQ(h_highpass_[ch].size(), filters_time_domain[ch].size()); - consistent_estimate_ = consistent_filter_detector_.Detect( - h_highpass_, region_, render_buffer.Block(-delay_blocks_)[0], peak_index_, - delay_blocks_); + st_ch.peak_index = + FindPeakIndex(h_highpass_[ch], st_ch.peak_index, region_.start_sample_, + region_.end_sample_); + filter_delays_blocks_[ch] = st_ch.peak_index >> kBlockSizeLog2; + UpdateFilterGain(h_highpass_[ch], &st_ch); + st_ch.filter_length_blocks = + filters_time_domain[ch].size() * kOneByBlockSize; + + st_ch.consistent_estimate = st_ch.consistent_filter_detector.Detect( + h_highpass_[ch], region_, + render_buffer.Block(-filter_delays_blocks_[ch])[0], st_ch.peak_index, + filter_delays_blocks_[ch]); + } } void FilterAnalyzer::UpdateFilterGain( rtc::ArrayView filter_time_domain, - size_t peak_index) { + FilterAnalysisState* st) { bool sufficient_time_to_converge = - ++blocks_since_reset_ > 5 * kNumBlocksPerSecond; + blocks_since_reset_ > 5 * kNumBlocksPerSecond; - if 
(sufficient_time_to_converge && consistent_estimate_) { - gain_ = fabsf(filter_time_domain[peak_index]); + if (sufficient_time_to_converge && st->consistent_estimate) { + st->gain = fabsf(filter_time_domain[st->peak_index]); } else { - if (gain_) { - gain_ = std::max(gain_, fabsf(filter_time_domain[peak_index])); + // TODO(peah): Verify whether this check against a float is ok. + if (st->gain) { + st->gain = std::max(st->gain, fabsf(filter_time_domain[st->peak_index])); } } - if (bounded_erl_ && gain_) { - gain_ = std::max(gain_, 0.01f); + if (bounded_erl_ && st->gain) { + st->gain = std::max(st->gain, 0.01f); } } -void FilterAnalyzer::PreProcessFilter( - rtc::ArrayView filter_time_domain) { - RTC_DCHECK_GE(h_highpass_.capacity(), filter_time_domain.size()); - h_highpass_.resize(filter_time_domain.size()); - // Minimum phase high-pass filter with cutoff frequency at about 600 Hz. - constexpr std::array h = {{0.7929742f, -0.36072128f, -0.47047766f}}; +void FilterAnalyzer::PreProcessFilters( + rtc::ArrayView> filters_time_domain) { + for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) { + RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size()); + RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size()); - std::fill(h_highpass_.begin() + region_.start_sample_, - h_highpass_.begin() + region_.end_sample_ + 1, 0.f); - for (size_t k = std::max(h.size() - 1, region_.start_sample_); - k <= region_.end_sample_; ++k) { - for (size_t j = 0; j < h.size(); ++j) { - h_highpass_[k] += filter_time_domain[k - j] * h[j]; + RTC_DCHECK_GE(h_highpass_[ch].capacity(), filters_time_domain[ch].size()); + h_highpass_[ch].resize(filters_time_domain[ch].size()); + // Minimum phase high-pass filter with cutoff frequency at about 600 Hz. 
+ constexpr std::array h = { + {0.7929742f, -0.36072128f, -0.47047766f}}; + + std::fill(h_highpass_[ch].begin() + region_.start_sample_, + h_highpass_[ch].begin() + region_.end_sample_ + 1, 0.f); + for (size_t k = std::max(h.size() - 1, region_.start_sample_); + k <= region_.end_sample_; ++k) { + for (size_t j = 0; j < h.size(); ++j) { + h_highpass_[ch][k] += filters_time_domain[ch][k - j] * h[j]; + } } } } @@ -141,19 +186,17 @@ void FilterAnalyzer::ResetRegion() { region_.end_sample_ = 0; } -void FilterAnalyzer::SetRegionToAnalyze( - rtc::ArrayView filter_time_domain) { +void FilterAnalyzer::SetRegionToAnalyze(size_t filter_size) { constexpr size_t kNumberBlocksToUpdate = 1; auto& r = region_; - r.start_sample_ = - r.end_sample_ >= filter_time_domain.size() - 1 ? 0 : r.end_sample_ + 1; + r.start_sample_ = r.end_sample_ >= filter_size - 1 ? 0 : r.end_sample_ + 1; r.end_sample_ = std::min(r.start_sample_ + kNumberBlocksToUpdate * kBlockSize - 1, - filter_time_domain.size() - 1); + filter_size - 1); // Check range. - RTC_DCHECK_LT(r.start_sample_, filter_time_domain.size()); - RTC_DCHECK_LT(r.end_sample_, filter_time_domain.size()); + RTC_DCHECK_LT(r.start_sample_, filter_size); + RTC_DCHECK_LT(r.end_sample_, filter_size); RTC_DCHECK_LE(r.start_sample_, r.end_sample_); } diff --git a/modules/audio_processing/aec3/filter_analyzer.h b/modules/audio_processing/aec3/filter_analyzer.h index de6c8a7dd2..a7375778c6 100644 --- a/modules/audio_processing/aec3/filter_analyzer.h +++ b/modules/audio_processing/aec3/filter_analyzer.h @@ -30,7 +30,8 @@ class RenderBuffer; // Class for analyzing the properties of an adaptive filter. class FilterAnalyzer { public: - explicit FilterAnalyzer(const EchoCanceller3Config& config); + FilterAnalyzer(const EchoCanceller3Config& config, + size_t num_capture_channels); ~FilterAnalyzer(); FilterAnalyzer(const FilterAnalyzer&) = delete; @@ -40,35 +41,43 @@ class FilterAnalyzer { void Reset(); // Updates the estimates with new input data. 
- void Update(rtc::ArrayView filter_time_domain, - const RenderBuffer& render_buffer); + void Update(rtc::ArrayView> filters_time_domain, + const RenderBuffer& render_buffer, + bool* any_filter_consistent, + float* max_echo_path_gain); - // Returns the delay of the filter in terms of blocks. - int DelayBlocks() const { return delay_blocks_; } + // Returns the delay in blocks for each filter. + rtc::ArrayView FilterDelaysBlocks() const { + return filter_delays_blocks_; + } - // Returns whether the filter is consistent in the sense that it does not - // change much over time. - bool Consistent() const { return consistent_estimate_; } - - // Returns the estimated filter gain. - float Gain() const { return gain_; } + // Returns the minimum delay of all filters in terms of blocks. + int MinFilterDelayBlocks() const { return min_filter_delay_blocks_; } // Returns the number of blocks for the current used filter. - int FilterLengthBlocks() const { return filter_length_blocks_; } + int FilterLengthBlocks() const { + return filter_analysis_states_[0].filter_length_blocks; + } // Returns the preprocessed filter. - rtc::ArrayView GetAdjustedFilter() const { return h_highpass_; } + rtc::ArrayView> GetAdjustedFilters() const { + return h_highpass_; + } // Public for testing purposes only. 
- void SetRegionToAnalyze(rtc::ArrayView filter_time_domain); + void SetRegionToAnalyze(size_t filter_size); private: - void AnalyzeRegion(rtc::ArrayView filter_time_domain, - const RenderBuffer& render_buffer); + struct FilterAnalysisState; - void UpdateFilterGain(rtc::ArrayView filter_time_domain, - size_t max_index); - void PreProcessFilter(rtc::ArrayView filter_time_domain); + void AnalyzeRegion( + rtc::ArrayView> filters_time_domain, + const RenderBuffer& render_buffer); + + void UpdateFilterGain(rtc::ArrayView filters_time_domain, + FilterAnalysisState* st); + void PreProcessFilters( + rtc::ArrayView> filters_time_domain); void ResetRegion(); @@ -100,19 +109,30 @@ class FilterAnalyzer { int consistent_delay_reference_ = -10; }; + struct FilterAnalysisState { + explicit FilterAnalysisState(const EchoCanceller3Config& config) + : filter_length_blocks(config.filter.main_initial.length_blocks), + consistent_filter_detector(config) {} + float gain; + size_t peak_index; + int filter_length_blocks; + bool consistent_estimate = false; + ConsistentFilterDetector consistent_filter_detector; + }; + static int instance_count_; std::unique_ptr data_dumper_; const bool bounded_erl_; const float default_gain_; - std::vector h_highpass_; - int delay_blocks_ = 0; + std::vector> h_highpass_; + size_t blocks_since_reset_ = 0; - bool consistent_estimate_ = false; - float gain_; - size_t peak_index_; - int filter_length_blocks_; FilterRegion region_; - ConsistentFilterDetector consistent_filter_detector_; + + std::vector filter_analysis_states_; + std::vector filter_delays_blocks_; + + int min_filter_delay_blocks_ = 0; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/filter_analyzer_unittest.cc b/modules/audio_processing/aec3/filter_analyzer_unittest.cc index 474d67d348..34104c39b2 100644 --- a/modules/audio_processing/aec3/filter_analyzer_unittest.cc +++ b/modules/audio_processing/aec3/filter_analyzer_unittest.cc @@ -21,11 +21,11 @@ namespace webrtc { 
TEST(FilterAnalyzer, FilterResize) { EchoCanceller3Config c; std::vector filter(65, 0.f); - FilterAnalyzer fa(c); - fa.SetRegionToAnalyze(filter); - fa.SetRegionToAnalyze(filter); + FilterAnalyzer fa(c, 1); + fa.SetRegionToAnalyze(filter.size()); + fa.SetRegionToAnalyze(filter.size()); filter.resize(32); - fa.SetRegionToAnalyze(filter); + fa.SetRegionToAnalyze(filter.size()); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/fullband_erle_estimator.cc b/modules/audio_processing/aec3/fullband_erle_estimator.cc index 7893b97b3a..086638d6b5 100644 --- a/modules/audio_processing/aec3/fullband_erle_estimator.cc +++ b/modules/audio_processing/aec3/fullband_erle_estimator.cc @@ -30,9 +30,13 @@ constexpr int kBlocksToHoldErle = 100; constexpr int kPointsToAccumulate = 6; } // namespace -FullBandErleEstimator::FullBandErleEstimator(float min_erle, float max_erle_lf) - : min_erle_log2_(FastApproxLog2f(min_erle + kEpsilon)), - max_erle_lf_log2(FastApproxLog2f(max_erle_lf + kEpsilon)) { +FullBandErleEstimator::FullBandErleEstimator( + const EchoCanceller3Config::Erle& config, + size_t num_capture_channels) + : min_erle_log2_(FastApproxLog2f(config.min + kEpsilon)), + max_erle_lf_log2(FastApproxLog2f(config.max_l + kEpsilon)), + instantaneous_erle_(config), + linear_filters_qualities_(num_capture_channels) { Reset(); } @@ -40,6 +44,7 @@ FullBandErleEstimator::~FullBandErleEstimator() = default; void FullBandErleEstimator::Reset() { instantaneous_erle_.Reset(); + UpdateQualityEstimates(); erle_time_domain_log2_ = min_erle_log2_; hold_counter_time_domain_ = 0; } @@ -72,6 +77,8 @@ void FullBandErleEstimator::Update(rtc::ArrayView X2, if (hold_counter_time_domain_ == 0) { instantaneous_erle_.ResetAccumulators(); } + + UpdateQualityEstimates(); } void FullBandErleEstimator::Dump( @@ -80,7 +87,15 @@ void FullBandErleEstimator::Dump( instantaneous_erle_.Dump(data_dumper); } -FullBandErleEstimator::ErleInstantaneous::ErleInstantaneous() { +void 
FullBandErleEstimator::UpdateQualityEstimates() { + std::fill(linear_filters_qualities_.begin(), linear_filters_qualities_.end(), + instantaneous_erle_.GetQualityEstimate()); +} + +FullBandErleEstimator::ErleInstantaneous::ErleInstantaneous( + const EchoCanceller3Config::Erle& config) + : clamp_inst_quality_to_zero_(config.clamp_quality_estimate_to_zero), + clamp_inst_quality_to_one_(config.clamp_quality_estimate_to_one) { Reset(); } @@ -154,6 +169,8 @@ void FullBandErleEstimator::ErleInstantaneous::UpdateQualityEstimate() { const float alpha = 0.07f; float quality_estimate = 0.f; RTC_DCHECK(erle_log2_); + // TODO(peah): Currently, the estimate can become less than 0; this should + // be corrected. if (max_erle_log2_ > min_erle_log2_) { quality_estimate = (erle_log2_.value() - min_erle_log2_) / (max_erle_log2_ - min_erle_log2_); diff --git a/modules/audio_processing/aec3/fullband_erle_estimator.h b/modules/audio_processing/aec3/fullband_erle_estimator.h index 175db55e11..64372a2009 100644 --- a/modules/audio_processing/aec3/fullband_erle_estimator.h +++ b/modules/audio_processing/aec3/fullband_erle_estimator.h @@ -12,9 +12,11 @@ #define MODULES_AUDIO_PROCESSING_AEC3_FULLBAND_ERLE_ESTIMATOR_H_ #include +#include #include "absl/types/optional.h" #include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" #include "modules/audio_processing/logging/apm_data_dumper.h" namespace webrtc { @@ -23,7 +25,8 @@ namespace webrtc { // freuquency bands. class FullBandErleEstimator { public: - FullBandErleEstimator(float min_erle, float max_erle_lf); + FullBandErleEstimator(const EchoCanceller3Config::Erle& config, + size_t num_capture_channels); ~FullBandErleEstimator(); // Resets the ERLE estimator. void Reset(); @@ -39,16 +42,19 @@ class FullBandErleEstimator { // Returns an estimation of the current linear filter quality. It returns a // float number between 0 and 1 mapping 1 to the highest possible quality.
- absl::optional GetInstLinearQualityEstimate() const { - return instantaneous_erle_.GetQualityEstimate(); + rtc::ArrayView> GetInstLinearQualityEstimates() + const { + return linear_filters_qualities_; } void Dump(const std::unique_ptr& data_dumper) const; private: + void UpdateQualityEstimates(); + class ErleInstantaneous { public: - ErleInstantaneous(); + explicit ErleInstantaneous(const EchoCanceller3Config::Erle& config); ~ErleInstantaneous(); // Updates the estimator with a new point, returns true @@ -64,14 +70,25 @@ class FullBandErleEstimator { // Gets an indication between 0 and 1 of the performance of the linear // filter for the current time instant. absl::optional GetQualityEstimate() const { - return erle_log2_ ? absl::optional(inst_quality_estimate_) - : absl::nullopt; + if (erle_log2_) { + float value = inst_quality_estimate_; + if (clamp_inst_quality_to_zero_) { + value = std::max(0.f, value); + } + if (clamp_inst_quality_to_one_) { + value = std::min(1.f, value); + } + return absl::optional(value); + } + return absl::nullopt; } void Dump(const std::unique_ptr& data_dumper) const; private: void UpdateMaxMin(); void UpdateQualityEstimate(); + const bool clamp_inst_quality_to_zero_; + const bool clamp_inst_quality_to_one_; absl::optional erle_log2_; float inst_quality_estimate_; float max_erle_log2_; @@ -86,6 +103,7 @@ class FullBandErleEstimator { const float min_erle_log2_; const float max_erle_lf_log2; ErleInstantaneous instantaneous_erle_; + std::vector> linear_filters_qualities_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/reverb_model_estimator.cc b/modules/audio_processing/aec3/reverb_model_estimator.cc index ce3e2be335..717431103f 100644 --- a/modules/audio_processing/aec3/reverb_model_estimator.cc +++ b/modules/audio_processing/aec3/reverb_model_estimator.cc @@ -12,26 +12,43 @@ namespace webrtc { -ReverbModelEstimator::ReverbModelEstimator(const EchoCanceller3Config& config) - : reverb_decay_estimator_(config) {} 
+ReverbModelEstimator::ReverbModelEstimator(const EchoCanceller3Config& config, + size_t num_capture_channels) + : reverb_decay_estimators_(num_capture_channels), + reverb_frequency_responses_(num_capture_channels) { + for (size_t ch = 0; ch < reverb_decay_estimators_.size(); ++ch) { + reverb_decay_estimators_[ch] = + std::make_unique(config); + } +} ReverbModelEstimator::~ReverbModelEstimator() = default; void ReverbModelEstimator::Update( - rtc::ArrayView impulse_response, - const std::vector>& - frequency_response, - const absl::optional& linear_filter_quality, - int filter_delay_blocks, - bool usable_linear_estimate, + rtc::ArrayView> impulse_responses, + rtc::ArrayView>> + frequency_responses, + rtc::ArrayView> linear_filter_qualities, + rtc::ArrayView filter_delays_blocks, + const std::vector& usable_linear_estimates, bool stationary_block) { - // Estimate the frequency response for the reverb. - reverb_frequency_response_.Update(frequency_response, filter_delay_blocks, - linear_filter_quality, stationary_block); + const size_t num_capture_channels = reverb_decay_estimators_.size(); + RTC_DCHECK_EQ(num_capture_channels, impulse_responses.size()); + RTC_DCHECK_EQ(num_capture_channels, frequency_responses.size()); + RTC_DCHECK_EQ(num_capture_channels, usable_linear_estimates.size()); - // Estimate the reverb decay, - reverb_decay_estimator_.Update(impulse_response, linear_filter_quality, - filter_delay_blocks, usable_linear_estimate, - stationary_block); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + // Estimate the frequency response for the reverb. 
+ reverb_frequency_responses_[ch].Update( + frequency_responses[ch], filter_delays_blocks[ch], + linear_filter_qualities[ch], stationary_block); + + // Estimate the reverb decay. + reverb_decay_estimators_[ch]->Update( + impulse_responses[ch], linear_filter_qualities[ch], + filter_delays_blocks[ch], usable_linear_estimates[ch], + stationary_block); + } } + } // namespace webrtc diff --git a/modules/audio_processing/aec3/reverb_model_estimator.h b/modules/audio_processing/aec3/reverb_model_estimator.h index 1112f93a71..3b9971abae 100644 --- a/modules/audio_processing/aec3/reverb_model_estimator.h +++ b/modules/audio_processing/aec3/reverb_model_estimator.h @@ -28,34 +28,38 @@ class ApmDataDumper; // Class for estimating the model parameters for the reverberant echo. class ReverbModelEstimator { public: - explicit ReverbModelEstimator(const EchoCanceller3Config& config); + ReverbModelEstimator(const EchoCanceller3Config& config, + size_t num_capture_channels); ~ReverbModelEstimator(); // Updates the estimates based on new data. - void Update(rtc::ArrayView impulse_response, - const std::vector>& - frequency_response, - const absl::optional& linear_filter_quality, - int filter_delay_blocks, - bool usable_linear_estimate, - bool stationary_block); + void Update( + rtc::ArrayView> impulse_responses, + rtc::ArrayView>> + frequency_responses, + rtc::ArrayView> linear_filter_qualities, + rtc::ArrayView filter_delays_blocks, + const std::vector& usable_linear_estimates, + bool stationary_block); // Returns the exponential decay of the reverberant echo. - float ReverbDecay() const { return reverb_decay_estimator_.Decay(); } + // TODO(peah): Correct to properly support multiple channels. + float ReverbDecay() const { return reverb_decay_estimators_[0]->Decay(); } // Return the frequency response of the reverberant echo. + // TODO(peah): Correct to properly support multiple channels.
rtc::ArrayView GetReverbFrequencyResponse() const { - return reverb_frequency_response_.FrequencyResponse(); + return reverb_frequency_responses_[0].FrequencyResponse(); } // Dumps debug data. void Dump(ApmDataDumper* data_dumper) const { - reverb_decay_estimator_.Dump(data_dumper); + reverb_decay_estimators_[0]->Dump(data_dumper); } private: - ReverbDecayEstimator reverb_decay_estimator_; - ReverbFrequencyResponse reverb_frequency_response_; + std::vector> reverb_decay_estimators_; + std::vector reverb_frequency_responses_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc b/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc index 8fce9d2867..50a4dc0256 100644 --- a/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc +++ b/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc @@ -11,8 +11,10 @@ #include "modules/audio_processing/aec3/reverb_model_estimator.h" #include +#include #include #include +#include #include "absl/types/optional.h" #include "api/array_view.h" @@ -25,14 +27,32 @@ namespace webrtc { +namespace { + +EchoCanceller3Config CreateConfigForTest(float default_decay) { + EchoCanceller3Config cfg; + cfg.ep_strength.default_len = default_decay; + cfg.filter.main.length_blocks = 40; + return cfg; +} + +constexpr int kFilterDelayBlocks = 2; + +} // namespace + class ReverbModelEstimatorTest { public: - explicit ReverbModelEstimatorTest(float default_decay) - : default_decay_(default_decay), estimated_decay_(default_decay) { - aec3_config_.ep_strength.default_len = default_decay_; - aec3_config_.filter.main.length_blocks = 40; - h_.resize(aec3_config_.filter.main.length_blocks * kBlockSize); - H2_.resize(aec3_config_.filter.main.length_blocks); + ReverbModelEstimatorTest(float default_decay, size_t num_capture_channels) + : aec3_config_(CreateConfigForTest(default_decay)), + estimated_decay_(default_decay), + h_(num_capture_channels, + std::vector( + 
aec3_config_.filter.main.length_blocks * kBlockSize, + 0.f)), + H2_(num_capture_channels, + std::vector>( + aec3_config_.filter.main.length_blocks)), + quality_linear_(num_capture_channels, 1.0f) { CreateImpulseResponseWithDecay(); } void RunEstimator(); @@ -43,51 +63,63 @@ class ReverbModelEstimatorTest { private: void CreateImpulseResponseWithDecay(); - - absl::optional quality_linear_ = 1.0f; - static constexpr int kFilterDelayBlocks = 2; - static constexpr bool kUsableLinearEstimate = true; static constexpr bool kStationaryBlock = false; static constexpr float kTruePowerDecay = 0.5f; - EchoCanceller3Config aec3_config_; - float default_decay_; + const EchoCanceller3Config aec3_config_; float estimated_decay_; float estimated_power_tail_ = 0.f; float true_power_tail_ = 0.f; - std::vector h_; - std::vector> H2_; + std::vector> h_; + std::vector>> H2_; + std::vector> quality_linear_; }; void ReverbModelEstimatorTest::CreateImpulseResponseWithDecay() { const Aec3Fft fft; - RTC_DCHECK_EQ(h_.size(), aec3_config_.filter.main.length_blocks * kBlockSize); - RTC_DCHECK_EQ(H2_.size(), aec3_config_.filter.main.length_blocks); + for (const auto& h_k : h_) { + RTC_DCHECK_EQ(h_k.size(), + aec3_config_.filter.main.length_blocks * kBlockSize); + } + for (const auto& H2_k : H2_) { + RTC_DCHECK_EQ(H2_k.size(), aec3_config_.filter.main.length_blocks); + } RTC_DCHECK_EQ(kFilterDelayBlocks, 2); float decay_sample = std::sqrt(powf(kTruePowerDecay, 1.f / kBlockSize)); const size_t filter_delay_coefficients = kFilterDelayBlocks * kBlockSize; - std::fill(h_.begin(), h_.end(), 0.f); - h_[filter_delay_coefficients] = 1.f; - for (size_t k = filter_delay_coefficients + 1; k < h_.size(); ++k) { - h_[k] = h_[k - 1] * decay_sample; + for (auto& h_i : h_) { + std::fill(h_i.begin(), h_i.end(), 0.f); + h_i[filter_delay_coefficients] = 1.f; + for (size_t k = filter_delay_coefficients + 1; k < h_i.size(); ++k) { + h_i[k] = h_i[k - 1] * decay_sample; + } } - std::array fft_data; - FftData H_j; - for 
(size_t j = 0, k = 0; j < H2_.size(); ++j, k += kBlockSize) { - fft_data.fill(0.f); - std::copy(h_.begin() + k, h_.begin() + k + kBlockSize, fft_data.begin()); - fft.Fft(&fft_data, &H_j); - H_j.Spectrum(Aec3Optimization::kNone, H2_[j]); + for (size_t ch = 0; ch < H2_.size(); ++ch) { + for (size_t j = 0, k = 0; j < H2_[ch].size(); ++j, k += kBlockSize) { + std::array fft_data; + fft_data.fill(0.f); + std::copy(h_[ch].begin() + k, h_[ch].begin() + k + kBlockSize, + fft_data.begin()); + FftData H_j; + fft.Fft(&fft_data, &H_j); + H_j.Spectrum(Aec3Optimization::kNone, H2_[ch][j]); + } } - rtc::ArrayView H2_tail(H2_[H2_.size() - 1]); + rtc::ArrayView H2_tail(H2_[0][H2_[0].size() - 1]); true_power_tail_ = std::accumulate(H2_tail.begin(), H2_tail.end(), 0.f); } void ReverbModelEstimatorTest::RunEstimator() { - ReverbModelEstimator estimator(aec3_config_); + const size_t num_capture_channels = H2_.size(); + constexpr bool kUsableLinearEstimate = true; + ReverbModelEstimator estimator(aec3_config_, num_capture_channels); + std::vector usable_linear_estimates(num_capture_channels, + kUsableLinearEstimate); + std::vector filter_delay_blocks(num_capture_channels, + kFilterDelayBlocks); for (size_t k = 0; k < 3000; ++k) { - estimator.Update(h_, H2_, quality_linear_, kFilterDelayBlocks, - kUsableLinearEstimate, kStationaryBlock); + estimator.Update(h_, H2_, quality_linear_, filter_delay_blocks, + usable_linear_estimates, kStationaryBlock); } estimated_decay_ = estimator.ReverbDecay(); auto freq_resp_tail = estimator.GetReverbFrequencyResponse(); @@ -96,19 +128,23 @@ void ReverbModelEstimatorTest::RunEstimator() { } TEST(ReverbModelEstimatorTests, NotChangingDecay) { - constexpr float default_decay = 0.9f; - ReverbModelEstimatorTest test(default_decay); - test.RunEstimator(); - EXPECT_EQ(test.GetDecay(), default_decay); - EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); + constexpr float kDefaultDecay = 0.9f; + for (size_t num_capture_channels : {1, 2, 4, 8}) { 
+ ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels); + test.RunEstimator(); + EXPECT_EQ(test.GetDecay(), kDefaultDecay); + EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); + } } TEST(ReverbModelEstimatorTests, ChangingDecay) { - constexpr float default_decay = -0.9f; - ReverbModelEstimatorTest test(default_decay); - test.RunEstimator(); - EXPECT_NEAR(test.GetDecay(), test.GetTrueDecay(), 0.1); - EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); + constexpr float kDefaultDecay = -0.9f; + for (size_t num_capture_channels : {1, 2, 4, 8}) { + ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels); + test.RunEstimator(); + EXPECT_NEAR(test.GetDecay(), test.GetTrueDecay(), 0.1); + EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); + } } } // namespace webrtc