Reformat python files checked by pylint (part 1/2).

After the recent change to .pylintrc (see [1]) we discovered that
the presubmit check always checks all the Python files even when
just one Python file gets updated.

This CL moves all these files one step closer to what the linter
wants.

Autogenerated with:

# Added all the files under pylint control to ~/Desktop/to-reformat
cat ~/Desktop/to-reformat | xargs sed -i '1i\\'
git cl format --python --full
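
A single reformatted file can also be spot-checked locally against the
updated .pylintrc with a plain pylint run (the file path below is only a
placeholder, not a specific file from this CL):

pylint --rcfile=.pylintrc path/to/reformatted_file.py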

This is part 1 out of 2. The second part will fix function names and
will not be automated.

[1] - https://webrtc-review.googlesource.com/c/src/+/186664

No-Presubmit: True
Bug: webrtc:12114
Change-Id: Idfec4d759f209a2090440d0af2413a1ddc01b841
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/190980
Commit-Queue: Mirko Bonadei <mbonadei@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32530}
Author: Mirko Bonadei
Date: 2020-10-30 10:13:45 +01:00
Committed by: Commit Bot
Parent: d3a3e9ef36
Commit: 8cc6695652
93 changed files with 9936 additions and 9285 deletions


@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Perform APM module quality assessment on one or more input files using one or
more APM simulator configuration files and one or more test data generators.
@@ -47,139 +46,172 @@ _POLQA_BIN_NAME = 'PolqaOem64'
def _InstanceArgumentsParser():
"""Arguments parser factory.
"""Arguments parser factory.
"""
parser = argparse.ArgumentParser(description=(
'Perform APM module quality assessment on one or more input files using '
'one or more APM simulator configuration files and one or more '
'test data generators.'))
parser = argparse.ArgumentParser(description=(
'Perform APM module quality assessment on one or more input files using '
'one or more APM simulator configuration files and one or more '
'test data generators.'))
parser.add_argument('-c', '--config_files', nargs='+', required=False,
help=('path to the configuration files defining the '
'arguments with which the APM simulator tool is '
'called'),
default=[_DEFAULT_CONFIG_FILE])
parser.add_argument('-c',
'--config_files',
nargs='+',
required=False,
help=('path to the configuration files defining the '
'arguments with which the APM simulator tool is '
'called'),
default=[_DEFAULT_CONFIG_FILE])
parser.add_argument('-i', '--capture_input_files', nargs='+', required=True,
help='path to the capture input wav files (one or more)')
parser.add_argument(
'-i',
'--capture_input_files',
nargs='+',
required=True,
help='path to the capture input wav files (one or more)')
parser.add_argument('-r', '--render_input_files', nargs='+', required=False,
help=('path to the render input wav files; either '
'omitted or one file for each file in '
'--capture_input_files (files will be paired by '
'index)'), default=None)
parser.add_argument('-r',
'--render_input_files',
nargs='+',
required=False,
help=('path to the render input wav files; either '
'omitted or one file for each file in '
'--capture_input_files (files will be paired by '
'index)'),
default=None)
parser.add_argument('-p', '--echo_path_simulator', required=False,
help=('custom echo path simulator name; required if '
'--render_input_files is specified'),
choices=_ECHO_PATH_SIMULATOR_NAMES,
default=echo_path_simulation.NoEchoPathSimulator.NAME)
parser.add_argument('-p',
'--echo_path_simulator',
required=False,
help=('custom echo path simulator name; required if '
'--render_input_files is specified'),
choices=_ECHO_PATH_SIMULATOR_NAMES,
default=echo_path_simulation.NoEchoPathSimulator.NAME)
parser.add_argument('-t', '--test_data_generators', nargs='+', required=False,
help='custom list of test data generators to use',
choices=_TEST_DATA_GENERATORS_NAMES,
default=_TEST_DATA_GENERATORS_NAMES)
parser.add_argument('-t',
'--test_data_generators',
nargs='+',
required=False,
help='custom list of test data generators to use',
choices=_TEST_DATA_GENERATORS_NAMES,
default=_TEST_DATA_GENERATORS_NAMES)
parser.add_argument('--additive_noise_tracks_path', required=False,
help='path to the wav files for the additive',
default=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH)
parser.add_argument('--additive_noise_tracks_path', required=False,
help='path to the wav files for the additive',
default=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH)
parser.add_argument('-e', '--eval_scores', nargs='+', required=False,
help='custom list of evaluation scores to use',
choices=_EVAL_SCORE_WORKER_NAMES,
default=_EVAL_SCORE_WORKER_NAMES)
parser.add_argument('-e',
'--eval_scores',
nargs='+',
required=False,
help='custom list of evaluation scores to use',
choices=_EVAL_SCORE_WORKER_NAMES,
default=_EVAL_SCORE_WORKER_NAMES)
parser.add_argument('-o', '--output_dir', required=False,
help=('base path to the output directory in which the '
'output wav files and the evaluation outcomes '
'are saved'),
default='output')
parser.add_argument('-o',
'--output_dir',
required=False,
help=('base path to the output directory in which the '
'output wav files and the evaluation outcomes '
'are saved'),
default='output')
parser.add_argument('--polqa_path', required=True,
help='path to the POLQA tool')
parser.add_argument('--polqa_path',
required=True,
help='path to the POLQA tool')
parser.add_argument('--air_db_path', required=True,
help='path to the Aechen IR database')
parser.add_argument('--air_db_path',
required=True,
help='path to the Aechen IR database')
parser.add_argument('--apm_sim_path', required=False,
help='path to the APM simulator tool',
default=audioproc_wrapper. \
AudioProcWrapper. \
DEFAULT_APM_SIMULATOR_BIN_PATH)
parser.add_argument('--apm_sim_path', required=False,
help='path to the APM simulator tool',
default=audioproc_wrapper. \
AudioProcWrapper. \
DEFAULT_APM_SIMULATOR_BIN_PATH)
parser.add_argument('--echo_metric_tool_bin_path', required=False,
help=('path to the echo metric binary '
'(required for the echo eval score)'),
default=None)
parser.add_argument('--echo_metric_tool_bin_path',
required=False,
help=('path to the echo metric binary '
'(required for the echo eval score)'),
default=None)
parser.add_argument('--copy_with_identity_generator', required=False,
help=('If true, the identity test data generator makes a '
'copy of the clean speech input file.'),
default=False)
parser.add_argument(
'--copy_with_identity_generator',
required=False,
help=('If true, the identity test data generator makes a '
'copy of the clean speech input file.'),
default=False)
parser.add_argument('--external_vad_paths', nargs='+', required=False,
help=('Paths to external VAD programs. Each must take'
'\'-i <wav file> -o <output>\' inputs'), default=[])
parser.add_argument('--external_vad_paths',
nargs='+',
required=False,
help=('Paths to external VAD programs. Each must take'
'\'-i <wav file> -o <output>\' inputs'),
default=[])
parser.add_argument('--external_vad_names', nargs='+', required=False,
help=('Keys to the vad paths. Must be different and '
'as many as the paths.'), default=[])
parser.add_argument('--external_vad_names',
nargs='+',
required=False,
help=('Keys to the vad paths. Must be different and '
'as many as the paths.'),
default=[])
return parser
return parser
def _ValidateArguments(args, parser):
if args.capture_input_files and args.render_input_files and (
len(args.capture_input_files) != len(args.render_input_files)):
parser.error('--render_input_files and --capture_input_files must be lists '
'having the same length')
sys.exit(1)
if args.capture_input_files and args.render_input_files and (len(
args.capture_input_files) != len(args.render_input_files)):
parser.error(
'--render_input_files and --capture_input_files must be lists '
'having the same length')
sys.exit(1)
if args.render_input_files and not args.echo_path_simulator:
parser.error('when --render_input_files is set, --echo_path_simulator is '
'also required')
sys.exit(1)
if args.render_input_files and not args.echo_path_simulator:
parser.error(
'when --render_input_files is set, --echo_path_simulator is '
'also required')
sys.exit(1)
if len(args.external_vad_names) != len(args.external_vad_paths):
parser.error('If provided, --external_vad_paths and '
'--external_vad_names must '
'have the same number of arguments.')
sys.exit(1)
if len(args.external_vad_names) != len(args.external_vad_paths):
parser.error('If provided, --external_vad_paths and '
'--external_vad_names must '
'have the same number of arguments.')
sys.exit(1)
def main():
# TODO(alessiob): level = logging.INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
_ValidateArguments(args, parser)
# TODO(alessiob): level = logging.INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
_ValidateArguments(args, parser)
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=args.air_db_path,
noise_tracks_path=args.additive_noise_tracks_path,
copy_with_identity=args.copy_with_identity_generator)),
evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path
),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
evaluator=evaluation.ApmModuleEvaluator(),
external_vads=external_vad.ExternalVad.ConstructVadDict(
args.external_vad_paths, args.external_vad_names))
simulator.Run(
config_filepaths=args.config_files,
capture_input_filepaths=args.capture_input_files,
render_input_filepaths=args.render_input_files,
echo_path_simulator_name=args.echo_path_simulator,
test_data_generator_names=args.test_data_generators,
eval_score_names=args.eval_scores,
output_dir=args.output_dir)
sys.exit(0)
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=args.air_db_path,
noise_tracks_path=args.additive_noise_tracks_path,
copy_with_identity=args.copy_with_identity_generator)),
evaluation_score_factory=eval_scores_factory.
EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
evaluator=evaluation.ApmModuleEvaluator(),
external_vads=external_vad.ExternalVad.ConstructVadDict(
args.external_vad_paths, args.external_vad_names))
simulator.Run(config_filepaths=args.config_files,
capture_input_filepaths=args.capture_input_files,
render_input_filepaths=args.render_input_files,
echo_path_simulator_name=args.echo_path_simulator,
test_data_generator_names=args.test_data_generators,
eval_score_names=args.eval_scores,
output_dir=args.output_dir)
sys.exit(0)
if __name__ == '__main__':
main()
main()


@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Shows boxplots of given score for different values of selected
parameters. Can be used to compare scores by audioproc_f flag.
@@ -30,29 +29,37 @@ import quality_assessment.collect_data as collect_data
def InstanceArgumentsParser():
"""Arguments parser factory.
"""Arguments parser factory.
"""
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Shows boxplot of given score for different values of selected'
'parameters. Can be used to compare scores by audioproc_f flag')
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Shows boxplot of given score for different values of selected'
'parameters. Can be used to compare scores by audioproc_f flag')
parser.add_argument('-v', '--eval_score', required=True,
help=('Score name for constructing boxplots'))
parser.add_argument('-v',
'--eval_score',
required=True,
help=('Score name for constructing boxplots'))
parser.add_argument('-n', '--config_dir', required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument(
'-n',
'--config_dir',
required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument('-z', '--params_to_plot', required=True,
nargs='+', help=('audioproc_f parameter values'
'by which to group scores (no leading dash)'))
parser.add_argument('-z',
'--params_to_plot',
required=True,
nargs='+',
help=('audioproc_f parameter values'
'by which to group scores (no leading dash)'))
return parser
return parser
def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
"""Filters data on the values of one or more parameters.
"""Filters data on the values of one or more parameters.
Args:
data_frame: pandas.DataFrame of all used input data.
@@ -71,34 +78,36 @@ def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
Returns: dictionary, key is a param value, result is all scores for
that param value (see `filter_params` for explanation).
"""
results = collections.defaultdict(dict)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
results = collections.defaultdict(dict)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + '.json'))
data_with_config = data_frame[data_frame.apm_config == config_name]
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
score_name]
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + '.json'))
data_with_config = data_frame[data_frame.apm_config == config_name]
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
score_name]
# Exactly one of |params_to_plot| must match:
(matching_param, ) = [x for x in filter_params if '-' + x in config_json]
# Exactly one of |params_to_plot| must match:
(matching_param, ) = [
x for x in filter_params if '-' + x in config_json
]
# Add scores for every track to the result.
for capture_name in data_cell_scores.capture:
result_score = float(data_cell_scores[data_cell_scores.capture ==
capture_name].score)
config_dict = results[config_json['-' + matching_param]]
if capture_name not in config_dict:
config_dict[capture_name] = {}
# Add scores for every track to the result.
for capture_name in data_cell_scores.capture:
result_score = float(data_cell_scores[data_cell_scores.capture ==
capture_name].score)
config_dict = results[config_json['-' + matching_param]]
if capture_name not in config_dict:
config_dict[capture_name] = {}
config_dict[capture_name][matching_param] = result_score
config_dict[capture_name][matching_param] = result_score
return results
return results
def _FlattenToScoresList(config_param_score_dict):
"""Extracts a list of scores from input data structure.
"""Extracts a list of scores from input data structure.
Args:
config_param_score_dict: of the form {'capture_name':
@@ -107,40 +116,39 @@ def _FlattenToScoresList(config_param_score_dict):
Returns: Plain list of all score value present in input data
structure
"""
result = []
for capture_name in config_param_score_dict:
result += list(config_param_score_dict[capture_name].values())
return result
result = []
for capture_name in config_param_score_dict:
result += list(config_param_score_dict[capture_name].values())
return result
def main():
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = InstanceArgumentsParser()
args = parser.parse_args()
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = InstanceArgumentsParser()
args = parser.parse_args()
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Filter the data by `args.params_to_plot`
scores_filtered = FilterScoresByParams(scores_data_frame,
args.params_to_plot,
args.eval_score,
args.config_dir)
# Filter the data by `args.params_to_plot`
scores_filtered = FilterScoresByParams(scores_data_frame,
args.params_to_plot,
args.eval_score, args.config_dir)
data_list = sorted(scores_filtered.items())
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
data_labels = [x for (x, _) in data_list]
data_list = sorted(scores_filtered.items())
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
data_labels = [x for (x, _) in data_list]
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
axes.boxplot(data_values, labels=data_labels)
axes.set_ylabel(args.eval_score)
axes.set_xlabel('/'.join(args.params_to_plot))
plt.show()
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
axes.boxplot(data_values, labels=data_labels)
axes.set_ylabel(args.eval_score)
axes.set_xlabel('/'.join(args.params_to_plot))
plt.show()
if __name__ == "__main__":
main()
main()


@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Export the scores computed by the apm_quality_assessment.py script into an
HTML file.
"""
@@ -20,7 +19,7 @@ import quality_assessment.export as export
def _BuildOutputFilename(filename_suffix):
"""Builds the filename for the exported file.
"""Builds the filename for the exported file.
Args:
filename_suffix: suffix for the output file name.
@@ -28,34 +27,37 @@ def _BuildOutputFilename(filename_suffix):
Returns:
A string.
"""
if filename_suffix is None:
return 'results.html'
return 'results-{}.html'.format(filename_suffix)
if filename_suffix is None:
return 'results.html'
return 'results-{}.html'.format(filename_suffix)
def main():
# Init.
logging.basicConfig(level=logging.DEBUG) # TODO(alessio): INFO once debugged.
parser = collect_data.InstanceArgumentsParser()
parser.add_argument('-f', '--filename_suffix',
help=('suffix of the exported file'))
parser.description = ('Exports pre-computed APM module quality assessment '
'results into HTML tables')
args = parser.parse_args()
# Init.
logging.basicConfig(
level=logging.DEBUG) # TODO(alessio): INFO once debugged.
parser = collect_data.InstanceArgumentsParser()
parser.add_argument('-f',
'--filename_suffix',
help=('suffix of the exported file'))
parser.description = ('Exports pre-computed APM module quality assessment '
'results into HTML tables')
args = parser.parse_args()
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Export.
output_filepath = os.path.join(args.output_dir, _BuildOutputFilename(
args.filename_suffix))
exporter = export.HtmlExport(output_filepath)
exporter.Export(scores_data_frame)
# Export.
output_filepath = os.path.join(args.output_dir,
_BuildOutputFilename(args.filename_suffix))
exporter = export.HtmlExport(output_filepath)
exporter.Export(scores_data_frame)
logging.info('output file successfully written in %s', output_filepath)
sys.exit(0)
logging.info('output file successfully written in %s', output_filepath)
sys.exit(0)
if __name__ == '__main__':
main()
main()


@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Generate .json files with which the APM module can be tested using the
apm_quality_assessment.py script and audioproc_f as APM simulator.
"""
@@ -20,7 +19,7 @@ OUTPUT_PATH = os.path.abspath('apm_configs')
def _GenerateDefaultOverridden(config_override):
"""Generates one or more APM overriden configurations.
"""Generates one or more APM overriden configurations.
For each item in config_override, it overrides the default configuration and
writes a new APM configuration file.
@@ -45,54 +44,85 @@ def _GenerateDefaultOverridden(config_override):
config_override: dict of APM configuration file names as keys; the values
are dict instances encoding the audioproc_f flags.
"""
for config_filename in config_override:
config = config_override[config_filename]
config['-all_default'] = None
for config_filename in config_override:
config = config_override[config_filename]
config['-all_default'] = None
config_filepath = os.path.join(OUTPUT_PATH, 'default-{}.json'.format(
config_filename))
logging.debug('config file <%s> | %s', config_filepath, config)
config_filepath = os.path.join(
OUTPUT_PATH, 'default-{}.json'.format(config_filename))
logging.debug('config file <%s> | %s', config_filepath, config)
data_access.AudioProcConfigFile.Save(config_filepath, config)
logging.info('config file created: <%s>', config_filepath)
data_access.AudioProcConfigFile.Save(config_filepath, config)
logging.info('config file created: <%s>', config_filepath)
def _GenerateAllDefaultButOne():
"""Disables the flags enabled by default one-by-one.
"""Disables the flags enabled by default one-by-one.
"""
config_sets = {
'no_AEC': {'-aec': 0,},
'no_AGC': {'-agc': 0,},
'no_HP_filter': {'-hpf': 0,},
'no_level_estimator': {'-le': 0,},
'no_noise_suppressor': {'-ns': 0,},
'no_transient_suppressor': {'-ts': 0,},
'no_vad': {'-vad': 0,},
}
_GenerateDefaultOverridden(config_sets)
config_sets = {
'no_AEC': {
'-aec': 0,
},
'no_AGC': {
'-agc': 0,
},
'no_HP_filter': {
'-hpf': 0,
},
'no_level_estimator': {
'-le': 0,
},
'no_noise_suppressor': {
'-ns': 0,
},
'no_transient_suppressor': {
'-ts': 0,
},
'no_vad': {
'-vad': 0,
},
}
_GenerateDefaultOverridden(config_sets)
def _GenerateAllDefaultPlusOne():
"""Enables the flags disabled by default one-by-one.
"""Enables the flags disabled by default one-by-one.
"""
config_sets = {
'with_AECM': {'-aec': 0, '-aecm': 1,}, # AEC and AECM are exclusive.
'with_AGC_limiter': {'-agc_limiter': 1,},
'with_AEC_delay_agnostic': {'-delay_agnostic': 1,},
'with_drift_compensation': {'-drift_compensation': 1,},
'with_residual_echo_detector': {'-ed': 1,},
'with_AEC_extended_filter': {'-extended_filter': 1,},
'with_LC': {'-lc': 1,},
'with_refined_adaptive_filter': {'-refined_adaptive_filter': 1,},
}
_GenerateDefaultOverridden(config_sets)
config_sets = {
'with_AECM': {
'-aec': 0,
'-aecm': 1,
}, # AEC and AECM are exclusive.
'with_AGC_limiter': {
'-agc_limiter': 1,
},
'with_AEC_delay_agnostic': {
'-delay_agnostic': 1,
},
'with_drift_compensation': {
'-drift_compensation': 1,
},
'with_residual_echo_detector': {
'-ed': 1,
},
'with_AEC_extended_filter': {
'-extended_filter': 1,
},
'with_LC': {
'-lc': 1,
},
'with_refined_adaptive_filter': {
'-refined_adaptive_filter': 1,
},
}
_GenerateDefaultOverridden(config_sets)
def main():
logging.basicConfig(level=logging.INFO)
_GenerateAllDefaultPlusOne()
_GenerateAllDefaultButOne()
logging.basicConfig(level=logging.INFO)
_GenerateAllDefaultPlusOne()
_GenerateAllDefaultButOne()
if __name__ == '__main__':
main()
main()


@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Finds the APM configuration that maximizes a provided metric by
parsing the output generated apm_quality_assessment.py.
"""
@@ -20,33 +19,44 @@ import os
import quality_assessment.data_access as data_access
import quality_assessment.collect_data as collect_data
def _InstanceArgumentsParser():
"""Arguments parser factory. Extends the arguments from 'collect_data'
"""Arguments parser factory. Extends the arguments from 'collect_data'
with a few extra for selecting what parameters to optimize for.
"""
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Rudimentary optimization of a function over different parameter'
'combinations.')
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Rudimentary optimization of a function over different parameter'
'combinations.')
parser.add_argument('-n', '--config_dir', required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument(
'-n',
'--config_dir',
required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument('-p', '--params', required=True, nargs='+',
help=('parameters to parse from the config files in'
'config_dir'))
parser.add_argument('-p',
'--params',
required=True,
nargs='+',
help=('parameters to parse from the config files in'
'config_dir'))
parser.add_argument('-z', '--params_not_to_optimize', required=False,
nargs='+', default=[],
help=('parameters from `params` not to be optimized for'))
parser.add_argument(
'-z',
'--params_not_to_optimize',
required=False,
nargs='+',
default=[],
help=('parameters from `params` not to be optimized for'))
return parser
return parser
def _ConfigurationAndScores(data_frame, params,
params_not_to_optimize, config_dir):
"""Returns a list of all configurations and scores.
def _ConfigurationAndScores(data_frame, params, params_not_to_optimize,
config_dir):
"""Returns a list of all configurations and scores.
Args:
data_frame: A pandas data frame with the scores and config name
@@ -72,47 +82,47 @@ def _ConfigurationAndScores(data_frame, params,
param combinations for params in `params_not_to_optimize` and
their scores.
"""
results = collections.defaultdict(list)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
score_names = data_frame['eval_score_name'].drop_duplicates().values.tolist()
results = collections.defaultdict(list)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
score_names = data_frame['eval_score_name'].drop_duplicates(
).values.tolist()
# Normalize the scores
normalization_constants = {}
for score_name in score_names:
scores = data_frame[data_frame.eval_score_name == score_name].score
normalization_constants[score_name] = max(scores)
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
param_combination = collections.namedtuple("ParamCombination",
params_to_optimize)
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + ".json"))
scores = {}
data_cell = data_frame[data_frame.apm_config == config_name]
# Normalize the scores
normalization_constants = {}
for score_name in score_names:
data_cell_scores = data_cell[data_cell.eval_score_name ==
score_name].score
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
scores[score_name] /= normalization_constants[score_name]
scores = data_frame[data_frame.eval_score_name == score_name].score
normalization_constants[score_name] = max(scores)
result = {'scores': scores, 'params': {}}
config_optimize_params = {}
for param in params:
if param in params_to_optimize:
config_optimize_params[param] = config_json['-' + param]
else:
result['params'][param] = config_json['-' + param]
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
param_combination = collections.namedtuple("ParamCombination",
params_to_optimize)
current_param_combination = param_combination(
**config_optimize_params)
results[current_param_combination].append(result)
return results
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + ".json"))
scores = {}
data_cell = data_frame[data_frame.apm_config == config_name]
for score_name in score_names:
data_cell_scores = data_cell[data_cell.eval_score_name ==
score_name].score
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
scores[score_name] /= normalization_constants[score_name]
result = {'scores': scores, 'params': {}}
config_optimize_params = {}
for param in params:
if param in params_to_optimize:
config_optimize_params[param] = config_json['-' + param]
else:
result['params'][param] = config_json['-' + param]
current_param_combination = param_combination(**config_optimize_params)
results[current_param_combination].append(result)
return results
def _FindOptimalParameter(configs_and_scores, score_weighting):
"""Finds the config producing the maximal score.
"""Finds the config producing the maximal score.
Args:
configs_and_scores: structure of the form returned by
@@ -127,53 +137,53 @@ def _FindOptimalParameter(configs_and_scores, score_weighting):
to its scores.
"""
min_score = float('+inf')
best_params = None
for config in configs_and_scores:
scores_and_params = configs_and_scores[config]
current_score = score_weighting(scores_and_params)
if current_score < min_score:
min_score = current_score
best_params = config
logging.debug("Score: %f", current_score)
logging.debug("Config: %s", str(config))
return best_params
min_score = float('+inf')
best_params = None
for config in configs_and_scores:
scores_and_params = configs_and_scores[config]
current_score = score_weighting(scores_and_params)
if current_score < min_score:
min_score = current_score
best_params = config
logging.debug("Score: %f", current_score)
logging.debug("Config: %s", str(config))
return best_params
def _ExampleWeighting(scores_and_configs):
"""Example argument to `_FindOptimalParameter`
"""Example argument to `_FindOptimalParameter`
Args:
scores_and_configs: a list of configs and scores, in the form
described in _FindOptimalParameter
Returns:
numeric value, the sum of all scores
"""
res = 0
for score_config in scores_and_configs:
res += sum(score_config['scores'].values())
return res
res = 0
for score_config in scores_and_configs:
res += sum(score_config['scores'].values())
return res
def main():
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug('Src path <%s>', src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
all_scores = _ConfigurationAndScores(scores_data_frame,
args.params,
args.params_not_to_optimize,
args.config_dir)
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug('Src path <%s>', src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
all_scores = _ConfigurationAndScores(scores_data_frame, args.params,
args.params_not_to_optimize,
args.config_dir)
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
logging.info('Optimal parameter combination: <%s>', opt_param)
logging.info('It\'s score values: <%s>', all_scores[opt_param])
logging.info('Optimal parameter combination: <%s>', opt_param)
logging.info('It\'s score values: <%s>', all_scores[opt_param])
if __name__ == "__main__":
main()
main()


@@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the apm_quality_assessment module.
"""
@@ -16,13 +15,14 @@ import mock
import apm_quality_assessment
class TestSimulationScript(unittest.TestCase):
"""Unit tests for the apm_quality_assessment module.
"""Unit tests for the apm_quality_assessment module.
"""
def testMain(self):
# Exit with error code if no arguments are passed.
with self.assertRaises(SystemExit) as cm, mock.patch.object(
sys, 'argv', ['apm_quality_assessment.py']):
apm_quality_assessment.main()
self.assertGreater(cm.exception.code, 0)
def testMain(self):
# Exit with error code if no arguments are passed.
with self.assertRaises(SystemExit) as cm, mock.patch.object(
sys, 'argv', ['apm_quality_assessment.py']):
apm_quality_assessment.main()
self.assertGreater(cm.exception.code, 0)


@@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Extraction of annotations from audio files.
"""
@@ -19,10 +18,10 @@ import sys
import tempfile
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import external_vad
from . import exceptions
@@ -30,262 +29,268 @@ from . import signal_processing
class AudioAnnotationsExtractor(object):
"""Extracts annotations from audio files.
"""Extracts annotations from audio files.
"""
class VadType(object):
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
class VadType(object):
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
def __init__(self, value):
if (not isinstance(value, int)) or not 0 <= value <= 7:
raise exceptions.InitializationException(
'Invalid vad type: ' + value)
self._value = value
def __init__(self, value):
if (not isinstance(value, int)) or not 0 <= value <= 7:
raise exceptions.InitializationException('Invalid vad type: ' +
value)
self._value = value
def Contains(self, vad_type):
return self._value | vad_type == self._value
def Contains(self, vad_type):
return self._value | vad_type == self._value
def __str__(self):
vads = []
if self.Contains(self.ENERGY_THRESHOLD):
vads.append("energy")
if self.Contains(self.WEBRTC_COMMON_AUDIO):
vads.append("common_audio")
if self.Contains(self.WEBRTC_APM):
vads.append("apm")
return "VadType({})".format(", ".join(vads))
def __str__(self):
vads = []
if self.Contains(self.ENERGY_THRESHOLD):
vads.append("energy")
if self.Contains(self.WEBRTC_COMMON_AUDIO):
vads.append("common_audio")
if self.Contains(self.WEBRTC_APM):
vads.append("apm")
return "VadType({})".format(", ".join(vads))
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
# Level estimation params.
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
_LEVEL_FRAME_SIZE_MS = 1.0
# The time constants in ms indicate the time it takes for the level estimate
# to go down/up by 1 db if the signal is zero.
_LEVEL_ATTACK_MS = 5.0
_LEVEL_DECAY_MS = 20.0
# VAD params.
_VAD_THRESHOLD = 1
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
os.path.abspath(__file__)), os.pardir, os.pardir)
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
_VAD_WEBRTC_APM_PATH = os.path.join(
_VAD_WEBRTC_PATH, 'apm_vad')
def __init__(self, vad_type, external_vads=None):
self._signal = None
self._level = None
self._level_frame_size = None
self._common_audio_vad = None
self._energy_vad = None
self._apm_vad_probs = None
self._apm_vad_rms = None
self._vad_frame_size = None
self._vad_frame_size_ms = None
self._c_attack = None
self._c_decay = None
self._vad_type = self.VadType(vad_type)
logging.info('VADs used for annotations: ' + str(self._vad_type))
if external_vads is None:
external_vads = {}
self._external_vads = external_vads
assert len(self._external_vads) == len(external_vads), (
'The external VAD names must be unique.')
for vad in external_vads.values():
if not isinstance(vad, external_vad.ExternalVad):
raise exceptions.InitializationException(
'Invalid vad type: ' + str(type(vad)))
logging.info('External VAD used for annotation: ' +
str(vad.name))
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
self._VAD_WEBRTC_COMMON_AUDIO_PATH
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
self._VAD_WEBRTC_APM_PATH
@classmethod
def GetOutputFileNameTemplate(cls):
return cls._OUTPUT_FILENAME_TEMPLATE
def GetLevel(self):
return self._level
def GetLevelFrameSize(self):
return self._level_frame_size
@classmethod
def GetLevelFrameSizeMs(cls):
return cls._LEVEL_FRAME_SIZE_MS
def GetVadOutput(self, vad_type):
if vad_type == self.VadType.ENERGY_THRESHOLD:
return self._energy_vad
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
return self._common_audio_vad
elif vad_type == self.VadType.WEBRTC_APM:
return (self._apm_vad_probs, self._apm_vad_rms)
else:
raise exceptions.InitializationException(
'Invalid vad type: ' + vad_type)
def GetVadFrameSize(self):
return self._vad_frame_size
def GetVadFrameSizeMs(self):
return self._vad_frame_size_ms
def Extract(self, filepath):
# Load signal.
self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
if self._signal.channels != 1:
raise NotImplementedError('Multiple-channel annotations not implemented')
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
# Level estimation params.
self._level_frame_size = int(self._signal.frame_rate / 1000 * (
self._LEVEL_FRAME_SIZE_MS))
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
self._ONE_DB_REDUCTION ** (
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_ATTACK_MS))
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
self._ONE_DB_REDUCTION ** (
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_DECAY_MS))
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
_LEVEL_FRAME_SIZE_MS = 1.0
# The time constants in ms indicate the time it takes for the level estimate
# to go down/up by 1 db if the signal is zero.
_LEVEL_ATTACK_MS = 5.0
_LEVEL_DECAY_MS = 20.0
# Compute level.
self._LevelEstimation()
# VAD params.
_VAD_THRESHOLD = 1
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
os.pardir, os.pardir)
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
# Ideal VAD output, it requires clean speech with high SNR as input.
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
# Naive VAD based on level thresholding.
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
self._energy_vad = np.uint8(self._level > vad_threshold)
self._vad_frame_size = self._level_frame_size
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
# WebRTC common_audio/ VAD.
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
# WebRTC modules/audio_processing/ VAD.
self._RunWebRtcApmVad(filepath)
for extvad_name in self._external_vads:
self._external_vads[extvad_name].Run(filepath)
_VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')
def Save(self, output_path, annotation_name=""):
ext_kwargs = {'extvad_conf-' + ext_vad:
self._external_vads[ext_vad].GetVadOutput()
for ext_vad in self._external_vads}
np.savez_compressed(
file=os.path.join(
def __init__(self, vad_type, external_vads=None):
self._signal = None
self._level = None
self._level_frame_size = None
self._common_audio_vad = None
self._energy_vad = None
self._apm_vad_probs = None
self._apm_vad_rms = None
self._vad_frame_size = None
self._vad_frame_size_ms = None
self._c_attack = None
self._c_decay = None
self._vad_type = self.VadType(vad_type)
logging.info('VADs used for annotations: ' + str(self._vad_type))
if external_vads is None:
external_vads = {}
self._external_vads = external_vads
assert len(self._external_vads) == len(external_vads), (
'The external VAD names must be unique.')
for vad in external_vads.values():
if not isinstance(vad, external_vad.ExternalVad):
raise exceptions.InitializationException('Invalid vad type: ' +
str(type(vad)))
logging.info('External VAD used for annotation: ' + str(vad.name))
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
self._VAD_WEBRTC_COMMON_AUDIO_PATH
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
self._VAD_WEBRTC_APM_PATH
@classmethod
def GetOutputFileNameTemplate(cls):
return cls._OUTPUT_FILENAME_TEMPLATE
def GetLevel(self):
return self._level
def GetLevelFrameSize(self):
return self._level_frame_size
@classmethod
def GetLevelFrameSizeMs(cls):
return cls._LEVEL_FRAME_SIZE_MS
def GetVadOutput(self, vad_type):
if vad_type == self.VadType.ENERGY_THRESHOLD:
return self._energy_vad
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
return self._common_audio_vad
elif vad_type == self.VadType.WEBRTC_APM:
return (self._apm_vad_probs, self._apm_vad_rms)
else:
raise exceptions.InitializationException('Invalid vad type: ' +
vad_type)
def GetVadFrameSize(self):
return self._vad_frame_size
def GetVadFrameSizeMs(self):
return self._vad_frame_size_ms
def Extract(self, filepath):
# Load signal.
self._signal = signal_processing.SignalProcessingUtils.LoadWav(
filepath)
if self._signal.channels != 1:
raise NotImplementedError(
'Multiple-channel annotations not implemented')
# Level estimation params.
self._level_frame_size = int(self._signal.frame_rate / 1000 *
(self._LEVEL_FRAME_SIZE_MS))
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
self._LEVEL_ATTACK_MS))
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
self._LEVEL_DECAY_MS))
# Compute level.
self._LevelEstimation()
# Ideal VAD output, it requires clean speech with high SNR as input.
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
# Naive VAD based on level thresholding.
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
self._energy_vad = np.uint8(self._level > vad_threshold)
self._vad_frame_size = self._level_frame_size
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
# WebRTC common_audio/ VAD.
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
# WebRTC modules/audio_processing/ VAD.
self._RunWebRtcApmVad(filepath)
for extvad_name in self._external_vads:
self._external_vads[extvad_name].Run(filepath)
def Save(self, output_path, annotation_name=""):
ext_kwargs = {
'extvad_conf-' + ext_vad:
self._external_vads[ext_vad].GetVadOutput()
for ext_vad in self._external_vads
}
np.savez_compressed(file=os.path.join(
output_path,
self.GetOutputFileNameTemplate().format(annotation_name)),
level=self._level,
level_frame_size=self._level_frame_size,
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
vad_output=self._common_audio_vad,
vad_energy_output=self._energy_vad,
vad_frame_size=self._vad_frame_size,
vad_frame_size_ms=self._vad_frame_size_ms,
vad_probs=self._apm_vad_probs,
vad_rms=self._apm_vad_rms,
**ext_kwargs
)
level=self._level,
level_frame_size=self._level_frame_size,
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
vad_output=self._common_audio_vad,
vad_energy_output=self._energy_vad,
vad_frame_size=self._vad_frame_size,
vad_frame_size_ms=self._vad_frame_size_ms,
vad_probs=self._apm_vad_probs,
vad_rms=self._apm_vad_rms,
**ext_kwargs)
def _LevelEstimation(self):
# Read samples.
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._signal).astype(np.float32) / 32768.0
num_frames = len(samples) // self._level_frame_size
num_samples = num_frames * self._level_frame_size
def _LevelEstimation(self):
# Read samples.
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._signal).astype(np.float32) / 32768.0
num_frames = len(samples) // self._level_frame_size
num_samples = num_frames * self._level_frame_size
# Envelope.
self._level = np.max(np.reshape(np.abs(samples[:num_samples]), (
num_frames, self._level_frame_size)), axis=1)
assert len(self._level) == num_frames
# Envelope.
self._level = np.max(np.reshape(np.abs(samples[:num_samples]),
(num_frames, self._level_frame_size)),
axis=1)
assert len(self._level) == num_frames
# Envelope smoothing.
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
for i in range(1, num_frames):
self._level[i] = smooth(
self._level[i], self._level[i - 1], self._c_attack if (
self._level[i] > self._level[i - 1]) else self._c_decay)
# Envelope smoothing.
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
for i in range(1, num_frames):
self._level[i] = smooth(
self._level[i], self._level[i - 1], self._c_attack if
(self._level[i] > self._level[i - 1]) else self._c_decay)
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
self._common_audio_vad = None
self._vad_frame_size = None
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
self._common_audio_vad = None
self._vad_frame_size = None
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad.tmp')
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_COMMON_AUDIO_PATH,
'-i', wav_file_path,
'-o', output_file_path
], cwd=self._VAD_WEBRTC_PATH)
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o',
output_file_path
],
cwd=self._VAD_WEBRTC_PATH)
# Read bytes.
with open(output_file_path, 'rb') as f:
raw_data = f.read()
# Read bytes.
with open(output_file_path, 'rb') as f:
raw_data = f.read()
# Parse side information.
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
assert self._vad_frame_size_ms in [10, 20, 30]
extra_bits = struct.unpack('B', raw_data[-1])[0]
assert 0 <= extra_bits <= 8
# Parse side information.
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
assert self._vad_frame_size_ms in [10, 20, 30]
extra_bits = struct.unpack('B', raw_data[-1])[0]
assert 0 <= extra_bits <= 8
# Init VAD vector.
num_bytes = len(raw_data)
num_frames = 8 * (num_bytes - 2) - extra_bits # 8 frames for each byte.
self._common_audio_vad = np.zeros(num_frames, np.uint8)
# Init VAD vector.
num_bytes = len(raw_data)
num_frames = 8 * (num_bytes -
2) - extra_bits # 8 frames for each byte.
self._common_audio_vad = np.zeros(num_frames, np.uint8)
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._common_audio_vad[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._common_audio_vad[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message +
')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
def _RunWebRtcApmVad(self, wav_file_path):
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path_probs = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
output_file_path_rms = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
def _RunWebRtcApmVad(self, wav_file_path):
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path_probs = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
output_file_path_rms = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_APM_PATH,
'-i', wav_file_path,
'-o_probs', output_file_path_probs,
'-o_rms', output_file_path_rms
], cwd=self._VAD_WEBRTC_PATH)
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs',
output_file_path_probs, '-o_rms', output_file_path_rms
],
cwd=self._VAD_WEBRTC_PATH)
# Parse annotations.
self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double)
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
# Parse annotations.
self._apm_vad_probs = np.fromfile(output_file_path_probs,
np.double)
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
except Exception as e:
logging.error('Error while running the WebRTC APM VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
except Exception as e:
logging.error('Error while running the WebRTC APM VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)


@@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the annotations module.
"""
@@ -25,133 +24,137 @@ from . import signal_processing
class TestAnnotationsExtraction(unittest.TestCase):
"""Unit tests for the annotations module.
"""Unit tests for the annotations module.
"""
_CLEAN_TMP_OUTPUT = True
_DEBUG_PLOT_VAD = False
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD |
_VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO |
_VAD_TYPE_CLASS.WEBRTC_APM)
_CLEAN_TMP_OUTPUT = True
_DEBUG_PLOT_VAD = False
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD
| _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO
| _VAD_TYPE_CLASS.WEBRTC_APM)
def setUp(self):
"""Create temporary folder."""
self._tmp_path = tempfile.mkdtemp()
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
'pure_tone', [440, 1000])
signal_processing.SignalProcessingUtils.SaveWav(
self._wav_file_path, pure_tone)
self._sample_rate = pure_tone.frame_rate
def setUp(self):
"""Create temporary folder."""
self._tmp_path = tempfile.mkdtemp()
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
'pure_tone', [440, 1000])
signal_processing.SignalProcessingUtils.SaveWav(
self._wav_file_path, pure_tone)
self._sample_rate = pure_tone.frame_rate
def tearDown(self):
"""Recursively delete temporary folder."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' +
(self._tmp_path))
def tearDown(self):
"""Recursively delete temporary folder."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' + (
self._tmp_path))
def testFrameSizes(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertEqual(
samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
e.GetLevelFrameSizeMs())
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
e.GetVadFrameSizeMs())
def testFrameSizes(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
e.GetLevelFrameSizeMs())
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
e.GetVadFrameSizeMs())
def testVoiceActivityDetectors(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
e.Extract(self._wav_file_path)
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(
self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(
float(np.sum(vad_output)) / len(vad_output), 0.95)
def testVoiceActivityDetectors(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
e.Extract(self._wav_file_path)
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(
self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(
float(np.sum(vad_output)) / len(vad_output), 0.95)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
# pylint: disable=unpacking-non-sequence
(vad_probs,
vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
self.assertGreater(len(vad_probs), 0)
self.assertGreater(len(vad_rms), 0)
self.assertGreaterEqual(
float(np.sum(vad_probs)) / len(vad_probs), 0.5)
self.assertGreaterEqual(
float(np.sum(vad_rms)) / len(vad_rms), 20000)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
# pylint: disable=unpacking-non-sequence
(vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
self.assertGreater(len(vad_probs), 0)
self.assertGreater(len(vad_rms), 0)
self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs),
0.5)
self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000)
if self._DEBUG_PLOT_VAD:
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
num_frames).astype(np.float32) * frame_size_ms / 1000.0
level = e.GetLevel()
t_level = frame_times_s(num_frames=len(level),
frame_size_ms=e.GetLevelFrameSizeMs())
t_vad = frame_times_s(num_frames=len(vad_output),
frame_size_ms=e.GetVadFrameSizeMs())
import matplotlib.pyplot as plt
plt.figure()
plt.hold(True)
plt.plot(t_level, level)
plt.plot(t_vad, vad_output * np.max(level), '.')
plt.show()
if self._DEBUG_PLOT_VAD:
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
num_frames).astype(np.float32) * frame_size_ms / 1000.0
level = e.GetLevel()
t_level = frame_times_s(
num_frames=len(level),
frame_size_ms=e.GetLevelFrameSizeMs())
t_vad = frame_times_s(
num_frames=len(vad_output),
frame_size_ms=e.GetVadFrameSizeMs())
import matplotlib.pyplot as plt
plt.figure()
plt.hold(True)
plt.plot(t_level, level)
plt.plot(t_vad, vad_output * np.max(level), '.')
plt.show()
def testSaveLoad(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, "fake-annotation")
def testSaveLoad(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, "fake-annotation")
data = np.load(
os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
np.testing.assert_array_equal(e.GetLevel(), data['level'])
self.assertEqual(np.float32, data['level'].dtype)
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
data['vad_energy_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
data['vad_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
data['vad_probs'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
data['vad_rms'])
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
self.assertEqual(np.float64, data['vad_probs'].dtype)
self.assertEqual(np.float64, data['vad_rms'].dtype)
data = np.load(os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
np.testing.assert_array_equal(e.GetLevel(), data['level'])
self.assertEqual(np.float32, data['level'].dtype)
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
data['vad_energy_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
data['vad_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms'])
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
self.assertEqual(np.float64, data['vad_probs'].dtype)
self.assertEqual(np.float64, data['vad_rms'].dtype)
def testEmptyExternalShouldNotCrash(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
annotations.AudioAnnotationsExtractor(vad_type_value, {})
def testEmptyExternalShouldNotCrash(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
annotations.AudioAnnotationsExtractor(vad_type_value, {})
def testFakeExternalSaveLoad(self):
def FakeExternalFactory():
return external_vad.ExternalVad(
os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fake_external_vad.py'), 'fake')
def testFakeExternalSaveLoad(self):
def FakeExternalFactory():
return external_vad.ExternalVad(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'),
'fake'
)
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
e = annotations.AudioAnnotationsExtractor(
vad_type_value,
{'fake': FakeExternalFactory()})
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, annotation_name="fake-annotation")
data = np.load(os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
data['extvad_conf-fake'])
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
e = annotations.AudioAnnotationsExtractor(
vad_type_value, {'fake': FakeExternalFactory()})
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, annotation_name="fake-annotation")
data = np.load(
os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
data['extvad_conf-fake'])
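For context, a minimal sketch of how the annotations extractor exercised by these tests can be driven; the quality_assessment import path, the VAD type value, and the file paths are assumptions used only for illustration.

import os
import tempfile

from quality_assessment import annotations  # assumed import path

vad_type = 1  # assumed flag value selecting one of the supported VADs
extractor = annotations.AudioAnnotationsExtractor(vad_type)
extractor.Extract('capture.wav')  # placeholder input file
out_dir = tempfile.mkdtemp()
extractor.Save(out_dir, annotation_name='example')
print(os.path.join(out_dir,
                   extractor.GetOutputFileNameTemplate().format('example')))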

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Class implementing a wrapper for APM simulators.
"""
@ -19,33 +18,36 @@ from . import exceptions
class AudioProcWrapper(object):
"""Wrapper for APM simulators.
"""Wrapper for APM simulators.
"""
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(os.path.join(
os.pardir, 'audioproc_f'))
OUTPUT_FILENAME = 'output.wav'
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(
os.path.join(os.pardir, 'audioproc_f'))
OUTPUT_FILENAME = 'output.wav'
def __init__(self, simulator_bin_path):
"""Ctor.
def __init__(self, simulator_bin_path):
"""Ctor.
Args:
simulator_bin_path: path to the APM simulator binary.
"""
self._simulator_bin_path = simulator_bin_path
self._config = None
self._output_signal_filepath = None
self._simulator_bin_path = simulator_bin_path
self._config = None
self._output_signal_filepath = None
# Profiler instance to measure running time.
self._profiler = cProfile.Profile()
# Profiler instance to measure running time.
self._profiler = cProfile.Profile()
@property
def output_filepath(self):
return self._output_signal_filepath
@property
def output_filepath(self):
return self._output_signal_filepath
def Run(self, config_filepath, capture_input_filepath, output_path,
render_input_filepath=None):
"""Runs APM simulator.
def Run(self,
config_filepath,
capture_input_filepath,
output_path,
render_input_filepath=None):
"""Runs APM simulator.
Args:
config_filepath: path to the configuration file specifying the arguments
@ -56,41 +58,43 @@ class AudioProcWrapper(object):
render_input_filepath: path to the render audio track input file (aka
reverse or far-end).
"""
# Init.
self._output_signal_filepath = os.path.join(
output_path, self.OUTPUT_FILENAME)
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
# Init.
self._output_signal_filepath = os.path.join(output_path,
self.OUTPUT_FILENAME)
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
# Skip if the output has already been generated.
if os.path.exists(self._output_signal_filepath) and os.path.exists(
profiling_stats_filepath):
return
# Skip if the output has already been generated.
if os.path.exists(self._output_signal_filepath) and os.path.exists(
profiling_stats_filepath):
return
# Load configuration.
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
# Load configuration.
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
# Set remaining parameters.
if not os.path.exists(capture_input_filepath):
raise exceptions.FileNotFoundError('cannot find capture input file')
self._config['-i'] = capture_input_filepath
self._config['-o'] = self._output_signal_filepath
if render_input_filepath is not None:
if not os.path.exists(render_input_filepath):
raise exceptions.FileNotFoundError('cannot find render input file')
self._config['-ri'] = render_input_filepath
# Set remaining parameters.
if not os.path.exists(capture_input_filepath):
raise exceptions.FileNotFoundError(
'cannot find capture input file')
self._config['-i'] = capture_input_filepath
self._config['-o'] = self._output_signal_filepath
if render_input_filepath is not None:
if not os.path.exists(render_input_filepath):
raise exceptions.FileNotFoundError(
'cannot find render input file')
self._config['-ri'] = render_input_filepath
# Build arguments list.
args = [self._simulator_bin_path]
for param_name in self._config:
args.append(param_name)
if self._config[param_name] is not None:
args.append(str(self._config[param_name]))
logging.debug(' '.join(args))
# Build arguments list.
args = [self._simulator_bin_path]
for param_name in self._config:
args.append(param_name)
if self._config[param_name] is not None:
args.append(str(self._config[param_name]))
logging.debug(' '.join(args))
# Run.
self._profiler.enable()
subprocess.call(args)
self._profiler.disable()
# Run.
self._profiler.enable()
subprocess.call(args)
self._profiler.disable()
# Save profiling stats.
self._profiler.dump_stats(profiling_stats_filepath)
# Save profiling stats.
self._profiler.dump_stats(profiling_stats_filepath)
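A minimal sketch of how this wrapper is typically driven; the audioproc_wrapper module name and the paths below are assumptions.

from quality_assessment import audioproc_wrapper  # assumed import path

wrapper = audioproc_wrapper.AudioProcWrapper(
    audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
wrapper.Run(config_filepath='apm_config.json',     # placeholder path
            capture_input_filepath='capture.wav',  # placeholder path
            output_path='/tmp/apm_out')            # placeholder path
print(wrapper.output_filepath)  # /tmp/apm_out/output.wav

Note that Run skips the simulation when both output.wav and profiling.stats already exist in output_path, so re-running over the same output directory is cheap.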

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Imports a filtered subset of the scores and configurations computed
by apm_quality_assessment.py into a pandas data frame.
"""
@ -18,71 +17,88 @@ import re
import sys
try:
import pandas as pd
import pandas as pd
except ImportError:
logging.critical('Cannot import the third-party Python package pandas')
sys.exit(1)
logging.critical('Cannot import the third-party Python package pandas')
sys.exit(1)
from . import data_access as data_access
from . import simulation as sim
# Compiled regular expressions used to extract score descriptors.
RE_CONFIG_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixApmConfig() + r'(.+)')
RE_CAPTURE_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixCapture() + r'(.+)')
RE_RENDER_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
RE_ECHO_SIM_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + r'(.+)')
RE_CONFIG_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixApmConfig() +
r'(.+)')
RE_CAPTURE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixCapture() +
r'(.+)')
RE_RENDER_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
RE_ECHO_SIM_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixEchoSimulator() +
r'(.+)')
RE_TEST_DATA_GEN_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + r'(.+)')
RE_TEST_DATA_GEN_PARAMS = re.compile(
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + r'(.+)')
RE_SCORE_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixScore() + r'(.+)(\..+)')
RE_SCORE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixScore() +
r'(.+)(\..+)')
def InstanceArgumentsParser():
"""Arguments parser factory.
"""Arguments parser factory.
"""
parser = argparse.ArgumentParser(description=(
'Override this description in a user script by changing'
' `parser.description` of the returned parser.'))
parser = argparse.ArgumentParser(
description=('Override this description in a user script by changing'
' `parser.description` of the returned parser.'))
parser.add_argument('-o', '--output_dir', required=True,
help=('the same base path used with the '
'apm_quality_assessment tool'))
parser.add_argument('-o',
'--output_dir',
required=True,
help=('the same base path used with the '
'apm_quality_assessment tool'))
parser.add_argument('-c', '--config_names', type=re.compile,
help=('regular expression to filter the APM configuration'
' names'))
parser.add_argument(
'-c',
'--config_names',
type=re.compile,
help=('regular expression to filter the APM configuration'
' names'))
parser.add_argument('-i', '--capture_names', type=re.compile,
help=('regular expression to filter the capture signal '
'names'))
parser.add_argument(
'-i',
'--capture_names',
type=re.compile,
help=('regular expression to filter the capture signal '
'names'))
parser.add_argument('-r', '--render_names', type=re.compile,
help=('regular expression to filter the render signal '
'names'))
parser.add_argument('-r',
'--render_names',
type=re.compile,
help=('regular expression to filter the render signal '
'names'))
parser.add_argument('-e', '--echo_simulator_names', type=re.compile,
help=('regular expression to filter the echo simulator '
'names'))
parser.add_argument(
'-e',
'--echo_simulator_names',
type=re.compile,
help=('regular expression to filter the echo simulator '
'names'))
parser.add_argument('-t', '--test_data_generators', type=re.compile,
help=('regular expression to filter the test data '
'generator names'))
parser.add_argument('-t',
'--test_data_generators',
type=re.compile,
help=('regular expression to filter the test data '
'generator names'))
parser.add_argument('-s', '--eval_scores', type=re.compile,
help=('regular expression to filter the evaluation score '
'names'))
parser.add_argument(
'-s',
'--eval_scores',
type=re.compile,
help=('regular expression to filter the evaluation score '
'names'))
return parser
return parser
def _GetScoreDescriptors(score_filepath):
"""Extracts a score descriptor from the given score file path.
"""Extracts a score descriptor from the given score file path.
Args:
score_filepath: path to the score file.
@ -92,23 +108,23 @@ def _GetScoreDescriptors(score_filepath):
render audio track name, echo simulator name, test data generator name,
test data generator parameters as string, evaluation score name).
"""
fields = score_filepath.split(os.sep)[-7:]
extract_name = lambda index, reg_expr: (
reg_expr.match(fields[index]).groups(0)[0])
return (
extract_name(0, RE_CONFIG_NAME),
extract_name(1, RE_CAPTURE_NAME),
extract_name(2, RE_RENDER_NAME),
extract_name(3, RE_ECHO_SIM_NAME),
extract_name(4, RE_TEST_DATA_GEN_NAME),
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
extract_name(6, RE_SCORE_NAME),
)
fields = score_filepath.split(os.sep)[-7:]
extract_name = lambda index, reg_expr: (reg_expr.match(fields[index]).
groups(0)[0])
return (
extract_name(0, RE_CONFIG_NAME),
extract_name(1, RE_CAPTURE_NAME),
extract_name(2, RE_RENDER_NAME),
extract_name(3, RE_ECHO_SIM_NAME),
extract_name(4, RE_TEST_DATA_GEN_NAME),
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
extract_name(6, RE_SCORE_NAME),
)
def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
test_data_gen_name, score_name, args):
"""Decides whether excluding a score.
"""Decides whether excluding a score.
A set of optional regular expressions in args is used to determine if the
score should be excluded (depending on its |*_name| descriptors).
@ -125,27 +141,27 @@ def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
Returns:
A boolean.
"""
value_regexpr_pairs = [
(config_name, args.config_names),
(capture_name, args.capture_names),
(render_name, args.render_names),
(echo_simulator_name, args.echo_simulator_names),
(test_data_gen_name, args.test_data_generators),
(score_name, args.eval_scores),
]
value_regexpr_pairs = [
(config_name, args.config_names),
(capture_name, args.capture_names),
(render_name, args.render_names),
(echo_simulator_name, args.echo_simulator_names),
(test_data_gen_name, args.test_data_generators),
(score_name, args.eval_scores),
]
# Score accepted if each value matches the corresponding regular expression.
for value, regexpr in value_regexpr_pairs:
if regexpr is None:
continue
if not regexpr.match(value):
return True
# Score accepted if each value matches the corresponding regular expression.
for value, regexpr in value_regexpr_pairs:
if regexpr is None:
continue
if not regexpr.match(value):
return True
return False
return False
def FindScores(src_path, args):
"""Given a search path, find scores and return a DataFrame object.
"""Given a search path, find scores and return a DataFrame object.
Args:
src_path: Search path pattern.
@ -154,89 +170,74 @@ def FindScores(src_path, args):
Returns:
A DataFrame object.
"""
# Get scores.
scores = []
for score_filepath in glob.iglob(src_path):
# Extract score descriptor fields from the path.
(config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
test_data_gen_params,
score_name) = _GetScoreDescriptors(score_filepath)
# Get scores.
scores = []
for score_filepath in glob.iglob(src_path):
# Extract score descriptor fields from the path.
(config_name, capture_name, render_name, echo_simulator_name,
test_data_gen_name, test_data_gen_params,
score_name) = _GetScoreDescriptors(score_filepath)
# Ignore the score if required.
if _ExcludeScore(
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
score_name,
args):
logging.info(
'ignored score: %s %s %s %s %s %s',
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
score_name)
continue
# Ignore the score if required.
if _ExcludeScore(config_name, capture_name, render_name,
echo_simulator_name, test_data_gen_name, score_name,
args):
logging.info('ignored score: %s %s %s %s %s %s', config_name,
capture_name, render_name, echo_simulator_name,
test_data_gen_name, score_name)
continue
# Read metadata and score.
metadata = data_access.Metadata.LoadAudioTestDataPaths(
os.path.split(score_filepath)[0])
score = data_access.ScoreFile.Load(score_filepath)
# Read metadata and score.
metadata = data_access.Metadata.LoadAudioTestDataPaths(
os.path.split(score_filepath)[0])
score = data_access.ScoreFile.Load(score_filepath)
# Add a score with its descriptor fields.
scores.append((
metadata['clean_capture_input_filepath'],
metadata['echo_free_capture_filepath'],
metadata['echo_filepath'],
metadata['render_filepath'],
metadata['capture_filepath'],
metadata['apm_output_filepath'],
metadata['apm_reference_filepath'],
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
test_data_gen_params,
score_name,
score,
))
# Add a score with its descriptor fields.
scores.append((
metadata['clean_capture_input_filepath'],
metadata['echo_free_capture_filepath'],
metadata['echo_filepath'],
metadata['render_filepath'],
metadata['capture_filepath'],
metadata['apm_output_filepath'],
metadata['apm_reference_filepath'],
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
test_data_gen_params,
score_name,
score,
))
return pd.DataFrame(
data=scores,
columns=(
'clean_capture_input_filepath',
'echo_free_capture_filepath',
'echo_filepath',
'render_filepath',
'capture_filepath',
'apm_output_filepath',
'apm_reference_filepath',
'apm_config',
'capture',
'render',
'echo_simulator',
'test_data_gen',
'test_data_gen_params',
'eval_score_name',
'score',
))
return pd.DataFrame(data=scores,
columns=(
'clean_capture_input_filepath',
'echo_free_capture_filepath',
'echo_filepath',
'render_filepath',
'capture_filepath',
'apm_output_filepath',
'apm_reference_filepath',
'apm_config',
'capture',
'render',
'echo_simulator',
'test_data_gen',
'test_data_gen_params',
'eval_score_name',
'score',
))
def ConstructSrcPath(args):
return os.path.join(
args.output_dir,
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
sim.ApmModuleSimulator.GetPrefixRender() + '*',
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
sim.ApmModuleSimulator.GetPrefixScore() + '*')
return os.path.join(
args.output_dir,
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
sim.ApmModuleSimulator.GetPrefixRender() + '*',
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
sim.ApmModuleSimulator.GetPrefixScore() + '*')
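A minimal sketch of a user script built on this module, as the parser description suggests; the collect_data module name is an assumption.

from quality_assessment import collect_data  # assumed import path

parser = collect_data.InstanceArgumentsParser()
parser.description = 'Collect APM quality assessment scores.'
args = parser.parse_args()
scores = collect_data.FindScores(collect_data.ConstructSrcPath(args), args)
print(scores[['apm_config', 'eval_score_name', 'score']].head())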

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Data access utility functions and classes.
"""
@ -14,29 +13,29 @@ import os
def MakeDirectory(path):
"""Makes a directory recursively without rising exceptions if existing.
"""Makes a directory recursively without rising exceptions if existing.
Args:
path: path to the directory to be created.
"""
if os.path.exists(path):
return
os.makedirs(path)
if os.path.exists(path):
return
os.makedirs(path)
class Metadata(object):
"""Data access class to save and load metadata.
"""Data access class to save and load metadata.
"""
def __init__(self):
pass
def __init__(self):
pass
_GENERIC_METADATA_SUFFIX = '.mdata'
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
_GENERIC_METADATA_SUFFIX = '.mdata'
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
@classmethod
def LoadFileMetadata(cls, filepath):
"""Loads generic metadata linked to a file.
@classmethod
def LoadFileMetadata(cls, filepath):
"""Loads generic metadata linked to a file.
Args:
filepath: path to the metadata file to read.
@ -44,23 +43,23 @@ class Metadata(object):
Returns:
A dict.
"""
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
return json.load(f)
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
return json.load(f)
@classmethod
def SaveFileMetadata(cls, filepath, metadata):
"""Saves generic metadata linked to a file.
@classmethod
def SaveFileMetadata(cls, filepath, metadata):
"""Saves generic metadata linked to a file.
Args:
filepath: path to the metadata file to write.
metadata: a dict.
"""
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
json.dump(metadata, f)
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
json.dump(metadata, f)
@classmethod
def LoadAudioTestDataPaths(cls, metadata_path):
"""Loads the input and the reference audio track paths.
@classmethod
def LoadAudioTestDataPaths(cls, metadata_path):
"""Loads the input and the reference audio track paths.
Args:
metadata_path: path to the directory containing the metadata file.
@ -68,14 +67,14 @@ class Metadata(object):
Returns:
Tuple with the paths to the input and output audio tracks.
"""
metadata_filepath = os.path.join(
metadata_path, cls._AUDIO_TEST_DATA_FILENAME)
with open(metadata_filepath) as f:
return json.load(f)
metadata_filepath = os.path.join(metadata_path,
cls._AUDIO_TEST_DATA_FILENAME)
with open(metadata_filepath) as f:
return json.load(f)
@classmethod
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
"""Saves the input and the reference audio track paths.
@classmethod
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
"""Saves the input and the reference audio track paths.
Args:
output_path: path to the directory containing the metadata file.
@ -83,23 +82,24 @@ class Metadata(object):
Keyword Args:
filepaths: collection of audio track file paths to save.
"""
output_filepath = os.path.join(output_path, cls._AUDIO_TEST_DATA_FILENAME)
with open(output_filepath, 'w') as f:
json.dump(filepaths, f)
output_filepath = os.path.join(output_path,
cls._AUDIO_TEST_DATA_FILENAME)
with open(output_filepath, 'w') as f:
json.dump(filepaths, f)
class AudioProcConfigFile(object):
"""Data access to load/save APM simulator argument lists.
"""Data access to load/save APM simulator argument lists.
The arguments stored in the config files are used to control the APM flags.
"""
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def Load(cls, filepath):
"""Loads a configuration file for an APM simulator.
@classmethod
def Load(cls, filepath):
"""Loads a configuration file for an APM simulator.
Args:
filepath: path to the configuration file.
@ -107,31 +107,31 @@ class AudioProcConfigFile(object):
Returns:
A dict containing the configuration.
"""
with open(filepath) as f:
return json.load(f)
with open(filepath) as f:
return json.load(f)
@classmethod
def Save(cls, filepath, config):
"""Saves a configuration file for an APM simulator.
@classmethod
def Save(cls, filepath, config):
"""Saves a configuration file for an APM simulator.
Args:
filepath: path to the configuration file.
config: a dict containing the configuration.
"""
with open(filepath, 'w') as f:
json.dump(config, f)
with open(filepath, 'w') as f:
json.dump(config, f)
class ScoreFile(object):
"""Data access class to save and load float scalar scores.
"""Data access class to save and load float scalar scores.
"""
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def Load(cls, filepath):
"""Loads a score from file.
@classmethod
def Load(cls, filepath):
"""Loads a score from file.
Args:
filepath: path to the score file.
@ -139,16 +139,16 @@ class ScoreFile(object):
Returns:
A float encoding the score.
"""
with open(filepath) as f:
return float(f.readline().strip())
with open(filepath) as f:
return float(f.readline().strip())
@classmethod
def Save(cls, filepath, score):
"""Saves a score into a file.
@classmethod
def Save(cls, filepath, score):
"""Saves a score into a file.
Args:
filepath: path to the score file.
score: float encoding the score.
"""
with open(filepath, 'w') as f:
f.write('{0:f}\n'.format(score))
with open(filepath, 'w') as f:
f.write('{0:f}\n'.format(score))
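A minimal round-trip sketch of these helpers; the import path and file paths are placeholders.

from quality_assessment import data_access  # assumed import path

data_access.MakeDirectory('/tmp/apm_scores')  # placeholder directory
data_access.ScoreFile.Save('/tmp/apm_scores/polqa.txt', 4.25)
assert data_access.ScoreFile.Load('/tmp/apm_scores/polqa.txt') == 4.25

data_access.Metadata.SaveFileMetadata(
    '/tmp/apm_scores/capture.wav', {'signal': 'pure_tone', 'frequency': 440})
print(data_access.Metadata.LoadFileMetadata('/tmp/apm_scores/capture.wav'))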

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Echo path simulation module.
"""
@ -16,21 +15,21 @@ from . import signal_processing
class EchoPathSimulator(object):
"""Abstract class for the echo path simulators.
"""Abstract class for the echo path simulators.
In general, an echo path simulator is a function of the render signal and
simulates the propagation of the latter into the microphone (e.g., due to
mechanical or electrical paths).
"""
NAME = None
REGISTERED_CLASSES = {}
NAME = None
REGISTERED_CLASSES = {}
def __init__(self):
pass
def __init__(self):
pass
def Simulate(self, output_path):
"""Creates the echo signal and stores it in an audio file (abstract method).
def Simulate(self, output_path):
"""Creates the echo signal and stores it in an audio file (abstract method).
Args:
output_path: Path in which any output can be saved.
@ -38,11 +37,11 @@ class EchoPathSimulator(object):
Returns:
Path to the generated audio track file or None if no echo is present.
"""
raise NotImplementedError()
raise NotImplementedError()
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EchoPathSimulator implementation.
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EchoPathSimulator implementation.
Decorator to automatically register the classes that extend
EchoPathSimulator.
@ -52,85 +51,86 @@ class EchoPathSimulator(object):
class NoEchoPathSimulator(EchoPathSimulator):
pass
"""
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
@EchoPathSimulator.RegisterClass
class NoEchoPathSimulator(EchoPathSimulator):
"""Simulates absence of echo."""
"""Simulates absence of echo."""
NAME = 'noecho'
NAME = 'noecho'
def __init__(self):
EchoPathSimulator.__init__(self)
def __init__(self):
EchoPathSimulator.__init__(self)
def Simulate(self, output_path):
return None
def Simulate(self, output_path):
return None
@EchoPathSimulator.RegisterClass
class LinearEchoPathSimulator(EchoPathSimulator):
"""Simulates linear echo path.
"""Simulates linear echo path.
This class applies a given impulse response to the render input and then
adds the resulting signal to the capture input signal.
"""
NAME = 'linear'
NAME = 'linear'
def __init__(self, render_input_filepath, impulse_response):
"""
def __init__(self, render_input_filepath, impulse_response):
"""
Args:
render_input_filepath: Render audio track file.
impulse_response: list or numpy vector of float values.
"""
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
self._impulse_response = impulse_response
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
self._impulse_response = impulse_response
def Simulate(self, output_path):
"""Simulates linear echo path."""
# Form the file name with a hash of the impulse response.
impulse_response_hash = hashlib.sha256(
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
echo_filepath = os.path.join(output_path, 'linear_echo_{}.wav'.format(
impulse_response_hash))
def Simulate(self, output_path):
"""Simulates linear echo path."""
# Form the file name with a hash of the impulse response.
impulse_response_hash = hashlib.sha256(
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
echo_filepath = os.path.join(
output_path, 'linear_echo_{}.wav'.format(impulse_response_hash))
# If the simulated echo audio track file does not exist, create it.
if not os.path.exists(echo_filepath):
render = signal_processing.SignalProcessingUtils.LoadWav(
self._render_input_filepath)
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
render, self._impulse_response)
signal_processing.SignalProcessingUtils.SaveWav(echo_filepath, echo)
# If the simulated echo audio track file does not exist, create it.
if not os.path.exists(echo_filepath):
render = signal_processing.SignalProcessingUtils.LoadWav(
self._render_input_filepath)
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
render, self._impulse_response)
signal_processing.SignalProcessingUtils.SaveWav(
echo_filepath, echo)
return echo_filepath
return echo_filepath
@EchoPathSimulator.RegisterClass
class RecordedEchoPathSimulator(EchoPathSimulator):
"""Uses recorded echo.
"""Uses recorded echo.
This class uses the clean capture input file name to build the file name of
the corresponding recording containing echo (a predefined suffix is used).
Such a file is expected to already exist.
"""
NAME = 'recorded'
NAME = 'recorded'
_FILE_NAME_SUFFIX = '_echo'
_FILE_NAME_SUFFIX = '_echo'
def __init__(self, render_input_filepath):
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
def __init__(self, render_input_filepath):
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
def Simulate(self, output_path):
"""Uses recorded echo path."""
path, file_name_ext = os.path.split(self._render_input_filepath)
file_name, file_ext = os.path.splitext(file_name_ext)
echo_filepath = os.path.join(path, '{}{}{}'.format(
file_name, self._FILE_NAME_SUFFIX, file_ext))
assert os.path.exists(echo_filepath), (
'cannot find the echo audio track file {}'.format(echo_filepath))
return echo_filepath
def Simulate(self, output_path):
"""Uses recorded echo path."""
path, file_name_ext = os.path.split(self._render_input_filepath)
file_name, file_ext = os.path.splitext(file_name_ext)
echo_filepath = os.path.join(
path, '{}{}{}'.format(file_name, self._FILE_NAME_SUFFIX, file_ext))
assert os.path.exists(echo_filepath), (
'cannot find the echo audio track file {}'.format(echo_filepath))
return echo_filepath
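A minimal sketch of registering a custom simulator through the decorator described above; the class, its NAME, and its behavior are hypothetical.

from quality_assessment import echo_path_simulation  # assumed import path


@echo_path_simulation.EchoPathSimulator.RegisterClass
class AttenuatedEchoPathSimulator(echo_path_simulation.EchoPathSimulator):
    """Toy simulator that reuses the render input as a weak echo."""

    NAME = 'attenuated'

    def __init__(self, render_input_filepath):
        echo_path_simulation.EchoPathSimulator.__init__(self)
        self._render_input_filepath = render_input_filepath

    def Simulate(self, output_path):
        # A real implementation would attenuate the render signal and save a
        # new file; returning the render input unchanged keeps the sketch short.
        return self._render_input_filepath

Once registered, EchoPathSimulatorFactory.GetInstance constructs it by passing the render input file path, exactly as for RecordedEchoPathSimulator.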

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Echo path simulation factory module.
"""
@ -16,16 +15,16 @@ from . import echo_path_simulation
class EchoPathSimulatorFactory(object):
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
# realistic impulse response.
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0]*(20 * 48) + [0.15])
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
# realistic impulse response.
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0] * (20 * 48) + [0.15])
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
"""Creates an EchoPathSimulator instance given a class object.
@classmethod
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
"""Creates an EchoPathSimulator instance given a class object.
Args:
echo_path_simulator_class: EchoPathSimulator class object (not an
@ -35,14 +34,15 @@ class EchoPathSimulatorFactory(object):
Returns:
An EchoPathSimulator instance.
"""
assert render_input_filepath is not None or (
echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator)
assert render_input_filepath is not None or (
echo_path_simulator_class ==
echo_path_simulation.NoEchoPathSimulator)
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
return echo_path_simulation.NoEchoPathSimulator()
elif echo_path_simulator_class == (
echo_path_simulation.LinearEchoPathSimulator):
return echo_path_simulation.LinearEchoPathSimulator(
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
else:
return echo_path_simulator_class(render_input_filepath)
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
return echo_path_simulation.NoEchoPathSimulator()
elif echo_path_simulator_class == (
echo_path_simulation.LinearEchoPathSimulator):
return echo_path_simulation.LinearEchoPathSimulator(
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
else:
return echo_path_simulator_class(render_input_filepath)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the echo path simulation module.
"""
@ -22,60 +21,62 @@ from . import signal_processing
class TestEchoPathSimulators(unittest.TestCase):
"""Unit tests for the eval_scores module.
"""Unit tests for the eval_scores module.
"""
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
# Create and save white noise.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._audio_track_num_samples = (
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
self._audio_track_filepath = os.path.join(self._tmp_path, 'white_noise.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._audio_track_filepath, white_noise)
# Create and save white noise.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._audio_track_num_samples = (
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
self._audio_track_filepath = os.path.join(self._tmp_path,
'white_noise.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._audio_track_filepath, white_noise)
# Make a copy of the white noise audio track file; it will be used by
# echo_path_simulation.RecordedEchoPathSimulator.
shutil.copy(self._audio_track_filepath, os.path.join(
self._tmp_path, 'white_noise_echo.wav'))
# Make a copy of the white noise audio track file; it will be used by
# echo_path_simulation.RecordedEchoPathSimulator.
shutil.copy(self._audio_track_filepath,
os.path.join(self._tmp_path, 'white_noise_echo.wav'))
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def testRegisteredClasses(self):
# Check that there is at least one registered echo path simulator.
registered_classes = (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
def testRegisteredClasses(self):
# Check that there is at least one registered echo path simulator.
registered_classes = (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instance factory.
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
# Instance factory.
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
# Try each registered echo path simulator.
for echo_path_simulator_name in registered_classes:
simulator = factory.GetInstance(
echo_path_simulator_class=registered_classes[
echo_path_simulator_name],
render_input_filepath=self._audio_track_filepath)
# Try each registered echo path simulator.
for echo_path_simulator_name in registered_classes:
simulator = factory.GetInstance(
echo_path_simulator_class=registered_classes[
echo_path_simulator_name],
render_input_filepath=self._audio_track_filepath)
echo_filepath = simulator.Simulate(self._tmp_path)
if echo_filepath is None:
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
echo_path_simulator_name)
# No other tests in this case.
continue
echo_filepath = simulator.Simulate(self._tmp_path)
if echo_filepath is None:
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
echo_path_simulator_name)
# No other tests in this case.
continue
# Check that the echo audio track file exists and that its length is greater
# than or equal to that of the render audio track.
self.assertTrue(os.path.exists(echo_filepath))
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
self.assertGreaterEqual(
signal_processing.SignalProcessingUtils.CountSamples(echo),
self._audio_track_num_samples)
# Check that the echo audio track file exists and that its length is greater
# than or equal to that of the render audio track.
self.assertTrue(os.path.exists(echo_filepath))
echo = signal_processing.SignalProcessingUtils.LoadWav(
echo_filepath)
self.assertGreaterEqual(
signal_processing.SignalProcessingUtils.CountSamples(echo),
self._audio_track_num_samples)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Evaluation score abstract class and implementations.
"""
@ -17,10 +16,10 @@ import subprocess
import sys
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import data_access
from . import exceptions
@ -29,23 +28,23 @@ from . import signal_processing
class EvaluationScore(object):
NAME = None
REGISTERED_CLASSES = {}
NAME = None
REGISTERED_CLASSES = {}
def __init__(self, score_filename_prefix):
self._score_filename_prefix = score_filename_prefix
self._input_signal_metadata = None
self._reference_signal = None
self._reference_signal_filepath = None
self._tested_signal = None
self._tested_signal_filepath = None
self._output_filepath = None
self._score = None
self._render_signal_filepath = None
def __init__(self, score_filename_prefix):
self._score_filename_prefix = score_filename_prefix
self._input_signal_metadata = None
self._reference_signal = None
self._reference_signal_filepath = None
self._tested_signal = None
self._tested_signal_filepath = None
self._output_filepath = None
self._score = None
self._render_signal_filepath = None
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EvaluationScore implementation.
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EvaluationScore implementation.
Decorator to automatically register the classes that extend EvaluationScore.
Example usage:
@ -54,91 +53,90 @@ class EvaluationScore(object):
class AudioLevelScore(EvaluationScore):
pass
"""
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
@property
def output_filepath(self):
return self._output_filepath
@property
def output_filepath(self):
return self._output_filepath
@property
def score(self):
return self._score
@property
def score(self):
return self._score
def SetInputSignalMetadata(self, metadata):
"""Sets input signal metadata.
def SetInputSignalMetadata(self, metadata):
"""Sets input signal metadata.
Args:
metadata: dict instance.
"""
self._input_signal_metadata = metadata
self._input_signal_metadata = metadata
def SetReferenceSignalFilepath(self, filepath):
"""Sets the path to the audio track used as reference signal.
def SetReferenceSignalFilepath(self, filepath):
"""Sets the path to the audio track used as reference signal.
Args:
filepath: path to the reference audio track.
"""
self._reference_signal_filepath = filepath
self._reference_signal_filepath = filepath
def SetTestedSignalFilepath(self, filepath):
"""Sets the path to the audio track used as test signal.
def SetTestedSignalFilepath(self, filepath):
"""Sets the path to the audio track used as test signal.
Args:
filepath: path to the test audio track.
"""
self._tested_signal_filepath = filepath
self._tested_signal_filepath = filepath
def SetRenderSignalFilepath(self, filepath):
"""Sets the path to the audio track used as render signal.
def SetRenderSignalFilepath(self, filepath):
"""Sets the path to the audio track used as render signal.
Args:
filepath: path to the render audio track.
"""
self._render_signal_filepath = filepath
self._render_signal_filepath = filepath
def Run(self, output_path):
"""Extracts the score for the set test data pair.
def Run(self, output_path):
"""Extracts the score for the set test data pair.
Args:
output_path: path to the directory where the output is written.
"""
self._output_filepath = os.path.join(
output_path, self._score_filename_prefix + self.NAME + '.txt')
try:
# If the score has already been computed, load.
self._LoadScore()
logging.debug('score found and loaded')
except IOError:
# Compute the score.
logging.debug('score not found, compute')
self._Run(output_path)
self._output_filepath = os.path.join(
output_path, self._score_filename_prefix + self.NAME + '.txt')
try:
# If the score has already been computed, load.
self._LoadScore()
logging.debug('score found and loaded')
except IOError:
# Compute the score.
logging.debug('score not found, compute')
self._Run(output_path)
def _Run(self, output_path):
# Abstract method.
raise NotImplementedError()
def _Run(self, output_path):
# Abstract method.
raise NotImplementedError()
def _LoadReferenceSignal(self):
assert self._reference_signal_filepath is not None
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._reference_signal_filepath)
def _LoadReferenceSignal(self):
assert self._reference_signal_filepath is not None
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._reference_signal_filepath)
def _LoadTestedSignal(self):
assert self._tested_signal_filepath is not None
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._tested_signal_filepath)
def _LoadTestedSignal(self):
assert self._tested_signal_filepath is not None
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._tested_signal_filepath)
def _LoadScore(self):
return data_access.ScoreFile.Load(self._output_filepath)
def _LoadScore(self):
return data_access.ScoreFile.Load(self._output_filepath)
def _SaveScore(self):
return data_access.ScoreFile.Save(self._output_filepath, self._score)
def _SaveScore(self):
return data_access.ScoreFile.Save(self._output_filepath, self._score)
@EvaluationScore.RegisterClass
class AudioLevelPeakScore(EvaluationScore):
"""Peak audio level score.
"""Peak audio level score.
Defined as the difference between the peak audio level of the tested and
the reference signals.
@ -148,21 +146,21 @@ class AudioLevelPeakScore(EvaluationScore):
Worst case: +/-inf dB
"""
NAME = 'audio_level_peak'
NAME = 'audio_level_peak'
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
self._SaveScore()
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
self._SaveScore()
@EvaluationScore.RegisterClass
class MeanAudioLevelScore(EvaluationScore):
"""Mean audio level score.
"""Mean audio level score.
Defined as the difference between the mean audio level of the tested and
the reference signals.
@ -172,29 +170,30 @@ class MeanAudioLevelScore(EvaluationScore):
Worst case: +/-inf dB
"""
NAME = 'audio_level_mean'
NAME = 'audio_level_mean'
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
dbfs_diffs_sum = 0.0
seconds = min(len(self._tested_signal), len(self._reference_signal)) // 1000
for t in range(seconds):
t0 = t * seconds
t1 = t0 + seconds
dbfs_diffs_sum += (
self._tested_signal[t0:t1].dBFS - self._reference_signal[t0:t1].dBFS)
self._score = dbfs_diffs_sum / float(seconds)
self._SaveScore()
dbfs_diffs_sum = 0.0
seconds = min(len(self._tested_signal), len(
self._reference_signal)) // 1000
for t in range(seconds):
t0 = t * seconds
t1 = t0 + seconds
dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
self._reference_signal[t0:t1].dBFS)
self._score = dbfs_diffs_sum / float(seconds)
self._SaveScore()
@EvaluationScore.RegisterClass
class EchoMetric(EvaluationScore):
"""Echo score.
"""Echo score.
Proportion of detected echo.
@ -203,46 +202,47 @@ class EchoMetric(EvaluationScore):
Worst case: 1
"""
NAME = 'echo_metric'
NAME = 'echo_metric'
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
# Echo detector binary file path.
self._echo_detector_bin_filepath = echo_detector_bin_filepath
if not os.path.exists(self._echo_detector_bin_filepath):
logging.error('cannot find EchoMetric tool binary file')
raise exceptions.FileNotFoundError()
# Echo detector binary file path.
self._echo_detector_bin_filepath = echo_detector_bin_filepath
if not os.path.exists(self._echo_detector_bin_filepath):
logging.error('cannot find EchoMetric tool binary file')
raise exceptions.FileNotFoundError()
self._echo_detector_bin_path, _ = os.path.split(
self._echo_detector_bin_filepath)
self._echo_detector_bin_path, _ = os.path.split(
self._echo_detector_bin_filepath)
def _Run(self, output_path):
echo_detector_out_filepath = os.path.join(output_path, 'echo_detector.out')
if os.path.exists(echo_detector_out_filepath):
os.unlink(echo_detector_out_filepath)
def _Run(self, output_path):
echo_detector_out_filepath = os.path.join(output_path,
'echo_detector.out')
if os.path.exists(echo_detector_out_filepath):
os.unlink(echo_detector_out_filepath)
logging.debug("Render signal filepath: %s", self._render_signal_filepath)
if not os.path.exists(self._render_signal_filepath):
logging.error("Render input required for evaluating the echo metric.")
logging.debug("Render signal filepath: %s",
self._render_signal_filepath)
if not os.path.exists(self._render_signal_filepath):
logging.error(
"Render input required for evaluating the echo metric.")
args = [
self._echo_detector_bin_filepath,
'--output_file', echo_detector_out_filepath,
'--',
'-i', self._tested_signal_filepath,
'-ri', self._render_signal_filepath
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._echo_detector_bin_path)
args = [
self._echo_detector_bin_filepath, '--output_file',
echo_detector_out_filepath, '--', '-i',
self._tested_signal_filepath, '-ri', self._render_signal_filepath
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._echo_detector_bin_path)
# Parse Echo detector tool output and extract the score.
self._score = self._ParseOutputFile(echo_detector_out_filepath)
self._SaveScore()
# Parse Echo detector tool output and extract the score.
self._score = self._ParseOutputFile(echo_detector_out_filepath)
self._SaveScore()
@classmethod
def _ParseOutputFile(cls, echo_metric_file_path):
"""
@classmethod
def _ParseOutputFile(cls, echo_metric_file_path):
"""
Parses the echo detector tool output and extracts the score.
Args:
@ -251,12 +251,13 @@ class EchoMetric(EvaluationScore):
Returns:
The score as a number in [0, 1].
"""
with open(echo_metric_file_path) as f:
return float(f.read())
with open(echo_metric_file_path) as f:
return float(f.read())
@EvaluationScore.RegisterClass
class PolqaScore(EvaluationScore):
"""POLQA score.
"""POLQA score.
See http://www.polqa.info/.
@ -265,44 +266,51 @@ class PolqaScore(EvaluationScore):
Worst case: 1.0
"""
NAME = 'polqa'
NAME = 'polqa'
def __init__(self, score_filename_prefix, polqa_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix, polqa_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
# POLQA binary file path.
self._polqa_bin_filepath = polqa_bin_filepath
if not os.path.exists(self._polqa_bin_filepath):
logging.error('cannot find POLQA tool binary file')
raise exceptions.FileNotFoundError()
# POLQA binary file path.
self._polqa_bin_filepath = polqa_bin_filepath
if not os.path.exists(self._polqa_bin_filepath):
logging.error('cannot find POLQA tool binary file')
raise exceptions.FileNotFoundError()
# Path to the POLQA directory with binary and license files.
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
# Path to the POLQA directory with binary and license files.
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
def _Run(self, output_path):
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
if os.path.exists(polqa_out_filepath):
os.unlink(polqa_out_filepath)
def _Run(self, output_path):
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
if os.path.exists(polqa_out_filepath):
os.unlink(polqa_out_filepath)
args = [
self._polqa_bin_filepath, '-t', '-q', '-Overwrite',
'-Ref', self._reference_signal_filepath,
'-Test', self._tested_signal_filepath,
'-LC', 'NB',
'-Out', polqa_out_filepath,
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._polqa_tool_path)
args = [
self._polqa_bin_filepath,
'-t',
'-q',
'-Overwrite',
'-Ref',
self._reference_signal_filepath,
'-Test',
self._tested_signal_filepath,
'-LC',
'NB',
'-Out',
polqa_out_filepath,
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._polqa_tool_path)
# Parse POLQA tool output and extract the score.
polqa_output = self._ParseOutputFile(polqa_out_filepath)
self._score = float(polqa_output['PolqaScore'])
# Parse POLQA tool output and extract the score.
polqa_output = self._ParseOutputFile(polqa_out_filepath)
self._score = float(polqa_output['PolqaScore'])
self._SaveScore()
self._SaveScore()
@classmethod
def _ParseOutputFile(cls, polqa_out_filepath):
"""
@classmethod
def _ParseOutputFile(cls, polqa_out_filepath):
"""
Parses the POLQA tool output formatted as a table ('-t' option).
Args:
@ -311,29 +319,32 @@ class PolqaScore(EvaluationScore):
Returns:
A dict.
"""
data = []
with open(polqa_out_filepath) as f:
for line in f:
line = line.strip()
if len(line) == 0 or line.startswith('*'):
# Ignore comments.
continue
# Read fields.
data.append(re.split(r'\t+', line))
data = []
with open(polqa_out_filepath) as f:
for line in f:
line = line.strip()
if len(line) == 0 or line.startswith('*'):
# Ignore comments.
continue
# Read fields.
data.append(re.split(r'\t+', line))
# Two rows expected (header and values).
assert len(data) == 2, 'Cannot parse POLQA output'
number_of_fields = len(data[0])
assert number_of_fields == len(data[1])
# Two rows expected (header and values).
assert len(data) == 2, 'Cannot parse POLQA output'
number_of_fields = len(data[0])
assert number_of_fields == len(data[1])
# Build and return a dictionary with field names (header) as keys and the
# corresponding field values as values.
return {data[0][index]: data[1][index] for index in range(number_of_fields)}
# Build and return a dictionary with field names (header) as keys and the
# corresponding field values as values.
return {
data[0][index]: data[1][index]
for index in range(number_of_fields)
}
@EvaluationScore.RegisterClass
class TotalHarmonicDistorsionScore(EvaluationScore):
"""Total harmonic distorsion plus noise score.
"""Total harmonic distorsion plus noise score.
Total harmonic distorsion plus noise score.
See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
@ -343,69 +354,74 @@ class TotalHarmonicDistorsionScore(EvaluationScore):
Worst case: +inf
"""
NAME = 'thd'
NAME = 'thd'
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
self._input_frequency = None
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
self._input_frequency = None
def _Run(self, output_path):
self._CheckInputSignal()
def _Run(self, output_path):
self._CheckInputSignal()
self._LoadTestedSignal()
if self._tested_signal.channels != 1:
raise exceptions.EvaluationScoreException(
'unsupported number of channels')
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._tested_signal)
self._LoadTestedSignal()
if self._tested_signal.channels != 1:
raise exceptions.EvaluationScoreException(
'unsupported number of channels')
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._tested_signal)
# Init.
num_samples = len(samples)
duration = len(self._tested_signal) / 1000.0
scaling = 2.0 / num_samples
max_freq = self._tested_signal.frame_rate / 2
f0_freq = float(self._input_frequency)
t = np.linspace(0, duration, num_samples)
# Init.
num_samples = len(samples)
duration = len(self._tested_signal) / 1000.0
scaling = 2.0 / num_samples
max_freq = self._tested_signal.frame_rate / 2
f0_freq = float(self._input_frequency)
t = np.linspace(0, duration, num_samples)
# Analyze harmonics.
b_terms = []
n = 1
while f0_freq * n < max_freq:
x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
b_terms.append(np.sqrt(x_n**2 + y_n**2))
n += 1
# Analyze harmonics.
b_terms = []
n = 1
while f0_freq * n < max_freq:
x_n = np.sum(
samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
y_n = np.sum(
samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
b_terms.append(np.sqrt(x_n**2 + y_n**2))
n += 1
output_without_fundamental = samples - b_terms[0] * np.sin(
2.0 * np.pi * f0_freq * t)
distortion_and_noise = np.sqrt(np.sum(
output_without_fundamental**2) * np.pi * scaling)
output_without_fundamental = samples - b_terms[0] * np.sin(
2.0 * np.pi * f0_freq * t)
distortion_and_noise = np.sqrt(
np.sum(output_without_fundamental**2) * np.pi * scaling)
# TODO(alessiob): Fix or remove if not needed.
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
# TODO(alessiob): Fix or remove if not needed.
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
# TODO(alessiob): Check the range of |thd_plus_noise| and update the class
# docstring above accordingly.
thd_plus_noise = distortion_and_noise / b_terms[0]
# TODO(alessiob): Check the range of |thd_plus_noise| and update the class
# docstring above accordingly.
thd_plus_noise = distortion_and_noise / b_terms[0]
self._score = thd_plus_noise
self._SaveScore()
self._score = thd_plus_noise
self._SaveScore()
def _CheckInputSignal(self):
# Check input signal and get properties.
try:
if self._input_signal_metadata['signal'] != 'pure_tone':
raise exceptions.EvaluationScoreException(
'The THD score requires a pure tone as input signal')
self._input_frequency = self._input_signal_metadata['frequency']
if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
self._input_signal_metadata['test_data_gen_config'] != 'default'):
raise exceptions.EvaluationScoreException(
'The THD score cannot be used with any test data generator other '
'than "identity"')
except TypeError:
raise exceptions.EvaluationScoreException(
'The THD score requires an input signal with associated metadata')
except KeyError:
raise exceptions.EvaluationScoreException(
'Invalid input signal metadata to compute the THD score')
def _CheckInputSignal(self):
# Check input signal and get properties.
try:
if self._input_signal_metadata['signal'] != 'pure_tone':
raise exceptions.EvaluationScoreException(
'The THD score requires a pure tone as input signal')
self._input_frequency = self._input_signal_metadata['frequency']
if self._input_signal_metadata[
'test_data_gen_name'] != 'identity' or (
self._input_signal_metadata['test_data_gen_config'] !=
'default'):
raise exceptions.EvaluationScoreException(
'The THD score cannot be used with any test data generator other '
'than "identity"')
except TypeError:
raise exceptions.EvaluationScoreException(
'The THD score requires an input signal with associated metadata'
)
except KeyError:
raise exceptions.EvaluationScoreException(
'Invalid input signal metadata to compute the THD score')
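A minimal sketch of adding and running a custom score on top of this base class; the class, file paths, and score prefix are hypothetical, and len() relies on pydub segments reporting their duration in milliseconds.

from quality_assessment import eval_scores  # assumed import path


@eval_scores.EvaluationScore.RegisterClass
class DurationDiffScore(eval_scores.EvaluationScore):
    """Toy score: length difference (ms) between tested and reference."""

    NAME = 'duration_diff'

    def __init__(self, score_filename_prefix):
        eval_scores.EvaluationScore.__init__(self, score_filename_prefix)

    def _Run(self, output_path):
        self._LoadReferenceSignal()
        self._LoadTestedSignal()
        self._score = len(self._tested_signal) - len(self._reference_signal)
        self._SaveScore()


score = DurationDiffScore('score-')
score.SetReferenceSignalFilepath('ref.wav')    # placeholder path
score.SetTestedSignalFilepath('apm_out.wav')   # placeholder path
score.Run('/tmp/apm_scores')                   # placeholder output directory
print(score.score)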

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""EvaluationScore factory class.
"""
@ -16,22 +15,22 @@ from . import eval_scores
class EvaluationScoreWorkerFactory(object):
"""Factory class used to instantiate evaluation score workers.
"""Factory class used to instantiate evaluation score workers.
The ctor gets the parameters that are used to instantiate the evaluation score
workers.
"""
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
self._score_filename_prefix = None
self._polqa_tool_bin_path = polqa_tool_bin_path
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
self._score_filename_prefix = None
self._polqa_tool_bin_path = polqa_tool_bin_path
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
def SetScoreFilenamePrefix(self, prefix):
self._score_filename_prefix = prefix
def SetScoreFilenamePrefix(self, prefix):
self._score_filename_prefix = prefix
def GetInstance(self, evaluation_score_class):
"""Creates an EvaluationScore instance given a class object.
def GetInstance(self, evaluation_score_class):
"""Creates an EvaluationScore instance given a class object.
Args:
evaluation_score_class: EvaluationScore class object (not an instance).
@ -39,17 +38,18 @@ class EvaluationScoreWorkerFactory(object):
Returns:
An EvaluationScore instance.
"""
if self._score_filename_prefix is None:
raise exceptions.InitializationException(
'The score file name prefix for evaluation score workers is not set')
logging.debug(
'factory producing a %s evaluation score', evaluation_score_class)
if self._score_filename_prefix is None:
raise exceptions.InitializationException(
'The score file name prefix for evaluation score workers is not set'
)
logging.debug('factory producing a %s evaluation score',
evaluation_score_class)
if evaluation_score_class == eval_scores.PolqaScore:
return eval_scores.PolqaScore(
self._score_filename_prefix, self._polqa_tool_bin_path)
elif evaluation_score_class == eval_scores.EchoMetric:
return eval_scores.EchoMetric(
self._score_filename_prefix, self._echo_metric_tool_bin_path)
else:
return evaluation_score_class(self._score_filename_prefix)
if evaluation_score_class == eval_scores.PolqaScore:
return eval_scores.PolqaScore(self._score_filename_prefix,
self._polqa_tool_bin_path)
elif evaluation_score_class == eval_scores.EchoMetric:
return eval_scores.EchoMetric(self._score_filename_prefix,
self._echo_metric_tool_bin_path)
else:
return evaluation_score_class(self._score_filename_prefix)
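A hedged usage sketch of the factory, mirroring the unit test that follows; it assumes the code runs inside the quality_assessment package with the fake_polqa helper binary built next to it:

import os

from . import eval_scores
from . import eval_scores_factory

factory = eval_scores_factory.EvaluationScoreWorkerFactory(
    polqa_tool_bin_path=os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
    echo_metric_tool_bin_path=None)
factory.SetScoreFilenamePrefix('scores-')
# Returns a PolqaScore wired to the (fake) POLQA binary.
polqa_worker = factory.GetInstance(eval_scores.PolqaScore)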

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the eval_scores module.
"""
@ -23,111 +22,116 @@ from . import signal_processing
class TestEvalScores(unittest.TestCase):
"""Unit tests for the eval_scores module.
"""Unit tests for the eval_scores module.
"""
def setUp(self):
"""Create temporary output folder and two audio track files."""
self._output_path = tempfile.mkdtemp()
def setUp(self):
"""Create temporary output folder and two audio track files."""
self._output_path = tempfile.mkdtemp()
# Create fake reference and tested (i.e., APM output) audio track files.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_reference_signal = (
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
fake_tested_signal = (
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
# Create fake reference and tested (i.e., APM output) audio track files.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_reference_signal = (signal_processing.SignalProcessingUtils.
GenerateWhiteNoise(silence))
fake_tested_signal = (signal_processing.SignalProcessingUtils.
GenerateWhiteNoise(silence))
# Save fake audio tracks.
self._fake_reference_signal_filepath = os.path.join(
self._output_path, 'fake_ref.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_reference_signal_filepath, fake_reference_signal)
self._fake_tested_signal_filepath = os.path.join(
self._output_path, 'fake_test.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_tested_signal_filepath, fake_tested_signal)
# Save fake audio tracks.
self._fake_reference_signal_filepath = os.path.join(
self._output_path, 'fake_ref.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_reference_signal_filepath, fake_reference_signal)
self._fake_tested_signal_filepath = os.path.join(
self._output_path, 'fake_test.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_tested_signal_filepath, fake_tested_signal)
def tearDown(self):
"""Recursively delete temporary folder."""
shutil.rmtree(self._output_path)
def tearDown(self):
"""Recursively delete temporary folder."""
shutil.rmtree(self._output_path)
def testRegisteredClasses(self):
# Evaluation score names to exclude (tested separately).
exceptions = ['thd', 'echo_metric']
def testRegisteredClasses(self):
# Evaluation score names to exclude (tested separately).
exceptions = ['thd', 'echo_metric']
# Preliminary check.
self.assertTrue(os.path.exists(self._output_path))
# Preliminary check.
self.assertTrue(os.path.exists(self._output_path))
# Check that there is at least one registered evaluation score worker.
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Check that there is at least one registered evaluation score worker.
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instantiate the evaluation score workers factory with fake dependencies.
eval_score_workers_factory = (
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
echo_metric_tool_bin_path=None
))
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
# Instantiate the evaluation score workers factory with fake dependencies.
eval_score_workers_factory = (
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
echo_metric_tool_bin_path=None))
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
# Try each registered evaluation score worker.
for eval_score_name in registered_classes:
if eval_score_name in exceptions:
continue
# Try each registered evaluation score worker.
for eval_score_name in registered_classes:
if eval_score_name in exceptions:
continue
# Instantiate the evaluation score worker.
eval_score_worker = eval_score_workers_factory.GetInstance(
registered_classes[eval_score_name])
# Instantiate the evaluation score worker.
eval_score_worker = eval_score_workers_factory.GetInstance(
registered_classes[eval_score_name])
# Set fake input metadata and reference and test file paths, then run.
eval_score_worker.SetReferenceSignalFilepath(
self._fake_reference_signal_filepath)
eval_score_worker.SetTestedSignalFilepath(
self._fake_tested_signal_filepath)
eval_score_worker.Run(self._output_path)
# Set fake input metadata and reference and test file paths, then run.
eval_score_worker.SetReferenceSignalFilepath(
self._fake_reference_signal_filepath)
eval_score_worker.SetTestedSignalFilepath(
self._fake_tested_signal_filepath)
eval_score_worker.Run(self._output_path)
# Check output.
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
self.assertTrue(isinstance(score, float))
# Check output.
score = data_access.ScoreFile.Load(
eval_score_worker.output_filepath)
self.assertTrue(isinstance(score, float))
def testTotalHarmonicDistorsionScore(self):
# Init.
pure_tone_freq = 5000.0
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
eval_score_worker.SetInputSignalMetadata({
'signal': 'pure_tone',
'frequency': pure_tone_freq,
'test_data_gen_name': 'identity',
'test_data_gen_config': 'default',
})
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
def testTotalHarmonicDistorsionScore(self):
# Init.
pure_tone_freq = 5000.0
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
eval_score_worker.SetInputSignalMetadata({
'signal':
'pure_tone',
'frequency':
pure_tone_freq,
'test_data_gen_name':
'identity',
'test_data_gen_config':
'default',
})
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
# Create 3 test signals: pure tone, pure tone + white noise, white noise
# only.
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
template, pure_tone_freq)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
template)
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
pure_tone, white_noise)
# Create 3 test signals: pure tone, pure tone + white noise, white noise
# only.
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
template, pure_tone_freq)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
template)
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
pure_tone, white_noise)
# Compute scores for increasingly distorted pure tone signals.
scores = [None, None, None]
for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]):
# Save signal.
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
signal_processing.SignalProcessingUtils.SaveWav(
tmp_filepath, tested_signal)
# Compute scores for increasingly distorted pure tone signals.
scores = [None, None, None]
for index, tested_signal in enumerate(
[pure_tone, noisy_tone, white_noise]):
# Save signal.
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
signal_processing.SignalProcessingUtils.SaveWav(
tmp_filepath, tested_signal)
# Compute score.
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
eval_score_worker.Run(self._output_path)
scores[index] = eval_score_worker.score
# Compute score.
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
eval_score_worker.Run(self._output_path)
scores[index] = eval_score_worker.score
# Remove output file to avoid caching.
os.remove(eval_score_worker.output_filepath)
# Remove output file to avoid caching.
os.remove(eval_score_worker.output_filepath)
# Validate scores (lowest score with a pure tone).
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
# Validate scores (lowest score with a pure tone).
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
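For reference, a standalone pydub sketch of the kind of fake test signals used above (white noise over a silent template, saved as a wav file); the parameter values are illustrative:

import pydub
import pydub.generators

template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
noise = pydub.generators.WhiteNoise(
    sample_rate=template.frame_rate).to_audio_segment(
        duration=len(template), volume=-20.0)
noisy = template.overlay(noise)
noisy.export('white_noise.wav', format='wav')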

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Evaluator of the APM module.
"""
@ -13,17 +12,17 @@ import logging
class ApmModuleEvaluator(object):
"""APM evaluator class.
"""APM evaluator class.
"""
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def Run(cls, evaluation_score_workers, apm_input_metadata,
apm_output_filepath, reference_input_filepath,
render_input_filepath, output_path):
"""Runs the evaluation.
@classmethod
def Run(cls, evaluation_score_workers, apm_input_metadata,
apm_output_filepath, reference_input_filepath,
render_input_filepath, output_path):
"""Runs the evaluation.
Iterates over the given evaluation score workers.
@ -37,20 +36,22 @@ class ApmModuleEvaluator(object):
Returns:
A dict of evaluation score name and score pairs.
"""
# Init.
scores = {}
# Init.
scores = {}
for evaluation_score_worker in evaluation_score_workers:
logging.info(' computing <%s> score', evaluation_score_worker.NAME)
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
evaluation_score_worker.SetReferenceSignalFilepath(
reference_input_filepath)
evaluation_score_worker.SetTestedSignalFilepath(
apm_output_filepath)
evaluation_score_worker.SetRenderSignalFilepath(
render_input_filepath)
for evaluation_score_worker in evaluation_score_workers:
logging.info(' computing <%s> score',
evaluation_score_worker.NAME)
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
evaluation_score_worker.SetReferenceSignalFilepath(
reference_input_filepath)
evaluation_score_worker.SetTestedSignalFilepath(
apm_output_filepath)
evaluation_score_worker.SetRenderSignalFilepath(
render_input_filepath)
evaluation_score_worker.Run(output_path)
scores[evaluation_score_worker.NAME] = evaluation_score_worker.score
evaluation_score_worker.Run(output_path)
scores[
evaluation_score_worker.NAME] = evaluation_score_worker.score
return scores
return scores
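A standalone sketch of the evaluation loop above with a stub worker; the stub class and its fake score are illustrative only and not part of this CL:

class StubScoreWorker(object):
    NAME = 'stub_score'

    def __init__(self):
        self.score = None

    def SetInputSignalMetadata(self, metadata):
        self._metadata = metadata

    def SetReferenceSignalFilepath(self, filepath):
        self._reference_filepath = filepath

    def SetTestedSignalFilepath(self, filepath):
        self._tested_filepath = filepath

    def SetRenderSignalFilepath(self, filepath):
        self._render_filepath = filepath

    def Run(self, output_path):
        self.score = 42.0  # Fake score.

scores = {}
for worker in [StubScoreWorker()]:
    worker.SetInputSignalMetadata({})
    worker.SetReferenceSignalFilepath('ref.wav')
    worker.SetTestedSignalFilepath('apm_out.wav')
    worker.SetRenderSignalFilepath('render.wav')
    worker.Run('.')
    scores[worker.NAME] = worker.score
print(scores)  # {'stub_score': 42.0}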

View File

@ -5,42 +5,41 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Exception classes.
"""
class FileNotFoundError(Exception):
"""File not found exception.
"""File not found exception.
"""
pass
pass
class SignalProcessingException(Exception):
"""Signal processing exception.
"""Signal processing exception.
"""
pass
pass
class InputMixerException(Exception):
"""Input mixer exception.
"""Input mixer exception.
"""
pass
pass
class InputSignalCreatorException(Exception):
"""Input signal creator exception.
"""Input signal creator exception.
"""
pass
pass
class EvaluationScoreException(Exception):
"""Evaluation score exception.
"""Evaluation score exception.
"""
pass
pass
class InitializationException(Exception):
"""Initialization exception.
"""Initialization exception.
"""
pass
pass

View File

@ -14,58 +14,58 @@ import re
import sys
try:
import csscompressor
import csscompressor
except ImportError:
logging.critical('Cannot import the third-party Python package csscompressor')
sys.exit(1)
logging.critical(
'Cannot import the third-party Python package csscompressor')
sys.exit(1)
try:
import jsmin
import jsmin
except ImportError:
logging.critical('Cannot import the third-party Python package jsmin')
sys.exit(1)
logging.critical('Cannot import the third-party Python package jsmin')
sys.exit(1)
class HtmlExport(object):
"""HTML exporter class for APM quality scores."""
"""HTML exporter class for APM quality scores."""
_NEW_LINE = '\n'
_NEW_LINE = '\n'
# CSS and JS file paths.
_PATH = os.path.dirname(os.path.realpath(__file__))
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
_CSS_MINIFIED = True
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
_JS_MINIFIED = True
# CSS and JS file paths.
_PATH = os.path.dirname(os.path.realpath(__file__))
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
_CSS_MINIFIED = True
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
_JS_MINIFIED = True
def __init__(self, output_filepath):
self._scores_data_frame = None
self._output_filepath = output_filepath
def __init__(self, output_filepath):
self._scores_data_frame = None
self._output_filepath = output_filepath
def Export(self, scores_data_frame):
"""Exports scores into an HTML file.
def Export(self, scores_data_frame):
"""Exports scores into an HTML file.
Args:
scores_data_frame: DataFrame instance.
"""
self._scores_data_frame = scores_data_frame
html = ['<html>',
self._scores_data_frame = scores_data_frame
html = [
'<html>',
self._BuildHeader(),
('<script type="text/javascript">'
'(function () {'
'window.addEventListener(\'load\', function () {'
'var inspector = new AudioInspector();'
'});'
'(function () {'
'window.addEventListener(\'load\', function () {'
'var inspector = new AudioInspector();'
'});'
'})();'
'</script>'),
'<body>',
self._BuildBody(),
'</body>',
'</html>']
self._Save(self._output_filepath, self._NEW_LINE.join(html))
'</script>'), '<body>',
self._BuildBody(), '</body>', '</html>'
]
self._Save(self._output_filepath, self._NEW_LINE.join(html))
def _BuildHeader(self):
"""Builds the <head> section of the HTML file.
def _BuildHeader(self):
"""Builds the <head> section of the HTML file.
The header contains the page title and either embedded or linked CSS and JS
files.
@ -73,325 +73,349 @@ class HtmlExport(object):
Returns:
A string with <head>...</head> HTML.
"""
html = ['<head>', '<title>Results</title>']
html = ['<head>', '<title>Results</title>']
# Add Material Design hosted libs.
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
'css?family=Roboto:300,400,500,700" type="text/css">')
html.append('<link rel="stylesheet" href="https://fonts.googleapis.com/'
'icon?family=Material+Icons">')
html.append('<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
'material.indigo-pink.min.css">')
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
'material.min.js"></script>')
# Embed custom JavaScript and CSS files.
html.append('<script>')
with open(self._JS_FILEPATH) as f:
html.append(jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
f.read().rstrip()))
html.append('</script>')
html.append('<style>')
with open(self._CSS_FILEPATH) as f:
html.append(csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
f.read().rstrip()))
html.append('</style>')
html.append('</head>')
return self._NEW_LINE.join(html)
def _BuildBody(self):
"""Builds the content of the <body> section."""
score_names = self._scores_data_frame['eval_score_name'].drop_duplicates(
).values.tolist()
html = [
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
'mdl-layout--fixed-tabs">'),
'<header class="mdl-layout__header">',
'<div class="mdl-layout__header-row">',
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
self._output_filepath),
'</div>',
]
# Tab selectors.
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
for tab_index, score_name in enumerate(score_names):
is_active = tab_index == 0
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
'{}</a>'.format(tab_index,
' is-active' if is_active else '',
self._FormatName(score_name)))
html.append('</div>')
html.append('</header>')
html.append('<main class="mdl-layout__content" style="overflow-x: auto;">')
# Tabs content.
for tab_index, score_name in enumerate(score_names):
html.append('<section class="mdl-layout__tab-panel{}" '
'id="score-tab-{}">'.format(
' is-active' if is_active else '', tab_index))
html.append('<div class="page-content">')
html.append(self._BuildScoreTab(score_name, ('s{}'.format(tab_index),)))
html.append('</div>')
html.append('</section>')
html.append('</main>')
html.append('</div>')
# Add snackbar for notifications.
html.append(
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
'<div class="mdl-snackbar__text"></div>'
'<button type="button" class="mdl-snackbar__action"></button>'
'</div>')
return self._NEW_LINE.join(html)
def _BuildScoreTab(self, score_name, anchor_data):
"""Builds the content of a tab."""
# Find unique values.
scores = self._scores_data_frame[
self._scores_data_frame.eval_score_name == score_name]
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
test_data_gen_configs = sorted(self._FindUniqueTuples(
scores, ['test_data_gen', 'test_data_gen_params']))
html = [
'<div class="mdl-grid">',
'<div class="mdl-layout-spacer"></div>',
'<div class="mdl-cell mdl-cell--10-col">',
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
'style="width: 100%;">'),
]
# Header.
html.append('<thead><tr><th>APM config / Test data generator</th>')
for test_data_gen_info in test_data_gen_configs:
html.append('<th>{} {}</th>'.format(
self._FormatName(test_data_gen_info[0]), test_data_gen_info[1]))
html.append('</tr></thead>')
# Body.
html.append('<tbody>')
for apm_config in apm_configs:
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1])
# Add Material Design hosted libs.
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
'css?family=Roboto:300,400,500,700" type="text/css">')
html.append(
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.format(
dialog_id, self._BuildScoreTableCell(
score_name, test_data_gen_info[0], test_data_gen_info[1],
apm_config[0])))
html.append('</tr>')
html.append('</tbody>')
'<link rel="stylesheet" href="https://fonts.googleapis.com/'
'icon?family=Material+Icons">')
html.append(
'<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
'material.indigo-pink.min.css">')
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
'material.min.js"></script>')
html.append('</table></div><div class="mdl-layout-spacer"></div></div>')
# Embed custom JavaScript and CSS files.
html.append('<script>')
with open(self._JS_FILEPATH) as f:
html.append(
jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
f.read().rstrip()))
html.append('</script>')
html.append('<style>')
with open(self._CSS_FILEPATH) as f:
html.append(
csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
f.read().rstrip()))
html.append('</style>')
html.append(self._BuildScoreStatsInspectorDialogs(
score_name, apm_configs, test_data_gen_configs,
anchor_data))
html.append('</head>')
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreTableCell(self, score_name, test_data_gen,
test_data_gen_params, apm_config):
"""Builds the content of a table cell for a score table."""
scores = self._SliceDataForScoreTableCell(
score_name, apm_config, test_data_gen, test_data_gen_params)
stats = self._ComputeScoreStats(scores)
def _BuildBody(self):
"""Builds the content of the <body> section."""
score_names = self._scores_data_frame[
'eval_score_name'].drop_duplicates().values.tolist()
html = []
items_id_prefix = (
score_name + test_data_gen + test_data_gen_params + apm_config)
if stats['count'] == 1:
# Show the only available score.
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
item_id, scores['score'].mean()))
html.append('<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
'</div>'.format(item_id, 'single value'))
else:
# Show stats.
for stat_name in ['min', 'max', 'mean', 'std dev']:
item_id = hashlib.md5(
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
html.append('<div id="stats-{0}">{1:f}</div>'.format(
item_id, stats[stat_name]))
html.append('<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
'</div>'.format(item_id, stat_name))
html = [
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
'mdl-layout--fixed-tabs">'),
'<header class="mdl-layout__header">',
'<div class="mdl-layout__header-row">',
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
self._output_filepath),
'</div>',
]
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorDialogs(
self, score_name, apm_configs, test_data_gen_configs, anchor_data):
"""Builds a set of score stats inspector dialogs."""
html = []
for apm_config in apm_configs:
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0],
test_data_gen_info[0], test_data_gen_info[1])
html.append('<dialog class="mdl-dialog" id="{}" '
'style="width: 40%;">'.format(dialog_id))
# Content.
html.append('<div class="mdl-dialog__content">')
html.append('<h6><strong>APM config preset</strong>: {}<br/>'
'<strong>Test data generator</strong>: {} ({})</h6>'.format(
self._FormatName(apm_config[0]),
self._FormatName(test_data_gen_info[0]),
test_data_gen_info[1]))
html.append(self._BuildScoreStatsInspectorDialog(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1], anchor_data + (dialog_id,)))
# Tab selectors.
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
for tab_index, score_name in enumerate(score_names):
is_active = tab_index == 0
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
'{}</a>'.format(tab_index,
' is-active' if is_active else '',
self._FormatName(score_name)))
html.append('</div>')
# Actions.
html.append('<div class="mdl-dialog__actions">')
html.append('<button type="button" class="mdl-button" '
'onclick="closeScoreStatsInspector()">'
'Close</button>')
html.append('</header>')
html.append(
'<main class="mdl-layout__content" style="overflow-x: auto;">')
# Tabs content.
for tab_index, score_name in enumerate(score_names):
html.append('<section class="mdl-layout__tab-panel{}" '
'id="score-tab-{}">'.format(
' is-active' if is_active else '', tab_index))
html.append('<div class="page-content">')
html.append(
self._BuildScoreTab(score_name, ('s{}'.format(tab_index), )))
html.append('</div>')
html.append('</section>')
html.append('</main>')
html.append('</div>')
html.append('</dialog>')
# Add snackbar for notifications.
html.append(
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
'<div class="mdl-snackbar__text"></div>'
'<button type="button" class="mdl-snackbar__action"></button>'
'</div>')
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorDialog(
self, score_name, apm_config, test_data_gen, test_data_gen_params,
anchor_data):
"""Builds one score stats inspector dialog."""
scores = self._SliceDataForScoreTableCell(
score_name, apm_config, test_data_gen, test_data_gen_params)
def _BuildScoreTab(self, score_name, anchor_data):
"""Builds the content of a tab."""
# Find unique values.
scores = self._scores_data_frame[
self._scores_data_frame.eval_score_name == score_name]
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
test_data_gen_configs = sorted(
self._FindUniqueTuples(scores,
['test_data_gen', 'test_data_gen_params']))
capture_render_pairs = sorted(self._FindUniqueTuples(
scores, ['capture', 'render']))
echo_simulators = sorted(self._FindUniqueTuples(scores, ['echo_simulator']))
html = [
'<div class="mdl-grid">',
'<div class="mdl-layout-spacer"></div>',
'<div class="mdl-cell mdl-cell--10-col">',
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
'style="width: 100%;">'),
]
html = ['<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">']
# Header.
html.append('<thead><tr><th>APM config / Test data generator</th>')
for test_data_gen_info in test_data_gen_configs:
html.append('<th>{} {}</th>'.format(
self._FormatName(test_data_gen_info[0]),
test_data_gen_info[1]))
html.append('</tr></thead>')
# Header.
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
for echo_simulator in echo_simulators:
html.append('<th>' + self._FormatName(echo_simulator[0]) +'</th>')
html.append('</tr></thead>')
# Body.
html.append('<tbody>')
for apm_config in apm_configs:
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1])
html.append(
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.
format(
dialog_id,
self._BuildScoreTableCell(score_name,
test_data_gen_info[0],
test_data_gen_info[1],
apm_config[0])))
html.append('</tr>')
html.append('</tbody>')
# Body.
html.append('<tbody>')
for row, (capture, render) in enumerate(capture_render_pairs):
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
capture, render))
for col, echo_simulator in enumerate(echo_simulators):
score_tuple = self._SliceDataForScoreStatsTableCell(
scores, capture, render, echo_simulator[0])
cell_class = 'r{}c{}'.format(row, col)
html.append('<td class="single-score-cell {}">{}</td>'.format(
cell_class, self._BuildScoreStatsInspectorTableCell(
score_tuple, anchor_data + (cell_class,))))
html.append('</tr>')
html.append('</tbody>')
html.append(
'</table></div><div class="mdl-layout-spacer"></div></div>')
html.append('</table>')
html.append(
self._BuildScoreStatsInspectorDialogs(score_name, apm_configs,
test_data_gen_configs,
anchor_data))
# Placeholder for the audio inspector.
html.append('<div class="audio-inspector-placeholder"></div>')
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreTableCell(self, score_name, test_data_gen,
test_data_gen_params, apm_config):
"""Builds the content of a table cell for a score table."""
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
test_data_gen,
test_data_gen_params)
stats = self._ComputeScoreStats(scores)
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
"""Builds the content of a cell of a score stats inspector."""
anchor = '&'.join(anchor_data)
html = [('<div class="v">{}</div>'
'<button class="mdl-button mdl-js-button mdl-button--icon"'
' data-anchor="{}">'
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
'</button>').format(score_tuple.score, anchor)]
html = []
items_id_prefix = (score_name + test_data_gen + test_data_gen_params +
apm_config)
if stats['count'] == 1:
# Show the only available score.
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
item_id, scores['score'].mean()))
html.append(
'<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
'</div>'.format(item_id, 'single value'))
else:
# Show stats.
for stat_name in ['min', 'max', 'mean', 'std dev']:
item_id = hashlib.md5(
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
html.append('<div id="stats-{0}">{1:f}</div>'.format(
item_id, stats[stat_name]))
html.append(
'<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
'</div>'.format(item_id, stat_name))
# Add all the available file paths as hidden data.
for field_name in score_tuple.keys():
if field_name.endswith('_filepath'):
html.append('<input type="hidden" name="{}" value="{}">'.format(
field_name, score_tuple[field_name]))
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorDialogs(self, score_name, apm_configs,
test_data_gen_configs, anchor_data):
"""Builds a set of score stats inspector dialogs."""
html = []
for apm_config in apm_configs:
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1])
def _SliceDataForScoreTableCell(
self, score_name, apm_config, test_data_gen, test_data_gen_params):
"""Slices |self._scores_data_frame| to extract the data for a tab."""
masks = []
masks.append(self._scores_data_frame.eval_score_name == score_name)
masks.append(self._scores_data_frame.apm_config == apm_config)
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
masks.append(
self._scores_data_frame.test_data_gen_params == test_data_gen_params)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
return self._scores_data_frame[mask]
html.append('<dialog class="mdl-dialog" id="{}" '
'style="width: 40%;">'.format(dialog_id))
@classmethod
def _SliceDataForScoreStatsTableCell(
cls, scores, capture, render, echo_simulator):
"""Slices |scores| to extract the data for a tab."""
masks = []
# Content.
html.append('<div class="mdl-dialog__content">')
html.append(
'<h6><strong>APM config preset</strong>: {}<br/>'
'<strong>Test data generator</strong>: {} ({})</h6>'.
format(self._FormatName(apm_config[0]),
self._FormatName(test_data_gen_info[0]),
test_data_gen_info[1]))
html.append(
self._BuildScoreStatsInspectorDialog(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1], anchor_data + (dialog_id, )))
html.append('</div>')
masks.append(scores.capture == capture)
masks.append(scores.render == render)
masks.append(scores.echo_simulator == echo_simulator)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
# Actions.
html.append('<div class="mdl-dialog__actions">')
html.append('<button type="button" class="mdl-button" '
'onclick="closeScoreStatsInspector()">'
'Close</button>')
html.append('</div>')
sliced_data = scores[mask]
assert len(sliced_data) == 1, 'single score is expected'
return sliced_data.iloc[0]
html.append('</dialog>')
@classmethod
def _FindUniqueTuples(cls, data_frame, fields):
"""Slices |data_frame| to a list of fields and finds unique tuples."""
return data_frame[fields].drop_duplicates().values.tolist()
return self._NEW_LINE.join(html)
@classmethod
def _ComputeScoreStats(cls, data_frame):
"""Computes score stats."""
scores = data_frame['score']
return {
'count': scores.count(),
'min': scores.min(),
'max': scores.max(),
'mean': scores.mean(),
'std dev': scores.std(),
}
def _BuildScoreStatsInspectorDialog(self, score_name, apm_config,
test_data_gen, test_data_gen_params,
anchor_data):
"""Builds one score stats inspector dialog."""
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
test_data_gen,
test_data_gen_params)
@classmethod
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config, test_data_gen,
test_data_gen_params):
"""Assigns a unique name to a dialog."""
return 'score-stats-dialog-' + hashlib.md5(
'score-stats-inspector-{}-{}-{}-{}'.format(
score_name, apm_config, test_data_gen,
test_data_gen_params).encode('utf-8')).hexdigest()
capture_render_pairs = sorted(
self._FindUniqueTuples(scores, ['capture', 'render']))
echo_simulators = sorted(
self._FindUniqueTuples(scores, ['echo_simulator']))
@classmethod
def _Save(cls, output_filepath, html):
"""Writes the HTML file.
html = [
'<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">'
]
# Header.
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
for echo_simulator in echo_simulators:
html.append('<th>' + self._FormatName(echo_simulator[0]) + '</th>')
html.append('</tr></thead>')
# Body.
html.append('<tbody>')
for row, (capture, render) in enumerate(capture_render_pairs):
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
capture, render))
for col, echo_simulator in enumerate(echo_simulators):
score_tuple = self._SliceDataForScoreStatsTableCell(
scores, capture, render, echo_simulator[0])
cell_class = 'r{}c{}'.format(row, col)
html.append('<td class="single-score-cell {}">{}</td>'.format(
cell_class,
self._BuildScoreStatsInspectorTableCell(
score_tuple, anchor_data + (cell_class, ))))
html.append('</tr>')
html.append('</tbody>')
html.append('</table>')
# Placeholder for the audio inspector.
html.append('<div class="audio-inspector-placeholder"></div>')
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
"""Builds the content of a cell of a score stats inspector."""
anchor = '&'.join(anchor_data)
html = [('<div class="v">{}</div>'
'<button class="mdl-button mdl-js-button mdl-button--icon"'
' data-anchor="{}">'
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
'</button>').format(score_tuple.score, anchor)]
# Add all the available file paths as hidden data.
for field_name in score_tuple.keys():
if field_name.endswith('_filepath'):
html.append(
'<input type="hidden" name="{}" value="{}">'.format(
field_name, score_tuple[field_name]))
return self._NEW_LINE.join(html)
def _SliceDataForScoreTableCell(self, score_name, apm_config,
test_data_gen, test_data_gen_params):
"""Slices |self._scores_data_frame| to extract the data for a tab."""
masks = []
masks.append(self._scores_data_frame.eval_score_name == score_name)
masks.append(self._scores_data_frame.apm_config == apm_config)
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
masks.append(self._scores_data_frame.test_data_gen_params ==
test_data_gen_params)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
return self._scores_data_frame[mask]
@classmethod
def _SliceDataForScoreStatsTableCell(cls, scores, capture, render,
echo_simulator):
"""Slices |scores| to extract the data for a tab."""
masks = []
masks.append(scores.capture == capture)
masks.append(scores.render == render)
masks.append(scores.echo_simulator == echo_simulator)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
sliced_data = scores[mask]
assert len(sliced_data) == 1, 'single score is expected'
return sliced_data.iloc[0]
@classmethod
def _FindUniqueTuples(cls, data_frame, fields):
"""Slices |data_frame| to a list of fields and finds unique tuples."""
return data_frame[fields].drop_duplicates().values.tolist()
@classmethod
def _ComputeScoreStats(cls, data_frame):
"""Computes score stats."""
scores = data_frame['score']
return {
'count': scores.count(),
'min': scores.min(),
'max': scores.max(),
'mean': scores.mean(),
'std dev': scores.std(),
}
@classmethod
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config,
test_data_gen, test_data_gen_params):
"""Assigns a unique name to a dialog."""
return 'score-stats-dialog-' + hashlib.md5(
'score-stats-inspector-{}-{}-{}-{}'.format(
score_name, apm_config, test_data_gen,
test_data_gen_params).encode('utf-8')).hexdigest()
@classmethod
def _Save(cls, output_filepath, html):
"""Writes the HTML file.
Args:
output_filepath: output file path.
html: string with the HTML content.
"""
with open(output_filepath, 'w') as f:
f.write(html)
with open(output_filepath, 'w') as f:
f.write(html)
@classmethod
def _FormatName(cls, name):
"""Formats a name.
@classmethod
def _FormatName(cls, name):
"""Formats a name.
Args:
name: a string.
@ -399,4 +423,4 @@ class HtmlExport(object):
Returns:
A copy of name in which underscores and dashes are replaced with a space.
"""
return re.sub(r'[_\-]', ' ', name)
return re.sub(r'[_\-]', ' ', name)
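The exporter's cell slicing and statistics rely on pandas boolean masks; a standalone sketch with a made-up DataFrame whose column names mirror the ones used above:

import functools

import pandas as pd

scores_data_frame = pd.DataFrame({
    'eval_score_name': ['polqa', 'polqa', 'thd'],
    'apm_config': ['default', 'default', 'default'],
    'score': [3.1, 3.4, 0.02],
})

masks = [
    scores_data_frame.eval_score_name == 'polqa',
    scores_data_frame.apm_config == 'default',
]
mask = functools.reduce(lambda m1, m2: m1 & m2, masks)
sliced = scores_data_frame[mask]
print({
    'count': sliced['score'].count(),
    'min': sliced['score'].min(),
    'max': sliced['score'].max(),
    'mean': sliced['score'].mean(),
    'std dev': sliced['score'].std(),
})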

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the export module.
"""
@ -27,60 +26,61 @@ from . import test_data_generation_factory
class TestExport(unittest.TestCase):
"""Unit tests for the export module.
"""Unit tests for the export module.
"""
_CLEAN_TMP_OUTPUT = True
_CLEAN_TMP_OUTPUT = True
def setUp(self):
"""Creates temporary data to export."""
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Creates temporary data to export."""
self._tmp_path = tempfile.mkdtemp()
# Run a fake experiment to produce data to export.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
echo_metric_tool_bin_path=None
)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
simulator.Run(
config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=[
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
],
test_data_generator_names=['identity', 'white_noise'],
eval_score_names=['audio_level_peak', 'audio_level_mean'],
output_dir=self._tmp_path)
# Run a fake experiment to produce data to export.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
simulator.Run(
config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=[
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
],
test_data_generator_names=['identity', 'white_noise'],
eval_score_names=['audio_level_peak', 'audio_level_mean'],
output_dir=self._tmp_path)
# Export results.
p = collect_data.InstanceArgumentsParser()
args = p.parse_args(['--output_dir', self._tmp_path])
src_path = collect_data.ConstructSrcPath(args)
self._data_to_export = collect_data.FindScores(src_path, args)
# Export results.
p = collect_data.InstanceArgumentsParser()
args = p.parse_args(['--output_dir', self._tmp_path])
src_path = collect_data.ConstructSrcPath(args)
self._data_to_export = collect_data.FindScores(src_path, args)
def tearDown(self):
"""Recursively deletes temporary folders."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' + (
self._tmp_path))
def tearDown(self):
"""Recursively deletes temporary folders."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' +
(self._tmp_path))
def testCreateHtmlReport(self):
fn_out = os.path.join(self._tmp_path, 'results.html')
exporter = export.HtmlExport(fn_out)
exporter.Export(self._data_to_export)
def testCreateHtmlReport(self):
fn_out = os.path.join(self._tmp_path, 'results.html')
exporter = export.HtmlExport(fn_out)
exporter.Export(self._data_to_export)
document = pq.PyQuery(filename=fn_out)
self.assertIsInstance(document, pq.PyQuery)
# TODO(alessiob): Use PyQuery API to check the HTML file.
document = pq.PyQuery(filename=fn_out)
self.assertIsInstance(document, pq.PyQuery)
# TODO(alessiob): Use PyQuery API to check the HTML file.
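A possible direction for the TODO above, assuming pyquery's CSS selector API (illustrative only; the HTML snippet is made up):

import pyquery as pq

document = pq.PyQuery(
    '<html><head><title>Results</title></head>'
    '<body><table class="mdl-data-table"></table></body></html>')
assert document('title').text() == 'Results'
assert document('table.mdl-data-table')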

View File

@ -16,62 +16,60 @@ import sys
import tempfile
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import signal_processing
class ExternalVad(object):
def __init__(self, path_to_binary, name):
"""Args:
class ExternalVad(object):
def __init__(self, path_to_binary, name):
"""Args:
path_to_binary: path to binary that accepts '-i <wav>', '-o
<float probabilities>'. There must be one float value per
10 ms of audio.
name: a name to identify the external VAD. Used for saving
the output as extvad_output-<name>.
"""
self._path_to_binary = path_to_binary
self.name = name
assert os.path.exists(self._path_to_binary), (
self._path_to_binary)
self._vad_output = None
self._path_to_binary = path_to_binary
self.name = name
assert os.path.exists(self._path_to_binary), (self._path_to_binary)
self._vad_output = None
def Run(self, wav_file_path):
_signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
if _signal.channels != 1:
raise NotImplementedError('Multiple-channel'
' annotations not implemented')
if _signal.frame_rate != 48000:
raise NotImplementedError('Frame rates '
'other than 48000 not implemented')
def Run(self, wav_file_path):
_signal = signal_processing.SignalProcessingUtils.LoadWav(
wav_file_path)
if _signal.channels != 1:
raise NotImplementedError('Multiple-channel'
' annotations not implemented')
if _signal.frame_rate != 48000:
raise NotImplementedError('Frame rates '
'other than 48000 not implemented')
tmp_path = tempfile.mkdtemp()
try:
output_file_path = os.path.join(
tmp_path, self.name + '_vad.tmp')
subprocess.call([
self._path_to_binary,
'-i', wav_file_path,
'-o', output_file_path
])
self._vad_output = np.fromfile(output_file_path, np.float32)
except Exception as e:
logging.error('Error while running the ' + self.name +
' VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
tmp_path = tempfile.mkdtemp()
try:
output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
subprocess.call([
self._path_to_binary, '-i', wav_file_path, '-o',
output_file_path
])
self._vad_output = np.fromfile(output_file_path, np.float32)
except Exception as e:
logging.error('Error while running the ' + self.name + ' VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
def GetVadOutput(self):
assert self._vad_output is not None
return self._vad_output
def GetVadOutput(self):
assert self._vad_output is not None
return self._vad_output
@classmethod
def ConstructVadDict(cls, vad_paths, vad_names):
external_vads = {}
for path, name in zip(vad_paths, vad_names):
external_vads[name] = ExternalVad(path, name)
return external_vads
@classmethod
def ConstructVadDict(cls, vad_paths, vad_names):
external_vads = {}
for path, name in zip(vad_paths, vad_names):
external_vads[name] = ExternalVad(path, name)
return external_vads
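A standalone sketch of the file convention assumed above (one float32 probability per 10 ms of audio, written by the external VAD binary and read back with numpy):

import os
import tempfile

import numpy as np

# 100 frames of 10 ms each, i.e. 1 s of audio.
probabilities = np.linspace(0.0, 1.0, num=100).astype(np.float32)
output_file_path = os.path.join(tempfile.mkdtemp(), 'stub_vad.tmp')
probabilities.tofile(output_file_path)
vad_output = np.fromfile(output_file_path, np.float32)
assert vad_output.shape == (100,)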

View File

@ -9,16 +9,17 @@
import argparse
import numpy as np
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', required=True)
parser.add_argument('-o', required=True)
parser = argparse.ArgumentParser()
parser.add_argument('-i', required=True)
parser.add_argument('-o', required=True)
args = parser.parse_args()
args = parser.parse_args()
array = np.arange(100, dtype=np.float32)
array.tofile(open(args.o, 'w'))
array = np.arange(100, dtype=np.float32)
array.tofile(open(args.o, 'w'))
if __name__ == '__main__':
main()
main()

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Input mixer module.
"""
@ -17,24 +16,24 @@ from . import signal_processing
class ApmInputMixer(object):
"""Class to mix a set of audio segments down to the APM input."""
"""Class to mix a set of audio segments down to the APM input."""
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def HardClippingLogMessage(cls):
"""Returns the log message used when hard clipping is detected in the mix.
@classmethod
def HardClippingLogMessage(cls):
"""Returns the log message used when hard clipping is detected in the mix.
This method is mainly intended to be used by the unit tests.
"""
return cls._HARD_CLIPPING_LOG_MSG
return cls._HARD_CLIPPING_LOG_MSG
@classmethod
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
"""Mixes capture and echo.
@classmethod
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
"""Mixes capture and echo.
Creates the overall capture input for APM by mixing the "echo-free" capture
signal with the echo signal (e.g., echo simulated via the
@ -58,38 +57,41 @@ class ApmInputMixer(object):
Returns:
Path to the mix audio track file.
"""
if echo_filepath is None:
return capture_input_filepath
if echo_filepath is None:
return capture_input_filepath
# Build the mix output file name as a function of the echo file name.
# This ensures that if the internal parameters of the echo path simulator
# change, no erroneous cache hit occurs.
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
capture_input_file_name, _ = os.path.splitext(
os.path.split(capture_input_filepath)[1])
mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format(
capture_input_file_name, echo_file_name))
# Build the mix output file name as a function of the echo file name.
# This ensures that if the internal parameters of the echo path simulator
# change, no erroneous cache hit occurs.
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
capture_input_file_name, _ = os.path.splitext(
os.path.split(capture_input_filepath)[1])
mix_filepath = os.path.join(
output_path,
'mix_capture_{}_{}.wav'.format(capture_input_file_name,
echo_file_name))
# Create the mix if not done yet.
mix = None
if not os.path.exists(mix_filepath):
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
capture_input_filepath)
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
# Create the mix if not done yet.
mix = None
if not os.path.exists(mix_filepath):
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
capture_input_filepath)
echo = signal_processing.SignalProcessingUtils.LoadWav(
echo_filepath)
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
signal_processing.SignalProcessingUtils.CountSamples(
echo_free_capture)):
raise exceptions.InputMixerException(
'echo cannot be shorter than capture')
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
signal_processing.SignalProcessingUtils.CountSamples(
echo_free_capture)):
raise exceptions.InputMixerException(
'echo cannot be shorter than capture')
mix = echo_free_capture.overlay(echo)
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
mix = echo_free_capture.overlay(echo)
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
# Check if hard clipping occurs.
if mix is None:
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
# Check if hard clipping occurs.
if mix is None:
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
return mix_filepath
return mix_filepath
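A minimal pydub sketch of the overlay step and the clipping check it motivates; the tone and noise levels are made up:

import pydub.generators

capture = pydub.generators.Sine(440, sample_rate=48000).to_audio_segment(
    duration=1000, volume=-10.0)
echo = pydub.generators.WhiteNoise(sample_rate=48000).to_audio_segment(
    duration=1000, volume=-20.0)
mix = capture.overlay(echo)
# A max_dBFS close to 0 dBFS suggests the mix is near hard clipping.
print(mix.max_dBFS)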

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the input mixer module.
"""
@ -23,122 +22,119 @@ from . import signal_processing
class TestApmInputMixer(unittest.TestCase):
"""Unit tests for the ApmInputMixer class.
"""Unit tests for the ApmInputMixer class.
"""
# Audio track file names created in setUp().
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
# Audio track file names created in setUp().
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
# Target peak power level (dBFS) of each audio track file created in setUp().
# These values are hand-crafted in order to make saturation happen when
# capture and echo_2 are mixed and the contrary for capture and echo_1.
# None means that the power is not changed.
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
# Target peak power level (dBFS) of each audio track file created in setUp().
# These values are hand-crafted in order to make saturation happen when
# capture and echo_2 are mixed and the contrary for capture and echo_1.
# None means that the power is not changed.
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
# Audio track file durations in milliseconds.
_DURATIONS = [1000, 1000, 1000, 800, 1200]
# Audio track file durations in milliseconds.
_DURATIONS = [1000, 1000, 1000, 800, 1200]
_SAMPLE_RATE = 48000
_SAMPLE_RATE = 48000
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
# Create audio track files.
self._audio_tracks = {}
for filename, peak_power, duration in zip(
self._FILENAMES, self._MAX_PEAK_POWER_LEVELS, self._DURATIONS):
audio_track_filepath = os.path.join(self._tmp_path, '{}.wav'.format(
filename))
# Create audio track files.
self._audio_tracks = {}
for filename, peak_power, duration in zip(self._FILENAMES,
self._MAX_PEAK_POWER_LEVELS,
self._DURATIONS):
audio_track_filepath = os.path.join(self._tmp_path,
'{}.wav'.format(filename))
# Create a pure tone with the target peak power level.
template = signal_processing.SignalProcessingUtils.GenerateSilence(
duration=duration, sample_rate=self._SAMPLE_RATE)
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
template)
if peak_power is not None:
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
# Create a pure tone with the target peak power level.
template = signal_processing.SignalProcessingUtils.GenerateSilence(
duration=duration, sample_rate=self._SAMPLE_RATE)
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
template)
if peak_power is not None:
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
signal_processing.SignalProcessingUtils.SaveWav(
audio_track_filepath, signal)
self._audio_tracks[filename] = {
'filepath': audio_track_filepath,
'num_samples': signal_processing.SignalProcessingUtils.CountSamples(
signal)
}
signal_processing.SignalProcessingUtils.SaveWav(
audio_track_filepath, signal)
self._audio_tracks[filename] = {
'filepath':
audio_track_filepath,
'num_samples':
signal_processing.SignalProcessingUtils.CountSamples(signal)
}
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def testCheckMixSameDuration(self):
"""Checks the duration when mixing capture and echo with same duration."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
def testCheckMixSameDuration(self):
"""Checks the duration when mixing capture and echo with same duration."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(
self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
def testRejectShorterEcho(self):
"""Rejects echo signals that are shorter than the capture signal."""
try:
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['shorter']['filepath'])
self.fail('no exception raised')
except exceptions.InputMixerException:
pass
def testRejectShorterEcho(self):
"""Rejects echo signals that are shorter than the capture signal."""
try:
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['shorter']['filepath'])
self.fail('no exception raised')
except exceptions.InputMixerException:
pass
def testCheckMixDurationWithLongerEcho(self):
"""Checks the duration when mixing an echo longer than the capture."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['longer']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
def testCheckMixDurationWithLongerEcho(self):
"""Checks the duration when mixing an echo longer than the capture."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['longer']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(
self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
def testCheckOutputFileNamesConflict(self):
"""Checks that different echo files lead to different output file names."""
mix1_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix1_filepath))
def testCheckOutputFileNamesConflict(self):
"""Checks that different echo files lead to different output file names."""
mix1_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix1_filepath))
mix2_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
self.assertTrue(os.path.exists(mix2_filepath))
mix2_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
self.assertTrue(os.path.exists(mix2_filepath))
self.assertNotEqual(mix1_filepath, mix2_filepath)
self.assertNotEqual(mix1_filepath, mix2_filepath)
def testHardClippingLogExpected(self):
"""Checks that hard clipping warning is raised when occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
logging.warning.assert_called_once_with(
input_mixer.ApmInputMixer.HardClippingLogMessage())
def testHardClippingLogExpected(self):
"""Checks that hard clipping warning is raised when occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
logging.warning.assert_called_once_with(
input_mixer.ApmInputMixer.HardClippingLogMessage())
def testHardClippingLogNotExpected(self):
"""Checks that hard clipping warning is not raised when not occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertNotIn(
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
logging.warning.call_args_list)
def testHardClippingLogNotExpected(self):
"""Checks that hard clipping warning is not raised when not occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertNotIn(
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
logging.warning.call_args_list)
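The tests above replace logging.warning with a MagicMock directly; a hedged alternative sketch that scopes the patch so it is restored automatically (the mock package is standalone on Python 2 and available as unittest.mock on Python 3):

import logging

import mock

with mock.patch.object(logging, 'warning') as warning_mock:
    logging.warning('hard clipping detected in the mixed signal')
warning_mock.assert_called_once_with(
    'hard clipping detected in the mixed signal')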

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Input signal creator module.
"""
@ -14,12 +13,12 @@ from . import signal_processing
class InputSignalCreator(object):
"""Input signal creator class.
"""Input signal creator class.
"""
@classmethod
def Create(cls, name, raw_params):
"""Creates a input signal and its metadata.
@classmethod
def Create(cls, name, raw_params):
"""Creates a input signal and its metadata.
Args:
name: Input signal creator name.
@ -28,29 +27,30 @@ class InputSignalCreator(object):
Returns:
(AudioSegment, dict) tuple.
"""
try:
signal = {}
params = {}
try:
signal = {}
params = {}
if name == 'pure_tone':
params['frequency'] = float(raw_params[0])
params['duration'] = int(raw_params[1])
signal = cls._CreatePureTone(params['frequency'], params['duration'])
else:
raise exceptions.InputSignalCreatorException(
'Invalid input signal creator name')
if name == 'pure_tone':
params['frequency'] = float(raw_params[0])
params['duration'] = int(raw_params[1])
signal = cls._CreatePureTone(params['frequency'],
params['duration'])
else:
raise exceptions.InputSignalCreatorException(
'Invalid input signal creator name')
# Complete metadata.
params['signal'] = name
# Complete metadata.
params['signal'] = name
return signal, params
except (TypeError, AssertionError) as e:
raise exceptions.InputSignalCreatorException(
'Invalid signal creator parameters: {}'.format(e))
return signal, params
except (TypeError, AssertionError) as e:
raise exceptions.InputSignalCreatorException(
'Invalid signal creator parameters: {}'.format(e))
@classmethod
def _CreatePureTone(cls, frequency, duration):
"""
@classmethod
def _CreatePureTone(cls, frequency, duration):
"""
Generates a pure tone at 48000 Hz.
Args:
@ -60,8 +60,9 @@ class InputSignalCreator(object):
Returns:
AudioSegment instance.
"""
assert 0 < frequency <= 24000
assert duration > 0
template = signal_processing.SignalProcessingUtils.GenerateSilence(duration)
return signal_processing.SignalProcessingUtils.GeneratePureTone(
template, frequency)
assert 0 < frequency <= 24000
assert duration > 0
template = signal_processing.SignalProcessingUtils.GenerateSilence(
duration)
return signal_processing.SignalProcessingUtils.GeneratePureTone(
template, frequency)
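A minimal usage sketch for the creator above (illustrative only; the import path is an assumption): the 'pure_tone' creator takes the frequency in Hz and the duration in ms as raw string parameters.

from quality_assessment import input_signal_creator  # assumed import path

# 440 Hz tone lasting 1000 ms; Create() returns (AudioSegment, metadata dict).
tone, params = input_signal_creator.InputSignalCreator.Create(
    'pure_tone', ['440', '1000'])
assert params['signal'] == 'pure_tone'
assert params['frequency'] == 440.0 and params['duration'] == 1000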

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Signal processing utility module.
"""
@ -16,44 +15,44 @@ import sys
import enum
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
try:
import pydub
import pydub.generators
import pydub
import pydub.generators
except ImportError:
logging.critical('Cannot import the third-party Python package pydub')
sys.exit(1)
logging.critical('Cannot import the third-party Python package pydub')
sys.exit(1)
try:
import scipy.signal
import scipy.fftpack
import scipy.signal
import scipy.fftpack
except ImportError:
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
from . import exceptions
class SignalProcessingUtils(object):
"""Collection of signal processing utilities.
"""Collection of signal processing utilities.
"""
@enum.unique
class MixPadding(enum.Enum):
NO_PADDING = 0
ZERO_PADDING = 1
LOOP = 2
@enum.unique
class MixPadding(enum.Enum):
NO_PADDING = 0
ZERO_PADDING = 1
LOOP = 2
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def LoadWav(cls, filepath, channels=1):
"""Loads wav file.
@classmethod
def LoadWav(cls, filepath, channels=1):
"""Loads wav file.
Args:
filepath: path to the wav audio track file to load.
@ -62,25 +61,26 @@ class SignalProcessingUtils(object):
Returns:
AudioSegment instance.
"""
if not os.path.exists(filepath):
logging.error('cannot find the <%s> audio track file', filepath)
raise exceptions.FileNotFoundError()
return pydub.AudioSegment.from_file(
filepath, format='wav', channels=channels)
if not os.path.exists(filepath):
logging.error('cannot find the <%s> audio track file', filepath)
raise exceptions.FileNotFoundError()
return pydub.AudioSegment.from_file(filepath,
format='wav',
channels=channels)
@classmethod
def SaveWav(cls, output_filepath, signal):
"""Saves wav file.
@classmethod
def SaveWav(cls, output_filepath, signal):
"""Saves wav file.
Args:
output_filepath: path to the wav audio track file to save.
signal: AudioSegment instance.
"""
return signal.export(output_filepath, format='wav')
return signal.export(output_filepath, format='wav')
@classmethod
def CountSamples(cls, signal):
"""Number of samples per channel.
@classmethod
def CountSamples(cls, signal):
"""Number of samples per channel.
Args:
signal: AudioSegment instance.
@ -88,14 +88,14 @@ class SignalProcessingUtils(object):
Returns:
An integer.
"""
number_of_samples = len(signal.get_array_of_samples())
assert signal.channels > 0
assert number_of_samples % signal.channels == 0
return number_of_samples / signal.channels
number_of_samples = len(signal.get_array_of_samples())
assert signal.channels > 0
assert number_of_samples % signal.channels == 0
return number_of_samples / signal.channels
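A short round-trip sketch for the I/O helpers above (illustrative only; the import path and the temporary file path are assumptions):

import pydub
from quality_assessment import signal_processing  # assumed import path

utils = signal_processing.SignalProcessingUtils
silence = pydub.AudioSegment.silent(duration=500, frame_rate=48000)  # 500 ms
utils.SaveWav('/tmp/silence.wav', silence)    # export as wav
reloaded = utils.LoadWav('/tmp/silence.wav')  # load back as mono
assert utils.CountSamples(reloaded) == 24000  # 500 ms at 48 kHz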
@classmethod
def GenerateSilence(cls, duration=1000, sample_rate=48000):
"""Generates silence.
@classmethod
def GenerateSilence(cls, duration=1000, sample_rate=48000):
"""Generates silence.
This method can also be used to create a template AudioSegment instance.
A template can then be used with other Generate*() methods accepting an
@ -108,11 +108,11 @@ class SignalProcessingUtils(object):
Returns:
AudioSegment instance.
"""
return pydub.AudioSegment.silent(duration, sample_rate)
return pydub.AudioSegment.silent(duration, sample_rate)
@classmethod
def GeneratePureTone(cls, template, frequency=440.0):
"""Generates a pure tone.
@classmethod
def GeneratePureTone(cls, template, frequency=440.0):
"""Generates a pure tone.
The pure tone is generated with the same duration and in the same format as
the given template signal.
@ -124,21 +124,18 @@ class SignalProcessingUtils(object):
Return:
AudioSegment instance.
"""
if frequency > template.frame_rate >> 1:
raise exceptions.SignalProcessingException('Invalid frequency')
if frequency > template.frame_rate >> 1:
raise exceptions.SignalProcessingException('Invalid frequency')
generator = pydub.generators.Sine(
sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8,
freq=frequency)
generator = pydub.generators.Sine(sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8,
freq=frequency)
return generator.to_audio_segment(
duration=len(template),
volume=0.0)
return generator.to_audio_segment(duration=len(template), volume=0.0)
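A template-based generation sketch (illustrative; import path assumed): a silent template fixes duration and format, and the tone is rendered into that same format.

from quality_assessment import signal_processing  # assumed import path

utils = signal_processing.SignalProcessingUtils
template = utils.GenerateSilence(duration=1000, sample_rate=48000)  # 1 s
tone = utils.GeneratePureTone(template, frequency=440.0)            # A4 tone
assert len(tone) == len(template)  # same duration in milliseconds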
@classmethod
def GenerateWhiteNoise(cls, template):
"""Generates white noise.
@classmethod
def GenerateWhiteNoise(cls, template):
"""Generates white noise.
The white noise is generated with the same duration and in the same format
as the given template signal.
@ -149,33 +146,32 @@ class SignalProcessingUtils(object):
Return:
AudioSegment instance.
"""
generator = pydub.generators.WhiteNoise(
sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8)
return generator.to_audio_segment(
duration=len(template),
volume=0.0)
generator = pydub.generators.WhiteNoise(
sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8)
return generator.to_audio_segment(duration=len(template), volume=0.0)
@classmethod
def AudioSegmentToRawData(cls, signal):
samples = signal.get_array_of_samples()
if samples.typecode != 'h':
raise exceptions.SignalProcessingException('Unsupported samples type')
return np.array(signal.get_array_of_samples(), np.int16)
@classmethod
def AudioSegmentToRawData(cls, signal):
samples = signal.get_array_of_samples()
if samples.typecode != 'h':
raise exceptions.SignalProcessingException(
'Unsupported samples type')
return np.array(signal.get_array_of_samples(), np.int16)
@classmethod
def Fft(cls, signal, normalize=True):
if signal.channels != 1:
raise NotImplementedError('multiple-channel FFT not implemented')
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
if normalize:
x /= max(abs(np.max(x)), 1.0)
y = scipy.fftpack.fft(x)
return y[:len(y) / 2]
@classmethod
def Fft(cls, signal, normalize=True):
if signal.channels != 1:
raise NotImplementedError('multiple-channel FFT not implemented')
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
if normalize:
x /= max(abs(np.max(x)), 1.0)
y = scipy.fftpack.fft(x)
return y[:len(y) / 2]
@classmethod
def DetectHardClipping(cls, signal, threshold=2):
"""Detects hard clipping.
@classmethod
def DetectHardClipping(cls, signal, threshold=2):
"""Detects hard clipping.
Hard clipping is simply detected by counting samples that touch either the
lower or upper bound too many times in a row (according to |threshold|).
@ -189,32 +185,33 @@ class SignalProcessingUtils(object):
Returns:
True if hard clipping is detected, False otherwise.
"""
if signal.channels != 1:
raise NotImplementedError('multiple-channel clipping not implemented')
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
raise exceptions.SignalProcessingException(
'hard-clipping detection only supported for 16 bit samples')
samples = cls.AudioSegmentToRawData(signal)
if signal.channels != 1:
raise NotImplementedError(
'multiple-channel clipping not implemented')
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
raise exceptions.SignalProcessingException(
'hard-clipping detection only supported for 16 bit samples')
samples = cls.AudioSegmentToRawData(signal)
# Detect adjacent clipped samples.
samples_type_info = np.iinfo(samples.dtype)
mask_min = samples == samples_type_info.min
mask_max = samples == samples_type_info.max
# Detect adjacent clipped samples.
samples_type_info = np.iinfo(samples.dtype)
mask_min = samples == samples_type_info.min
mask_max = samples == samples_type_info.max
def HasLongSequence(vector, min_legth=threshold):
"""Returns True if there are one or more long sequences of True flags."""
seq_length = 0
for b in vector:
seq_length = seq_length + 1 if b else 0
if seq_length >= min_legth:
return True
return False
def HasLongSequence(vector, min_legth=threshold):
"""Returns True if there are one or more long sequences of True flags."""
seq_length = 0
for b in vector:
seq_length = seq_length + 1 if b else 0
if seq_length >= min_legth:
return True
return False
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
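A small synthetic check of the detector above (illustrative; the import path is assumed, and the pydub metadata constructor mirrors the one already used in this module):

import numpy as np
import pydub
from quality_assessment import signal_processing  # assumed import path

# 10 ms of 16-bit mono silence with 5 adjacent samples stuck at the positive
# rail, which the default threshold of 2 flags as hard clipping.
samples = np.zeros(480, dtype=np.int16)
samples[100:105] = np.iinfo(np.int16).max
clipped = pydub.AudioSegment(data=samples.tobytes(),
                             metadata={'sample_width': 2,
                                       'frame_rate': 48000,
                                       'frame_width': 2,
                                       'channels': 1})
assert signal_processing.SignalProcessingUtils.DetectHardClipping(clipped)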
@classmethod
def ApplyImpulseResponse(cls, signal, impulse_response):
"""Applies an impulse response to a signal.
@classmethod
def ApplyImpulseResponse(cls, signal, impulse_response):
"""Applies an impulse response to a signal.
Args:
signal: AudioSegment instance.
@ -223,44 +220,48 @@ class SignalProcessingUtils(object):
Returns:
AudioSegment instance.
"""
# Get samples.
assert signal.channels == 1, (
'multiple-channel recordings not supported')
samples = signal.get_array_of_samples()
# Get samples.
assert signal.channels == 1, (
'multiple-channel recordings not supported')
samples = signal.get_array_of_samples()
# Convolve.
logging.info('applying %d order impulse response to a signal lasting %d ms',
len(impulse_response), len(signal))
convolved_samples = scipy.signal.fftconvolve(
in1=samples,
in2=impulse_response,
mode='full').astype(np.int16)
logging.info('convolution computed')
# Convolve.
logging.info(
'applying %d order impulse response to a signal lasting %d ms',
len(impulse_response), len(signal))
convolved_samples = scipy.signal.fftconvolve(in1=samples,
in2=impulse_response,
mode='full').astype(
np.int16)
logging.info('convolution computed')
# Cast.
convolved_samples = array.array(signal.array_type, convolved_samples)
# Cast.
convolved_samples = array.array(signal.array_type, convolved_samples)
# Verify.
logging.debug('signal length: %d samples', len(samples))
logging.debug('convolved signal length: %d samples', len(convolved_samples))
assert len(convolved_samples) > len(samples)
# Verify.
logging.debug('signal length: %d samples', len(samples))
logging.debug('convolved signal length: %d samples',
len(convolved_samples))
assert len(convolved_samples) > len(samples)
# Generate convolved signal AudioSegment instance.
convolved_signal = pydub.AudioSegment(
data=convolved_samples,
metadata={
'sample_width': signal.sample_width,
'frame_rate': signal.frame_rate,
'frame_width': signal.frame_width,
'channels': signal.channels,
})
assert len(convolved_signal) > len(signal)
# Generate convolved signal AudioSegment instance.
convolved_signal = pydub.AudioSegment(data=convolved_samples,
metadata={
'sample_width':
signal.sample_width,
'frame_rate':
signal.frame_rate,
'frame_width':
signal.frame_width,
'channels': signal.channels,
})
assert len(convolved_signal) > len(signal)
return convolved_signal
return convolved_signal
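A minimal convolution sketch for the helper above (illustrative; import path assumed). A unit impulse plus one attenuated tap is enough to see the full convolution slightly outlast the input.

import numpy as np
from quality_assessment import signal_processing  # assumed import path

utils = signal_processing.SignalProcessingUtils
template = utils.GenerateSilence(duration=200, sample_rate=48000)
dry = utils.GenerateWhiteNoise(template)  # mono test signal

# Dirac delta plus an attenuated tap 1 ms later (48 samples at 48 kHz).
impulse_response = np.zeros(64)
impulse_response[0] = 1.0
impulse_response[48] = 0.5

wet = utils.ApplyImpulseResponse(dry, impulse_response)
assert len(wet) >= len(dry)  # mode='full' adds the impulse response tail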
@classmethod
def Normalize(cls, signal):
"""Normalizes a signal.
@classmethod
def Normalize(cls, signal):
"""Normalizes a signal.
Args:
signal: AudioSegment instance.
@ -268,11 +269,11 @@ class SignalProcessingUtils(object):
Returns:
An AudioSegment instance.
"""
return signal.apply_gain(-signal.max_dBFS)
return signal.apply_gain(-signal.max_dBFS)
@classmethod
def Copy(cls, signal):
"""Makes a copy os a signal.
@classmethod
def Copy(cls, signal):
"""Makes a copy os a signal.
Args:
signal: AudioSegment instance.
@ -280,19 +281,21 @@ class SignalProcessingUtils(object):
Returns:
An AudioSegment instance.
"""
return pydub.AudioSegment(
data=signal.get_array_of_samples(),
metadata={
'sample_width': signal.sample_width,
'frame_rate': signal.frame_rate,
'frame_width': signal.frame_width,
'channels': signal.channels,
})
return pydub.AudioSegment(data=signal.get_array_of_samples(),
metadata={
'sample_width': signal.sample_width,
'frame_rate': signal.frame_rate,
'frame_width': signal.frame_width,
'channels': signal.channels,
})
@classmethod
def MixSignals(cls, signal, noise, target_snr=0.0,
pad_noise=MixPadding.NO_PADDING):
"""Mixes |signal| and |noise| with a target SNR.
@classmethod
def MixSignals(cls,
signal,
noise,
target_snr=0.0,
pad_noise=MixPadding.NO_PADDING):
"""Mixes |signal| and |noise| with a target SNR.
Mix |signal| and |noise| with a desired SNR by scaling |noise|.
If the target SNR is +/- infinite, a copy of signal/noise is returned.
@ -312,45 +315,45 @@ class SignalProcessingUtils(object):
Returns:
An AudioSegment instance.
"""
# Handle infinite target SNR.
if target_snr == -np.Inf:
# Return a copy of noise.
logging.warning('SNR = -Inf, returning noise')
return cls.Copy(noise)
elif target_snr == np.Inf:
# Return a copy of signal.
logging.warning('SNR = +Inf, returning signal')
return cls.Copy(signal)
# Handle infinite target SNR.
if target_snr == -np.Inf:
# Return a copy of noise.
logging.warning('SNR = -Inf, returning noise')
return cls.Copy(noise)
elif target_snr == np.Inf:
# Return a copy of signal.
logging.warning('SNR = +Inf, returning signal')
return cls.Copy(signal)
# Check signal and noise power.
signal_power = float(signal.dBFS)
noise_power = float(noise.dBFS)
if signal_power == -np.Inf:
logging.error('signal has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
if noise_power == -np.Inf:
logging.error('noise has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
# Check signal and noise power.
signal_power = float(signal.dBFS)
noise_power = float(noise.dBFS)
if signal_power == -np.Inf:
logging.error('signal has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
if noise_power == -np.Inf:
logging.error('noise has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
# Mix.
gain_db = signal_power - noise_power - target_snr
signal_duration = len(signal)
noise_duration = len(noise)
if signal_duration <= noise_duration:
# Ignore |pad_noise|, |noise| is truncated if longer than |signal|; the
# mix will have the same length as |signal|.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.NO_PADDING:
# |signal| is longer than |noise|, but no padding is applied to |noise|.
# Truncate |signal|.
return noise.overlay(signal, gain_during_overlay=gain_db)
elif pad_noise == cls.MixPadding.ZERO_PADDING:
# TODO(alessiob): Check that this works as expected.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.LOOP:
# |signal| is longer than |noise|, extend |noise| by looping.
return signal.overlay(noise.apply_gain(gain_db), loop=True)
else:
raise exceptions.SignalProcessingException('invalid padding type')
# Mix.
gain_db = signal_power - noise_power - target_snr
signal_duration = len(signal)
noise_duration = len(noise)
if signal_duration <= noise_duration:
# Ignore |pad_noise|, |noise| is truncated if longer than |signal|; the
# mix will have the same length as |signal|.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.NO_PADDING:
# |signal| is longer than |noise|, but no padding is applied to |noise|.
# Truncate |signal|.
return noise.overlay(signal, gain_during_overlay=gain_db)
elif pad_noise == cls.MixPadding.ZERO_PADDING:
# TODO(alessiob): Check that this works as expected.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.LOOP:
# |signal| is longer than |noise|, extend |noise| by looping.
return signal.overlay(noise.apply_gain(gain_db), loop=True)
else:
raise exceptions.SignalProcessingException('invalid padding type')
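A short mixing sketch tying the pieces above together (illustrative; import path assumed): |noise| is scaled so the mix reaches the requested SNR, and the padding mode only matters when the signal outlasts the noise.

from quality_assessment import signal_processing  # assumed import path

utils = signal_processing.SignalProcessingUtils
template = utils.GenerateSilence(duration=1000, sample_rate=48000)
tone = utils.GeneratePureTone(template, frequency=440.0)
noise = utils.GenerateWhiteNoise(template)

# 10 dB SNR mix; the inputs have equal length, so padding is irrelevant here.
mix = utils.MixSignals(tone, noise, target_snr=10.0,
                       pad_noise=utils.MixPadding.NO_PADDING)
assert len(mix) == len(tone)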

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the signal_processing module.
"""
@ -19,168 +18,166 @@ from . import signal_processing
class TestSignalProcessing(unittest.TestCase):
"""Unit tests for the signal_processing module.
"""Unit tests for the signal_processing module.
"""
def testMixSignals(self):
# Generate a template signal with which white noise can be generated.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
def testMixSignals(self):
# Generate a template signal with which white noise can be generated.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
# Generate two distinct AudioSegment instances with 1 second of white noise.
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
# Generate two distinct AudioSegment instances with 1 second of white noise.
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
# Extract samples.
signal_samples = signal.get_array_of_samples()
noise_samples = noise.get_array_of_samples()
# Extract samples.
signal_samples = signal.get_array_of_samples()
noise_samples = noise.get_array_of_samples()
# Test target SNR -Inf (noise expected).
mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, -np.Inf)
self.assertTrue(len(noise), len(mix_neg_inf)) # Check duration.
mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))
# Test target SNR -Inf (noise expected).
mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, -np.Inf)
self.assertTrue(len(noise), len(mix_neg_inf)) # Check duration.
mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))
# Test target SNR 0.0 (different data expected).
mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, 0.0)
self.assertTrue(len(signal), len(mix_0)) # Check duration.
self.assertTrue(len(noise), len(mix_0))
mix_0_samples = mix_0.get_array_of_samples()
self.assertTrue(
any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
self.assertTrue(
any([x != y for x, y in zip(noise_samples, mix_0_samples)]))
# Test target SNR 0.0 (different data expected).
mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, 0.0)
self.assertTrue(len(signal), len(mix_0)) # Check duration.
self.assertTrue(len(noise), len(mix_0))
mix_0_samples = mix_0.get_array_of_samples()
self.assertTrue(
any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
self.assertTrue(
any([x != y for x, y in zip(noise_samples, mix_0_samples)]))
# Test target SNR +Inf (signal expected).
mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, np.Inf)
self.assertTrue(len(signal), len(mix_pos_inf)) # Check duration.
mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))
# Test target SNR +Inf (signal expected).
mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, np.Inf)
self.assertTrue(len(signal), len(mix_pos_inf)) # Check duration.
mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))
def testMixSignalsMinInfPower(self):
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
def testMixSignalsMinInfPower(self):
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
signal, silence, 0.0)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
signal, silence, 0.0)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
silence, signal, 0.0)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
silence, signal, 0.0)
def testMixSignalNoiseDifferentLengths(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
def testMixSignalNoiseDifferentLengths(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
# When the signal is shorter than the noise, the mix length always equals
# that of the signal regardless of whether padding is applied.
# No noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
self.assertEqual(len(shorter), len(mix))
# When the signal is shorter than the noise, the mix length always equals
# that of the signal regardless of whether padding is applied.
# No noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
ZERO_PADDING)
self.assertEqual(len(shorter), len(mix))
# When the signal is longer than the noise, the mix length depends on
# whether padding is applied.
# No noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
self.assertEqual(len(longer), len(mix))
# When the signal is longer than the noise, the mix length depends on
# whether padding is applied.
# No noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
ZERO_PADDING)
self.assertEqual(len(longer), len(mix))
def testMixSignalNoisePaddingTypes(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
def testMixSignalNoisePaddingTypes(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
# Zero padding: expect pure tone only in 1-2s.
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
# Zero padding: expect pure tone only in 1-2s.
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
ZERO_PADDING)
# Loop: expect pure tone plus noise in 1-2s.
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
# Loop: expect pure tone plus noise in 1-2s.
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
def Energy(signal):
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
signal).astype(np.float32)
return np.sum(samples * samples)
def Energy(signal):
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
signal).astype(np.float32)
return np.sum(samples * samples)
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
e_mix_loop = Energy(mix_loop[-1000:])
self.assertLess(0, e_mix_zero_pad)
self.assertLess(e_mix_zero_pad, e_mix_loop)
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
e_mix_loop = Energy(mix_loop[-1000:])
self.assertLess(0, e_mix_zero_pad)
self.assertLess(e_mix_zero_pad, e_mix_loop)
def testMixSignalSnr(self):
# Test signals.
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
def testMixSignalSnr(self):
# Test signals.
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
def ToneAmplitudes(mix):
"""Returns the amplitude of the coefficients #16 and #192, which
def ToneAmplitudes(mix):
"""Returns the amplitude of the coefficients #16 and #192, which
correspond to the tones at 250 and 3k Hz respectively."""
mix_fft = np.absolute(signal_processing.SignalProcessingUtils.Fft(mix))
return mix_fft[16], mix_fft[192]
mix_fft = np.absolute(
signal_processing.SignalProcessingUtils.Fft(mix))
return mix_fft[16], mix_fft[192]
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low,
noise=tone_high,
target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low, noise=tone_high, target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high,
noise=tone_low,
target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high, noise=tone_low, target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low,
noise=tone_high,
target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low, noise=tone_high, target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high,
noise=tone_low,
target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high, noise=tone_low, target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)
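For reference, the bin indices hard-coded in ToneAmplitudes() above follow from the test setup: 64 ms at 8 kHz gives 512 samples, so the FFT bin spacing is 8000 / 512 = 15.625 Hz, and the 250 Hz and 3000 Hz tones land on bins 250 / 15.625 = 16 and 3000 / 15.625 = 192 respectively, both inside the 256-bin half-spectrum returned by Fft().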

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""APM module simulator.
"""
@ -25,85 +24,93 @@ from . import test_data_generation
class ApmModuleSimulator(object):
"""Audio processing module (APM) simulator class.
"""Audio processing module (APM) simulator class.
"""
_TEST_DATA_GENERATOR_CLASSES = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
_TEST_DATA_GENERATOR_CLASSES = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
_PREFIX_APM_CONFIG = 'apmcfg-'
_PREFIX_CAPTURE = 'capture-'
_PREFIX_RENDER = 'render-'
_PREFIX_ECHO_SIMULATOR = 'echosim-'
_PREFIX_TEST_DATA_GEN = 'datagen-'
_PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
_PREFIX_SCORE = 'score-'
_PREFIX_APM_CONFIG = 'apmcfg-'
_PREFIX_CAPTURE = 'capture-'
_PREFIX_RENDER = 'render-'
_PREFIX_ECHO_SIMULATOR = 'echosim-'
_PREFIX_TEST_DATA_GEN = 'datagen-'
_PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
_PREFIX_SCORE = 'score-'
def __init__(self, test_data_generator_factory, evaluation_score_factory,
ap_wrapper, evaluator, external_vads=None):
if external_vads is None:
external_vads = {}
self._test_data_generator_factory = test_data_generator_factory
self._evaluation_score_factory = evaluation_score_factory
self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator
self._annotator = annotations.AudioAnnotationsExtractor(
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
external_vads
)
def __init__(self,
test_data_generator_factory,
evaluation_score_factory,
ap_wrapper,
evaluator,
external_vads=None):
if external_vads is None:
external_vads = {}
self._test_data_generator_factory = test_data_generator_factory
self._evaluation_score_factory = evaluation_score_factory
self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator
self._annotator = annotations.AudioAnnotationsExtractor(
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
external_vads)
# Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix(
self._PREFIX_TEST_DATA_GEN_PARAMS)
self._evaluation_score_factory.SetScoreFilenamePrefix(
self._PREFIX_SCORE)
# Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix(
self._PREFIX_TEST_DATA_GEN_PARAMS)
self._evaluation_score_factory.SetScoreFilenamePrefix(
self._PREFIX_SCORE)
# Properties for each run.
self._base_output_path = None
self._output_cache_path = None
self._test_data_generators = None
self._evaluation_score_workers = None
self._config_filepaths = None
self._capture_input_filepaths = None
self._render_input_filepaths = None
self._echo_path_simulator_class = None
# Properties for each run.
self._base_output_path = None
self._output_cache_path = None
self._test_data_generators = None
self._evaluation_score_workers = None
self._config_filepaths = None
self._capture_input_filepaths = None
self._render_input_filepaths = None
self._echo_path_simulator_class = None
@classmethod
def GetPrefixApmConfig(cls):
return cls._PREFIX_APM_CONFIG
@classmethod
def GetPrefixApmConfig(cls):
return cls._PREFIX_APM_CONFIG
@classmethod
def GetPrefixCapture(cls):
return cls._PREFIX_CAPTURE
@classmethod
def GetPrefixCapture(cls):
return cls._PREFIX_CAPTURE
@classmethod
def GetPrefixRender(cls):
return cls._PREFIX_RENDER
@classmethod
def GetPrefixRender(cls):
return cls._PREFIX_RENDER
@classmethod
def GetPrefixEchoSimulator(cls):
return cls._PREFIX_ECHO_SIMULATOR
@classmethod
def GetPrefixEchoSimulator(cls):
return cls._PREFIX_ECHO_SIMULATOR
@classmethod
def GetPrefixTestDataGenerator(cls):
return cls._PREFIX_TEST_DATA_GEN
@classmethod
def GetPrefixTestDataGenerator(cls):
return cls._PREFIX_TEST_DATA_GEN
@classmethod
def GetPrefixTestDataGeneratorParameters(cls):
return cls._PREFIX_TEST_DATA_GEN_PARAMS
@classmethod
def GetPrefixTestDataGeneratorParameters(cls):
return cls._PREFIX_TEST_DATA_GEN_PARAMS
@classmethod
def GetPrefixScore(cls):
return cls._PREFIX_SCORE
@classmethod
def GetPrefixScore(cls):
return cls._PREFIX_SCORE
def Run(self, config_filepaths, capture_input_filepaths,
test_data_generator_names, eval_score_names, output_dir,
render_input_filepaths=None, echo_path_simulator_name=(
echo_path_simulation.NoEchoPathSimulator.NAME)):
"""Runs the APM simulation.
def Run(self,
config_filepaths,
capture_input_filepaths,
test_data_generator_names,
eval_score_names,
output_dir,
render_input_filepaths=None,
echo_path_simulator_name=(
echo_path_simulation.NoEchoPathSimulator.NAME)):
"""Runs the APM simulation.
Initializes paths and required instances, then runs all the simulations.
The render input can be optionally added. If added, the number of capture
@ -120,132 +127,140 @@ class ApmModuleSimulator(object):
echo_path_simulator_name: name of the echo path simulator to use when
render input is provided.
"""
assert render_input_filepaths is None or (
len(capture_input_filepaths) == len(render_input_filepaths)), (
'render input set size not matching input set size')
assert render_input_filepaths is None or echo_path_simulator_name in (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
'invalid echo path simulator')
self._base_output_path = os.path.abspath(output_dir)
assert render_input_filepaths is None or (
len(capture_input_filepaths) == len(render_input_filepaths)), (
'render input set size not matching input set size')
assert render_input_filepaths is None or echo_path_simulator_name in (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
'invalid echo path simulator')
self._base_output_path = os.path.abspath(output_dir)
# Output path used to cache the data shared across simulations.
self._output_cache_path = os.path.join(self._base_output_path, '_cache')
# Output path used to cache the data shared across simulations.
self._output_cache_path = os.path.join(self._base_output_path,
'_cache')
# Instance test data generators.
self._test_data_generators = [self._test_data_generator_factory.GetInstance(
test_data_generators_class=(
self._TEST_DATA_GENERATOR_CLASSES[name])) for name in (
test_data_generator_names)]
# Instance test data generators.
self._test_data_generators = [
self._test_data_generator_factory.GetInstance(
test_data_generators_class=(
self._TEST_DATA_GENERATOR_CLASSES[name]))
for name in (test_data_generator_names)
]
# Instance evaluation score workers.
self._evaluation_score_workers = [
self._evaluation_score_factory.GetInstance(
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name]) for (
name) in eval_score_names]
# Instance evaluation score workers.
self._evaluation_score_workers = [
self._evaluation_score_factory.GetInstance(
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name])
for (name) in eval_score_names
]
# Set APM configuration file paths.
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
# Set APM configuration file paths.
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
# Set probing signal file paths.
if render_input_filepaths is None:
# Capture input only.
self._capture_input_filepaths = self._CreatePathsCollection(
capture_input_filepaths)
self._render_input_filepaths = None
else:
# Set both capture and render input signals.
self._SetTestInputSignalFilePaths(
capture_input_filepaths, render_input_filepaths)
# Set probing signal file paths.
if render_input_filepaths is None:
# Capture input only.
self._capture_input_filepaths = self._CreatePathsCollection(
capture_input_filepaths)
self._render_input_filepaths = None
else:
# Set both capture and render input signals.
self._SetTestInputSignalFilePaths(capture_input_filepaths,
render_input_filepaths)
# Set the echo path simulator class.
self._echo_path_simulator_class = (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES[
echo_path_simulator_name])
# Set the echo path simulator class.
self._echo_path_simulator_class = (
echo_path_simulation.EchoPathSimulator.
REGISTERED_CLASSES[echo_path_simulator_name])
self._SimulateAll()
self._SimulateAll()
def _SimulateAll(self):
"""Runs all the simulations.
def _SimulateAll(self):
"""Runs all the simulations.
Iterates over the combinations of APM configurations, probing signals, and
test data generators. This method is mainly responsible for the creation of
the cache and output directories required in order to call _Simulate().
"""
without_render_input = self._render_input_filepaths is None
without_render_input = self._render_input_filepaths is None
# Try different APM config files.
for config_name in self._config_filepaths:
config_filepath = self._config_filepaths[config_name]
# Try different APM config files.
for config_name in self._config_filepaths:
config_filepath = self._config_filepaths[config_name]
# Try different capture-render pairs.
for capture_input_name in self._capture_input_filepaths:
# Output path for the capture signal annotations.
capture_annotations_cache_path = os.path.join(
self._output_cache_path,
self._PREFIX_CAPTURE + capture_input_name)
data_access.MakeDirectory(capture_annotations_cache_path)
# Try different capture-render pairs.
for capture_input_name in self._capture_input_filepaths:
# Output path for the capture signal annotations.
capture_annotations_cache_path = os.path.join(
self._output_cache_path,
self._PREFIX_CAPTURE + capture_input_name)
data_access.MakeDirectory(capture_annotations_cache_path)
# Capture.
capture_input_filepath = self._capture_input_filepaths[
capture_input_name]
if not os.path.exists(capture_input_filepath):
# If the input signal file does not exist, try to create it using the
# available input signal creators.
self._CreateInputSignal(capture_input_filepath)
assert os.path.exists(capture_input_filepath)
self._ExtractCaptureAnnotations(
capture_input_filepath, capture_annotations_cache_path)
# Capture.
capture_input_filepath = self._capture_input_filepaths[
capture_input_name]
if not os.path.exists(capture_input_filepath):
# If the input signal file does not exist, try to create it using the
# available input signal creators.
self._CreateInputSignal(capture_input_filepath)
assert os.path.exists(capture_input_filepath)
self._ExtractCaptureAnnotations(
capture_input_filepath, capture_annotations_cache_path)
# Render and simulated echo path (optional).
render_input_filepath = None if without_render_input else (
self._render_input_filepaths[capture_input_name])
render_input_name = '(none)' if without_render_input else (
self._ExtractFileName(render_input_filepath))
echo_path_simulator = (
echo_path_simulation_factory.EchoPathSimulatorFactory.GetInstance(
self._echo_path_simulator_class, render_input_filepath))
# Render and simulated echo path (optional).
render_input_filepath = None if without_render_input else (
self._render_input_filepaths[capture_input_name])
render_input_name = '(none)' if without_render_input else (
self._ExtractFileName(render_input_filepath))
echo_path_simulator = (echo_path_simulation_factory.
EchoPathSimulatorFactory.GetInstance(
self._echo_path_simulator_class,
render_input_filepath))
# Try different test data generators.
for test_data_generators in self._test_data_generators:
logging.info('APM config preset: <%s>, capture: <%s>, render: <%s>, '
'test data generator: <%s>, echo simulator: <%s>',
config_name, capture_input_name, render_input_name,
test_data_generators.NAME, echo_path_simulator.NAME)
# Try different test data generators.
for test_data_generators in self._test_data_generators:
logging.info(
'APM config preset: <%s>, capture: <%s>, render: <%s>, '
'test data generator: <%s>, echo simulator: <%s>',
config_name, capture_input_name, render_input_name,
test_data_generators.NAME, echo_path_simulator.NAME)
# Output path for the generated test data.
test_data_cache_path = os.path.join(
capture_annotations_cache_path,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(test_data_cache_path)
logging.debug('test data cache path: <%s>', test_data_cache_path)
# Output path for the generated test data.
test_data_cache_path = os.path.join(
capture_annotations_cache_path,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(test_data_cache_path)
logging.debug('test data cache path: <%s>',
test_data_cache_path)
# Output path for the echo simulator and APM input mixer output.
echo_test_data_cache_path = os.path.join(
test_data_cache_path, 'echosim-{}'.format(
echo_path_simulator.NAME))
data_access.MakeDirectory(echo_test_data_cache_path)
logging.debug('echo test data cache path: <%s>',
echo_test_data_cache_path)
# Output path for the echo simulator and APM input mixer output.
echo_test_data_cache_path = os.path.join(
test_data_cache_path,
'echosim-{}'.format(echo_path_simulator.NAME))
data_access.MakeDirectory(echo_test_data_cache_path)
logging.debug('echo test data cache path: <%s>',
echo_test_data_cache_path)
# Full output path.
output_path = os.path.join(
self._base_output_path,
self._PREFIX_APM_CONFIG + config_name,
self._PREFIX_CAPTURE + capture_input_name,
self._PREFIX_RENDER + render_input_name,
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(output_path)
logging.debug('output path: <%s>', output_path)
# Full output path.
output_path = os.path.join(
self._base_output_path,
self._PREFIX_APM_CONFIG + config_name,
self._PREFIX_CAPTURE + capture_input_name,
self._PREFIX_RENDER + render_input_name,
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(output_path)
logging.debug('output path: <%s>', output_path)
self._Simulate(test_data_generators, capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path,
config_filepath, echo_path_simulator)
self._Simulate(test_data_generators,
capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path,
config_filepath, echo_path_simulator)
@staticmethod
def _CreateInputSignal(input_signal_filepath):
"""Creates a missing input signal file.
@staticmethod
def _CreateInputSignal(input_signal_filepath):
"""Creates a missing input signal file.
The file name is parsed to extract input signal creator and params. If a
creator is matched and the parameters are valid, a new signal is generated
@ -257,30 +272,33 @@ class ApmModuleSimulator(object):
Raises:
InputSignalCreatorException
"""
filename = os.path.splitext(os.path.split(input_signal_filepath)[-1])[0]
filename_parts = filename.split('-')
filename = os.path.splitext(
os.path.split(input_signal_filepath)[-1])[0]
filename_parts = filename.split('-')
if len(filename_parts) < 2:
raise exceptions.InputSignalCreatorException(
'Cannot parse input signal file name')
if len(filename_parts) < 2:
raise exceptions.InputSignalCreatorException(
'Cannot parse input signal file name')
signal, metadata = input_signal_creator.InputSignalCreator.Create(
filename_parts[0], filename_parts[1].split('_'))
signal, metadata = input_signal_creator.InputSignalCreator.Create(
filename_parts[0], filename_parts[1].split('_'))
signal_processing.SignalProcessingUtils.SaveWav(
input_signal_filepath, signal)
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
signal_processing.SignalProcessingUtils.SaveWav(
input_signal_filepath, signal)
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
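For example (mirroring the unit tests further below), a missing capture file named pure_tone-440_1000.wav is split into the creator name 'pure_tone' and the raw parameters ['440', '1000'], so a 440 Hz tone lasting 1000 ms is generated, saved at that path, and its metadata stored alongside it.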
def _ExtractCaptureAnnotations(self, input_filepath, output_path,
annotation_name=""):
self._annotator.Extract(input_filepath)
self._annotator.Save(output_path, annotation_name)
def _ExtractCaptureAnnotations(self,
input_filepath,
output_path,
annotation_name=""):
self._annotator.Extract(input_filepath)
self._annotator.Save(output_path, annotation_name)
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path, config_filepath,
echo_path_simulator):
"""Runs a single set of simulation.
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path, config_filepath,
echo_path_simulator):
"""Runs a single set of simulation.
Simulates a given combination of APM configuration, probing signal, and
test data generator. It iterates over the test data generator
@ -298,90 +316,92 @@ class ApmModuleSimulator(object):
config_filepath: APM configuration file to test.
echo_path_simulator: EchoPathSimulator instance.
"""
# Generate pairs of noisy input and reference signal files.
test_data_generators.Generate(
input_signal_filepath=clean_capture_input_filepath,
test_data_cache_path=test_data_cache_path,
base_output_path=output_path)
# Generate pairs of noisy input and reference signal files.
test_data_generators.Generate(
input_signal_filepath=clean_capture_input_filepath,
test_data_cache_path=test_data_cache_path,
base_output_path=output_path)
# Extract metadata linked to the clean input file (if any).
apm_input_metadata = None
try:
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
clean_capture_input_filepath)
except IOError as e:
apm_input_metadata = {}
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
apm_input_metadata['test_data_gen_config'] = None
# Extract metadata linked to the clean input file (if any).
apm_input_metadata = None
try:
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
clean_capture_input_filepath)
except IOError as e:
apm_input_metadata = {}
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
apm_input_metadata['test_data_gen_config'] = None
# For each test data pair, simulate a call and evaluate.
for config_name in test_data_generators.config_names:
logging.info(' - test data generator config: <%s>', config_name)
apm_input_metadata['test_data_gen_config'] = config_name
# For each test data pair, simulate a call and evaluate.
for config_name in test_data_generators.config_names:
logging.info(' - test data generator config: <%s>', config_name)
apm_input_metadata['test_data_gen_config'] = config_name
# Paths to the test data generator output.
# Note that the reference signal does not depend on the render input
# which is optional.
noisy_capture_input_filepath = (
test_data_generators.noisy_signal_filepaths[config_name])
reference_signal_filepath = (
test_data_generators.reference_signal_filepaths[config_name])
# Paths to the test data generator output.
# Note that the reference signal does not depend on the render input
# which is optional.
noisy_capture_input_filepath = (
test_data_generators.noisy_signal_filepaths[config_name])
reference_signal_filepath = (
test_data_generators.reference_signal_filepaths[config_name])
# Output path for the evaluation (e.g., APM output file).
evaluation_output_path = test_data_generators.apm_output_paths[
config_name]
# Output path for the evaluation (e.g., APM output file).
evaluation_output_path = test_data_generators.apm_output_paths[
config_name]
# Paths to the APM input signals.
echo_path_filepath = echo_path_simulator.Simulate(
echo_test_data_cache_path)
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
echo_test_data_cache_path, noisy_capture_input_filepath,
echo_path_filepath)
# Paths to the APM input signals.
echo_path_filepath = echo_path_simulator.Simulate(
echo_test_data_cache_path)
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
echo_test_data_cache_path, noisy_capture_input_filepath,
echo_path_filepath)
# Extract annotations for the APM input mix.
apm_input_basepath, apm_input_filename = os.path.split(
apm_input_filepath)
self._ExtractCaptureAnnotations(
apm_input_filepath, apm_input_basepath,
os.path.splitext(apm_input_filename)[0] + '-')
# Extract annotations for the APM input mix.
apm_input_basepath, apm_input_filename = os.path.split(
apm_input_filepath)
self._ExtractCaptureAnnotations(
apm_input_filepath, apm_input_basepath,
os.path.splitext(apm_input_filename)[0] + '-')
# Simulate a call using APM.
self._audioproc_wrapper.Run(
config_filepath=config_filepath,
capture_input_filepath=apm_input_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path)
# Simulate a call using APM.
self._audioproc_wrapper.Run(
config_filepath=config_filepath,
capture_input_filepath=apm_input_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path)
try:
# Evaluate.
self._evaluator.Run(
evaluation_score_workers=self._evaluation_score_workers,
apm_input_metadata=apm_input_metadata,
apm_output_filepath=self._audioproc_wrapper.output_filepath,
reference_input_filepath=reference_signal_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path,
)
try:
# Evaluate.
self._evaluator.Run(
evaluation_score_workers=self._evaluation_score_workers,
apm_input_metadata=apm_input_metadata,
apm_output_filepath=self._audioproc_wrapper.
output_filepath,
reference_input_filepath=reference_signal_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path,
)
# Save simulation metadata.
data_access.Metadata.SaveAudioTestDataPaths(
output_path=evaluation_output_path,
clean_capture_input_filepath=clean_capture_input_filepath,
echo_free_capture_filepath=noisy_capture_input_filepath,
echo_filepath=echo_path_filepath,
render_filepath=render_input_filepath,
capture_filepath=apm_input_filepath,
apm_output_filepath=self._audioproc_wrapper.output_filepath,
apm_reference_filepath=reference_signal_filepath,
apm_config_filepath=config_filepath,
)
except exceptions.EvaluationScoreException as e:
logging.warning('the evaluation failed: %s', e.message)
continue
# Save simulation metadata.
data_access.Metadata.SaveAudioTestDataPaths(
output_path=evaluation_output_path,
clean_capture_input_filepath=clean_capture_input_filepath,
echo_free_capture_filepath=noisy_capture_input_filepath,
echo_filepath=echo_path_filepath,
render_filepath=render_input_filepath,
capture_filepath=apm_input_filepath,
apm_output_filepath=self._audioproc_wrapper.
output_filepath,
apm_reference_filepath=reference_signal_filepath,
apm_config_filepath=config_filepath,
)
except exceptions.EvaluationScoreException as e:
logging.warning('the evaluation failed: %s', e.message)
continue
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
render_input_filepaths):
"""Sets input and render input file paths collections.
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
render_input_filepaths):
"""Sets input and render input file paths collections.
Pairs the input and render input files by storing the file paths into two
collections. The key is the file name of the input file.
@ -390,20 +410,20 @@ class ApmModuleSimulator(object):
capture_input_filepaths: list of file paths.
render_input_filepaths: list of file paths.
"""
self._capture_input_filepaths = {}
self._render_input_filepaths = {}
assert len(capture_input_filepaths) == len(render_input_filepaths)
for capture_input_filepath, render_input_filepath in zip(
capture_input_filepaths, render_input_filepaths):
name = self._ExtractFileName(capture_input_filepath)
self._capture_input_filepaths[name] = os.path.abspath(
capture_input_filepath)
self._render_input_filepaths[name] = os.path.abspath(
render_input_filepath)
self._capture_input_filepaths = {}
self._render_input_filepaths = {}
assert len(capture_input_filepaths) == len(render_input_filepaths)
for capture_input_filepath, render_input_filepath in zip(
capture_input_filepaths, render_input_filepaths):
name = self._ExtractFileName(capture_input_filepath)
self._capture_input_filepaths[name] = os.path.abspath(
capture_input_filepath)
self._render_input_filepaths[name] = os.path.abspath(
render_input_filepath)
@classmethod
def _CreatePathsCollection(cls, filepaths):
"""Creates a collection of file paths.
@classmethod
def _CreatePathsCollection(cls, filepaths):
"""Creates a collection of file paths.
Given a list of file paths, makes a collection with one item for each file
path. The value is the absolute path, the key is the file name without
@ -415,12 +435,12 @@ class ApmModuleSimulator(object):
Returns:
A dict.
"""
filepaths_collection = {}
for filepath in filepaths:
name = cls._ExtractFileName(filepath)
filepaths_collection[name] = os.path.abspath(filepath)
return filepaths_collection
filepaths_collection = {}
for filepath in filepaths:
name = cls._ExtractFileName(filepath)
filepaths_collection[name] = os.path.abspath(filepath)
return filepaths_collection
@classmethod
def _ExtractFileName(cls, filepath):
return os.path.splitext(os.path.split(filepath)[-1])[0]
@classmethod
def _ExtractFileName(cls, filepath):
return os.path.splitext(os.path.split(filepath)[-1])[0]
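A quick illustration of the two helpers above (hypothetical paths): _ExtractFileName('/data/capture-a.wav') returns 'capture-a', and _CreatePathsCollection(['/data/capture-a.wav', '/data/capture-b.wav']) returns {'capture-a': '/data/capture-a.wav', 'capture-b': '/data/capture-b.wav'}, i.e. each file name without extension mapped to the file's absolute path.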

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the simulation module.
"""
@ -28,177 +27,177 @@ from . import test_data_generation_factory
class TestApmModuleSimulator(unittest.TestCase):
"""Unit tests for the ApmModuleSimulator class.
"""Unit tests for the ApmModuleSimulator class.
"""
def setUp(self):
"""Create temporary folders and fake audio track."""
self._output_path = tempfile.mkdtemp()
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Create temporary folders and fake audio track."""
self._output_path = tempfile.mkdtemp()
self._tmp_path = tempfile.mkdtemp()
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_audio_track_path, fake_signal)
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._fake_audio_track_path = os.path.join(self._output_path,
'fake.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_audio_track_path, fake_signal)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._output_path)
shutil.rmtree(self._tmp_path)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._output_path)
shutil.rmtree(self._tmp_path)
def testSimulation(self):
# Instance dependencies to mock and inject.
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
evaluator = evaluation.ApmModuleEvaluator()
ap_wrapper.Run = mock.MagicMock(name='Run')
evaluator.Run = mock.MagicMock(name='Run')
def testSimulation(self):
# Instance dependencies to mock and inject.
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
evaluator = evaluation.ApmModuleEvaluator()
ap_wrapper.Run = mock.MagicMock(name='Run')
evaluator.Run = mock.MagicMock(name='Run')
# Instance non-mocked dependencies.
test_data_generator_factory = (
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False))
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(__file__), 'fake_polqa'),
echo_metric_tool_bin_path=None
)
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=test_data_generator_factory,
evaluation_score_factory=evaluation_score_factory,
ap_wrapper=ap_wrapper,
evaluator=evaluator,
external_vads={'fake': external_vad.ExternalVad(os.path.join(
os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')}
)
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [self._fake_audio_track_path]
test_data_generators = ['identity', 'white_noise']
eval_scores = ['audio_level_mean', 'polqa']
# Run all simulations.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=test_data_generators,
eval_score_names=eval_scores,
output_dir=self._output_path)
# Check.
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
# the client code (e.g., number of SNR pairs for the white noise test data
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
# is known; use that with assertEqual.
min_number_of_simulations = len(config_files) * len(input_files) * len(
test_data_generators)
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
min_number_of_simulations)
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)
def testInputSignalCreation(self):
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
# Instance non-mocked dependencies.
test_data_generator_factory = (
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(__file__), 'fake_polqa'),
echo_metric_tool_bin_path=None
)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
copy_with_identity=False))
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)
# Inexistent input files to be silently created.
input_files = [
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
]
self.assertFalse(any([os.path.exists(input_file) for input_file in (
input_files)]))
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=test_data_generator_factory,
evaluation_score_factory=evaluation_score_factory,
ap_wrapper=ap_wrapper,
evaluator=evaluator,
external_vads={
'fake':
external_vad.ExternalVad(
os.path.join(os.path.dirname(__file__),
'fake_external_vad.py'), 'fake')
})
# The input files are created during the simulation.
simulator.Run(
config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=['audio_level_peak'],
output_dir=self._output_path)
self.assertTrue(all([os.path.exists(input_file) for input_file in (
input_files)]))
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [self._fake_audio_track_path]
test_data_generators = ['identity', 'white_noise']
eval_scores = ['audio_level_mean', 'polqa']
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
logging.warning = mock.MagicMock(name='warning')
# Run all simulations.
simulator.Run(config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=test_data_generators,
eval_score_names=eval_scores,
output_dir=self._output_path)
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(__file__), 'fake_polqa'),
echo_metric_tool_bin_path=None
)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# Check.
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
# the client code (e.g., number of SNR pairs for the white noise test data
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
# is known; use that with assertEqual.
min_number_of_simulations = len(config_files) * len(input_files) * len(
test_data_generators)
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
min_number_of_simulations)
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
eval_scores = ['thd']
def testInputSignalCreation(self):
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# Should work.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=eval_scores,
output_dir=self._output_path)
self.assertFalse(logging.warning.called)
# Inexistent input files to be silently created.
input_files = [
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
]
self.assertFalse(
any([os.path.exists(input_file) for input_file in (input_files)]))
# Warning expected.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['white_noise'], # Not allowed with THD.
eval_score_names=eval_scores,
output_dir=self._output_path)
logging.warning.assert_called_with('the evaluation failed: %s', (
'The THD score cannot be used with any test data generator other than '
'"identity"'))
# The input files are created during the simulation.
simulator.Run(config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=['audio_level_peak'],
output_dir=self._output_path)
self.assertTrue(
all([os.path.exists(input_file) for input_file in (input_files)]))
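The pure_tone-440_1000.wav / pure_tone-1000_500.wav names above are not arbitrary: the simulator synthesizes the missing capture files from the name itself. Reading the pattern as pure_tone-<frequency Hz>_<duration ms> is an inference from these tests, not something stated in this file; a rough sketch of such a parse:

import re

name = 'pure_tone-440_1000.wav'
match = re.match(r'pure_tone-(\d+)_(\d+)\.wav$', name)
frequency_hz, duration_ms = [int(g) for g in match.groups()]
print(frequency_hz, duration_ms)  # 440 1000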
# # Init.
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
# input_signal_filepath = os.path.join(
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
logging.warning = mock.MagicMock(name='warning')
# # Check that the input signal is generated.
# self.assertFalse(os.path.exists(input_signal_filepath))
# generator.Generate(
# input_signal_filepath=input_signal_filepath,
# test_data_cache_path=self._test_data_cache_path,
# base_output_path=self._base_output_path)
# self.assertTrue(os.path.exists(input_signal_filepath))
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# # Check input signal properties.
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
# input_signal_filepath)
# self.assertEqual(1000, len(input_signal))
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
eval_scores = ['thd']
# Should work.
simulator.Run(config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=eval_scores,
output_dir=self._output_path)
self.assertFalse(logging.warning.called)
# Warning expected.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['white_noise'], # Not allowed with THD.
eval_score_names=eval_scores,
output_dir=self._output_path)
logging.warning.assert_called_with('the evaluation failed: %s', (
'The THD score cannot be used with any test data generator other than '
'"identity"'))
# # Init.
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
# input_signal_filepath = os.path.join(
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
# # Check that the input signal is generated.
# self.assertFalse(os.path.exists(input_signal_filepath))
# generator.Generate(
# input_signal_filepath=input_signal_filepath,
# test_data_cache_path=self._test_data_cache_path,
# base_output_path=self._base_output_path)
# self.assertTrue(os.path.exists(input_signal_filepath))
# # Check input signal properties.
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
# input_signal_filepath)
# self.assertEqual(1000, len(input_signal))
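Outside the unit tests, a real run uses the same wiring with non-mocked dependencies; a condensed, hypothetical sketch (the POLQA binary, wav and output paths are placeholders, and external VADs are left at their default):

simulator = simulation.ApmModuleSimulator(
    test_data_generator_factory=(
        test_data_generation_factory.TestDataGeneratorFactory(
            aechen_ir_database_path='',
            noise_tracks_path='',
            copy_with_identity=False)),
    evaluation_score_factory=(
        eval_scores_factory.EvaluationScoreWorkerFactory(
            polqa_tool_bin_path='/path/to/polqa_binary',  # placeholder
            echo_metric_tool_bin_path=None)),
    ap_wrapper=audioproc_wrapper.AudioProcWrapper(
        audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
    evaluator=evaluation.ApmModuleEvaluator())
simulator.Run(config_filepaths=['apm_configs/default.json'],
              capture_input_filepaths=['/path/to/capture.wav'],  # placeholder
              test_data_generator_names=['identity', 'white_noise'],
              eval_score_names=['audio_level_mean'],
              output_dir='/path/to/output')  # placeholder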


@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Test data generators producing signals pairs intended to be used to
test the APM module. Each pair consists of a noisy input and a reference signal.
The former is used as APM input and it is generated by adding noise to a
@ -27,10 +26,10 @@ import shutil
import sys
try:
import scipy.io
import scipy.io
except ImportError:
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
from . import data_access
from . import exceptions
@ -38,7 +37,7 @@ from . import signal_processing
class TestDataGenerator(object):
"""Abstract class responsible for the generation of noisy signals.
"""Abstract class responsible for the generation of noisy signals.
Given a clean signal, it generates two streams named noisy signal and
reference. The former is the clean signal deteriorated by the noise source,
@ -50,24 +49,24 @@ class TestDataGenerator(object):
A test data generator generates one or more pairs.
"""
NAME = None
REGISTERED_CLASSES = {}
NAME = None
REGISTERED_CLASSES = {}
def __init__(self, output_directory_prefix):
self._output_directory_prefix = output_directory_prefix
# Init dictionaries with one entry for each test data generator
# configuration (e.g., different SNRs).
# Noisy audio track files (stored separately in a cache folder).
self._noisy_signal_filepaths = None
# Path to be used for the APM simulation output files.
self._apm_output_paths = None
# Reference audio track files (stored separately in a cache folder).
self._reference_signal_filepaths = None
self.Clear()
def __init__(self, output_directory_prefix):
self._output_directory_prefix = output_directory_prefix
# Init dictionaries with one entry for each test data generator
# configuration (e.g., different SNRs).
# Noisy audio track files (stored separately in a cache folder).
self._noisy_signal_filepaths = None
# Path to be used for the APM simulation output files.
self._apm_output_paths = None
# Reference audio track files (stored separately in a cache folder).
self._reference_signal_filepaths = None
self.Clear()
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers a TestDataGenerator implementation.
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers a TestDataGenerator implementation.
Decorator to automatically register the classes that extend
TestDataGenerator.
@ -77,28 +76,28 @@ class TestDataGenerator(object):
class IdentityGenerator(TestDataGenerator):
pass
"""
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
@property
def config_names(self):
return self._noisy_signal_filepaths.keys()
@property
def config_names(self):
return self._noisy_signal_filepaths.keys()
@property
def noisy_signal_filepaths(self):
return self._noisy_signal_filepaths
@property
def noisy_signal_filepaths(self):
return self._noisy_signal_filepaths
@property
def apm_output_paths(self):
return self._apm_output_paths
@property
def apm_output_paths(self):
return self._apm_output_paths
@property
def reference_signal_filepaths(self):
return self._reference_signal_filepaths
@property
def reference_signal_filepaths(self):
return self._reference_signal_filepaths
def Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Generates a set of noisy input and reference audiotrack file pairs.
def Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Generates a set of noisy input and reference audiotrack file pairs.
This method initializes an empty set of pairs and calls the _Generate()
method implemented in a concrete class.
@ -109,26 +108,26 @@ class TestDataGenerator(object):
files.
base_output_path: base path where output is written.
"""
self.Clear()
self._Generate(
input_signal_filepath, test_data_cache_path, base_output_path)
self.Clear()
self._Generate(input_signal_filepath, test_data_cache_path,
base_output_path)
def Clear(self):
"""Clears the generated output path dictionaries.
def Clear(self):
"""Clears the generated output path dictionaries.
"""
self._noisy_signal_filepaths = {}
self._apm_output_paths = {}
self._reference_signal_filepaths = {}
self._noisy_signal_filepaths = {}
self._apm_output_paths = {}
self._reference_signal_filepaths = {}
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Abstract method to be implemented in each concrete class.
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Abstract method to be implemented in each concrete class.
"""
raise NotImplementedError()
raise NotImplementedError()
def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
snr_value_pairs):
"""Adds noisy-reference signal pairs.
def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
snr_value_pairs):
"""Adds noisy-reference signal pairs.
Args:
base_output_path: noisy tracks base output path.
@ -136,22 +135,22 @@ class TestDataGenerator(object):
by noisy track name and SNR level.
snr_value_pairs: list of SNR pairs.
"""
for noise_track_name in noisy_mix_filepaths:
for snr_noisy, snr_refence in snr_value_pairs:
config_name = '{0}_{1:d}_{2:d}_SNR'.format(
noise_track_name, snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[
noise_track_name][snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[
noise_track_name][snr_refence],
output_path=output_path)
for noise_track_name in noisy_mix_filepaths:
for snr_noisy, snr_refence in snr_value_pairs:
config_name = '{0}_{1:d}_{2:d}_SNR'.format(
noise_track_name, snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[noise_track_name]
[snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[
noise_track_name][snr_refence],
output_path=output_path)
def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
reference_signal_filepath, output_path):
"""Adds one noisy-reference signal pair.
def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
reference_signal_filepath, output_path):
"""Adds one noisy-reference signal pair.
Args:
config_name: name of the APM configuration.
@ -159,264 +158,275 @@ class TestDataGenerator(object):
reference_signal_filepath: path to reference audio track file.
output_path: APM output path.
"""
assert config_name not in self._noisy_signal_filepaths
self._noisy_signal_filepaths[config_name] = os.path.abspath(
noisy_signal_filepath)
self._apm_output_paths[config_name] = os.path.abspath(output_path)
self._reference_signal_filepaths[config_name] = os.path.abspath(
reference_signal_filepath)
assert config_name not in self._noisy_signal_filepaths
self._noisy_signal_filepaths[config_name] = os.path.abspath(
noisy_signal_filepath)
self._apm_output_paths[config_name] = os.path.abspath(output_path)
self._reference_signal_filepaths[config_name] = os.path.abspath(
reference_signal_filepath)
def _MakeDir(self, base_output_path, test_data_generator_config_name):
output_path = os.path.join(
base_output_path,
self._output_directory_prefix + test_data_generator_config_name)
data_access.MakeDirectory(output_path)
return output_path
def _MakeDir(self, base_output_path, test_data_generator_config_name):
output_path = os.path.join(
base_output_path,
self._output_directory_prefix + test_data_generator_config_name)
data_access.MakeDirectory(output_path)
return output_path
@TestDataGenerator.RegisterClass
class IdentityTestDataGenerator(TestDataGenerator):
"""Generator that adds no noise.
"""Generator that adds no noise.
Both the noisy and the reference signals are the input signal.
"""
NAME = 'identity'
NAME = 'identity'
def __init__(self, output_directory_prefix, copy_with_identity):
TestDataGenerator.__init__(self, output_directory_prefix)
self._copy_with_identity = copy_with_identity
def __init__(self, output_directory_prefix, copy_with_identity):
TestDataGenerator.__init__(self, output_directory_prefix)
self._copy_with_identity = copy_with_identity
@property
def copy_with_identity(self):
return self._copy_with_identity
@property
def copy_with_identity(self):
return self._copy_with_identity
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
config_name = 'default'
output_path = self._MakeDir(base_output_path, config_name)
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
config_name = 'default'
output_path = self._MakeDir(base_output_path, config_name)
if self._copy_with_identity:
input_signal_filepath_new = os.path.join(
test_data_cache_path, os.path.split(input_signal_filepath)[1])
logging.info('copying ' + input_signal_filepath + ' to ' + (
input_signal_filepath_new))
shutil.copy(input_signal_filepath, input_signal_filepath_new)
input_signal_filepath = input_signal_filepath_new
if self._copy_with_identity:
input_signal_filepath_new = os.path.join(
test_data_cache_path,
os.path.split(input_signal_filepath)[1])
logging.info('copying ' + input_signal_filepath + ' to ' +
(input_signal_filepath_new))
shutil.copy(input_signal_filepath, input_signal_filepath_new)
input_signal_filepath = input_signal_filepath_new
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=input_signal_filepath,
reference_signal_filepath=input_signal_filepath,
output_path=output_path)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=input_signal_filepath,
reference_signal_filepath=input_signal_filepath,
output_path=output_path)
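To illustrate the extension point, a hypothetical generator following the same pattern could look as below. This is a sketch only, not part of this change; it assumes it lives in this module (so os, signal_processing and the base class are in scope), and the -6 dB attenuation idea is made up. Because its constructor only takes the output directory prefix, the factory's final else branch would instantiate it without further changes.

@TestDataGenerator.RegisterClass
class AttenuationTestDataGenerator(TestDataGenerator):
    """Hypothetical generator: the noisy input is the clean signal at -6 dB."""

    NAME = 'attenuation'

    def __init__(self, output_directory_prefix):
        TestDataGenerator.__init__(self, output_directory_prefix)

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        config_name = 'minus_6_db'
        output_path = self._MakeDir(base_output_path, config_name)
        noisy_signal_filepath = os.path.join(test_data_cache_path,
                                             'minus_6_db.wav')
        if not os.path.exists(noisy_signal_filepath):
            input_signal = signal_processing.SignalProcessingUtils.LoadWav(
                input_signal_filepath)
            # pydub AudioSegment: subtracting a number applies a gain in dB.
            signal_processing.SignalProcessingUtils.SaveWav(
                noisy_signal_filepath, input_signal - 6)
        self._AddNoiseReferenceFilesPair(
            config_name=config_name,
            noisy_signal_filepath=noisy_signal_filepath,
            reference_signal_filepath=input_signal_filepath,
            output_path=output_path)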
@TestDataGenerator.RegisterClass
class WhiteNoiseTestDataGenerator(TestDataGenerator):
"""Generator that adds white noise.
"""Generator that adds white noise.
"""
NAME = 'white_noise'
NAME = 'white_noise'
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Create the noise track.
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
input_signal)
# Create the noise track.
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
input_signal)
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths = {}
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths = {}
snr_values = set(
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Add file to the collection of mixes.
noisy_mix_filepaths[snr] = noisy_signal_filepath
# Add file to the collection of mixes.
noisy_mix_filepaths[snr] = noisy_signal_filepath
# Add all the noisy-reference signal pairs.
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
output_path=output_path)
# Add all the noisy-reference signal pairs.
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
output_path=output_path)
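For orientation, the SNR pairs above expand into one cached noisy mix per unique SNR value and one noisy/reference configuration per pair; a quick illustration:

snr_value_pairs = [[20, 30], [10, 20], [5, 15], [0, 10]]
print(sorted({snr for pair in snr_value_pairs for snr in pair}))
# [0, 5, 10, 15, 20, 30]  -> six noise_<snr>_SNR.wav files in the cache
print(['{0:d}_{1:d}_SNR'.format(noisy, ref) for noisy, ref in snr_value_pairs])
# ['20_30_SNR', '10_20_SNR', '5_15_SNR', '0_10_SNR']  -> four config names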
# TODO(alessiob): remove comment when class implemented.
# @TestDataGenerator.RegisterClass
class NarrowBandNoiseTestDataGenerator(TestDataGenerator):
"""Generator that adds narrow-band noise.
"""Generator that adds narrow-band noise.
"""
NAME = 'narrow_band_noise'
NAME = 'narrow_band_noise'
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
# TODO(alessiob): implement.
pass
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
# TODO(alessiob): implement.
pass
@TestDataGenerator.RegisterClass
class AdditiveNoiseTestDataGenerator(TestDataGenerator):
"""Generator that adds noise loops.
"""Generator that adds noise loops.
This generator uses all the wav files in a given path (default: noise_tracks/)
and mixes them with the clean speech at different target SNRs (hard-coded).
"""
NAME = 'additive_noise'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
NAME = 'additive_noise'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
DEFAULT_NOISE_TRACKS_PATH = os.path.join(
os.path.dirname(__file__), os.pardir, 'noise_tracks')
DEFAULT_NOISE_TRACKS_PATH = os.path.join(os.path.dirname(__file__),
os.pardir, 'noise_tracks')
# TODO(alessiob): Make the list of SNR pairs customizable.
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
# TODO(alessiob): Make the list of SNR pairs customizable.
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
def __init__(self, output_directory_prefix, noise_tracks_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._noise_tracks_path = noise_tracks_path
self._noise_tracks_file_names = [n for n in os.listdir(
self._noise_tracks_path) if n.lower().endswith('.wav')]
if len(self._noise_tracks_file_names) == 0:
raise exceptions.InitializationException(
'No wav files found in the noise tracks path %s' % (
self._noise_tracks_path))
def __init__(self, output_directory_prefix, noise_tracks_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._noise_tracks_path = noise_tracks_path
self._noise_tracks_file_names = [
n for n in os.listdir(self._noise_tracks_path)
if n.lower().endswith('.wav')
]
if len(self._noise_tracks_file_names) == 0:
raise exceptions.InitializationException(
'No wav files found in the noise tracks path %s' %
(self._noise_tracks_path))
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Generates test data pairs using environmental noise.
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Generates test data pairs using environmental noise.
For each noise track and pair of SNR values, the following two audio tracks
are created: the noisy signal and the reference signal. The former is
obtained by mixing the (clean) input signal with the corresponding noise
track enforcing the target SNR.
"""
# Init.
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Init.
snr_values = set(
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
noisy_mix_filepaths = {}
for noise_track_filename in self._noise_tracks_file_names:
# Load the noise track.
noise_track_name, _ = os.path.splitext(noise_track_filename)
noise_track_filepath = os.path.join(
self._noise_tracks_path, noise_track_filename)
if not os.path.exists(noise_track_filepath):
logging.error('cannot find the <%s> noise track', noise_track_filename)
raise exceptions.FileNotFoundError()
noisy_mix_filepaths = {}
for noise_track_filename in self._noise_tracks_file_names:
# Load the noise track.
noise_track_name, _ = os.path.splitext(noise_track_filename)
noise_track_filepath = os.path.join(self._noise_tracks_path,
noise_track_filename)
if not os.path.exists(noise_track_filepath):
logging.error('cannot find the <%s> noise track',
noise_track_filename)
raise exceptions.FileNotFoundError()
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[noise_track_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(noise_track_name, snr))
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[noise_track_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
noise_track_name, snr))
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal,
noise_signal,
snr,
pad_noise=signal_processing.SignalProcessingUtils.
MixPadding.LOOP)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Add file to the collection of mixes.
noisy_mix_filepaths[noise_track_name][snr] = noisy_signal_filepath
# Add file to the collection of mixes.
noisy_mix_filepaths[noise_track_name][
snr] = noisy_signal_filepath
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(
base_output_path, noisy_mix_filepaths, self._SNR_VALUE_PAIRS)
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
self._SNR_VALUE_PAIRS)
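The pad_noise=...LOOP argument is what lets a noise track shorter than the capture signal still cover it; conceptually it amounts to the following pydub-only sketch (assuming MixSignals behaves like overlaying a looped noise segment; the actual SNR scaling lives in SignalProcessingUtils):

import pydub

speech = pydub.AudioSegment.silent(duration=10000, frame_rate=48000)  # placeholder, 10 s
noise = pydub.AudioSegment.silent(duration=3000, frame_rate=48000)    # placeholder, 3 s
looped_noise = (noise * 4)[:len(speech)]  # repeat, then trim to the speech length (ms)
noisy = speech.overlay(looped_noise)      # mix; gain/SNR adjustment omitted here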
@TestDataGenerator.RegisterClass
class ReverberationTestDataGenerator(TestDataGenerator):
"""Generator that adds reverberation noise.
"""Generator that adds reverberation noise.
TODO(alessiob): Make this class more generic since the impulse response can be
anything (not just reverberation); call it e.g.,
ConvolutionalNoiseTestDataGenerator.
"""
NAME = 'reverberation'
NAME = 'reverberation'
_IMPULSE_RESPONSES = {
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
}
_MAX_IMPULSE_RESPONSE_LENGTH = None
_IMPULSE_RESPONSES = {
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
}
_MAX_IMPULSE_RESPONSE_LENGTH = None
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 5 dB higher.
_SNR_VALUE_PAIRS = [
[3, 8], # Smallest noise.
[-3, 2], # Largest noise.
]
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 5 dB higher.
_SNR_VALUE_PAIRS = [
[3, 8], # Smallest noise.
[-3, 2], # Largest noise.
]
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
def __init__(self, output_directory_prefix, aechen_ir_database_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._aechen_ir_database_path = aechen_ir_database_path
def __init__(self, output_directory_prefix, aechen_ir_database_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._aechen_ir_database_path = aechen_ir_database_path
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Generates test data pairs using reverberation noise.
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Generates test data pairs using reverberation noise.
For each impulse response, one noise track is created. For each impulse
response and pair of SNR values, the following 2 audio tracks are
@ -424,61 +434,64 @@ class ReverberationTestDataGenerator(TestDataGenerator):
obtained by mixing the (clean) input signal with the corresponding noise
track enforcing the target SNR.
"""
# Init.
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Init.
snr_values = set(
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
noisy_mix_filepaths = {}
for impulse_response_name in self._IMPULSE_RESPONSES:
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
impulse_response_name)
noise_track_filepath = os.path.join(
test_data_cache_path, noise_track_filename)
noise_signal = None
try:
# Load noise track.
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
except exceptions.FileNotFoundError:
# Generate noise track by applying the impulse response.
impulse_response_filepath = os.path.join(
self._aechen_ir_database_path,
self._IMPULSE_RESPONSES[impulse_response_name])
noise_signal = self._GenerateNoiseTrack(
noise_track_filepath, input_signal, impulse_response_filepath)
assert noise_signal is not None
noisy_mix_filepaths = {}
for impulse_response_name in self._IMPULSE_RESPONSES:
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
impulse_response_name)
noise_track_filepath = os.path.join(test_data_cache_path,
noise_track_filename)
noise_signal = None
try:
# Load noise track.
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
except exceptions.FileNotFoundError:
# Generate noise track by applying the impulse response.
impulse_response_filepath = os.path.join(
self._aechen_ir_database_path,
self._IMPULSE_RESPONSES[impulse_response_name])
noise_signal = self._GenerateNoiseTrack(
noise_track_filepath, input_signal,
impulse_response_filepath)
assert noise_signal is not None
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[impulse_response_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
impulse_response_name, snr))
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[impulse_response_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
impulse_response_name, snr))
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Add file to the collection of mixes.
noisy_mix_filepaths[impulse_response_name][snr] = noisy_signal_filepath
# Add file to the collection of mixes.
noisy_mix_filepaths[impulse_response_name][
snr] = noisy_signal_filepath
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
self._SNR_VALUE_PAIRS)
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
self._SNR_VALUE_PAIRS)
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
impulse_response_filepath):
"""Generates noise track.
"""Generates noise track.
Generate a signal by convolving input_signal with the impulse response in
impulse_response_filepath; then save to noise_track_filepath.
@ -491,21 +504,23 @@ class ReverberationTestDataGenerator(TestDataGenerator):
Returns:
AudioSegment instance.
"""
# Load impulse response.
data = scipy.io.loadmat(impulse_response_filepath)
impulse_response = data['h_air'].flatten()
if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
logging.info('truncating impulse response from %d to %d samples',
len(impulse_response), self._MAX_IMPULSE_RESPONSE_LENGTH)
impulse_response = impulse_response[:self._MAX_IMPULSE_RESPONSE_LENGTH]
# Load impulse response.
data = scipy.io.loadmat(impulse_response_filepath)
impulse_response = data['h_air'].flatten()
if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
logging.info('truncating impulse response from %d to %d samples',
len(impulse_response),
self._MAX_IMPULSE_RESPONSE_LENGTH)
impulse_response = impulse_response[:self.
_MAX_IMPULSE_RESPONSE_LENGTH]
# Apply impulse response.
processed_signal = (
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
input_signal, impulse_response))
# Apply impulse response.
processed_signal = (
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
input_signal, impulse_response))
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noise_track_filepath, processed_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noise_track_filepath, processed_signal)
return processed_signal
return processed_signal
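At its core, the generated reverberant noise track is the convolution of the clean signal with the loaded impulse response; a standalone numpy/scipy sketch of that operation, with random data standing in for the audio and for data['h_air']:

import numpy as np
from scipy.signal import fftconvolve

clean = np.random.randn(48000)            # 1 s of fake audio at 48 kHz
impulse_response = np.random.randn(1000)  # stand-in for data['h_air'].flatten()
reverberant = fftconvolve(clean, impulse_response)[:clean.size]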


@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""TestDataGenerator factory class.
"""
@ -16,15 +15,15 @@ from . import test_data_generation
class TestDataGeneratorFactory(object):
"""Factory class used to create test data generators.
"""Factory class used to create test data generators.
Usage: Create a factory passing parameters to the ctor with which the
generators will be produced.
"""
def __init__(self, aechen_ir_database_path, noise_tracks_path,
copy_with_identity):
"""Ctor.
def __init__(self, aechen_ir_database_path, noise_tracks_path,
copy_with_identity):
"""Ctor.
Args:
aechen_ir_database_path: Path to the Aechen Impulse Response database.
@ -32,16 +31,16 @@ class TestDataGeneratorFactory(object):
copy_with_identity: Flag indicating whether the identity generator has to
make copies of the clean speech input files.
"""
self._output_directory_prefix = None
self._aechen_ir_database_path = aechen_ir_database_path
self._noise_tracks_path = noise_tracks_path
self._copy_with_identity = copy_with_identity
self._output_directory_prefix = None
self._aechen_ir_database_path = aechen_ir_database_path
self._noise_tracks_path = noise_tracks_path
self._copy_with_identity = copy_with_identity
def SetOutputDirectoryPrefix(self, prefix):
self._output_directory_prefix = prefix
def SetOutputDirectoryPrefix(self, prefix):
self._output_directory_prefix = prefix
def GetInstance(self, test_data_generators_class):
"""Creates an TestDataGenerator instance given a class object.
def GetInstance(self, test_data_generators_class):
"""Creates an TestDataGenerator instance given a class object.
Args:
test_data_generators_class: TestDataGenerator class object (not an
@ -50,22 +49,23 @@ class TestDataGeneratorFactory(object):
Returns:
TestDataGenerator instance.
"""
if self._output_directory_prefix is None:
raise exceptions.InitializationException(
'The output directory prefix for test data generators is not set')
logging.debug('factory producing %s', test_data_generators_class)
if self._output_directory_prefix is None:
raise exceptions.InitializationException(
'The output directory prefix for test data generators is not set'
)
logging.debug('factory producing %s', test_data_generators_class)
if test_data_generators_class == (
test_data_generation.IdentityTestDataGenerator):
return test_data_generation.IdentityTestDataGenerator(
self._output_directory_prefix, self._copy_with_identity)
elif test_data_generators_class == (
test_data_generation.ReverberationTestDataGenerator):
return test_data_generation.ReverberationTestDataGenerator(
self._output_directory_prefix, self._aechen_ir_database_path)
elif test_data_generators_class == (
test_data_generation.AdditiveNoiseTestDataGenerator):
return test_data_generation.AdditiveNoiseTestDataGenerator(
self._output_directory_prefix, self._noise_tracks_path)
else:
return test_data_generators_class(self._output_directory_prefix)
if test_data_generators_class == (
test_data_generation.IdentityTestDataGenerator):
return test_data_generation.IdentityTestDataGenerator(
self._output_directory_prefix, self._copy_with_identity)
elif test_data_generators_class == (
test_data_generation.ReverberationTestDataGenerator):
return test_data_generation.ReverberationTestDataGenerator(
self._output_directory_prefix, self._aechen_ir_database_path)
elif test_data_generators_class == (
test_data_generation.AdditiveNoiseTestDataGenerator):
return test_data_generation.AdditiveNoiseTestDataGenerator(
self._output_directory_prefix, self._noise_tracks_path)
else:
return test_data_generators_class(self._output_directory_prefix)
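A minimal usage sketch of the factory, mirroring how the unit tests below drive it (the relative imports assume the quality_assessment package context and the AIR database path is a placeholder):

from . import test_data_generation
from . import test_data_generation_factory

factory = test_data_generation_factory.TestDataGeneratorFactory(
    aechen_ir_database_path='/path/to/AIR_database',  # placeholder
    noise_tracks_path=(
        test_data_generation.AdditiveNoiseTestDataGenerator.
        DEFAULT_NOISE_TRACKS_PATH),
    copy_with_identity=False)
factory.SetOutputDirectoryPrefix('datagen-')
generator = factory.GetInstance(
    test_data_generation.ReverberationTestDataGenerator)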


@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the test_data_generation module.
"""
@ -23,141 +22,143 @@ from . import signal_processing
class TestTestDataGenerators(unittest.TestCase):
"""Unit tests for the test_data_generation module.
"""Unit tests for the test_data_generation module.
"""
def setUp(self):
"""Create temporary folders."""
self._base_output_path = tempfile.mkdtemp()
self._test_data_cache_path = tempfile.mkdtemp()
self._fake_air_db_path = tempfile.mkdtemp()
def setUp(self):
"""Create temporary folders."""
self._base_output_path = tempfile.mkdtemp()
self._test_data_cache_path = tempfile.mkdtemp()
self._fake_air_db_path = tempfile.mkdtemp()
# Fake AIR DB impulse responses.
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
# impulse responses. When changed, the coupling below between
# impulse_response_mat_file_names and
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
impulse_response_mat_file_names = [
'air_binaural_lecture_0_0_1.mat',
'air_binaural_booth_0_0_1.mat',
]
for impulse_response_mat_file_name in impulse_response_mat_file_names:
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
scipy.io.savemat(os.path.join(
self._fake_air_db_path, impulse_response_mat_file_name), data)
# Fake AIR DB impulse responses.
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
# impulse responses. When changed, the coupling below between
# impulse_response_mat_file_names and
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
impulse_response_mat_file_names = [
'air_binaural_lecture_0_0_1.mat',
'air_binaural_booth_0_0_1.mat',
]
for impulse_response_mat_file_name in impulse_response_mat_file_names:
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
scipy.io.savemat(
os.path.join(self._fake_air_db_path,
impulse_response_mat_file_name), data)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._base_output_path)
shutil.rmtree(self._test_data_cache_path)
shutil.rmtree(self._fake_air_db_path)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._base_output_path)
shutil.rmtree(self._test_data_cache_path)
shutil.rmtree(self._fake_air_db_path)
def testTestDataGenerators(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
def testTestDataGenerators(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
# Check that there is at least one registered test data generator.
registered_classes = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Check that there is at least one registered test data generator.
registered_classes = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instance generators factory.
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=self._fake_air_db_path,
noise_tracks_path=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH,
copy_with_identity=False)
generators_factory.SetOutputDirectoryPrefix('datagen-')
# Instance generators factory.
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=self._fake_air_db_path,
noise_tracks_path=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH,
copy_with_identity=False)
generators_factory.SetOutputDirectoryPrefix('datagen-')
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(
os.getcwd(), 'probing_signals', 'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
# Load input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Load input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Try each registered test data generator.
for generator_name in registered_classes:
# Instance test data generator.
generator = generators_factory.GetInstance(
registered_classes[generator_name])
# Try each registered test data generator.
for generator_name in registered_classes:
# Instance test data generator.
generator = generators_factory.GetInstance(
registered_classes[generator_name])
# Generate the noisy input - reference pairs.
generator.Generate(
input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
# Generate the noisy input - reference pairs.
generator.Generate(input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
# Perform checks.
self._CheckGeneratedPairsListSizes(generator)
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
self._CheckGeneratedPairsOutputPaths(generator)
# Perform checks.
self._CheckGeneratedPairsListSizes(generator)
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
self._CheckGeneratedPairsOutputPaths(generator)
def testTestidentityDataGenerator(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
def testTestidentityDataGenerator(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(
os.getcwd(), 'probing_signals', 'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
def GetNoiseReferenceFilePaths(identity_generator):
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
reference_signal_filepaths = identity_generator.reference_signal_filepaths
assert noisy_signal_filepaths.keys() == reference_signal_filepaths.keys()
assert len(noisy_signal_filepaths.keys()) == 1
key = noisy_signal_filepaths.keys()[0]
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
def GetNoiseReferenceFilePaths(identity_generator):
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
reference_signal_filepaths = identity_generator.reference_signal_filepaths
assert noisy_signal_filepaths.keys(
) == reference_signal_filepaths.keys()
assert len(noisy_signal_filepaths.keys()) == 1
key = noisy_signal_filepaths.keys()[0]
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
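One detail worth noting in this helper: indexing dict.keys() only works under Python 2, where keys() returns a list; a version that also runs on Python 3 would be:

key = next(iter(noisy_signal_filepaths))  # instead of noisy_signal_filepaths.keys()[0]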
# Test the |copy_with_identity| flag.
for copy_with_identity in [False, True]:
# Instance the generator through the factory.
factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='', noise_tracks_path='',
copy_with_identity=copy_with_identity)
factory.SetOutputDirectoryPrefix('datagen-')
generator = factory.GetInstance(
test_data_generation.IdentityTestDataGenerator)
# Check |copy_with_identity| is set correctly.
self.assertEqual(copy_with_identity, generator.copy_with_identity)
# Test the |copy_with_identity| flag.
for copy_with_identity in [False, True]:
# Instance the generator through the factory.
factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=copy_with_identity)
factory.SetOutputDirectoryPrefix('datagen-')
generator = factory.GetInstance(
test_data_generation.IdentityTestDataGenerator)
# Check |copy_with_identity| is set correctly.
self.assertEqual(copy_with_identity, generator.copy_with_identity)
# Generate test data and extract the paths to the noise and the reference
# files.
generator.Generate(
input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
noisy_signal_filepath, reference_signal_filepath = (
GetNoiseReferenceFilePaths(generator))
# Generate test data and extract the paths to the noise and the reference
# files.
generator.Generate(input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
noisy_signal_filepath, reference_signal_filepath = (
GetNoiseReferenceFilePaths(generator))
# Check that a copy is made if and only if |copy_with_identity| is True.
if copy_with_identity:
self.assertNotEqual(noisy_signal_filepath, input_signal_filepath)
self.assertNotEqual(reference_signal_filepath, input_signal_filepath)
else:
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
self.assertEqual(reference_signal_filepath, input_signal_filepath)
# Check that a copy is made if and only if |copy_with_identity| is True.
if copy_with_identity:
self.assertNotEqual(noisy_signal_filepath,
input_signal_filepath)
self.assertNotEqual(reference_signal_filepath,
input_signal_filepath)
else:
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
self.assertEqual(reference_signal_filepath,
input_signal_filepath)
def _CheckGeneratedPairsListSizes(self, generator):
config_names = generator.config_names
number_of_pairs = len(config_names)
self.assertEqual(number_of_pairs,
len(generator.noisy_signal_filepaths))
self.assertEqual(number_of_pairs,
len(generator.apm_output_paths))
self.assertEqual(number_of_pairs,
len(generator.reference_signal_filepaths))
def _CheckGeneratedPairsListSizes(self, generator):
config_names = generator.config_names
number_of_pairs = len(config_names)
self.assertEqual(number_of_pairs,
len(generator.noisy_signal_filepaths))
self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
self.assertEqual(number_of_pairs,
len(generator.reference_signal_filepaths))
def _CheckGeneratedPairsSignalDurations(
self, generator, input_signal):
"""Checks duration of the generated signals.
def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
"""Checks duration of the generated signals.
Checks that the noisy input and the reference tracks are audio files
with duration equal to or greater than that of the input signal.
@ -166,41 +167,41 @@ class TestTestDataGenerators(unittest.TestCase):
generator: TestDataGenerator instance.
input_signal: AudioSegment instance.
"""
input_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
input_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
# Load the noisy input file.
noisy_signal_filepath = generator.noisy_signal_filepaths[
config_name]
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
noisy_signal_filepath)
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
# Load the noisy input file.
noisy_signal_filepath = generator.noisy_signal_filepaths[
config_name]
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
noisy_signal_filepath)
# Check noisy input signal length.
noisy_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(noisy_signal))
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
# Check noisy input signal length.
noisy_signal_length = (signal_processing.SignalProcessingUtils.
CountSamples(noisy_signal))
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
# Load the reference file.
reference_signal_filepath = generator.reference_signal_filepaths[
config_name]
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
reference_signal_filepath)
# Load the reference file.
reference_signal_filepath = generator.reference_signal_filepaths[
config_name]
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
reference_signal_filepath)
# Check noisy input signal length.
reference_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(
reference_signal))
self.assertGreaterEqual(reference_signal_length, input_signal_length)
# Check noisy input signal length.
reference_signal_length = (signal_processing.SignalProcessingUtils.
CountSamples(reference_signal))
self.assertGreaterEqual(reference_signal_length,
input_signal_length)
def _CheckGeneratedPairsOutputPaths(self, generator):
"""Checks that the output path created by the generator exists.
def _CheckGeneratedPairsOutputPaths(self, generator):
"""Checks that the output path created by the generator exists.
Args:
generator: TestDataGenerator instance.
"""
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
output_path = generator.apm_output_paths[config_name]
self.assertTrue(os.path.exists(output_path))
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
output_path = generator.apm_output_paths[config_name]
self.assertTrue(os.path.exists(output_path))
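For completeness, the sample-count comparison performed above boils down to comparing pydub sample arrays, assuming CountSamples returns the raw sample count; a self-contained sketch with synthetic segments standing in for the real files:

import pydub

input_signal = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
noisy_signal = pydub.AudioSegment.silent(duration=1500, frame_rate=48000)
assert len(noisy_signal.get_array_of_samples()) >= len(
    input_signal.get_array_of_samples())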