This CL includes the following changes:

- BUILD file reorganized, unit tests now have dedicated targets.
- "fake_polqa" is a binary producing fake output in the same format of PolqaOem64; the binary is injected for unit tests instead of the actual POLQA tool.
- Minor refactoring to inject the path to the POLQA binary instead of its parent folder.
- Unit tests for the evaluation score workers.
- Unit tests for the ApmModuleSimulator class.
- Unit tests for the test data generators: ReverberationTestDataGenerator added.

BUG=webrtc:7218

Review-Url: https://codereview.webrtc.org/2811953002
Cr-Commit-Position: refs/heads/master@{#17674}
This commit is contained in:
alessiob
2017-04-12 06:56:25 -07:00
committed by Commit bot
parent 103ac7e7d9
commit a79143f3e9
10 changed files with 337 additions and 73 deletions

View File

@ -8,7 +8,15 @@
import("../../../../webrtc.gni")
copy("py_quality_assessment") {
group("py_quality_assessment") {
testonly = true
deps = [
":scripts",
":unit_tests",
]
}
copy("scripts") {
testonly = true
sources = [
"README.md",
@ -16,7 +24,6 @@ copy("py_quality_assessment") {
"apm_quality_assessment.sh",
"apm_quality_assessment_export.py",
"apm_quality_assessment_gencfgs.py",
"apm_quality_assessment_unittest.py",
]
outputs = [
"$root_build_dir/py_quality_assessment/{{source_file_part}}",
@ -28,37 +35,7 @@ copy("py_quality_assessment") {
"../..:audioproc_f",
"//resources/audio_processing/test/py_quality_assessment:probing_signals",
]
} # py_quality_assessment
copy("lib") {
testonly = true
sources = [
"quality_assessment/__init__.py",
"quality_assessment/audioproc_wrapper.py",
"quality_assessment/data_access.py",
"quality_assessment/eval_scores.py",
"quality_assessment/eval_scores_factory.py",
"quality_assessment/eval_scores_unittest.py",
"quality_assessment/evaluation.py",
"quality_assessment/exceptions.py",
"quality_assessment/export.py",
"quality_assessment/results.css",
"quality_assessment/results.js",
"quality_assessment/signal_processing.py",
"quality_assessment/signal_processing_unittest.py",
"quality_assessment/simulation.py",
"quality_assessment/test_data_generation.py",
"quality_assessment/test_data_generation_factory.py",
"quality_assessment/test_data_generation_unittest.py",
]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [
"$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}",
]
deps = [
"//resources/audio_processing/test/py_quality_assessment:noise_tracks",
]
} # lib
}
copy("apm_configs") {
testonly = true
@ -71,6 +48,33 @@ copy("apm_configs") {
]
} # apm_configs
copy("lib") {
testonly = true
sources = [
"quality_assessment/__init__.py",
"quality_assessment/audioproc_wrapper.py",
"quality_assessment/data_access.py",
"quality_assessment/eval_scores.py",
"quality_assessment/eval_scores_factory.py",
"quality_assessment/evaluation.py",
"quality_assessment/exceptions.py",
"quality_assessment/export.py",
"quality_assessment/results.css",
"quality_assessment/results.js",
"quality_assessment/signal_processing.py",
"quality_assessment/simulation.py",
"quality_assessment/test_data_generation.py",
"quality_assessment/test_data_generation_factory.py",
]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [
"$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}",
]
deps = [
"//resources/audio_processing/test/py_quality_assessment:noise_tracks",
]
}
copy("output") {
testonly = true
sources = [
@ -80,4 +84,52 @@ copy("output") {
outputs = [
"$root_build_dir/py_quality_assessment/output/{{source_file_part}}",
]
} # output
}
group("unit_tests") {
testonly = true
visibility = [ ":*" ] # Only targets in this file can depend on this.
deps = [
":fake_polqa",
":lib_unit_tests",
":scripts_unit_tests",
]
}
rtc_executable("fake_polqa") {
testonly = true
sources = [
"quality_assessment/fake_polqa.cc",
]
visibility = [ ":*" ] # Only targets in this file can depend on this.
output_name = "py_quality_assessment/quality_assessment/fake_polqa"
deps = [
"//webrtc:webrtc_common",
"//webrtc/base:rtc_base_approved",
]
}
copy("lib_unit_tests") {
testonly = true
sources = [
"quality_assessment/eval_scores_unittest.py",
"quality_assessment/signal_processing_unittest.py",
"quality_assessment/simulation_unittest.py",
"quality_assessment/test_data_generation_unittest.py",
]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [
"$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}",
]
}
copy("scripts_unit_tests") {
testonly = true
sources = [
"apm_quality_assessment_unittest.py",
]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [
"$root_build_dir/py_quality_assessment/{{source_file_part}}",
]
}

View File

@ -19,9 +19,12 @@ Usage: apm_quality_assessment.py -i audio1.wav [audio2.wav ...]
import argparse
import logging
import os
import sys
import quality_assessment.audioproc_wrapper as audioproc_wrapper
import quality_assessment.eval_scores as eval_scores
import quality_assessment.evaluation as evaluation
import quality_assessment.test_data_generation as test_data_generation
import quality_assessment.simulation as simulation
@ -33,6 +36,8 @@ _EVAL_SCORE_WORKER_NAMES = _EVAL_SCORE_WORKER_CLASSES.keys()
_DEFAULT_CONFIG_FILE = 'apm_configs/default.json'
_POLQA_BIN_NAME = 'PolqaOem64'
def _InstanceArgumentsParser():
"""Arguments parser factory.
@ -85,7 +90,9 @@ def main():
simulator = simulation.ApmModuleSimulator(
aechen_ir_database_path=args.air_db_path,
polqa_tool_path=args.polqa_path)
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(),
evaluator=evaluation.ApmModuleEvaluator())
simulator.Run(
config_filepaths=args.config_files,
input_filepaths=args.input_files,

View File

@ -145,21 +145,19 @@ class PolqaScore(EvaluationScore):
"""
NAME = 'polqa'
_BIN_FILENAME = 'PolqaOem64'
def __init__(self, polqa_tool_path):
def __init__(self, polqa_bin_filepath):
EvaluationScore.__init__(self)
# Path to the POLQA directory with binary and license files.
self._polqa_tool_path = polqa_tool_path
# POLQA binary file path.
self._polqa_bin_filepath = os.path.join(
self._polqa_tool_path, self._BIN_FILENAME)
self._polqa_bin_filepath = polqa_bin_filepath
if not os.path.exists(self._polqa_bin_filepath):
logging.error('cannot find POLQA tool binary file')
raise exceptions.FileNotFoundError()
# Path to the POLQA directory with binary and license files.
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
def _Run(self, output_path):
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
if os.path.exists(polqa_out_filepath):

View File

@ -21,8 +21,8 @@ class EvaluationScoreWorkerFactory(object):
workers.
"""
def __init__(self, polqa_tool_path):
self._polqa_tool_path = polqa_tool_path
def __init__(self, polqa_tool_bin_path):
self._polqa_tool_bin_path = polqa_tool_bin_path
def GetInstance(self, evaluation_score_class):
"""Creates an EvaluationScore instance given a class object.
@ -33,7 +33,7 @@ class EvaluationScoreWorkerFactory(object):
logging.debug(
'factory producing a %s evaluation score', evaluation_score_class)
if evaluation_score_class == eval_scores.PolqaScore:
return eval_scores.PolqaScore(self._polqa_tool_path)
return eval_scores.PolqaScore(self._polqa_tool_bin_path)
else:
# By default, no arguments in the constructor.
return evaluation_score_class()

View File

@ -9,17 +9,76 @@
"""Unit tests for the eval_scores module.
"""
import os
import shutil
import tempfile
import unittest
import pydub
from . import data_access
from . import eval_scores
from . import eval_scores_factory
from . import signal_processing
class TestEvalScores(unittest.TestCase):
"""Unit tests for the eval_scores module.
"""
def setUp(self):
"""Create temporary output folder and two audio track files."""
self._output_path = tempfile.mkdtemp()
# Create fake reference and tested (i.e., APM output) audio track files.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_reference_signal = (
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
fake_tested_signal = (
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
# Save fake audio tracks.
self._fake_reference_signal_filepath = os.path.join(
self._output_path, 'fake_ref.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_reference_signal_filepath, fake_reference_signal)
self._fake_tested_signal_filepath = os.path.join(
self._output_path, 'fake_test.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_tested_signal_filepath, fake_tested_signal)
def tearDown(self):
"""Recursively delete temporary folder."""
shutil.rmtree(self._output_path)
def test_registered_classes(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._output_path))
# Check that there is at least one registered evaluation score worker.
classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
self.assertIsInstance(classes, dict)
self.assertGreater(len(classes), 0)
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instance evaluation score workers factory with fake dependencies.
eval_score_workers_factory = (
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa')))
# Try each registered evaluation score worker.
for eval_score_name in registered_classes:
# Instance evaluation score worker.
eval_score_worker = eval_score_workers_factory.GetInstance(
registered_classes[eval_score_name])
# Set reference and test, then run.
eval_score_worker.SetReferenceSignalFilepath(
self._fake_reference_signal_filepath)
eval_score_worker.SetTestedSignalFilepath(
self._fake_tested_signal_filepath)
eval_score_worker.Run(self._output_path)
# Check output.
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
self.assertTrue(isinstance(score, float))

View File

@ -0,0 +1,54 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <fstream>
#include <iostream>
#include "webrtc/base/checks.h"
namespace webrtc {
namespace test {
namespace {
const char* const kErrorMessage = "-Out /path/to/output/file is mandatory";
// Writes fake output intended to be parsed by
// quality_assessment.eval_scores.PolqaScore.
void WriteOutputFile(const std::string& output_file_path) {
RTC_CHECK_NE(output_file_path, "");
std::ofstream out(output_file_path);
RTC_CHECK(!out.bad());
out << "* Fake Polqa output" << std::endl;
out << "FakeField1\tPolqaScore\tFakeField2" << std::endl;
out << "FakeValue1\t3.25\tFakeValue2" << std::endl;
out.close();
}
} // namespace
int main(int argc, char* argv[]) {
// Find "-Out" and use its next argument as output file path.
RTC_CHECK_GE(argc, 3) << kErrorMessage;
const std::string kSoughtFlagName = "-Out";
for (int i = 1; i < argc - 1; ++i) {
if (kSoughtFlagName.compare(argv[i]) == 0) {
WriteOutputFile(argv[i + 1]);
return 0;
}
}
FATAL() << kErrorMessage;
}
} // namespace test
} // namespace webrtc
int main(int argc, char* argv[]) {
return webrtc::test::main(argc, argv);
}

View File

@ -12,11 +12,9 @@
import logging
import os
from . import audioproc_wrapper
from . import data_access
from . import eval_scores
from . import eval_scores_factory
from . import evaluation
from . import test_data_generation
from . import test_data_generation_factory
@ -29,10 +27,11 @@ class ApmModuleSimulator(object):
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
def __init__(self, aechen_ir_database_path, polqa_tool_path):
def __init__(self, aechen_ir_database_path, polqa_tool_bin_path,
ap_wrapper, evaluator):
# Init.
self._audioproc_wrapper = audioproc_wrapper.AudioProcWrapper()
self._evaluator = evaluation.ApmModuleEvaluator()
self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator
# Instance factory objects.
self._test_data_generator_factory = (
@ -40,7 +39,7 @@ class ApmModuleSimulator(object):
aechen_ir_database_path=aechen_ir_database_path))
self._evaluation_score_factory = (
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_path=polqa_tool_path))
polqa_tool_bin_path=polqa_tool_bin_path))
# Properties for each run.
self._base_output_path = None

View File

@ -0,0 +1,89 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the simulation module.
"""
import os
import shutil
import sys
import tempfile
import unittest
SRC = os.path.abspath(os.path.join(
os.path.dirname((__file__)), os.pardir, os.pardir, os.pardir, os.pardir))
sys.path.append(os.path.join(SRC, 'third_party', 'pymock'))
import mock
import pydub
from . import audioproc_wrapper
from . import evaluation
from . import signal_processing
from . import simulation
class TestApmModuleSimulator(unittest.TestCase):
"""Unit tests for the ApmModuleSimulator class.
"""
def setUp(self):
"""Create temporary folder and fake audio track."""
self._output_path = tempfile.mkdtemp()
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_audio_track_path, fake_signal)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._output_path)
def testSimulation(self):
# Instance dependencies to inject and mock.
ap_wrapper = audioproc_wrapper.AudioProcWrapper()
evaluator = evaluation.ApmModuleEvaluator()
ap_wrapper.Run = mock.MagicMock(name='Run')
evaluator.Run = mock.MagicMock(name='Run')
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
aechen_ir_database_path='',
polqa_tool_bin_path=os.path.join(
os.path.dirname(__file__), 'fake_polqa'),
ap_wrapper=ap_wrapper,
evaluator=evaluator)
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [self._fake_audio_track_path]
test_data_generators = ['identity', 'white_noise']
eval_scores = ['audio_level', 'polqa']
# Run all simulations.
simulator.Run(
config_filepaths=config_files,
input_filepaths=input_files,
test_data_generator_names=test_data_generators,
eval_score_names=eval_scores,
output_dir=self._output_path)
# Check.
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
# the client code (e.g., number of SNR pairs for the white noise teste data
# gnerator), the exact number of calls to ap_wrapper.Run and evaluator.Run
# is known; use that with assertEqual.
min_number_of_simulations = len(config_files) * len(input_files) * len(
test_data_generators)
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
min_number_of_simulations)
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)

View File

@ -297,10 +297,11 @@ class EnvironmentalNoiseTestDataGenerator(TestDataGenerator):
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
# TODO(alessiob): allow the user to store the noise tracks in a custom path.
_NOISE_TRACKS_PATH = os.path.join(os.getcwd(), 'noise_tracks')
_NOISE_TRACKS_PATH = os.path.join(
os.path.dirname(__file__), os.pardir, 'noise_tracks')
# TODO(alessiob): allow the user to have custom noise tracks.
# TODO(alessiob): exploit TestDataGeneratorFactory.GetInstance().
# TODO(alessiob): Allow the user to have custom noise tracks.
# TODO(alessiob): Exploit TestDataGeneratorFactory.GetInstance().
_NOISE_TRACKS = [
'city.wav'
]
@ -436,7 +437,7 @@ class ReverberationTestDataGenerator(TestDataGenerator):
# Load noise track.
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
except IOError: # File not found.
except exceptions.FileNotFoundError:
# Generate noise track by applying the impulse response.
impulse_response_filepath = os.path.join(
self._aechen_ir_database_path,

View File

@ -14,6 +14,9 @@ import shutil
import tempfile
import unittest
import numpy as np
import scipy.io
from . import test_data_generation
from . import test_data_generation_factory
from . import signal_processing
@ -27,11 +30,27 @@ class TestTestDataGenerators(unittest.TestCase):
"""Create temporary folders."""
self._base_output_path = tempfile.mkdtemp()
self._input_noise_cache_path = tempfile.mkdtemp()
self._fake_air_db_path = tempfile.mkdtemp()
# Fake AIR DB impulse responses.
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
# impulse responses. When changed, the coupling below between
# impulse_response_mat_file_names and
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
impulse_response_mat_file_names = [
'air_binaural_lecture_0_0_1.mat',
'air_binaural_booth_0_0_1.mat',
]
for impulse_response_mat_file_name in impulse_response_mat_file_names:
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
scipy.io.savemat(os.path.join(
self._fake_air_db_path, impulse_response_mat_file_name), data)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._base_output_path)
shutil.rmtree(self._input_noise_cache_path)
shutil.rmtree(self._fake_air_db_path)
def testTestDataGenerators(self):
# Preliminary check.
@ -47,11 +66,7 @@ class TestTestDataGenerators(unittest.TestCase):
# Instance generators factory.
generators_factory = (
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=''))
# TODO(alessiob): Replace with a mock of TestDataGeneratorFactory that
# takes no arguments in the ctor. For those generators that need parameters,
# it will return a mock generator (see the first comment in the next for
# loop).
aechen_ir_database_path=self._fake_air_db_path))
# Use a sample input file as clean input signal.
input_signal_filepath = os.path.join(
@ -64,16 +79,6 @@ class TestTestDataGenerators(unittest.TestCase):
# Try each registered test data generator.
for generator_name in registered_classes:
# Exclude ReverberationTestDataGenerator.
# TODO(alessiob): Mock ReverberationTestDataGenerator, the mock
# should rely on hard-coded impulse responses. This requires a mock for
# TestDataGeneratorFactory. The latter knows whether returning the
# actual generator or a mock object (as in the case of
# ReverberationTestDataGenerator).
if generator_name == (
test_data_generation.ReverberationTestDataGenerator.NAME):
continue
# Instance test data generator.
generator = generators_factory.GetInstance(
registered_classes[generator_name])