This CL includes the following changes:
- BUILD file reorganized, unit tests now have dedicated targets. - "fake_polqa" is a binary producing fake output in the same format of PolqaOem64; the binary is injected for unit tests instead of the actual POLQA tool. - Minor refactoring to inject the path to the POLQA binary instead of its parent folder. - Unit tests for the evaluation score workers. - Unit tests for the ApmModuleSimulator class. - Unit tests for the test data generators: ReverberationTestDataGenerator added. BUG=webrtc:7218 Review-Url: https://codereview.webrtc.org/2811953002 Cr-Commit-Position: refs/heads/master@{#17674}
This commit is contained in:
@ -8,7 +8,15 @@
|
||||
|
||||
import("../../../../webrtc.gni")
|
||||
|
||||
copy("py_quality_assessment") {
|
||||
group("py_quality_assessment") {
|
||||
testonly = true
|
||||
deps = [
|
||||
":scripts",
|
||||
":unit_tests",
|
||||
]
|
||||
}
|
||||
|
||||
copy("scripts") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"README.md",
|
||||
@ -16,7 +24,6 @@ copy("py_quality_assessment") {
|
||||
"apm_quality_assessment.sh",
|
||||
"apm_quality_assessment_export.py",
|
||||
"apm_quality_assessment_gencfgs.py",
|
||||
"apm_quality_assessment_unittest.py",
|
||||
]
|
||||
outputs = [
|
||||
"$root_build_dir/py_quality_assessment/{{source_file_part}}",
|
||||
@ -28,37 +35,7 @@ copy("py_quality_assessment") {
|
||||
"../..:audioproc_f",
|
||||
"//resources/audio_processing/test/py_quality_assessment:probing_signals",
|
||||
]
|
||||
} # py_quality_assessment
|
||||
|
||||
copy("lib") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"quality_assessment/__init__.py",
|
||||
"quality_assessment/audioproc_wrapper.py",
|
||||
"quality_assessment/data_access.py",
|
||||
"quality_assessment/eval_scores.py",
|
||||
"quality_assessment/eval_scores_factory.py",
|
||||
"quality_assessment/eval_scores_unittest.py",
|
||||
"quality_assessment/evaluation.py",
|
||||
"quality_assessment/exceptions.py",
|
||||
"quality_assessment/export.py",
|
||||
"quality_assessment/results.css",
|
||||
"quality_assessment/results.js",
|
||||
"quality_assessment/signal_processing.py",
|
||||
"quality_assessment/signal_processing_unittest.py",
|
||||
"quality_assessment/simulation.py",
|
||||
"quality_assessment/test_data_generation.py",
|
||||
"quality_assessment/test_data_generation_factory.py",
|
||||
"quality_assessment/test_data_generation_unittest.py",
|
||||
]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [
|
||||
"$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}",
|
||||
]
|
||||
deps = [
|
||||
"//resources/audio_processing/test/py_quality_assessment:noise_tracks",
|
||||
]
|
||||
} # lib
|
||||
}
|
||||
|
||||
copy("apm_configs") {
|
||||
testonly = true
|
||||
@ -71,6 +48,33 @@ copy("apm_configs") {
|
||||
]
|
||||
} # apm_configs
|
||||
|
||||
copy("lib") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"quality_assessment/__init__.py",
|
||||
"quality_assessment/audioproc_wrapper.py",
|
||||
"quality_assessment/data_access.py",
|
||||
"quality_assessment/eval_scores.py",
|
||||
"quality_assessment/eval_scores_factory.py",
|
||||
"quality_assessment/evaluation.py",
|
||||
"quality_assessment/exceptions.py",
|
||||
"quality_assessment/export.py",
|
||||
"quality_assessment/results.css",
|
||||
"quality_assessment/results.js",
|
||||
"quality_assessment/signal_processing.py",
|
||||
"quality_assessment/simulation.py",
|
||||
"quality_assessment/test_data_generation.py",
|
||||
"quality_assessment/test_data_generation_factory.py",
|
||||
]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [
|
||||
"$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}",
|
||||
]
|
||||
deps = [
|
||||
"//resources/audio_processing/test/py_quality_assessment:noise_tracks",
|
||||
]
|
||||
}
|
||||
|
||||
copy("output") {
|
||||
testonly = true
|
||||
sources = [
|
||||
@ -80,4 +84,52 @@ copy("output") {
|
||||
outputs = [
|
||||
"$root_build_dir/py_quality_assessment/output/{{source_file_part}}",
|
||||
]
|
||||
} # output
|
||||
}
|
||||
|
||||
group("unit_tests") {
|
||||
testonly = true
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
deps = [
|
||||
":fake_polqa",
|
||||
":lib_unit_tests",
|
||||
":scripts_unit_tests",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_executable("fake_polqa") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"quality_assessment/fake_polqa.cc",
|
||||
]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
output_name = "py_quality_assessment/quality_assessment/fake_polqa"
|
||||
deps = [
|
||||
"//webrtc:webrtc_common",
|
||||
"//webrtc/base:rtc_base_approved",
|
||||
]
|
||||
}
|
||||
|
||||
copy("lib_unit_tests") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"quality_assessment/eval_scores_unittest.py",
|
||||
"quality_assessment/signal_processing_unittest.py",
|
||||
"quality_assessment/simulation_unittest.py",
|
||||
"quality_assessment/test_data_generation_unittest.py",
|
||||
]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [
|
||||
"$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}",
|
||||
]
|
||||
}
|
||||
|
||||
copy("scripts_unit_tests") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"apm_quality_assessment_unittest.py",
|
||||
]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [
|
||||
"$root_build_dir/py_quality_assessment/{{source_file_part}}",
|
||||
]
|
||||
}
|
||||
|
||||
@ -19,9 +19,12 @@ Usage: apm_quality_assessment.py -i audio1.wav [audio2.wav ...]
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import quality_assessment.audioproc_wrapper as audioproc_wrapper
|
||||
import quality_assessment.eval_scores as eval_scores
|
||||
import quality_assessment.evaluation as evaluation
|
||||
import quality_assessment.test_data_generation as test_data_generation
|
||||
import quality_assessment.simulation as simulation
|
||||
|
||||
@ -33,6 +36,8 @@ _EVAL_SCORE_WORKER_NAMES = _EVAL_SCORE_WORKER_CLASSES.keys()
|
||||
|
||||
_DEFAULT_CONFIG_FILE = 'apm_configs/default.json'
|
||||
|
||||
_POLQA_BIN_NAME = 'PolqaOem64'
|
||||
|
||||
|
||||
def _InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
@ -85,7 +90,9 @@ def main():
|
||||
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
aechen_ir_database_path=args.air_db_path,
|
||||
polqa_tool_path=args.polqa_path)
|
||||
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
simulator.Run(
|
||||
config_filepaths=args.config_files,
|
||||
input_filepaths=args.input_files,
|
||||
|
||||
@ -145,21 +145,19 @@ class PolqaScore(EvaluationScore):
|
||||
"""
|
||||
|
||||
NAME = 'polqa'
|
||||
_BIN_FILENAME = 'PolqaOem64'
|
||||
|
||||
def __init__(self, polqa_tool_path):
|
||||
def __init__(self, polqa_bin_filepath):
|
||||
EvaluationScore.__init__(self)
|
||||
|
||||
# Path to the POLQA directory with binary and license files.
|
||||
self._polqa_tool_path = polqa_tool_path
|
||||
|
||||
# POLQA binary file path.
|
||||
self._polqa_bin_filepath = os.path.join(
|
||||
self._polqa_tool_path, self._BIN_FILENAME)
|
||||
self._polqa_bin_filepath = polqa_bin_filepath
|
||||
if not os.path.exists(self._polqa_bin_filepath):
|
||||
logging.error('cannot find POLQA tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
# Path to the POLQA directory with binary and license files.
|
||||
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
|
||||
|
||||
def _Run(self, output_path):
|
||||
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
|
||||
if os.path.exists(polqa_out_filepath):
|
||||
|
||||
@ -21,8 +21,8 @@ class EvaluationScoreWorkerFactory(object):
|
||||
workers.
|
||||
"""
|
||||
|
||||
def __init__(self, polqa_tool_path):
|
||||
self._polqa_tool_path = polqa_tool_path
|
||||
def __init__(self, polqa_tool_bin_path):
|
||||
self._polqa_tool_bin_path = polqa_tool_bin_path
|
||||
|
||||
def GetInstance(self, evaluation_score_class):
|
||||
"""Creates an EvaluationScore instance given a class object.
|
||||
@ -33,7 +33,7 @@ class EvaluationScoreWorkerFactory(object):
|
||||
logging.debug(
|
||||
'factory producing a %s evaluation score', evaluation_score_class)
|
||||
if evaluation_score_class == eval_scores.PolqaScore:
|
||||
return eval_scores.PolqaScore(self._polqa_tool_path)
|
||||
return eval_scores.PolqaScore(self._polqa_tool_bin_path)
|
||||
else:
|
||||
# By default, no arguments in the constructor.
|
||||
return evaluation_score_class()
|
||||
|
||||
@ -9,17 +9,76 @@
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pydub
|
||||
|
||||
from . import data_access
|
||||
from . import eval_scores
|
||||
from . import eval_scores_factory
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestEvalScores(unittest.TestCase):
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary output folder and two audio track files."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
|
||||
# Create fake reference and tested (i.e., APM output) audio track files.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_reference_signal = (
|
||||
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
|
||||
fake_tested_signal = (
|
||||
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
|
||||
|
||||
# Save fake audio tracks.
|
||||
self._fake_reference_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_ref.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_reference_signal_filepath, fake_reference_signal)
|
||||
self._fake_tested_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_test.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_tested_signal_filepath, fake_tested_signal)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
shutil.rmtree(self._output_path)
|
||||
|
||||
def test_registered_classes(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._output_path))
|
||||
|
||||
# Check that there is at least one registered evaluation score worker.
|
||||
classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
self.assertIsInstance(classes, dict)
|
||||
self.assertGreater(len(classes), 0)
|
||||
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance evaluation score workers factory with fake dependencies.
|
||||
eval_score_workers_factory = (
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa')))
|
||||
|
||||
# Try each registered evaluation score worker.
|
||||
for eval_score_name in registered_classes:
|
||||
# Instance evaluation score worker.
|
||||
eval_score_worker = eval_score_workers_factory.GetInstance(
|
||||
registered_classes[eval_score_name])
|
||||
|
||||
# Set reference and test, then run.
|
||||
eval_score_worker.SetReferenceSignalFilepath(
|
||||
self._fake_reference_signal_filepath)
|
||||
eval_score_worker.SetTestedSignalFilepath(
|
||||
self._fake_tested_signal_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
|
||||
# Check output.
|
||||
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
|
||||
self.assertTrue(isinstance(score, float))
|
||||
|
||||
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "webrtc/base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
const char* const kErrorMessage = "-Out /path/to/output/file is mandatory";
|
||||
|
||||
// Writes fake output intended to be parsed by
|
||||
// quality_assessment.eval_scores.PolqaScore.
|
||||
void WriteOutputFile(const std::string& output_file_path) {
|
||||
RTC_CHECK_NE(output_file_path, "");
|
||||
std::ofstream out(output_file_path);
|
||||
RTC_CHECK(!out.bad());
|
||||
out << "* Fake Polqa output" << std::endl;
|
||||
out << "FakeField1\tPolqaScore\tFakeField2" << std::endl;
|
||||
out << "FakeValue1\t3.25\tFakeValue2" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
// Find "-Out" and use its next argument as output file path.
|
||||
RTC_CHECK_GE(argc, 3) << kErrorMessage;
|
||||
const std::string kSoughtFlagName = "-Out";
|
||||
for (int i = 1; i < argc - 1; ++i) {
|
||||
if (kSoughtFlagName.compare(argv[i]) == 0) {
|
||||
WriteOutputFile(argv[i + 1]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
FATAL() << kErrorMessage;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return webrtc::test::main(argc, argv);
|
||||
}
|
||||
@ -12,11 +12,9 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from . import audioproc_wrapper
|
||||
from . import data_access
|
||||
from . import eval_scores
|
||||
from . import eval_scores_factory
|
||||
from . import evaluation
|
||||
from . import test_data_generation
|
||||
from . import test_data_generation_factory
|
||||
|
||||
@ -29,10 +27,11 @@ class ApmModuleSimulator(object):
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
|
||||
def __init__(self, aechen_ir_database_path, polqa_tool_path):
|
||||
def __init__(self, aechen_ir_database_path, polqa_tool_bin_path,
|
||||
ap_wrapper, evaluator):
|
||||
# Init.
|
||||
self._audioproc_wrapper = audioproc_wrapper.AudioProcWrapper()
|
||||
self._evaluator = evaluation.ApmModuleEvaluator()
|
||||
self._audioproc_wrapper = ap_wrapper
|
||||
self._evaluator = evaluator
|
||||
|
||||
# Instance factory objects.
|
||||
self._test_data_generator_factory = (
|
||||
@ -40,7 +39,7 @@ class ApmModuleSimulator(object):
|
||||
aechen_ir_database_path=aechen_ir_database_path))
|
||||
self._evaluation_score_factory = (
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_path=polqa_tool_path))
|
||||
polqa_tool_bin_path=polqa_tool_bin_path))
|
||||
|
||||
# Properties for each run.
|
||||
self._base_output_path = None
|
||||
|
||||
@ -0,0 +1,89 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the simulation module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
SRC = os.path.abspath(os.path.join(
|
||||
os.path.dirname((__file__)), os.pardir, os.pardir, os.pardir, os.pardir))
|
||||
sys.path.append(os.path.join(SRC, 'third_party', 'pymock'))
|
||||
|
||||
import mock
|
||||
import pydub
|
||||
|
||||
from . import audioproc_wrapper
|
||||
from . import evaluation
|
||||
from . import signal_processing
|
||||
from . import simulation
|
||||
|
||||
|
||||
class TestApmModuleSimulator(unittest.TestCase):
|
||||
"""Unit tests for the ApmModuleSimulator class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folder and fake audio track."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_audio_track_path, fake_signal)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._output_path)
|
||||
|
||||
def testSimulation(self):
|
||||
# Instance dependencies to inject and mock.
|
||||
ap_wrapper = audioproc_wrapper.AudioProcWrapper()
|
||||
evaluator = evaluation.ApmModuleEvaluator()
|
||||
ap_wrapper.Run = mock.MagicMock(name='Run')
|
||||
evaluator.Run = mock.MagicMock(name='Run')
|
||||
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
aechen_ir_database_path='',
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(__file__), 'fake_polqa'),
|
||||
ap_wrapper=ap_wrapper,
|
||||
evaluator=evaluator)
|
||||
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [self._fake_audio_track_path]
|
||||
test_data_generators = ['identity', 'white_noise']
|
||||
eval_scores = ['audio_level', 'polqa']
|
||||
|
||||
# Run all simulations.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
input_filepaths=input_files,
|
||||
test_data_generator_names=test_data_generators,
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
|
||||
# Check.
|
||||
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
|
||||
# the client code (e.g., number of SNR pairs for the white noise teste data
|
||||
# gnerator), the exact number of calls to ap_wrapper.Run and evaluator.Run
|
||||
# is known; use that with assertEqual.
|
||||
min_number_of_simulations = len(config_files) * len(input_files) * len(
|
||||
test_data_generators)
|
||||
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
@ -297,10 +297,11 @@ class EnvironmentalNoiseTestDataGenerator(TestDataGenerator):
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
|
||||
# TODO(alessiob): allow the user to store the noise tracks in a custom path.
|
||||
_NOISE_TRACKS_PATH = os.path.join(os.getcwd(), 'noise_tracks')
|
||||
_NOISE_TRACKS_PATH = os.path.join(
|
||||
os.path.dirname(__file__), os.pardir, 'noise_tracks')
|
||||
|
||||
# TODO(alessiob): allow the user to have custom noise tracks.
|
||||
# TODO(alessiob): exploit TestDataGeneratorFactory.GetInstance().
|
||||
# TODO(alessiob): Allow the user to have custom noise tracks.
|
||||
# TODO(alessiob): Exploit TestDataGeneratorFactory.GetInstance().
|
||||
_NOISE_TRACKS = [
|
||||
'city.wav'
|
||||
]
|
||||
@ -436,7 +437,7 @@ class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
# Load noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
except IOError: # File not found.
|
||||
except exceptions.FileNotFoundError:
|
||||
# Generate noise track by applying the impulse response.
|
||||
impulse_response_filepath = os.path.join(
|
||||
self._aechen_ir_database_path,
|
||||
|
||||
@ -14,6 +14,9 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
|
||||
from . import test_data_generation
|
||||
from . import test_data_generation_factory
|
||||
from . import signal_processing
|
||||
@ -27,11 +30,27 @@ class TestTestDataGenerators(unittest.TestCase):
|
||||
"""Create temporary folders."""
|
||||
self._base_output_path = tempfile.mkdtemp()
|
||||
self._input_noise_cache_path = tempfile.mkdtemp()
|
||||
self._fake_air_db_path = tempfile.mkdtemp()
|
||||
|
||||
# Fake AIR DB impulse responses.
|
||||
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
|
||||
# impulse responses. When changed, the coupling below between
|
||||
# impulse_response_mat_file_names and
|
||||
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
|
||||
impulse_response_mat_file_names = [
|
||||
'air_binaural_lecture_0_0_1.mat',
|
||||
'air_binaural_booth_0_0_1.mat',
|
||||
]
|
||||
for impulse_response_mat_file_name in impulse_response_mat_file_names:
|
||||
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
|
||||
scipy.io.savemat(os.path.join(
|
||||
self._fake_air_db_path, impulse_response_mat_file_name), data)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._base_output_path)
|
||||
shutil.rmtree(self._input_noise_cache_path)
|
||||
shutil.rmtree(self._fake_air_db_path)
|
||||
|
||||
def testTestDataGenerators(self):
|
||||
# Preliminary check.
|
||||
@ -47,11 +66,7 @@ class TestTestDataGenerators(unittest.TestCase):
|
||||
# Instance generators factory.
|
||||
generators_factory = (
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=''))
|
||||
# TODO(alessiob): Replace with a mock of TestDataGeneratorFactory that
|
||||
# takes no arguments in the ctor. For those generators that need parameters,
|
||||
# it will return a mock generator (see the first comment in the next for
|
||||
# loop).
|
||||
aechen_ir_database_path=self._fake_air_db_path))
|
||||
|
||||
# Use a sample input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(
|
||||
@ -64,16 +79,6 @@ class TestTestDataGenerators(unittest.TestCase):
|
||||
|
||||
# Try each registered test data generator.
|
||||
for generator_name in registered_classes:
|
||||
# Exclude ReverberationTestDataGenerator.
|
||||
# TODO(alessiob): Mock ReverberationTestDataGenerator, the mock
|
||||
# should rely on hard-coded impulse responses. This requires a mock for
|
||||
# TestDataGeneratorFactory. The latter knows whether returning the
|
||||
# actual generator or a mock object (as in the case of
|
||||
# ReverberationTestDataGenerator).
|
||||
if generator_name == (
|
||||
test_data_generation.ReverberationTestDataGenerator.NAME):
|
||||
continue
|
||||
|
||||
# Instance test data generator.
|
||||
generator = generators_factory.GetInstance(
|
||||
registered_classes[generator_name])
|
||||
|
||||
Reference in New Issue
Block a user