Reformat python files checked by pylint (part 1/2).

After recently changing .pylintrc (see [1]) we discovered that the presubmit
check always checks all the python files when just one python file gets
updated. This CL moves all these files one step closer to what the linter
wants.

Autogenerated with:

  # Added all the files under pylint control to ~/Desktop/to-reformat
  cat ~/Desktop/to-reformat | xargs sed -i '1i\\'
  git cl format --python --full

This is part 1 out of 2. The second part will fix function names and will
not be automated.

[1] - https://webrtc-review.googlesource.com/c/src/+/186664

No-Presubmit: True
Bug: webrtc:12114
Change-Id: Idfec4d759f209a2090440d0af2413a1ddc01b841
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/190980
Commit-Queue: Mirko Bonadei <mbonadei@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32530}
Commit: 8cc6695652
Parent: d3a3e9ef36
Committed by: Commit Bot
@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Perform APM module quality assessment on one or more input files using one or
|
||||
more APM simulator configuration files and one or more test data generators.
|
||||
|
||||
@@ -47,139 +46,172 @@ _POLQA_BIN_NAME = 'PolqaOem64'
|
||||
|
||||
def _InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
'Perform APM module quality assessment on one or more input files using '
|
||||
'one or more APM simulator configuration files and one or more '
|
||||
'test data generators.'))
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
'Perform APM module quality assessment on one or more input files using '
|
||||
'one or more APM simulator configuration files and one or more '
|
||||
'test data generators.'))
|
||||
|
||||
parser.add_argument('-c', '--config_files', nargs='+', required=False,
|
||||
help=('path to the configuration files defining the '
|
||||
'arguments with which the APM simulator tool is '
|
||||
'called'),
|
||||
default=[_DEFAULT_CONFIG_FILE])
|
||||
parser.add_argument('-c',
|
||||
'--config_files',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('path to the configuration files defining the '
|
||||
'arguments with which the APM simulator tool is '
|
||||
'called'),
|
||||
default=[_DEFAULT_CONFIG_FILE])
|
||||
|
||||
parser.add_argument('-i', '--capture_input_files', nargs='+', required=True,
|
||||
help='path to the capture input wav files (one or more)')
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--capture_input_files',
|
||||
nargs='+',
|
||||
required=True,
|
||||
help='path to the capture input wav files (one or more)')
|
||||
|
||||
parser.add_argument('-r', '--render_input_files', nargs='+', required=False,
|
||||
help=('path to the render input wav files; either '
|
||||
'omitted or one file for each file in '
|
||||
'--capture_input_files (files will be paired by '
|
||||
'index)'), default=None)
|
||||
parser.add_argument('-r',
|
||||
'--render_input_files',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('path to the render input wav files; either '
|
||||
'omitted or one file for each file in '
|
||||
'--capture_input_files (files will be paired by '
|
||||
'index)'),
|
||||
default=None)
|
||||
|
||||
parser.add_argument('-p', '--echo_path_simulator', required=False,
|
||||
help=('custom echo path simulator name; required if '
|
||||
'--render_input_files is specified'),
|
||||
choices=_ECHO_PATH_SIMULATOR_NAMES,
|
||||
default=echo_path_simulation.NoEchoPathSimulator.NAME)
|
||||
parser.add_argument('-p',
|
||||
'--echo_path_simulator',
|
||||
required=False,
|
||||
help=('custom echo path simulator name; required if '
|
||||
'--render_input_files is specified'),
|
||||
choices=_ECHO_PATH_SIMULATOR_NAMES,
|
||||
default=echo_path_simulation.NoEchoPathSimulator.NAME)
|
||||
|
||||
parser.add_argument('-t', '--test_data_generators', nargs='+', required=False,
|
||||
help='custom list of test data generators to use',
|
||||
choices=_TEST_DATA_GENERATORS_NAMES,
|
||||
default=_TEST_DATA_GENERATORS_NAMES)
|
||||
parser.add_argument('-t',
|
||||
'--test_data_generators',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help='custom list of test data generators to use',
|
||||
choices=_TEST_DATA_GENERATORS_NAMES,
|
||||
default=_TEST_DATA_GENERATORS_NAMES)
|
||||
|
||||
parser.add_argument('--additive_noise_tracks_path', required=False,
|
||||
help='path to the wav files for the additive',
|
||||
default=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH)
|
||||
parser.add_argument('--additive_noise_tracks_path', required=False,
|
||||
help='path to the wav files for the additive',
|
||||
default=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH)
|
||||
|
||||
parser.add_argument('-e', '--eval_scores', nargs='+', required=False,
|
||||
help='custom list of evaluation scores to use',
|
||||
choices=_EVAL_SCORE_WORKER_NAMES,
|
||||
default=_EVAL_SCORE_WORKER_NAMES)
|
||||
parser.add_argument('-e',
|
||||
'--eval_scores',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help='custom list of evaluation scores to use',
|
||||
choices=_EVAL_SCORE_WORKER_NAMES,
|
||||
default=_EVAL_SCORE_WORKER_NAMES)
|
||||
|
||||
parser.add_argument('-o', '--output_dir', required=False,
|
||||
help=('base path to the output directory in which the '
|
||||
'output wav files and the evaluation outcomes '
|
||||
'are saved'),
|
||||
default='output')
|
||||
parser.add_argument('-o',
|
||||
'--output_dir',
|
||||
required=False,
|
||||
help=('base path to the output directory in which the '
|
||||
'output wav files and the evaluation outcomes '
|
||||
'are saved'),
|
||||
default='output')
|
||||
|
||||
parser.add_argument('--polqa_path', required=True,
|
||||
help='path to the POLQA tool')
|
||||
parser.add_argument('--polqa_path',
|
||||
required=True,
|
||||
help='path to the POLQA tool')
|
||||
|
||||
parser.add_argument('--air_db_path', required=True,
|
||||
help='path to the Aechen IR database')
|
||||
parser.add_argument('--air_db_path',
|
||||
required=True,
|
||||
help='path to the Aechen IR database')
|
||||
|
||||
parser.add_argument('--apm_sim_path', required=False,
|
||||
help='path to the APM simulator tool',
|
||||
default=audioproc_wrapper. \
|
||||
AudioProcWrapper. \
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
parser.add_argument('--apm_sim_path', required=False,
|
||||
help='path to the APM simulator tool',
|
||||
default=audioproc_wrapper. \
|
||||
AudioProcWrapper. \
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
|
||||
parser.add_argument('--echo_metric_tool_bin_path', required=False,
|
||||
help=('path to the echo metric binary '
|
||||
'(required for the echo eval score)'),
|
||||
default=None)
|
||||
parser.add_argument('--echo_metric_tool_bin_path',
|
||||
required=False,
|
||||
help=('path to the echo metric binary '
|
||||
'(required for the echo eval score)'),
|
||||
default=None)
|
||||
|
||||
parser.add_argument('--copy_with_identity_generator', required=False,
|
||||
help=('If true, the identity test data generator makes a '
|
||||
'copy of the clean speech input file.'),
|
||||
default=False)
|
||||
parser.add_argument(
|
||||
'--copy_with_identity_generator',
|
||||
required=False,
|
||||
help=('If true, the identity test data generator makes a '
|
||||
'copy of the clean speech input file.'),
|
||||
default=False)
|
||||
|
||||
parser.add_argument('--external_vad_paths', nargs='+', required=False,
|
||||
help=('Paths to external VAD programs. Each must take'
|
||||
'\'-i <wav file> -o <output>\' inputs'), default=[])
|
||||
parser.add_argument('--external_vad_paths',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('Paths to external VAD programs. Each must take'
|
||||
'\'-i <wav file> -o <output>\' inputs'),
|
||||
default=[])
|
||||
|
||||
parser.add_argument('--external_vad_names', nargs='+', required=False,
|
||||
help=('Keys to the vad paths. Must be different and '
|
||||
'as many as the paths.'), default=[])
|
||||
parser.add_argument('--external_vad_names',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('Keys to the vad paths. Must be different and '
|
||||
'as many as the paths.'),
|
||||
default=[])
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def _ValidateArguments(args, parser):
|
||||
if args.capture_input_files and args.render_input_files and (
|
||||
len(args.capture_input_files) != len(args.render_input_files)):
|
||||
parser.error('--render_input_files and --capture_input_files must be lists '
|
||||
'having the same length')
|
||||
sys.exit(1)
|
||||
if args.capture_input_files and args.render_input_files and (len(
|
||||
args.capture_input_files) != len(args.render_input_files)):
|
||||
parser.error(
|
||||
'--render_input_files and --capture_input_files must be lists '
|
||||
'having the same length')
|
||||
sys.exit(1)
|
||||
|
||||
if args.render_input_files and not args.echo_path_simulator:
|
||||
parser.error('when --render_input_files is set, --echo_path_simulator is '
|
||||
'also required')
|
||||
sys.exit(1)
|
||||
if args.render_input_files and not args.echo_path_simulator:
|
||||
parser.error(
|
||||
'when --render_input_files is set, --echo_path_simulator is '
|
||||
'also required')
|
||||
sys.exit(1)
|
||||
|
||||
if len(args.external_vad_names) != len(args.external_vad_paths):
|
||||
parser.error('If provided, --external_vad_paths and '
|
||||
'--external_vad_names must '
|
||||
'have the same number of arguments.')
|
||||
sys.exit(1)
|
||||
if len(args.external_vad_names) != len(args.external_vad_paths):
|
||||
parser.error('If provided, --external_vad_paths and '
|
||||
'--external_vad_names must '
|
||||
'have the same number of arguments.')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
# TODO(alessiob): level = logging.INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
_ValidateArguments(args, parser)
|
||||
# TODO(alessiob): level = logging.INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
_ValidateArguments(args, parser)
|
||||
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=args.air_db_path,
|
||||
noise_tracks_path=args.additive_noise_tracks_path,
|
||||
copy_with_identity=args.copy_with_identity_generator)),
|
||||
evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
|
||||
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path
|
||||
),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
|
||||
evaluator=evaluation.ApmModuleEvaluator(),
|
||||
external_vads=external_vad.ExternalVad.ConstructVadDict(
|
||||
args.external_vad_paths, args.external_vad_names))
|
||||
simulator.Run(
|
||||
config_filepaths=args.config_files,
|
||||
capture_input_filepaths=args.capture_input_files,
|
||||
render_input_filepaths=args.render_input_files,
|
||||
echo_path_simulator_name=args.echo_path_simulator,
|
||||
test_data_generator_names=args.test_data_generators,
|
||||
eval_score_names=args.eval_scores,
|
||||
output_dir=args.output_dir)
|
||||
sys.exit(0)
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=args.air_db_path,
|
||||
noise_tracks_path=args.additive_noise_tracks_path,
|
||||
copy_with_identity=args.copy_with_identity_generator)),
|
||||
evaluation_score_factory=eval_scores_factory.
|
||||
EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
|
||||
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
|
||||
evaluator=evaluation.ApmModuleEvaluator(),
|
||||
external_vads=external_vad.ExternalVad.ConstructVadDict(
|
||||
args.external_vad_paths, args.external_vad_names))
|
||||
simulator.Run(config_filepaths=args.config_files,
|
||||
capture_input_filepaths=args.capture_input_files,
|
||||
render_input_filepaths=args.render_input_files,
|
||||
echo_path_simulator_name=args.echo_path_simulator,
|
||||
test_data_generator_names=args.test_data_generators,
|
||||
eval_score_names=args.eval_scores,
|
||||
output_dir=args.output_dir)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Shows boxplots of given score for different values of selected
|
||||
parameters. Can be used to compare scores by audioproc_f flag.
|
||||
|
||||
@@ -30,29 +29,37 @@ import quality_assessment.collect_data as collect_data
|
||||
|
||||
def InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Shows boxplot of given score for different values of selected'
|
||||
'parameters. Can be used to compare scores by audioproc_f flag')
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Shows boxplot of given score for different values of selected'
|
||||
'parameters. Can be used to compare scores by audioproc_f flag')
|
||||
|
||||
parser.add_argument('-v', '--eval_score', required=True,
|
||||
help=('Score name for constructing boxplots'))
|
||||
parser.add_argument('-v',
|
||||
'--eval_score',
|
||||
required=True,
|
||||
help=('Score name for constructing boxplots'))
|
||||
|
||||
parser.add_argument('-n', '--config_dir', required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--config_dir',
|
||||
required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
|
||||
parser.add_argument('-z', '--params_to_plot', required=True,
|
||||
nargs='+', help=('audioproc_f parameter values'
|
||||
'by which to group scores (no leading dash)'))
|
||||
parser.add_argument('-z',
|
||||
'--params_to_plot',
|
||||
required=True,
|
||||
nargs='+',
|
||||
help=('audioproc_f parameter values'
|
||||
'by which to group scores (no leading dash)'))
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
|
||||
"""Filters data on the values of one or more parameters.
|
||||
"""Filters data on the values of one or more parameters.
|
||||
|
||||
Args:
|
||||
data_frame: pandas.DataFrame of all used input data.
|
||||
@@ -71,34 +78,36 @@ def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
Returns: dictionary, key is a param value, result is all scores for
|
||||
that param value (see `filter_params` for explanation).
|
||||
"""
|
||||
results = collections.defaultdict(dict)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
results = collections.defaultdict(dict)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + '.json'))
|
||||
data_with_config = data_frame[data_frame.apm_config == config_name]
|
||||
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
|
||||
score_name]
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + '.json'))
|
||||
data_with_config = data_frame[data_frame.apm_config == config_name]
|
||||
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
|
||||
score_name]
|
||||
|
||||
# Exactly one of |params_to_plot| must match:
|
||||
(matching_param, ) = [x for x in filter_params if '-' + x in config_json]
|
||||
# Exactly one of |params_to_plot| must match:
|
||||
(matching_param, ) = [
|
||||
x for x in filter_params if '-' + x in config_json
|
||||
]
|
||||
|
||||
# Add scores for every track to the result.
|
||||
for capture_name in data_cell_scores.capture:
|
||||
result_score = float(data_cell_scores[data_cell_scores.capture ==
|
||||
capture_name].score)
|
||||
config_dict = results[config_json['-' + matching_param]]
|
||||
if capture_name not in config_dict:
|
||||
config_dict[capture_name] = {}
|
||||
# Add scores for every track to the result.
|
||||
for capture_name in data_cell_scores.capture:
|
||||
result_score = float(data_cell_scores[data_cell_scores.capture ==
|
||||
capture_name].score)
|
||||
config_dict = results[config_json['-' + matching_param]]
|
||||
if capture_name not in config_dict:
|
||||
config_dict[capture_name] = {}
|
||||
|
||||
config_dict[capture_name][matching_param] = result_score
|
||||
config_dict[capture_name][matching_param] = result_score
|
||||
|
||||
return results
|
||||
return results
|
||||
|
||||
|
||||
def _FlattenToScoresList(config_param_score_dict):
|
||||
"""Extracts a list of scores from input data structure.
|
||||
"""Extracts a list of scores from input data structure.
|
||||
|
||||
Args:
|
||||
config_param_score_dict: of the form {'capture_name':
|
||||
@@ -107,40 +116,39 @@ def _FlattenToScoresList(config_param_score_dict):
Returns: Plain list of all score value present in input data
|
||||
structure
|
||||
"""
|
||||
result = []
|
||||
for capture_name in config_param_score_dict:
|
||||
result += list(config_param_score_dict[capture_name].values())
|
||||
return result
|
||||
result = []
|
||||
for capture_name in config_param_score_dict:
|
||||
result += list(config_param_score_dict[capture_name].values())
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
|
||||
# Filter the data by `args.params_to_plot`
|
||||
scores_filtered = FilterScoresByParams(scores_data_frame,
|
||||
args.params_to_plot,
|
||||
args.eval_score,
|
||||
args.config_dir)
|
||||
# Filter the data by `args.params_to_plot`
|
||||
scores_filtered = FilterScoresByParams(scores_data_frame,
|
||||
args.params_to_plot,
|
||||
args.eval_score, args.config_dir)
|
||||
|
||||
data_list = sorted(scores_filtered.items())
|
||||
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
|
||||
data_labels = [x for (x, _) in data_list]
|
||||
data_list = sorted(scores_filtered.items())
|
||||
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
|
||||
data_labels = [x for (x, _) in data_list]
|
||||
|
||||
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
|
||||
axes.boxplot(data_values, labels=data_labels)
|
||||
axes.set_ylabel(args.eval_score)
|
||||
axes.set_xlabel('/'.join(args.params_to_plot))
|
||||
plt.show()
|
||||
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
|
||||
axes.boxplot(data_values, labels=data_labels)
|
||||
axes.set_ylabel(args.eval_score)
|
||||
axes.set_xlabel('/'.join(args.params_to_plot))
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Export the scores computed by the apm_quality_assessment.py script into an
|
||||
HTML file.
|
||||
"""
|
||||
@@ -20,7 +19,7 @@ import quality_assessment.export as export
|
||||
|
||||
def _BuildOutputFilename(filename_suffix):
|
||||
"""Builds the filename for the exported file.
|
||||
"""Builds the filename for the exported file.
|
||||
|
||||
Args:
|
||||
filename_suffix: suffix for the output file name.
|
||||
@@ -28,34 +27,37 @@ def _BuildOutputFilename(filename_suffix):
Returns:
|
||||
A string.
|
||||
"""
|
||||
if filename_suffix is None:
|
||||
return 'results.html'
|
||||
return 'results-{}.html'.format(filename_suffix)
|
||||
if filename_suffix is None:
|
||||
return 'results.html'
|
||||
return 'results-{}.html'.format(filename_suffix)
|
||||
|
||||
|
||||
def main():
|
||||
# Init.
|
||||
logging.basicConfig(level=logging.DEBUG) # TODO(alessio): INFO once debugged.
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.add_argument('-f', '--filename_suffix',
|
||||
help=('suffix of the exported file'))
|
||||
parser.description = ('Exports pre-computed APM module quality assessment '
|
||||
'results into HTML tables')
|
||||
args = parser.parse_args()
|
||||
# Init.
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG) # TODO(alessio): INFO once debugged.
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.add_argument('-f',
|
||||
'--filename_suffix',
|
||||
help=('suffix of the exported file'))
|
||||
parser.description = ('Exports pre-computed APM module quality assessment '
|
||||
'results into HTML tables')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
|
||||
# Export.
|
||||
output_filepath = os.path.join(args.output_dir, _BuildOutputFilename(
|
||||
args.filename_suffix))
|
||||
exporter = export.HtmlExport(output_filepath)
|
||||
exporter.Export(scores_data_frame)
|
||||
# Export.
|
||||
output_filepath = os.path.join(args.output_dir,
|
||||
_BuildOutputFilename(args.filename_suffix))
|
||||
exporter = export.HtmlExport(output_filepath)
|
||||
exporter.Export(scores_data_frame)
|
||||
|
||||
logging.info('output file successfully written in %s', output_filepath)
|
||||
sys.exit(0)
|
||||
logging.info('output file successfully written in %s', output_filepath)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Generate .json files with which the APM module can be tested using the
|
||||
apm_quality_assessment.py script and audioproc_f as APM simulator.
|
||||
"""
|
||||
@@ -20,7 +19,7 @@ OUTPUT_PATH = os.path.abspath('apm_configs')
|
||||
|
||||
def _GenerateDefaultOverridden(config_override):
|
||||
"""Generates one or more APM overriden configurations.
|
||||
"""Generates one or more APM overriden configurations.
|
||||
|
||||
For each item in config_override, it overrides the default configuration and
|
||||
writes a new APM configuration file.
|
||||
@@ -45,54 +44,85 @@ def _GenerateDefaultOverridden(config_override):
config_override: dict of APM configuration file names as keys; the values
|
||||
are dict instances encoding the audioproc_f flags.
|
||||
"""
|
||||
for config_filename in config_override:
|
||||
config = config_override[config_filename]
|
||||
config['-all_default'] = None
|
||||
for config_filename in config_override:
|
||||
config = config_override[config_filename]
|
||||
config['-all_default'] = None
|
||||
|
||||
config_filepath = os.path.join(OUTPUT_PATH, 'default-{}.json'.format(
|
||||
config_filename))
|
||||
logging.debug('config file <%s> | %s', config_filepath, config)
|
||||
config_filepath = os.path.join(
|
||||
OUTPUT_PATH, 'default-{}.json'.format(config_filename))
|
||||
logging.debug('config file <%s> | %s', config_filepath, config)
|
||||
|
||||
data_access.AudioProcConfigFile.Save(config_filepath, config)
|
||||
logging.info('config file created: <%s>', config_filepath)
|
||||
data_access.AudioProcConfigFile.Save(config_filepath, config)
|
||||
logging.info('config file created: <%s>', config_filepath)
|
||||
|
||||
|
||||
def _GenerateAllDefaultButOne():
|
||||
"""Disables the flags enabled by default one-by-one.
|
||||
"""Disables the flags enabled by default one-by-one.
|
||||
"""
|
||||
config_sets = {
|
||||
'no_AEC': {'-aec': 0,},
|
||||
'no_AGC': {'-agc': 0,},
|
||||
'no_HP_filter': {'-hpf': 0,},
|
||||
'no_level_estimator': {'-le': 0,},
|
||||
'no_noise_suppressor': {'-ns': 0,},
|
||||
'no_transient_suppressor': {'-ts': 0,},
|
||||
'no_vad': {'-vad': 0,},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
config_sets = {
|
||||
'no_AEC': {
|
||||
'-aec': 0,
|
||||
},
|
||||
'no_AGC': {
|
||||
'-agc': 0,
|
||||
},
|
||||
'no_HP_filter': {
|
||||
'-hpf': 0,
|
||||
},
|
||||
'no_level_estimator': {
|
||||
'-le': 0,
|
||||
},
|
||||
'no_noise_suppressor': {
|
||||
'-ns': 0,
|
||||
},
|
||||
'no_transient_suppressor': {
|
||||
'-ts': 0,
|
||||
},
|
||||
'no_vad': {
|
||||
'-vad': 0,
|
||||
},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
|
||||
|
||||
def _GenerateAllDefaultPlusOne():
|
||||
"""Enables the flags disabled by default one-by-one.
|
||||
"""Enables the flags disabled by default one-by-one.
|
||||
"""
|
||||
config_sets = {
|
||||
'with_AECM': {'-aec': 0, '-aecm': 1,}, # AEC and AECM are exclusive.
|
||||
'with_AGC_limiter': {'-agc_limiter': 1,},
|
||||
'with_AEC_delay_agnostic': {'-delay_agnostic': 1,},
|
||||
'with_drift_compensation': {'-drift_compensation': 1,},
|
||||
'with_residual_echo_detector': {'-ed': 1,},
|
||||
'with_AEC_extended_filter': {'-extended_filter': 1,},
|
||||
'with_LC': {'-lc': 1,},
|
||||
'with_refined_adaptive_filter': {'-refined_adaptive_filter': 1,},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
config_sets = {
|
||||
'with_AECM': {
|
||||
'-aec': 0,
|
||||
'-aecm': 1,
|
||||
}, # AEC and AECM are exclusive.
|
||||
'with_AGC_limiter': {
|
||||
'-agc_limiter': 1,
|
||||
},
|
||||
'with_AEC_delay_agnostic': {
|
||||
'-delay_agnostic': 1,
|
||||
},
|
||||
'with_drift_compensation': {
|
||||
'-drift_compensation': 1,
|
||||
},
|
||||
'with_residual_echo_detector': {
|
||||
'-ed': 1,
|
||||
},
|
||||
'with_AEC_extended_filter': {
|
||||
'-extended_filter': 1,
|
||||
},
|
||||
'with_LC': {
|
||||
'-lc': 1,
|
||||
},
|
||||
'with_refined_adaptive_filter': {
|
||||
'-refined_adaptive_filter': 1,
|
||||
},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
_GenerateAllDefaultPlusOne()
|
||||
_GenerateAllDefaultButOne()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
_GenerateAllDefaultPlusOne()
|
||||
_GenerateAllDefaultButOne()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Finds the APM configuration that maximizes a provided metric by
|
||||
parsing the output generated apm_quality_assessment.py.
|
||||
"""
|
||||
@@ -20,33 +19,44 @@ import os
import quality_assessment.data_access as data_access
|
||||
import quality_assessment.collect_data as collect_data
|
||||
|
||||
|
||||
def _InstanceArgumentsParser():
|
||||
"""Arguments parser factory. Extends the arguments from 'collect_data'
|
||||
"""Arguments parser factory. Extends the arguments from 'collect_data'
|
||||
with a few extra for selecting what parameters to optimize for.
|
||||
"""
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Rudimentary optimization of a function over different parameter'
|
||||
'combinations.')
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Rudimentary optimization of a function over different parameter'
|
||||
'combinations.')
|
||||
|
||||
parser.add_argument('-n', '--config_dir', required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--config_dir',
|
||||
required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
|
||||
parser.add_argument('-p', '--params', required=True, nargs='+',
|
||||
help=('parameters to parse from the config files in'
|
||||
'config_dir'))
|
||||
parser.add_argument('-p',
|
||||
'--params',
|
||||
required=True,
|
||||
nargs='+',
|
||||
help=('parameters to parse from the config files in'
|
||||
'config_dir'))
|
||||
|
||||
parser.add_argument('-z', '--params_not_to_optimize', required=False,
|
||||
nargs='+', default=[],
|
||||
help=('parameters from `params` not to be optimized for'))
|
||||
parser.add_argument(
|
||||
'-z',
|
||||
'--params_not_to_optimize',
|
||||
required=False,
|
||||
nargs='+',
|
||||
default=[],
|
||||
help=('parameters from `params` not to be optimized for'))
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def _ConfigurationAndScores(data_frame, params,
|
||||
params_not_to_optimize, config_dir):
|
||||
"""Returns a list of all configurations and scores.
|
||||
def _ConfigurationAndScores(data_frame, params, params_not_to_optimize,
|
||||
config_dir):
|
||||
"""Returns a list of all configurations and scores.
|
||||
|
||||
Args:
|
||||
data_frame: A pandas data frame with the scores and config name
|
||||
@@ -72,47 +82,47 @@ def _ConfigurationAndScores(data_frame, params,
param combinations for params in `params_not_to_optimize` and
|
||||
their scores.
|
||||
"""
|
||||
results = collections.defaultdict(list)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
score_names = data_frame['eval_score_name'].drop_duplicates().values.tolist()
|
||||
results = collections.defaultdict(list)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
score_names = data_frame['eval_score_name'].drop_duplicates(
|
||||
).values.tolist()
|
||||
|
||||
# Normalize the scores
|
||||
normalization_constants = {}
|
||||
for score_name in score_names:
|
||||
scores = data_frame[data_frame.eval_score_name == score_name].score
|
||||
normalization_constants[score_name] = max(scores)
|
||||
|
||||
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
|
||||
param_combination = collections.namedtuple("ParamCombination",
|
||||
params_to_optimize)
|
||||
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + ".json"))
|
||||
scores = {}
|
||||
data_cell = data_frame[data_frame.apm_config == config_name]
|
||||
# Normalize the scores
|
||||
normalization_constants = {}
|
||||
for score_name in score_names:
|
||||
data_cell_scores = data_cell[data_cell.eval_score_name ==
|
||||
score_name].score
|
||||
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
|
||||
scores[score_name] /= normalization_constants[score_name]
|
||||
scores = data_frame[data_frame.eval_score_name == score_name].score
|
||||
normalization_constants[score_name] = max(scores)
|
||||
|
||||
result = {'scores': scores, 'params': {}}
|
||||
config_optimize_params = {}
|
||||
for param in params:
|
||||
if param in params_to_optimize:
|
||||
config_optimize_params[param] = config_json['-' + param]
|
||||
else:
|
||||
result['params'][param] = config_json['-' + param]
|
||||
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
|
||||
param_combination = collections.namedtuple("ParamCombination",
|
||||
params_to_optimize)
|
||||
|
||||
current_param_combination = param_combination(
|
||||
**config_optimize_params)
|
||||
results[current_param_combination].append(result)
|
||||
return results
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + ".json"))
|
||||
scores = {}
|
||||
data_cell = data_frame[data_frame.apm_config == config_name]
|
||||
for score_name in score_names:
|
||||
data_cell_scores = data_cell[data_cell.eval_score_name ==
|
||||
score_name].score
|
||||
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
|
||||
scores[score_name] /= normalization_constants[score_name]
|
||||
|
||||
result = {'scores': scores, 'params': {}}
|
||||
config_optimize_params = {}
|
||||
for param in params:
|
||||
if param in params_to_optimize:
|
||||
config_optimize_params[param] = config_json['-' + param]
|
||||
else:
|
||||
result['params'][param] = config_json['-' + param]
|
||||
|
||||
current_param_combination = param_combination(**config_optimize_params)
|
||||
results[current_param_combination].append(result)
|
||||
return results
|
||||
|
||||
|
||||
def _FindOptimalParameter(configs_and_scores, score_weighting):
|
||||
"""Finds the config producing the maximal score.
|
||||
"""Finds the config producing the maximal score.
|
||||
|
||||
Args:
|
||||
configs_and_scores: structure of the form returned by
|
||||
@@ -127,53 +137,53 @@ def _FindOptimalParameter(configs_and_scores, score_weighting):
to its scores.
|
||||
"""
|
||||
|
||||
min_score = float('+inf')
|
||||
best_params = None
|
||||
for config in configs_and_scores:
|
||||
scores_and_params = configs_and_scores[config]
|
||||
current_score = score_weighting(scores_and_params)
|
||||
if current_score < min_score:
|
||||
min_score = current_score
|
||||
best_params = config
|
||||
logging.debug("Score: %f", current_score)
|
||||
logging.debug("Config: %s", str(config))
|
||||
return best_params
|
||||
min_score = float('+inf')
|
||||
best_params = None
|
||||
for config in configs_and_scores:
|
||||
scores_and_params = configs_and_scores[config]
|
||||
current_score = score_weighting(scores_and_params)
|
||||
if current_score < min_score:
|
||||
min_score = current_score
|
||||
best_params = config
|
||||
logging.debug("Score: %f", current_score)
|
||||
logging.debug("Config: %s", str(config))
|
||||
return best_params
|
||||
|
||||
|
||||
def _ExampleWeighting(scores_and_configs):
|
||||
"""Example argument to `_FindOptimalParameter`
|
||||
"""Example argument to `_FindOptimalParameter`
|
||||
Args:
|
||||
scores_and_configs: a list of configs and scores, in the form
|
||||
described in _FindOptimalParameter
|
||||
Returns:
|
||||
numeric value, the sum of all scores
|
||||
"""
|
||||
res = 0
|
||||
for score_config in scores_and_configs:
|
||||
res += sum(score_config['scores'].values())
|
||||
return res
|
||||
res = 0
|
||||
for score_config in scores_and_configs:
|
||||
res += sum(score_config['scores'].values())
|
||||
return res
|
||||
|
||||
|
||||
def main():
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug('Src path <%s>', src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
all_scores = _ConfigurationAndScores(scores_data_frame,
|
||||
args.params,
|
||||
args.params_not_to_optimize,
|
||||
args.config_dir)
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug('Src path <%s>', src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
all_scores = _ConfigurationAndScores(scores_data_frame, args.params,
|
||||
args.params_not_to_optimize,
|
||||
args.config_dir)
|
||||
|
||||
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
|
||||
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
|
||||
|
||||
logging.info('Optimal parameter combination: <%s>', opt_param)
|
||||
logging.info('It\'s score values: <%s>', all_scores[opt_param])
|
||||
|
||||
logging.info('Optimal parameter combination: <%s>', opt_param)
|
||||
logging.info('It\'s score values: <%s>', all_scores[opt_param])
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the apm_quality_assessment module.
|
||||
"""
|
||||
|
||||
@@ -16,13 +15,14 @@ import mock
|
||||
import apm_quality_assessment
|
||||
|
||||
|
||||
class TestSimulationScript(unittest.TestCase):
|
||||
"""Unit tests for the apm_quality_assessment module.
|
||||
"""Unit tests for the apm_quality_assessment module.
|
||||
"""
|
||||
|
||||
def testMain(self):
|
||||
# Exit with error code if no arguments are passed.
|
||||
with self.assertRaises(SystemExit) as cm, mock.patch.object(
|
||||
sys, 'argv', ['apm_quality_assessment.py']):
|
||||
apm_quality_assessment.main()
|
||||
self.assertGreater(cm.exception.code, 0)
|
||||
def testMain(self):
|
||||
# Exit with error code if no arguments are passed.
|
||||
with self.assertRaises(SystemExit) as cm, mock.patch.object(
|
||||
sys, 'argv', ['apm_quality_assessment.py']):
|
||||
apm_quality_assessment.main()
|
||||
self.assertGreater(cm.exception.code, 0)
|
||||
|
||||
@@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Extraction of annotations from audio files.
|
||||
"""
|
||||
|
||||
@@ -19,10 +18,10 @@ import sys
import tempfile
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import external_vad
|
||||
from . import exceptions
|
||||
@@ -30,262 +29,268 @@ from . import signal_processing
|
||||
|
||||
class AudioAnnotationsExtractor(object):
|
||||
"""Extracts annotations from audio files.
|
||||
"""Extracts annotations from audio files.
|
||||
"""
|
||||
|
||||
class VadType(object):
|
||||
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
|
||||
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
|
||||
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
|
||||
class VadType(object):
|
||||
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
|
||||
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
|
||||
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
|
||||
|
||||
def __init__(self, value):
|
||||
if (not isinstance(value, int)) or not 0 <= value <= 7:
|
||||
raise exceptions.InitializationException(
|
||||
'Invalid vad type: ' + value)
|
||||
self._value = value
|
||||
def __init__(self, value):
|
||||
if (not isinstance(value, int)) or not 0 <= value <= 7:
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
value)
|
||||
self._value = value
|
||||
|
||||
def Contains(self, vad_type):
|
||||
return self._value | vad_type == self._value
|
||||
def Contains(self, vad_type):
|
||||
return self._value | vad_type == self._value
|
||||
|
||||
def __str__(self):
|
||||
vads = []
|
||||
if self.Contains(self.ENERGY_THRESHOLD):
|
||||
vads.append("energy")
|
||||
if self.Contains(self.WEBRTC_COMMON_AUDIO):
|
||||
vads.append("common_audio")
|
||||
if self.Contains(self.WEBRTC_APM):
|
||||
vads.append("apm")
|
||||
return "VadType({})".format(", ".join(vads))
|
||||
def __str__(self):
|
||||
vads = []
|
||||
if self.Contains(self.ENERGY_THRESHOLD):
|
||||
vads.append("energy")
|
||||
if self.Contains(self.WEBRTC_COMMON_AUDIO):
|
||||
vads.append("common_audio")
|
||||
if self.Contains(self.WEBRTC_APM):
|
||||
vads.append("apm")
|
||||
return "VadType({})".format(", ".join(vads))
|
||||
|
||||
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
|
||||
|
||||
# Level estimation params.
|
||||
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
|
||||
_LEVEL_FRAME_SIZE_MS = 1.0
|
||||
# The time constants in ms indicate the time it takes for the level estimate
|
||||
# to go down/up by 1 db if the signal is zero.
|
||||
_LEVEL_ATTACK_MS = 5.0
|
||||
_LEVEL_DECAY_MS = 20.0
|
||||
|
||||
# VAD params.
|
||||
_VAD_THRESHOLD = 1
|
||||
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
|
||||
os.path.abspath(__file__)), os.pardir, os.pardir)
|
||||
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
|
||||
|
||||
_VAD_WEBRTC_APM_PATH = os.path.join(
|
||||
_VAD_WEBRTC_PATH, 'apm_vad')
|
||||
|
||||
def __init__(self, vad_type, external_vads=None):
|
||||
self._signal = None
|
||||
self._level = None
|
||||
self._level_frame_size = None
|
||||
self._common_audio_vad = None
|
||||
self._energy_vad = None
|
||||
self._apm_vad_probs = None
|
||||
self._apm_vad_rms = None
|
||||
self._vad_frame_size = None
|
||||
self._vad_frame_size_ms = None
|
||||
self._c_attack = None
|
||||
self._c_decay = None
|
||||
|
||||
self._vad_type = self.VadType(vad_type)
|
||||
logging.info('VADs used for annotations: ' + str(self._vad_type))
|
||||
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._external_vads = external_vads
|
||||
|
||||
assert len(self._external_vads) == len(external_vads), (
|
||||
'The external VAD names must be unique.')
|
||||
for vad in external_vads.values():
|
||||
if not isinstance(vad, external_vad.ExternalVad):
|
||||
raise exceptions.InitializationException(
|
||||
'Invalid vad type: ' + str(type(vad)))
|
||||
logging.info('External VAD used for annotation: ' +
|
||||
str(vad.name))
|
||||
|
||||
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH
|
||||
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
|
||||
self._VAD_WEBRTC_APM_PATH
|
||||
|
||||
@classmethod
|
||||
def GetOutputFileNameTemplate(cls):
|
||||
return cls._OUTPUT_FILENAME_TEMPLATE
|
||||
|
||||
def GetLevel(self):
|
||||
return self._level
|
||||
|
||||
def GetLevelFrameSize(self):
|
||||
return self._level_frame_size
|
||||
|
||||
@classmethod
|
||||
def GetLevelFrameSizeMs(cls):
|
||||
return cls._LEVEL_FRAME_SIZE_MS
|
||||
|
||||
def GetVadOutput(self, vad_type):
|
||||
if vad_type == self.VadType.ENERGY_THRESHOLD:
|
||||
return self._energy_vad
|
||||
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
|
||||
return self._common_audio_vad
|
||||
elif vad_type == self.VadType.WEBRTC_APM:
|
||||
return (self._apm_vad_probs, self._apm_vad_rms)
|
||||
else:
|
||||
raise exceptions.InitializationException(
|
||||
'Invalid vad type: ' + vad_type)
|
||||
|
||||
def GetVadFrameSize(self):
|
||||
return self._vad_frame_size
|
||||
|
||||
def GetVadFrameSizeMs(self):
|
||||
return self._vad_frame_size_ms
|
||||
|
||||
def Extract(self, filepath):
|
||||
# Load signal.
|
||||
self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
|
||||
if self._signal.channels != 1:
|
||||
raise NotImplementedError('Multiple-channel annotations not implemented')
|
||||
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
|
||||
|
||||
# Level estimation params.
|
||||
self._level_frame_size = int(self._signal.frame_rate / 1000 * (
|
||||
self._LEVEL_FRAME_SIZE_MS))
|
||||
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION ** (
|
||||
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_ATTACK_MS))
|
||||
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION ** (
|
||||
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_DECAY_MS))
|
||||
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
|
||||
_LEVEL_FRAME_SIZE_MS = 1.0
|
||||
# The time constants in ms indicate the time it takes for the level estimate
|
||||
# to go down/up by 1 db if the signal is zero.
|
||||
_LEVEL_ATTACK_MS = 5.0
|
||||
_LEVEL_DECAY_MS = 20.0
|
||||
|
||||
# Compute level.
|
||||
self._LevelEstimation()
|
||||
# VAD params.
|
||||
_VAD_THRESHOLD = 1
|
||||
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
os.pardir, os.pardir)
|
||||
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
|
||||
|
||||
# Ideal VAD output, it requires clean speech with high SNR as input.
|
||||
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
|
||||
# Naive VAD based on level thresholding.
|
||||
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
|
||||
self._energy_vad = np.uint8(self._level > vad_threshold)
|
||||
self._vad_frame_size = self._level_frame_size
|
||||
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
|
||||
# WebRTC common_audio/ VAD.
|
||||
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
|
||||
# WebRTC modules/audio_processing/ VAD.
|
||||
self._RunWebRtcApmVad(filepath)
|
||||
for extvad_name in self._external_vads:
|
||||
self._external_vads[extvad_name].Run(filepath)
|
||||
_VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')
|
||||
|
||||
def Save(self, output_path, annotation_name=""):
|
||||
ext_kwargs = {'extvad_conf-' + ext_vad:
|
||||
self._external_vads[ext_vad].GetVadOutput()
|
||||
for ext_vad in self._external_vads}
|
||||
np.savez_compressed(
|
||||
file=os.path.join(
|
||||
def __init__(self, vad_type, external_vads=None):
|
||||
self._signal = None
|
||||
self._level = None
|
||||
self._level_frame_size = None
|
||||
self._common_audio_vad = None
|
||||
self._energy_vad = None
|
||||
self._apm_vad_probs = None
|
||||
self._apm_vad_rms = None
|
||||
self._vad_frame_size = None
|
||||
self._vad_frame_size_ms = None
|
||||
self._c_attack = None
|
||||
self._c_decay = None
|
||||
|
||||
self._vad_type = self.VadType(vad_type)
|
||||
logging.info('VADs used for annotations: ' + str(self._vad_type))
|
||||
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._external_vads = external_vads
|
||||
|
||||
assert len(self._external_vads) == len(external_vads), (
|
||||
'The external VAD names must be unique.')
|
||||
for vad in external_vads.values():
|
||||
if not isinstance(vad, external_vad.ExternalVad):
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
str(type(vad)))
|
||||
logging.info('External VAD used for annotation: ' + str(vad.name))
|
||||
|
||||
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH
|
||||
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
|
||||
self._VAD_WEBRTC_APM_PATH
|
||||
|
||||
@classmethod
|
||||
def GetOutputFileNameTemplate(cls):
|
||||
return cls._OUTPUT_FILENAME_TEMPLATE
|
||||
|
||||
def GetLevel(self):
|
||||
return self._level
|
||||
|
||||
def GetLevelFrameSize(self):
|
||||
return self._level_frame_size
|
||||
|
||||
@classmethod
|
||||
def GetLevelFrameSizeMs(cls):
|
||||
return cls._LEVEL_FRAME_SIZE_MS
|
||||
|
||||
def GetVadOutput(self, vad_type):
|
||||
if vad_type == self.VadType.ENERGY_THRESHOLD:
|
||||
return self._energy_vad
|
||||
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
|
||||
return self._common_audio_vad
|
||||
elif vad_type == self.VadType.WEBRTC_APM:
|
||||
return (self._apm_vad_probs, self._apm_vad_rms)
|
||||
else:
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
vad_type)
|
||||
|
||||
def GetVadFrameSize(self):
|
||||
return self._vad_frame_size
|
||||
|
||||
def GetVadFrameSizeMs(self):
|
||||
return self._vad_frame_size_ms
|
||||
|
||||
def Extract(self, filepath):
|
||||
# Load signal.
|
||||
self._signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
filepath)
|
||||
if self._signal.channels != 1:
|
||||
raise NotImplementedError(
|
||||
'Multiple-channel annotations not implemented')
|
||||
|
||||
# Level estimation params.
|
||||
self._level_frame_size = int(self._signal.frame_rate / 1000 *
|
||||
(self._LEVEL_FRAME_SIZE_MS))
|
||||
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
|
||||
self._LEVEL_ATTACK_MS))
|
||||
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
|
||||
self._LEVEL_DECAY_MS))
|
||||
|
||||
# Compute level.
|
||||
self._LevelEstimation()
|
||||
|
||||
# Ideal VAD output, it requires clean speech with high SNR as input.
|
||||
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
|
||||
# Naive VAD based on level thresholding.
|
||||
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
|
||||
self._energy_vad = np.uint8(self._level > vad_threshold)
|
||||
self._vad_frame_size = self._level_frame_size
|
||||
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
|
||||
# WebRTC common_audio/ VAD.
|
||||
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
|
||||
# WebRTC modules/audio_processing/ VAD.
|
||||
self._RunWebRtcApmVad(filepath)
|
||||
for extvad_name in self._external_vads:
|
||||
self._external_vads[extvad_name].Run(filepath)
|
||||
|
||||
def Save(self, output_path, annotation_name=""):
|
||||
ext_kwargs = {
|
||||
'extvad_conf-' + ext_vad:
|
||||
self._external_vads[ext_vad].GetVadOutput()
|
||||
for ext_vad in self._external_vads
|
||||
}
|
||||
np.savez_compressed(file=os.path.join(
|
||||
output_path,
|
||||
self.GetOutputFileNameTemplate().format(annotation_name)),
|
||||
level=self._level,
|
||||
level_frame_size=self._level_frame_size,
|
||||
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
|
||||
vad_output=self._common_audio_vad,
|
||||
vad_energy_output=self._energy_vad,
|
||||
vad_frame_size=self._vad_frame_size,
|
||||
vad_frame_size_ms=self._vad_frame_size_ms,
|
||||
vad_probs=self._apm_vad_probs,
|
||||
vad_rms=self._apm_vad_rms,
|
||||
**ext_kwargs
|
||||
)
|
||||
level=self._level,
|
||||
level_frame_size=self._level_frame_size,
|
||||
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
|
||||
vad_output=self._common_audio_vad,
|
||||
vad_energy_output=self._energy_vad,
|
||||
vad_frame_size=self._vad_frame_size,
|
||||
vad_frame_size_ms=self._vad_frame_size_ms,
|
||||
vad_probs=self._apm_vad_probs,
|
||||
vad_rms=self._apm_vad_rms,
|
||||
**ext_kwargs)
|
||||
|
||||
def _LevelEstimation(self):
|
||||
# Read samples.
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._signal).astype(np.float32) / 32768.0
|
||||
num_frames = len(samples) // self._level_frame_size
|
||||
num_samples = num_frames * self._level_frame_size
|
||||
def _LevelEstimation(self):
|
||||
# Read samples.
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._signal).astype(np.float32) / 32768.0
|
||||
num_frames = len(samples) // self._level_frame_size
|
||||
num_samples = num_frames * self._level_frame_size
|
||||
|
||||
# Envelope.
|
||||
self._level = np.max(np.reshape(np.abs(samples[:num_samples]), (
|
||||
num_frames, self._level_frame_size)), axis=1)
|
||||
assert len(self._level) == num_frames
|
||||
# Envelope.
|
||||
self._level = np.max(np.reshape(np.abs(samples[:num_samples]),
|
||||
(num_frames, self._level_frame_size)),
|
||||
axis=1)
|
||||
assert len(self._level) == num_frames
|
||||
|
||||
# Envelope smoothing.
|
||||
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
|
||||
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
|
||||
for i in range(1, num_frames):
|
||||
self._level[i] = smooth(
|
||||
self._level[i], self._level[i - 1], self._c_attack if (
|
||||
self._level[i] > self._level[i - 1]) else self._c_decay)
|
||||
# Envelope smoothing.
|
||||
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
|
||||
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
|
||||
for i in range(1, num_frames):
|
||||
self._level[i] = smooth(
|
||||
self._level[i], self._level[i - 1], self._c_attack if
|
||||
(self._level[i] > self._level[i - 1]) else self._c_decay)
|
||||
|
||||
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
|
||||
self._common_audio_vad = None
|
||||
self._vad_frame_size = None
|
||||
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
|
||||
self._common_audio_vad = None
|
||||
self._vad_frame_size = None
|
||||
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path = os.path.join(
|
||||
tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path = os.path.join(
|
||||
tmp_path,
|
||||
os.path.split(wav_file_path)[1] + '_vad.tmp')
|
||||
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH,
|
||||
'-i', wav_file_path,
|
||||
'-o', output_file_path
|
||||
], cwd=self._VAD_WEBRTC_PATH)
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o',
|
||||
output_file_path
|
||||
],
|
||||
cwd=self._VAD_WEBRTC_PATH)
|
||||
|
||||
# Read bytes.
|
||||
with open(output_file_path, 'rb') as f:
|
||||
raw_data = f.read()
|
||||
# Read bytes.
|
||||
with open(output_file_path, 'rb') as f:
|
||||
raw_data = f.read()
|
||||
|
||||
# Parse side information.
|
||||
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
|
||||
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
|
||||
assert self._vad_frame_size_ms in [10, 20, 30]
|
||||
extra_bits = struct.unpack('B', raw_data[-1])[0]
|
||||
assert 0 <= extra_bits <= 8
|
||||
# Parse side information.
|
||||
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
|
||||
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
|
||||
assert self._vad_frame_size_ms in [10, 20, 30]
|
||||
extra_bits = struct.unpack('B', raw_data[-1])[0]
|
||||
assert 0 <= extra_bits <= 8
|
||||
|
||||
# Init VAD vector.
|
||||
num_bytes = len(raw_data)
|
||||
num_frames = 8 * (num_bytes - 2) - extra_bits # 8 frames for each byte.
|
||||
self._common_audio_vad = np.zeros(num_frames, np.uint8)
|
||||
# Init VAD vector.
|
||||
num_bytes = len(raw_data)
|
||||
num_frames = 8 * (num_bytes -
|
||||
2) - extra_bits # 8 frames for each byte.
|
||||
self._common_audio_vad = np.zeros(num_frames, np.uint8)
|
||||
|
||||
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._common_audio_vad[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._common_audio_vad[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message +
')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)

def _RunWebRtcApmVad(self, wav_file_path):
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path_probs = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
output_file_path_rms = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
def _RunWebRtcApmVad(self, wav_file_path):
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path_probs = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
output_file_path_rms = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad_rms.tmp')

# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_APM_PATH,
'-i', wav_file_path,
'-o_probs', output_file_path_probs,
'-o_rms', output_file_path_rms
], cwd=self._VAD_WEBRTC_PATH)
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs',
output_file_path_probs, '-o_rms', output_file_path_rms
],
cwd=self._VAD_WEBRTC_PATH)

# Parse annotations.
self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double)
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
# Parse annotations.
self._apm_vad_probs = np.fromfile(output_file_path_probs,
np.double)
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)

except Exception as e:
logging.error('Error while running the WebRTC APM VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
except Exception as e:
logging.error('Error while running the WebRTC APM VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the annotations module.
"""

@ -25,133 +24,137 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestAnnotationsExtraction(unittest.TestCase):
|
||||
"""Unit tests for the annotations module.
|
||||
"""Unit tests for the annotations module.
|
||||
"""
|
||||
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
_DEBUG_PLOT_VAD = False
|
||||
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
|
||||
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD |
|
||||
_VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO |
|
||||
_VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
_DEBUG_PLOT_VAD = False
|
||||
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
|
||||
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD
|
||||
| _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO
|
||||
| _VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folder."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
|
||||
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
|
||||
'pure_tone', [440, 1000])
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._wav_file_path, pure_tone)
|
||||
self._sample_rate = pure_tone.frame_rate
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folder."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
|
||||
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
|
||||
'pure_tone', [440, 1000])
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._wav_file_path, pure_tone)
|
||||
self._sample_rate = pure_tone.frame_rate
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' +
|
||||
(self._tmp_path))
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' + (
|
||||
self._tmp_path))
|
||||
def testFrameSizes(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
samples_to_ms = lambda n, sr: 1000 * n // sr
|
||||
self.assertEqual(
|
||||
samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
|
||||
e.GetLevelFrameSizeMs())
|
||||
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
|
||||
e.GetVadFrameSizeMs())
|
||||
|
||||
def testFrameSizes(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
samples_to_ms = lambda n, sr: 1000 * n // sr
|
||||
self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
|
||||
e.GetLevelFrameSizeMs())
|
||||
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
|
||||
e.GetVadFrameSizeMs())
|
||||
def testVoiceActivityDetectors(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
|
||||
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
|
||||
e.Extract(self._wav_file_path)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(
|
||||
self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_output)) / len(vad_output), 0.95)
|
||||
|
||||
def testVoiceActivityDetectors(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
|
||||
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
|
||||
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
|
||||
e.Extract(self._wav_file_path)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
|
||||
0.95)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(
|
||||
self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_output)) / len(vad_output), 0.95)
|
||||
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
|
||||
0.95)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
(vad_probs,
|
||||
vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
self.assertGreater(len(vad_probs), 0)
|
||||
self.assertGreater(len(vad_rms), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_probs)) / len(vad_probs), 0.5)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_rms)) / len(vad_rms), 20000)
|
||||
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
(vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
self.assertGreater(len(vad_probs), 0)
|
||||
self.assertGreater(len(vad_rms), 0)
|
||||
self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs),
|
||||
0.5)
|
||||
self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000)
|
||||
if self._DEBUG_PLOT_VAD:
|
||||
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
|
||||
num_frames).astype(np.float32) * frame_size_ms / 1000.0
|
||||
level = e.GetLevel()
|
||||
t_level = frame_times_s(num_frames=len(level),
|
||||
frame_size_ms=e.GetLevelFrameSizeMs())
|
||||
t_vad = frame_times_s(num_frames=len(vad_output),
|
||||
frame_size_ms=e.GetVadFrameSizeMs())
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure()
|
||||
plt.hold(True)
|
||||
plt.plot(t_level, level)
|
||||
plt.plot(t_vad, vad_output * np.max(level), '.')
|
||||
plt.show()
|
||||
|
||||
if self._DEBUG_PLOT_VAD:
|
||||
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
|
||||
num_frames).astype(np.float32) * frame_size_ms / 1000.0
|
||||
level = e.GetLevel()
|
||||
t_level = frame_times_s(
|
||||
num_frames=len(level),
|
||||
frame_size_ms=e.GetLevelFrameSizeMs())
|
||||
t_vad = frame_times_s(
|
||||
num_frames=len(vad_output),
|
||||
frame_size_ms=e.GetVadFrameSizeMs())
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure()
|
||||
plt.hold(True)
|
||||
plt.plot(t_level, level)
|
||||
plt.plot(t_vad, vad_output * np.max(level), '.')
|
||||
plt.show()
|
||||
def testSaveLoad(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, "fake-annotation")
|
||||
|
||||
def testSaveLoad(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, "fake-annotation")
|
||||
data = np.load(
|
||||
os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
np.testing.assert_array_equal(e.GetLevel(), data['level'])
|
||||
self.assertEqual(np.float32, data['level'].dtype)
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
|
||||
data['vad_energy_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
|
||||
data['vad_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
|
||||
data['vad_probs'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
|
||||
data['vad_rms'])
|
||||
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_probs'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_rms'].dtype)
|
||||
|
||||
data = np.load(os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
np.testing.assert_array_equal(e.GetLevel(), data['level'])
|
||||
self.assertEqual(np.float32, data['level'].dtype)
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
|
||||
data['vad_energy_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
|
||||
data['vad_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms'])
|
||||
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_probs'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_rms'].dtype)
|
||||
def testEmptyExternalShouldNotCrash(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
annotations.AudioAnnotationsExtractor(vad_type_value, {})
|
||||
|
||||
def testEmptyExternalShouldNotCrash(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
|
||||
annotations.AudioAnnotationsExtractor(vad_type_value, {})
|
||||
def testFakeExternalSaveLoad(self):
|
||||
def FakeExternalFactory():
|
||||
return external_vad.ExternalVad(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'fake_external_vad.py'), 'fake')
|
||||
|
||||
def testFakeExternalSaveLoad(self):
|
||||
def FakeExternalFactory():
|
||||
return external_vad.ExternalVad(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'),
|
||||
'fake'
|
||||
)
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
|
||||
e = annotations.AudioAnnotationsExtractor(
|
||||
vad_type_value,
|
||||
{'fake': FakeExternalFactory()})
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, annotation_name="fake-annotation")
|
||||
data = np.load(os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
|
||||
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
|
||||
data['extvad_conf-fake'])
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
e = annotations.AudioAnnotationsExtractor(
|
||||
vad_type_value, {'fake': FakeExternalFactory()})
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, annotation_name="fake-annotation")
|
||||
data = np.load(
|
||||
os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
|
||||
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
|
||||
data['extvad_conf-fake'])
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Class implementing a wrapper for APM simulators.
|
||||
"""
|
||||
|
||||
@ -19,33 +18,36 @@ from . import exceptions
|
||||
|
||||
|
||||
class AudioProcWrapper(object):
|
||||
"""Wrapper for APM simulators.
|
||||
"""Wrapper for APM simulators.
|
||||
"""
|
||||
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(os.path.join(
|
||||
os.pardir, 'audioproc_f'))
|
||||
OUTPUT_FILENAME = 'output.wav'
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(
|
||||
os.path.join(os.pardir, 'audioproc_f'))
|
||||
OUTPUT_FILENAME = 'output.wav'
|
||||
|
||||
def __init__(self, simulator_bin_path):
|
||||
"""Ctor.
|
||||
def __init__(self, simulator_bin_path):
|
||||
"""Ctor.
|
||||
|
||||
Args:
|
||||
simulator_bin_path: path to the APM simulator binary.
|
||||
"""
|
||||
self._simulator_bin_path = simulator_bin_path
|
||||
self._config = None
|
||||
self._output_signal_filepath = None
|
||||
self._simulator_bin_path = simulator_bin_path
|
||||
self._config = None
|
||||
self._output_signal_filepath = None
|
||||
|
||||
# Profiler instance to measure running time.
|
||||
self._profiler = cProfile.Profile()
|
||||
# Profiler instance to measure running time.
|
||||
self._profiler = cProfile.Profile()
|
||||
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_signal_filepath
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_signal_filepath
|
||||
|
||||
def Run(self, config_filepath, capture_input_filepath, output_path,
|
||||
render_input_filepath=None):
|
||||
"""Runs APM simulator.
|
||||
def Run(self,
|
||||
config_filepath,
|
||||
capture_input_filepath,
|
||||
output_path,
|
||||
render_input_filepath=None):
|
||||
"""Runs APM simulator.
|
||||
|
||||
Args:
|
||||
config_filepath: path to the configuration file specifying the arguments
|
||||
@ -56,41 +58,43 @@ class AudioProcWrapper(object):
|
||||
render_input_filepath: path to the render audio track input file (aka
|
||||
reverse or far-end).
|
||||
"""
|
||||
# Init.
|
||||
self._output_signal_filepath = os.path.join(
|
||||
output_path, self.OUTPUT_FILENAME)
|
||||
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
|
||||
# Init.
|
||||
self._output_signal_filepath = os.path.join(output_path,
|
||||
self.OUTPUT_FILENAME)
|
||||
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
|
||||
|
||||
# Skip if the output has already been generated.
|
||||
if os.path.exists(self._output_signal_filepath) and os.path.exists(
|
||||
profiling_stats_filepath):
|
||||
return
|
||||
# Skip if the output has already been generated.
|
||||
if os.path.exists(self._output_signal_filepath) and os.path.exists(
|
||||
profiling_stats_filepath):
|
||||
return
|
||||
|
||||
# Load configuration.
|
||||
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
|
||||
# Load configuration.
|
||||
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
|
||||
|
||||
# Set remaining parameters.
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
raise exceptions.FileNotFoundError('cannot find capture input file')
|
||||
self._config['-i'] = capture_input_filepath
|
||||
self._config['-o'] = self._output_signal_filepath
|
||||
if render_input_filepath is not None:
|
||||
if not os.path.exists(render_input_filepath):
|
||||
raise exceptions.FileNotFoundError('cannot find render input file')
|
||||
self._config['-ri'] = render_input_filepath
|
||||
# Set remaining parameters.
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
raise exceptions.FileNotFoundError(
|
||||
'cannot find capture input file')
|
||||
self._config['-i'] = capture_input_filepath
|
||||
self._config['-o'] = self._output_signal_filepath
|
||||
if render_input_filepath is not None:
|
||||
if not os.path.exists(render_input_filepath):
|
||||
raise exceptions.FileNotFoundError(
|
||||
'cannot find render input file')
|
||||
self._config['-ri'] = render_input_filepath
|
||||
|
||||
# Build arguments list.
|
||||
args = [self._simulator_bin_path]
|
||||
for param_name in self._config:
|
||||
args.append(param_name)
|
||||
if self._config[param_name] is not None:
|
||||
args.append(str(self._config[param_name]))
|
||||
logging.debug(' '.join(args))
|
||||
# Build arguments list.
|
||||
args = [self._simulator_bin_path]
|
||||
for param_name in self._config:
|
||||
args.append(param_name)
|
||||
if self._config[param_name] is not None:
|
||||
args.append(str(self._config[param_name]))
|
||||
logging.debug(' '.join(args))
|
||||
|
||||
# Run.
|
||||
self._profiler.enable()
|
||||
subprocess.call(args)
|
||||
self._profiler.disable()
|
||||
# Run.
|
||||
self._profiler.enable()
|
||||
subprocess.call(args)
|
||||
self._profiler.disable()
|
||||
|
||||
# Save profiling stats.
|
||||
self._profiler.dump_stats(profiling_stats_filepath)
|
||||
# Save profiling stats.
|
||||
self._profiler.dump_stats(profiling_stats_filepath)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Imports a filtered subset of the scores and configurations computed
|
||||
by apm_quality_assessment.py into a pandas data frame.
|
||||
"""
|
||||
@ -18,71 +17,88 @@ import re
|
||||
import sys
|
||||
|
||||
try:
import pandas as pd
import pandas as pd
except ImportError:
logging.critical('Cannot import the third-party Python package pandas')
sys.exit(1)
logging.critical('Cannot import the third-party Python package pandas')
sys.exit(1)

from . import data_access as data_access
|
||||
from . import simulation as sim
|
||||
|
||||
# Compiled regular expressions used to extract score descriptors.
|
||||
RE_CONFIG_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixApmConfig() + r'(.+)')
|
||||
RE_CAPTURE_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixCapture() + r'(.+)')
|
||||
RE_RENDER_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
|
||||
RE_ECHO_SIM_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + r'(.+)')
|
||||
RE_CONFIG_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixApmConfig() +
|
||||
r'(.+)')
|
||||
RE_CAPTURE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixCapture() +
|
||||
r'(.+)')
|
||||
RE_RENDER_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
|
||||
RE_ECHO_SIM_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixEchoSimulator() +
|
||||
r'(.+)')
|
||||
RE_TEST_DATA_GEN_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + r'(.+)')
|
||||
RE_TEST_DATA_GEN_PARAMS = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + r'(.+)')
|
||||
RE_SCORE_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixScore() + r'(.+)(\..+)')
|
||||
RE_SCORE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixScore() +
|
||||
r'(.+)(\..+)')
|
||||
|
||||
|
||||
def InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
'Override this description in a user script by changing'
|
||||
' `parser.description` of the returned parser.'))
|
||||
parser = argparse.ArgumentParser(
|
||||
description=('Override this description in a user script by changing'
|
||||
' `parser.description` of the returned parser.'))
|
||||
|
||||
parser.add_argument('-o', '--output_dir', required=True,
|
||||
help=('the same base path used with the '
|
||||
'apm_quality_assessment tool'))
|
||||
parser.add_argument('-o',
|
||||
'--output_dir',
|
||||
required=True,
|
||||
help=('the same base path used with the '
|
||||
'apm_quality_assessment tool'))
|
||||
|
||||
parser.add_argument('-c', '--config_names', type=re.compile,
|
||||
help=('regular expression to filter the APM configuration'
|
||||
' names'))
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--config_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the APM configuration'
|
||||
' names'))
|
||||
|
||||
parser.add_argument('-i', '--capture_names', type=re.compile,
|
||||
help=('regular expression to filter the capture signal '
|
||||
'names'))
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--capture_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the capture signal '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-r', '--render_names', type=re.compile,
|
||||
help=('regular expression to filter the render signal '
|
||||
'names'))
|
||||
parser.add_argument('-r',
|
||||
'--render_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the render signal '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-e', '--echo_simulator_names', type=re.compile,
|
||||
help=('regular expression to filter the echo simulator '
|
||||
'names'))
|
||||
parser.add_argument(
|
||||
'-e',
|
||||
'--echo_simulator_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the echo simulator '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-t', '--test_data_generators', type=re.compile,
|
||||
help=('regular expression to filter the test data '
|
||||
'generator names'))
|
||||
parser.add_argument('-t',
|
||||
'--test_data_generators',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the test data '
|
||||
'generator names'))
|
||||
|
||||
parser.add_argument('-s', '--eval_scores', type=re.compile,
|
||||
help=('regular expression to filter the evaluation score '
|
||||
'names'))
|
||||
parser.add_argument(
|
||||
'-s',
|
||||
'--eval_scores',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the evaluation score '
|
||||
'names'))
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def _GetScoreDescriptors(score_filepath):
|
||||
"""Extracts a score descriptor from the given score file path.
|
||||
"""Extracts a score descriptor from the given score file path.
|
||||
|
||||
Args:
|
||||
score_filepath: path to the score file.
|
||||
@ -92,23 +108,23 @@ def _GetScoreDescriptors(score_filepath):
|
||||
render audio track name, echo simulator name, test data generator name,
|
||||
test data generator parameters as string, evaluation score name).
|
||||
"""
|
||||
fields = score_filepath.split(os.sep)[-7:]
|
||||
extract_name = lambda index, reg_expr: (
|
||||
reg_expr.match(fields[index]).groups(0)[0])
|
||||
return (
|
||||
extract_name(0, RE_CONFIG_NAME),
|
||||
extract_name(1, RE_CAPTURE_NAME),
|
||||
extract_name(2, RE_RENDER_NAME),
|
||||
extract_name(3, RE_ECHO_SIM_NAME),
|
||||
extract_name(4, RE_TEST_DATA_GEN_NAME),
|
||||
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
|
||||
extract_name(6, RE_SCORE_NAME),
|
||||
)
|
||||
fields = score_filepath.split(os.sep)[-7:]
|
||||
extract_name = lambda index, reg_expr: (reg_expr.match(fields[index]).
|
||||
groups(0)[0])
|
||||
return (
|
||||
extract_name(0, RE_CONFIG_NAME),
|
||||
extract_name(1, RE_CAPTURE_NAME),
|
||||
extract_name(2, RE_RENDER_NAME),
|
||||
extract_name(3, RE_ECHO_SIM_NAME),
|
||||
extract_name(4, RE_TEST_DATA_GEN_NAME),
|
||||
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
|
||||
extract_name(6, RE_SCORE_NAME),
|
||||
)
|
||||
|
||||
|
||||
def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, score_name, args):
|
||||
"""Decides whether excluding a score.
|
||||
"""Decides whether excluding a score.
|
||||
|
||||
A set of optional regular expressions in args is used to determine if the
|
||||
score should be excluded (depending on its |*_name| descriptors).
|
||||
@ -125,27 +141,27 @@ def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
|
||||
Returns:
|
||||
A boolean.
|
||||
"""
|
||||
value_regexpr_pairs = [
|
||||
(config_name, args.config_names),
|
||||
(capture_name, args.capture_names),
|
||||
(render_name, args.render_names),
|
||||
(echo_simulator_name, args.echo_simulator_names),
|
||||
(test_data_gen_name, args.test_data_generators),
|
||||
(score_name, args.eval_scores),
|
||||
]
|
||||
value_regexpr_pairs = [
|
||||
(config_name, args.config_names),
|
||||
(capture_name, args.capture_names),
|
||||
(render_name, args.render_names),
|
||||
(echo_simulator_name, args.echo_simulator_names),
|
||||
(test_data_gen_name, args.test_data_generators),
|
||||
(score_name, args.eval_scores),
|
||||
]
|
||||
|
||||
# Score accepted if each value matches the corresponding regular expression.
|
||||
for value, regexpr in value_regexpr_pairs:
|
||||
if regexpr is None:
|
||||
continue
|
||||
if not regexpr.match(value):
|
||||
return True
|
||||
# Score accepted if each value matches the corresponding regular expression.
|
||||
for value, regexpr in value_regexpr_pairs:
|
||||
if regexpr is None:
|
||||
continue
|
||||
if not regexpr.match(value):
|
||||
return True
|
||||
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def FindScores(src_path, args):
|
||||
"""Given a search path, find scores and return a DataFrame object.
|
||||
"""Given a search path, find scores and return a DataFrame object.
|
||||
|
||||
Args:
|
||||
src_path: Search path pattern.
|
||||
@ -154,89 +170,74 @@ def FindScores(src_path, args):
|
||||
Returns:
|
||||
A DataFrame object.
|
||||
"""
|
||||
# Get scores.
|
||||
scores = []
|
||||
for score_filepath in glob.iglob(src_path):
|
||||
# Extract score descriptor fields from the path.
|
||||
(config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
test_data_gen_params,
|
||||
score_name) = _GetScoreDescriptors(score_filepath)
|
||||
# Get scores.
|
||||
scores = []
|
||||
for score_filepath in glob.iglob(src_path):
|
||||
# Extract score descriptor fields from the path.
|
||||
(config_name, capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, test_data_gen_params,
|
||||
score_name) = _GetScoreDescriptors(score_filepath)
|
||||
|
||||
# Ignore the score if required.
|
||||
if _ExcludeScore(
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
score_name,
|
||||
args):
|
||||
logging.info(
|
||||
'ignored score: %s %s %s %s %s %s',
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
score_name)
|
||||
continue
|
||||
# Ignore the score if required.
|
||||
if _ExcludeScore(config_name, capture_name, render_name,
|
||||
echo_simulator_name, test_data_gen_name, score_name,
|
||||
args):
|
||||
logging.info('ignored score: %s %s %s %s %s %s', config_name,
|
||||
capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, score_name)
|
||||
continue
|
||||
|
||||
# Read metadata and score.
|
||||
metadata = data_access.Metadata.LoadAudioTestDataPaths(
|
||||
os.path.split(score_filepath)[0])
|
||||
score = data_access.ScoreFile.Load(score_filepath)
|
||||
# Read metadata and score.
|
||||
metadata = data_access.Metadata.LoadAudioTestDataPaths(
|
||||
os.path.split(score_filepath)[0])
|
||||
score = data_access.ScoreFile.Load(score_filepath)
|
||||
|
||||
# Add a score with its descriptor fields.
|
||||
scores.append((
|
||||
metadata['clean_capture_input_filepath'],
|
||||
metadata['echo_free_capture_filepath'],
|
||||
metadata['echo_filepath'],
|
||||
metadata['render_filepath'],
|
||||
metadata['capture_filepath'],
|
||||
metadata['apm_output_filepath'],
|
||||
metadata['apm_reference_filepath'],
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
test_data_gen_params,
|
||||
score_name,
|
||||
score,
|
||||
))
|
||||
# Add a score with its descriptor fields.
|
||||
scores.append((
|
||||
metadata['clean_capture_input_filepath'],
|
||||
metadata['echo_free_capture_filepath'],
|
||||
metadata['echo_filepath'],
|
||||
metadata['render_filepath'],
|
||||
metadata['capture_filepath'],
|
||||
metadata['apm_output_filepath'],
|
||||
metadata['apm_reference_filepath'],
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
test_data_gen_params,
|
||||
score_name,
|
||||
score,
|
||||
))
|
||||
|
||||
return pd.DataFrame(
|
||||
data=scores,
|
||||
columns=(
|
||||
'clean_capture_input_filepath',
|
||||
'echo_free_capture_filepath',
|
||||
'echo_filepath',
|
||||
'render_filepath',
|
||||
'capture_filepath',
|
||||
'apm_output_filepath',
|
||||
'apm_reference_filepath',
|
||||
'apm_config',
|
||||
'capture',
|
||||
'render',
|
||||
'echo_simulator',
|
||||
'test_data_gen',
|
||||
'test_data_gen_params',
|
||||
'eval_score_name',
|
||||
'score',
|
||||
))
|
||||
return pd.DataFrame(data=scores,
|
||||
columns=(
|
||||
'clean_capture_input_filepath',
|
||||
'echo_free_capture_filepath',
|
||||
'echo_filepath',
|
||||
'render_filepath',
|
||||
'capture_filepath',
|
||||
'apm_output_filepath',
|
||||
'apm_reference_filepath',
|
||||
'apm_config',
|
||||
'capture',
|
||||
'render',
|
||||
'echo_simulator',
|
||||
'test_data_gen',
|
||||
'test_data_gen_params',
|
||||
'eval_score_name',
|
||||
'score',
|
||||
))
|
||||
|
||||
|
||||
def ConstructSrcPath(args):
|
||||
return os.path.join(
|
||||
args.output_dir,
|
||||
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixRender() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixScore() + '*')
|
||||
return os.path.join(
|
||||
args.output_dir,
|
||||
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixRender() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixScore() + '*')
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Data access utility functions and classes.
|
||||
"""
|
||||
|
||||
@ -14,29 +13,29 @@ import os
|
||||
|
||||
|
||||
def MakeDirectory(path):
|
||||
"""Makes a directory recursively without rising exceptions if existing.
|
||||
"""Makes a directory recursively without rising exceptions if existing.
|
||||
|
||||
Args:
|
||||
path: path to the directory to be created.
|
||||
"""
|
||||
if os.path.exists(path):
|
||||
return
|
||||
os.makedirs(path)
|
||||
if os.path.exists(path):
|
||||
return
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
"""Data access class to save and load metadata.
|
||||
"""Data access class to save and load metadata.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
_GENERIC_METADATA_SUFFIX = '.mdata'
|
||||
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
|
||||
_GENERIC_METADATA_SUFFIX = '.mdata'
|
||||
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
|
||||
|
||||
@classmethod
|
||||
def LoadFileMetadata(cls, filepath):
|
||||
"""Loads generic metadata linked to a file.
|
||||
@classmethod
|
||||
def LoadFileMetadata(cls, filepath):
|
||||
"""Loads generic metadata linked to a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the metadata file to read.
|
||||
@ -44,23 +43,23 @@ class Metadata(object):
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
|
||||
return json.load(f)
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def SaveFileMetadata(cls, filepath, metadata):
|
||||
"""Saves generic metadata linked to a file.
|
||||
@classmethod
|
||||
def SaveFileMetadata(cls, filepath, metadata):
|
||||
"""Saves generic metadata linked to a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the metadata file to write.
|
||||
metadata: a dict.
|
||||
"""
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
|
||||
json.dump(metadata, f)
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
|
||||
json.dump(metadata, f)
|
||||
|
||||
@classmethod
|
||||
def LoadAudioTestDataPaths(cls, metadata_path):
|
||||
"""Loads the input and the reference audio track paths.
|
||||
@classmethod
|
||||
def LoadAudioTestDataPaths(cls, metadata_path):
|
||||
"""Loads the input and the reference audio track paths.
|
||||
|
||||
Args:
|
||||
metadata_path: path to the directory containing the metadata file.
|
||||
@ -68,14 +67,14 @@ class Metadata(object):
|
||||
Returns:
|
||||
Tuple with the paths to the input and output audio tracks.
|
||||
"""
|
||||
metadata_filepath = os.path.join(
|
||||
metadata_path, cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(metadata_filepath) as f:
|
||||
return json.load(f)
|
||||
metadata_filepath = os.path.join(metadata_path,
|
||||
cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(metadata_filepath) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
|
||||
"""Saves the input and the reference audio track paths.
|
||||
@classmethod
|
||||
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
|
||||
"""Saves the input and the reference audio track paths.
|
||||
|
||||
Args:
|
||||
output_path: path to the directory containing the metadata file.
|
||||
@ -83,23 +82,24 @@ class Metadata(object):
|
||||
Keyword Args:
|
||||
filepaths: collection of audio track file paths to save.
|
||||
"""
|
||||
output_filepath = os.path.join(output_path, cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(output_filepath, 'w') as f:
|
||||
json.dump(filepaths, f)
|
||||
output_filepath = os.path.join(output_path,
|
||||
cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(output_filepath, 'w') as f:
|
||||
json.dump(filepaths, f)
|
||||
|
||||
|
||||
class AudioProcConfigFile(object):
|
||||
"""Data access to load/save APM simulator argument lists.
|
||||
"""Data access to load/save APM simulator argument lists.
|
||||
|
||||
The arguments stored in the config files are used to control the APM flags.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a configuration file for an APM simulator.
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a configuration file for an APM simulator.
|
||||
|
||||
Args:
|
||||
filepath: path to the configuration file.
|
||||
@ -107,31 +107,31 @@ class AudioProcConfigFile(object):
|
||||
Returns:
|
||||
A dict containing the configuration.
|
||||
"""
|
||||
with open(filepath) as f:
|
||||
return json.load(f)
|
||||
with open(filepath) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def Save(cls, filepath, config):
|
||||
"""Saves a configuration file for an APM simulator.
|
||||
@classmethod
|
||||
def Save(cls, filepath, config):
|
||||
"""Saves a configuration file for an APM simulator.
|
||||
|
||||
Args:
|
||||
filepath: path to the configuration file.
|
||||
config: a dict containing the configuration.
|
||||
"""
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(config, f)
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
class ScoreFile(object):
|
||||
"""Data access class to save and load float scalar scores.
|
||||
"""Data access class to save and load float scalar scores.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a score from file.
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a score from file.
|
||||
|
||||
Args:
|
||||
filepath: path to the score file.
|
||||
@ -139,16 +139,16 @@ class ScoreFile(object):
|
||||
Returns:
|
||||
A float encoding the score.
|
||||
"""
|
||||
with open(filepath) as f:
|
||||
return float(f.readline().strip())
|
||||
with open(filepath) as f:
|
||||
return float(f.readline().strip())
|
||||
|
||||
@classmethod
|
||||
def Save(cls, filepath, score):
|
||||
"""Saves a score into a file.
|
||||
@classmethod
|
||||
def Save(cls, filepath, score):
|
||||
"""Saves a score into a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the score file.
|
||||
score: float encoding the score.
|
||||
"""
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('{0:f}\n'.format(score))
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('{0:f}\n'.format(score))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Echo path simulation module.
|
||||
"""
|
||||
|
||||
@ -16,21 +15,21 @@ from . import signal_processing
|
||||
|
||||
|
||||
class EchoPathSimulator(object):
|
||||
"""Abstract class for the echo path simulators.
|
||||
"""Abstract class for the echo path simulators.
|
||||
|
||||
In general, an echo path simulator is a function of the render signal and
|
||||
simulates the propagation of the latter into the microphone (e.g., due to
|
||||
mechanical or electrical paths).
|
||||
"""
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Creates the echo signal and stores it in an audio file (abstract method).
|
||||
def Simulate(self, output_path):
|
||||
"""Creates the echo signal and stores it in an audio file (abstract method).
|
||||
|
||||
Args:
|
||||
output_path: Path in which any output can be saved.
|
||||
@ -38,11 +37,11 @@ class EchoPathSimulator(object):
|
||||
Returns:
|
||||
Path to the generated audio track file or None if no echo is present.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EchoPathSimulator implementation.
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EchoPathSimulator implementation.
|
||||
|
||||
Decorator to automatically register the classes that extend
|
||||
EchoPathSimulator.
|
||||
@ -52,85 +51,86 @@ class EchoPathSimulator(object):
|
||||
class NoEchoPathSimulator(EchoPathSimulator):
|
||||
pass
|
||||
"""
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class NoEchoPathSimulator(EchoPathSimulator):
|
||||
"""Simulates absence of echo."""
|
||||
"""Simulates absence of echo."""
|
||||
|
||||
NAME = 'noecho'
|
||||
NAME = 'noecho'
|
||||
|
||||
def __init__(self):
|
||||
EchoPathSimulator.__init__(self)
|
||||
def __init__(self):
|
||||
EchoPathSimulator.__init__(self)
|
||||
|
||||
def Simulate(self, output_path):
|
||||
return None
|
||||
def Simulate(self, output_path):
|
||||
return None
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class LinearEchoPathSimulator(EchoPathSimulator):
|
||||
"""Simulates linear echo path.
|
||||
"""Simulates linear echo path.
|
||||
|
||||
This class applies a given impulse response to the render input and then it
|
||||
sums the signal to the capture input signal.
|
||||
"""
|
||||
|
||||
NAME = 'linear'
|
||||
NAME = 'linear'
|
||||
|
||||
def __init__(self, render_input_filepath, impulse_response):
|
||||
"""
|
||||
def __init__(self, render_input_filepath, impulse_response):
|
||||
"""
|
||||
Args:
|
||||
render_input_filepath: Render audio track file.
|
||||
impulse_response: list or numpy vector of float values.
|
||||
"""
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
self._impulse_response = impulse_response
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
self._impulse_response = impulse_response
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Simulates linear echo path."""
|
||||
# Form the file name with a hash of the impulse response.
|
||||
impulse_response_hash = hashlib.sha256(
|
||||
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
|
||||
echo_filepath = os.path.join(output_path, 'linear_echo_{}.wav'.format(
|
||||
impulse_response_hash))
|
||||
def Simulate(self, output_path):
|
||||
"""Simulates linear echo path."""
|
||||
# Form the file name with a hash of the impulse response.
|
||||
impulse_response_hash = hashlib.sha256(
|
||||
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
|
||||
echo_filepath = os.path.join(
|
||||
output_path, 'linear_echo_{}.wav'.format(impulse_response_hash))
|
||||
|
||||
# If the simulated echo audio track file does not exists, create it.
|
||||
if not os.path.exists(echo_filepath):
|
||||
render = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._render_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
render, self._impulse_response)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(echo_filepath, echo)
|
||||
# If the simulated echo audio track file does not exists, create it.
|
||||
if not os.path.exists(echo_filepath):
|
||||
render = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._render_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
render, self._impulse_response)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
echo_filepath, echo)
|
||||
|
||||
return echo_filepath
|
||||
return echo_filepath
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class RecordedEchoPathSimulator(EchoPathSimulator):
|
||||
"""Uses recorded echo.
|
||||
"""Uses recorded echo.
|
||||
|
||||
This class uses the clean capture input file name to build the file name of
|
||||
the corresponding recording containing echo (a predefined suffix is used).
|
||||
Such a file is expected to be already existing.
|
||||
"""
|
||||
|
||||
NAME = 'recorded'
|
||||
NAME = 'recorded'
|
||||
|
||||
_FILE_NAME_SUFFIX = '_echo'
|
||||
_FILE_NAME_SUFFIX = '_echo'
|
||||
|
||||
def __init__(self, render_input_filepath):
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
def __init__(self, render_input_filepath):
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Uses recorded echo path."""
|
||||
path, file_name_ext = os.path.split(self._render_input_filepath)
|
||||
file_name, file_ext = os.path.splitext(file_name_ext)
|
||||
echo_filepath = os.path.join(path, '{}{}{}'.format(
|
||||
file_name, self._FILE_NAME_SUFFIX, file_ext))
|
||||
assert os.path.exists(echo_filepath), (
|
||||
'cannot find the echo audio track file {}'.format(echo_filepath))
|
||||
return echo_filepath
|
||||
def Simulate(self, output_path):
|
||||
"""Uses recorded echo path."""
|
||||
path, file_name_ext = os.path.split(self._render_input_filepath)
|
||||
file_name, file_ext = os.path.splitext(file_name_ext)
|
||||
echo_filepath = os.path.join(
|
||||
path, '{}{}{}'.format(file_name, self._FILE_NAME_SUFFIX, file_ext))
|
||||
assert os.path.exists(echo_filepath), (
|
||||
'cannot find the echo audio track file {}'.format(echo_filepath))
|
||||
return echo_filepath
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Echo path simulation factory module.
|
||||
"""
|
||||
|
||||
@ -16,16 +15,16 @@ from . import echo_path_simulation
|
||||
|
||||
class EchoPathSimulatorFactory(object):
|
||||
|
||||
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
|
||||
# realistic impulse response.
|
||||
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0]*(20 * 48) + [0.15])
|
||||
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
|
||||
# realistic impulse response.
|
||||
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0] * (20 * 48) + [0.15])
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
|
||||
"""Creates an EchoPathSimulator instance given a class object.
|
||||
@classmethod
|
||||
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
|
||||
"""Creates an EchoPathSimulator instance given a class object.
|
||||
|
||||
Args:
|
||||
echo_path_simulator_class: EchoPathSimulator class object (not an
|
||||
@ -35,14 +34,15 @@ class EchoPathSimulatorFactory(object):
|
||||
Returns:
|
||||
An EchoPathSimulator instance.
|
||||
"""
|
||||
assert render_input_filepath is not None or (
|
||||
echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator)
|
||||
assert render_input_filepath is not None or (
|
||||
echo_path_simulator_class ==
|
||||
echo_path_simulation.NoEchoPathSimulator)
|
||||
|
||||
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
|
||||
return echo_path_simulation.NoEchoPathSimulator()
|
||||
elif echo_path_simulator_class == (
|
||||
echo_path_simulation.LinearEchoPathSimulator):
|
||||
return echo_path_simulation.LinearEchoPathSimulator(
|
||||
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
|
||||
else:
|
||||
return echo_path_simulator_class(render_input_filepath)
|
||||
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
|
||||
return echo_path_simulation.NoEchoPathSimulator()
|
||||
elif echo_path_simulator_class == (
|
||||
echo_path_simulation.LinearEchoPathSimulator):
|
||||
return echo_path_simulation.LinearEchoPathSimulator(
|
||||
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
|
||||
else:
|
||||
return echo_path_simulator_class(render_input_filepath)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the echo path simulation module.
|
||||
"""
|
||||
|
||||
@ -22,60 +21,62 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestEchoPathSimulators(unittest.TestCase):
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Create and save white noise.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._audio_track_num_samples = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
|
||||
self._audio_track_filepath = os.path.join(self._tmp_path, 'white_noise.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._audio_track_filepath, white_noise)
|
||||
# Create and save white noise.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._audio_track_num_samples = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
|
||||
self._audio_track_filepath = os.path.join(self._tmp_path,
|
||||
'white_noise.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._audio_track_filepath, white_noise)
|
||||
|
||||
# Make a copy the white noise audio track file; it will be used by
|
||||
# echo_path_simulation.RecordedEchoPathSimulator.
|
||||
shutil.copy(self._audio_track_filepath, os.path.join(
|
||||
self._tmp_path, 'white_noise_echo.wav'))
|
||||
# Make a copy the white noise audio track file; it will be used by
|
||||
# echo_path_simulation.RecordedEchoPathSimulator.
|
||||
shutil.copy(self._audio_track_filepath,
|
||||
os.path.join(self._tmp_path, 'white_noise_echo.wav'))
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testRegisteredClasses(self):
|
||||
# Check that there is at least one registered echo path simulator.
|
||||
registered_classes = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
def testRegisteredClasses(self):
|
||||
# Check that there is at least one registered echo path simulator.
|
||||
registered_classes = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance factory.
|
||||
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
|
||||
# Instance factory.
|
||||
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
|
||||
|
||||
# Try each registered echo path simulator.
|
||||
for echo_path_simulator_name in registered_classes:
|
||||
simulator = factory.GetInstance(
|
||||
echo_path_simulator_class=registered_classes[
|
||||
echo_path_simulator_name],
|
||||
render_input_filepath=self._audio_track_filepath)
|
||||
# Try each registered echo path simulator.
|
||||
for echo_path_simulator_name in registered_classes:
|
||||
simulator = factory.GetInstance(
|
||||
echo_path_simulator_class=registered_classes[
|
||||
echo_path_simulator_name],
|
||||
render_input_filepath=self._audio_track_filepath)
|
||||
|
||||
echo_filepath = simulator.Simulate(self._tmp_path)
|
||||
if echo_filepath is None:
|
||||
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
|
||||
echo_path_simulator_name)
|
||||
# No other tests in this case.
|
||||
continue
|
||||
echo_filepath = simulator.Simulate(self._tmp_path)
|
||||
if echo_filepath is None:
|
||||
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
|
||||
echo_path_simulator_name)
|
||||
# No other tests in this case.
|
||||
continue
|
||||
|
||||
# Check that the echo audio track file exists and its length is greater or
|
||||
# equal to that of the render audio track.
|
||||
self.assertTrue(os.path.exists(echo_filepath))
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
|
||||
self.assertGreaterEqual(
|
||||
signal_processing.SignalProcessingUtils.CountSamples(echo),
|
||||
self._audio_track_num_samples)
|
||||
# Check that the echo audio track file exists and its length is greater or
|
||||
# equal to that of the render audio track.
|
||||
self.assertTrue(os.path.exists(echo_filepath))
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
echo_filepath)
|
||||
self.assertGreaterEqual(
|
||||
signal_processing.SignalProcessingUtils.CountSamples(echo),
|
||||
self._audio_track_num_samples)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Evaluation score abstract class and implementations.
|
||||
"""
|
||||
|
||||
@ -17,10 +16,10 @@ import subprocess
|
||||
import sys
|
||||
|
||||
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)

from . import data_access
from . import exceptions
@ -29,23 +28,23 @@ from . import signal_processing
|
||||
|
||||
class EvaluationScore(object):
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
self._score_filename_prefix = score_filename_prefix
|
||||
self._input_signal_metadata = None
|
||||
self._reference_signal = None
|
||||
self._reference_signal_filepath = None
|
||||
self._tested_signal = None
|
||||
self._tested_signal_filepath = None
|
||||
self._output_filepath = None
|
||||
self._score = None
|
||||
self._render_signal_filepath = None
|
||||
def __init__(self, score_filename_prefix):
|
||||
self._score_filename_prefix = score_filename_prefix
|
||||
self._input_signal_metadata = None
|
||||
self._reference_signal = None
|
||||
self._reference_signal_filepath = None
|
||||
self._tested_signal = None
|
||||
self._tested_signal_filepath = None
|
||||
self._output_filepath = None
|
||||
self._score = None
|
||||
self._render_signal_filepath = None

  @classmethod
  def RegisterClass(cls, class_to_register):
    """Registers an EvaluationScore implementation.
    @classmethod
    def RegisterClass(cls, class_to_register):
        """Registers an EvaluationScore implementation.

    Decorator to automatically register the classes that extend EvaluationScore.
    Example usage:
@ -54,91 +53,90 @@ class EvaluationScore(object):
    class AudioLevelScore(EvaluationScore):
      pass
    """
    cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
    return class_to_register
        cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
        return class_to_register

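
A minimal usage sketch (illustrative names and paths, not taken from this CL) of how the REGISTERED_CLASSES mapping filled in by this decorator is typically used to look up and run a score worker:

# Illustrative only: look up a registered worker by its NAME and run it.
# The 'audio_level_peak' key and the file paths are example values.
score_class = eval_scores.EvaluationScore.REGISTERED_CLASSES['audio_level_peak']
worker = score_class(score_filename_prefix='scores-')
worker.SetReferenceSignalFilepath('/tmp/ref.wav')
worker.SetTestedSignalFilepath('/tmp/test.wav')
worker.Run('/tmp/output')
print(worker.score)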
@property
|
||||
def output_filepath(self):
|
||||
return self._output_filepath
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_filepath
|
||||
|
||||
@property
|
||||
def score(self):
|
||||
return self._score
|
||||
@property
|
||||
def score(self):
|
||||
return self._score
|
||||
|
||||
def SetInputSignalMetadata(self, metadata):
|
||||
"""Sets input signal metadata.
|
||||
def SetInputSignalMetadata(self, metadata):
|
||||
"""Sets input signal metadata.
|
||||
|
||||
Args:
|
||||
metadata: dict instance.
|
||||
"""
|
||||
self._input_signal_metadata = metadata
|
||||
self._input_signal_metadata = metadata
|
||||
|
||||
def SetReferenceSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as reference signal.
|
||||
def SetReferenceSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as reference signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the reference audio track.
|
||||
"""
|
||||
self._reference_signal_filepath = filepath
|
||||
self._reference_signal_filepath = filepath
|
||||
|
||||
def SetTestedSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as test signal.
|
||||
def SetTestedSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as test signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the test audio track.
|
||||
"""
|
||||
self._tested_signal_filepath = filepath
|
||||
self._tested_signal_filepath = filepath
|
||||
|
||||
def SetRenderSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as render signal.
|
||||
def SetRenderSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as render signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the test audio track.
|
||||
"""
|
||||
self._render_signal_filepath = filepath
|
||||
self._render_signal_filepath = filepath
|
||||
|
||||
def Run(self, output_path):
|
||||
"""Extracts the score for the set test data pair.
|
||||
def Run(self, output_path):
|
||||
"""Extracts the score for the set test data pair.
|
||||
|
||||
Args:
|
||||
output_path: path to the directory where the output is written.
|
||||
"""
|
||||
self._output_filepath = os.path.join(
|
||||
output_path, self._score_filename_prefix + self.NAME + '.txt')
|
||||
try:
|
||||
# If the score has already been computed, load.
|
||||
self._LoadScore()
|
||||
logging.debug('score found and loaded')
|
||||
except IOError:
|
||||
# Compute the score.
|
||||
logging.debug('score not found, compute')
|
||||
self._Run(output_path)
|
||||
self._output_filepath = os.path.join(
|
||||
output_path, self._score_filename_prefix + self.NAME + '.txt')
|
||||
try:
|
||||
# If the score has already been computed, load.
|
||||
self._LoadScore()
|
||||
logging.debug('score found and loaded')
|
||||
except IOError:
|
||||
# Compute the score.
|
||||
logging.debug('score not found, compute')
|
||||
self._Run(output_path)
|
||||
|
||||
def _Run(self, output_path):
|
||||
# Abstract method.
|
||||
raise NotImplementedError()
|
||||
def _Run(self, output_path):
|
||||
# Abstract method.
|
||||
raise NotImplementedError()
|
||||
|
||||
def _LoadReferenceSignal(self):
|
||||
assert self._reference_signal_filepath is not None
|
||||
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._reference_signal_filepath)
|
||||
def _LoadReferenceSignal(self):
|
||||
assert self._reference_signal_filepath is not None
|
||||
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._reference_signal_filepath)
|
||||
|
||||
def _LoadTestedSignal(self):
|
||||
assert self._tested_signal_filepath is not None
|
||||
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._tested_signal_filepath)
|
||||
def _LoadTestedSignal(self):
|
||||
assert self._tested_signal_filepath is not None
|
||||
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._tested_signal_filepath)
|
||||
|
||||
def _LoadScore(self):
|
||||
return data_access.ScoreFile.Load(self._output_filepath)
|
||||
|
||||
def _LoadScore(self):
|
||||
return data_access.ScoreFile.Load(self._output_filepath)
|
||||
|
||||
def _SaveScore(self):
|
||||
return data_access.ScoreFile.Save(self._output_filepath, self._score)
|
||||
def _SaveScore(self):
|
||||
return data_access.ScoreFile.Save(self._output_filepath, self._score)
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class AudioLevelPeakScore(EvaluationScore):
|
||||
"""Peak audio level score.
|
||||
"""Peak audio level score.
|
||||
|
||||
Defined as the difference between the peak audio level of the tested and
|
||||
the reference signals.
|
||||
@ -148,21 +146,21 @@ class AudioLevelPeakScore(EvaluationScore):
|
||||
Worst case: +/-inf dB
|
||||
"""
|
||||
|
||||
NAME = 'audio_level_peak'
|
||||
NAME = 'audio_level_peak'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
|
||||
self._SaveScore()
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
|
||||
self._SaveScore()
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class MeanAudioLevelScore(EvaluationScore):
|
||||
"""Mean audio level score.
|
||||
"""Mean audio level score.
|
||||
|
||||
Defined as the difference between the mean audio level of the tested and
|
||||
the reference signals.
|
||||
@ -172,29 +170,30 @@ class MeanAudioLevelScore(EvaluationScore):
|
||||
Worst case: +/-inf dB
|
||||
"""
|
||||
|
||||
NAME = 'audio_level_mean'
|
||||
NAME = 'audio_level_mean'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
|
||||
dbfs_diffs_sum = 0.0
|
||||
seconds = min(len(self._tested_signal), len(self._reference_signal)) // 1000
|
||||
for t in range(seconds):
|
||||
t0 = t * seconds
|
||||
t1 = t0 + seconds
|
||||
dbfs_diffs_sum += (
|
||||
self._tested_signal[t0:t1].dBFS - self._reference_signal[t0:t1].dBFS)
|
||||
self._score = dbfs_diffs_sum / float(seconds)
|
||||
self._SaveScore()
|
||||
        dbfs_diffs_sum = 0.0
        seconds = min(len(self._tested_signal), len(
            self._reference_signal)) // 1000
        for t in range(seconds):
            t0 = t * seconds
            t1 = t0 + seconds
            dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
                               self._reference_signal[t0:t1].dBFS)
        self._score = dbfs_diffs_sum / float(seconds)
        self._SaveScore()
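
Written as a formula, this is a restatement of the loop above (slice indices are in milliseconds, as used by pydub):

\[
\text{score} = \frac{1}{T}\sum_{t=0}^{T-1}\Bigl(L_{\text{test}}[t_0:t_1] - L_{\text{ref}}[t_0:t_1]\Bigr),
\qquad t_0 = tT,\; t_1 = t_0 + T,\;
T = \left\lfloor \frac{\min(\lvert\text{test}\rvert,\lvert\text{ref}\rvert)}{1000} \right\rfloor
\]

where \(L[a:b]\) denotes the dBFS level of the audio slice between \(a\) and \(b\) milliseconds.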
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class EchoMetric(EvaluationScore):
|
||||
"""Echo score.
|
||||
"""Echo score.
|
||||
|
||||
Proportion of detected echo.
|
||||
|
||||
@ -203,46 +202,47 @@ class EchoMetric(EvaluationScore):
|
||||
Worst case: 1
|
||||
"""
|
||||
|
||||
NAME = 'echo_metric'
|
||||
NAME = 'echo_metric'
|
||||
|
||||
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
# POLQA binary file path.
|
||||
self._echo_detector_bin_filepath = echo_detector_bin_filepath
|
||||
if not os.path.exists(self._echo_detector_bin_filepath):
|
||||
logging.error('cannot find EchoMetric tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
# POLQA binary file path.
|
||||
self._echo_detector_bin_filepath = echo_detector_bin_filepath
|
||||
if not os.path.exists(self._echo_detector_bin_filepath):
|
||||
logging.error('cannot find EchoMetric tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
self._echo_detector_bin_path, _ = os.path.split(
|
||||
self._echo_detector_bin_filepath)
|
||||
self._echo_detector_bin_path, _ = os.path.split(
|
||||
self._echo_detector_bin_filepath)
|
||||
|
||||
def _Run(self, output_path):
|
||||
echo_detector_out_filepath = os.path.join(output_path, 'echo_detector.out')
|
||||
if os.path.exists(echo_detector_out_filepath):
|
||||
os.unlink(echo_detector_out_filepath)
|
||||
def _Run(self, output_path):
|
||||
echo_detector_out_filepath = os.path.join(output_path,
|
||||
'echo_detector.out')
|
||||
if os.path.exists(echo_detector_out_filepath):
|
||||
os.unlink(echo_detector_out_filepath)
|
||||
|
||||
logging.debug("Render signal filepath: %s", self._render_signal_filepath)
|
||||
if not os.path.exists(self._render_signal_filepath):
|
||||
logging.error("Render input required for evaluating the echo metric.")
|
||||
logging.debug("Render signal filepath: %s",
|
||||
self._render_signal_filepath)
|
||||
if not os.path.exists(self._render_signal_filepath):
|
||||
logging.error(
|
||||
"Render input required for evaluating the echo metric.")
|
||||
|
||||
args = [
|
||||
self._echo_detector_bin_filepath,
|
||||
'--output_file', echo_detector_out_filepath,
|
||||
'--',
|
||||
'-i', self._tested_signal_filepath,
|
||||
'-ri', self._render_signal_filepath
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._echo_detector_bin_path)
|
||||
        args = [
            self._echo_detector_bin_filepath, '--output_file',
            echo_detector_out_filepath, '--', '-i',
            self._tested_signal_filepath, '-ri', self._render_signal_filepath
        ]
        logging.debug(' '.join(args))
        subprocess.call(args, cwd=self._echo_detector_bin_path)
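
For reference, the command assembled above has this shape (paths are placeholders, not actual locations):

/path/to/echo_detector_binary --output_file <output_path>/echo_detector.out -- -i <tested_signal>.wav -ri <render_signal>.wav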
|
||||
# Parse Echo detector tool output and extract the score.
|
||||
self._score = self._ParseOutputFile(echo_detector_out_filepath)
|
||||
self._SaveScore()
|
||||
# Parse Echo detector tool output and extract the score.
|
||||
self._score = self._ParseOutputFile(echo_detector_out_filepath)
|
||||
self._SaveScore()
|
||||
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, echo_metric_file_path):
|
||||
"""
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, echo_metric_file_path):
|
||||
"""
|
||||
Parses the echo detector tool output file (a single number).
|
||||
|
||||
Args:
|
||||
@ -251,12 +251,13 @@ class EchoMetric(EvaluationScore):
|
||||
Returns:
|
||||
The score as a number in [0, 1].
|
||||
"""
|
||||
with open(echo_metric_file_path) as f:
|
||||
return float(f.read())
|
||||
with open(echo_metric_file_path) as f:
|
||||
return float(f.read())
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class PolqaScore(EvaluationScore):
|
||||
"""POLQA score.
|
||||
"""POLQA score.
|
||||
|
||||
See http://www.polqa.info/.
|
||||
|
||||
@ -265,44 +266,51 @@ class PolqaScore(EvaluationScore):
|
||||
Worst case: 1.0
|
||||
"""
|
||||
|
||||
NAME = 'polqa'
|
||||
NAME = 'polqa'
|
||||
|
||||
def __init__(self, score_filename_prefix, polqa_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix, polqa_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
# POLQA binary file path.
|
||||
self._polqa_bin_filepath = polqa_bin_filepath
|
||||
if not os.path.exists(self._polqa_bin_filepath):
|
||||
logging.error('cannot find POLQA tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
# POLQA binary file path.
|
||||
self._polqa_bin_filepath = polqa_bin_filepath
|
||||
if not os.path.exists(self._polqa_bin_filepath):
|
||||
logging.error('cannot find POLQA tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
# Path to the POLQA directory with binary and license files.
|
||||
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
|
||||
# Path to the POLQA directory with binary and license files.
|
||||
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
|
||||
|
||||
def _Run(self, output_path):
|
||||
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
|
||||
if os.path.exists(polqa_out_filepath):
|
||||
os.unlink(polqa_out_filepath)
|
||||
def _Run(self, output_path):
|
||||
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
|
||||
if os.path.exists(polqa_out_filepath):
|
||||
os.unlink(polqa_out_filepath)
|
||||
|
||||
args = [
|
||||
self._polqa_bin_filepath, '-t', '-q', '-Overwrite',
|
||||
'-Ref', self._reference_signal_filepath,
|
||||
'-Test', self._tested_signal_filepath,
|
||||
'-LC', 'NB',
|
||||
'-Out', polqa_out_filepath,
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._polqa_tool_path)
|
||||
args = [
|
||||
self._polqa_bin_filepath,
|
||||
'-t',
|
||||
'-q',
|
||||
'-Overwrite',
|
||||
'-Ref',
|
||||
self._reference_signal_filepath,
|
||||
'-Test',
|
||||
self._tested_signal_filepath,
|
||||
'-LC',
|
||||
'NB',
|
||||
'-Out',
|
||||
polqa_out_filepath,
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._polqa_tool_path)
|
||||
|
||||
# Parse POLQA tool output and extract the score.
|
||||
polqa_output = self._ParseOutputFile(polqa_out_filepath)
|
||||
self._score = float(polqa_output['PolqaScore'])
|
||||
# Parse POLQA tool output and extract the score.
|
||||
polqa_output = self._ParseOutputFile(polqa_out_filepath)
|
||||
self._score = float(polqa_output['PolqaScore'])
|
||||
|
||||
self._SaveScore()
|
||||
self._SaveScore()
|
||||
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, polqa_out_filepath):
|
||||
"""
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, polqa_out_filepath):
|
||||
"""
|
||||
Parses the POLQA tool output formatted as a table ('-t' option).
|
||||
|
||||
Args:
|
||||
@ -311,29 +319,32 @@ class PolqaScore(EvaluationScore):
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
data = []
|
||||
with open(polqa_out_filepath) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if len(line) == 0 or line.startswith('*'):
|
||||
# Ignore comments.
|
||||
continue
|
||||
# Read fields.
|
||||
data.append(re.split(r'\t+', line))
|
||||
data = []
|
||||
with open(polqa_out_filepath) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if len(line) == 0 or line.startswith('*'):
|
||||
# Ignore comments.
|
||||
continue
|
||||
# Read fields.
|
||||
data.append(re.split(r'\t+', line))
|
||||
|
||||
# Two rows expected (header and values).
|
||||
assert len(data) == 2, 'Cannot parse POLQA output'
|
||||
number_of_fields = len(data[0])
|
||||
assert number_of_fields == len(data[1])
|
||||
# Two rows expected (header and values).
|
||||
assert len(data) == 2, 'Cannot parse POLQA output'
|
||||
number_of_fields = len(data[0])
|
||||
assert number_of_fields == len(data[1])
|
||||
|
||||
    # Build and return a dictionary with field names (header) as keys and the
    # corresponding field values as values.
    return {data[0][index]: data[1][index] for index in range(number_of_fields)}
        # Build and return a dictionary with field names (header) as keys and the
        # corresponding field values as values.
        return {
            data[0][index]: data[1][index]
            for index in range(number_of_fields)
        }
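
To make the expected format concrete, a small self-contained sketch (all field names other than PolqaScore are invented for illustration):

import re

# Hypothetical POLQA '-t' output: comment lines start with '*', then a
# tab-separated header row followed by a single value row.
example = ('* some comment\n'
           'Ref\tTest\tPolqaScore\n'
           'ref.wav\ttest.wav\t3.87\n')
rows = [re.split(r'\t+', line.strip()) for line in example.splitlines()
        if line.strip() and not line.startswith('*')]
parsed = {rows[0][i]: rows[1][i] for i in range(len(rows[0]))}
# parsed == {'Ref': 'ref.wav', 'Test': 'test.wav', 'PolqaScore': '3.87'};
# PolqaScore._Run() then reads float(parsed['PolqaScore']).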
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class TotalHarmonicDistorsionScore(EvaluationScore):
|
||||
"""Total harmonic distorsion plus noise score.
|
||||
"""Total harmonic distorsion plus noise score.
|
||||
|
||||
Total harmonic distorsion plus noise score.
|
||||
See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
|
||||
@ -343,69 +354,74 @@ class TotalHarmonicDistorsionScore(EvaluationScore):
|
||||
Worst case: +inf
|
||||
"""
|
||||
|
||||
NAME = 'thd'
|
||||
NAME = 'thd'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
self._input_frequency = None
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
self._input_frequency = None
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._CheckInputSignal()
|
||||
def _Run(self, output_path):
|
||||
self._CheckInputSignal()
|
||||
|
||||
self._LoadTestedSignal()
|
||||
if self._tested_signal.channels != 1:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'unsupported number of channels')
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._tested_signal)
|
||||
self._LoadTestedSignal()
|
||||
if self._tested_signal.channels != 1:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'unsupported number of channels')
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._tested_signal)
|
||||
|
||||
# Init.
|
||||
num_samples = len(samples)
|
||||
duration = len(self._tested_signal) / 1000.0
|
||||
scaling = 2.0 / num_samples
|
||||
max_freq = self._tested_signal.frame_rate / 2
|
||||
f0_freq = float(self._input_frequency)
|
||||
t = np.linspace(0, duration, num_samples)
|
||||
# Init.
|
||||
num_samples = len(samples)
|
||||
duration = len(self._tested_signal) / 1000.0
|
||||
scaling = 2.0 / num_samples
|
||||
max_freq = self._tested_signal.frame_rate / 2
|
||||
f0_freq = float(self._input_frequency)
|
||||
t = np.linspace(0, duration, num_samples)
|
||||
|
||||
# Analyze harmonics.
|
||||
b_terms = []
|
||||
n = 1
|
||||
while f0_freq * n < max_freq:
|
||||
x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
b_terms.append(np.sqrt(x_n**2 + y_n**2))
|
||||
n += 1
|
||||
# Analyze harmonics.
|
||||
b_terms = []
|
||||
n = 1
|
||||
while f0_freq * n < max_freq:
|
||||
x_n = np.sum(
|
||||
samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
y_n = np.sum(
|
||||
samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
b_terms.append(np.sqrt(x_n**2 + y_n**2))
|
||||
n += 1
|
||||
|
||||
output_without_fundamental = samples - b_terms[0] * np.sin(
|
||||
2.0 * np.pi * f0_freq * t)
|
||||
distortion_and_noise = np.sqrt(np.sum(
|
||||
output_without_fundamental**2) * np.pi * scaling)
|
||||
output_without_fundamental = samples - b_terms[0] * np.sin(
|
||||
2.0 * np.pi * f0_freq * t)
|
||||
distortion_and_noise = np.sqrt(
|
||||
np.sum(output_without_fundamental**2) * np.pi * scaling)
|
||||
|
||||
# TODO(alessiob): Fix or remove if not needed.
|
||||
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
|
||||
# TODO(alessiob): Fix or remove if not needed.
|
||||
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
|
||||
|
||||
    # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
    # docstring above if accordingly.
    thd_plus_noise = distortion_and_noise / b_terms[0]
        # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
        # docstring above if accordingly.
        thd_plus_noise = distortion_and_noise / b_terms[0]

    self._score = thd_plus_noise
    self._SaveScore()
        self._score = thd_plus_noise
        self._SaveScore()
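
As a math summary of the estimate computed above (notation added here, not from the source): with samples \(s_k\), \(k = 0,\dots,N-1\), taken at times \(t_k\), pure-tone frequency \(f_0\) and scaling factor \(2/N\),

\[
x_n = \frac{2}{N}\sum_{k} s_k \sin(2\pi n f_0 t_k),\qquad
y_n = \frac{2}{N}\sum_{k} s_k \cos(2\pi n f_0 t_k),\qquad
b_n = \sqrt{x_n^2 + y_n^2}
\]
\[
\text{THD+N} = \frac{\sqrt{\pi \cdot \frac{2}{N}\sum_{k}\bigl(s_k - b_1 \sin(2\pi f_0 t_k)\bigr)^2}}{b_1}
\]

with harmonics \(n f_0\) evaluated up to the Nyquist frequency.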
|
||||
def _CheckInputSignal(self):
|
||||
# Check input signal and get properties.
|
||||
try:
|
||||
if self._input_signal_metadata['signal'] != 'pure_tone':
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires a pure tone as input signal')
|
||||
self._input_frequency = self._input_signal_metadata['frequency']
|
||||
if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
|
||||
self._input_signal_metadata['test_data_gen_config'] != 'default'):
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score cannot be used with any test data generator other '
|
||||
'than "identity"')
|
||||
except TypeError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires an input signal with associated metadata')
|
||||
except KeyError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'Invalid input signal metadata to compute the THD score')
|
||||
def _CheckInputSignal(self):
|
||||
# Check input signal and get properties.
|
||||
try:
|
||||
if self._input_signal_metadata['signal'] != 'pure_tone':
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires a pure tone as input signal')
|
||||
self._input_frequency = self._input_signal_metadata['frequency']
|
||||
if self._input_signal_metadata[
|
||||
'test_data_gen_name'] != 'identity' or (
|
||||
self._input_signal_metadata['test_data_gen_config'] !=
|
||||
'default'):
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score cannot be used with any test data generator other '
|
||||
'than "identity"')
|
||||
except TypeError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires an input signal with associated metadata'
|
||||
)
|
||||
except KeyError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'Invalid input signal metadata to compute the THD score')
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""EvaluationScore factory class.
|
||||
"""
|
||||
|
||||
@ -16,22 +15,22 @@ from . import eval_scores
|
||||
|
||||
|
||||
class EvaluationScoreWorkerFactory(object):
|
||||
"""Factory class used to instantiate evaluation score workers.
|
||||
"""Factory class used to instantiate evaluation score workers.
|
||||
|
||||
The ctor gets the parameters that are used to instantiate the evaluation score
|
||||
workers.
|
||||
"""
|
||||
|
||||
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
|
||||
self._score_filename_prefix = None
|
||||
self._polqa_tool_bin_path = polqa_tool_bin_path
|
||||
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
|
||||
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
|
||||
self._score_filename_prefix = None
|
||||
self._polqa_tool_bin_path = polqa_tool_bin_path
|
||||
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
|
||||
|
||||
def SetScoreFilenamePrefix(self, prefix):
|
||||
self._score_filename_prefix = prefix
|
||||
def SetScoreFilenamePrefix(self, prefix):
|
||||
self._score_filename_prefix = prefix
|
||||
|
||||
def GetInstance(self, evaluation_score_class):
|
||||
"""Creates an EvaluationScore instance given a class object.
|
||||
def GetInstance(self, evaluation_score_class):
|
||||
"""Creates an EvaluationScore instance given a class object.
|
||||
|
||||
Args:
|
||||
evaluation_score_class: EvaluationScore class object (not an instance).
|
||||
@ -39,17 +38,18 @@ class EvaluationScoreWorkerFactory(object):
|
||||
Returns:
|
||||
An EvaluationScore instance.
|
||||
"""
|
||||
if self._score_filename_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The score file name prefix for evaluation score workers is not set')
|
||||
logging.debug(
|
||||
'factory producing a %s evaluation score', evaluation_score_class)
|
||||
if self._score_filename_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The score file name prefix for evaluation score workers is not set'
|
||||
)
|
||||
logging.debug('factory producing a %s evaluation score',
|
||||
evaluation_score_class)
|
||||
|
||||
if evaluation_score_class == eval_scores.PolqaScore:
|
||||
return eval_scores.PolqaScore(
|
||||
self._score_filename_prefix, self._polqa_tool_bin_path)
|
||||
elif evaluation_score_class == eval_scores.EchoMetric:
|
||||
return eval_scores.EchoMetric(
|
||||
self._score_filename_prefix, self._echo_metric_tool_bin_path)
|
||||
else:
|
||||
return evaluation_score_class(self._score_filename_prefix)
|
||||
        if evaluation_score_class == eval_scores.PolqaScore:
            return eval_scores.PolqaScore(self._score_filename_prefix,
                                          self._polqa_tool_bin_path)
        elif evaluation_score_class == eval_scores.EchoMetric:
            return eval_scores.EchoMetric(self._score_filename_prefix,
                                          self._echo_metric_tool_bin_path)
        else:
            return evaluation_score_class(self._score_filename_prefix)
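
A minimal usage sketch of the factory (the tool paths below are placeholders):

# Illustrative only.
factory = eval_scores_factory.EvaluationScoreWorkerFactory(
    polqa_tool_bin_path='/path/to/polqa_bin',
    echo_metric_tool_bin_path='/path/to/echo_detector_bin')
factory.SetScoreFilenamePrefix('scores-')
worker = factory.GetInstance(eval_scores.AudioLevelPeakScore)
# For PolqaScore and EchoMetric the factory also injects the tool path;
# every other class only receives the score file name prefix.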
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
@ -23,111 +22,116 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestEvalScores(unittest.TestCase):
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary output folder and two audio track files."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Create temporary output folder and two audio track files."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
|
||||
# Create fake reference and tested (i.e., APM output) audio track files.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_reference_signal = (
|
||||
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
|
||||
fake_tested_signal = (
|
||||
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
|
||||
# Create fake reference and tested (i.e., APM output) audio track files.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_reference_signal = (signal_processing.SignalProcessingUtils.
|
||||
GenerateWhiteNoise(silence))
|
||||
fake_tested_signal = (signal_processing.SignalProcessingUtils.
|
||||
GenerateWhiteNoise(silence))
|
||||
|
||||
# Save fake audio tracks.
|
||||
self._fake_reference_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_ref.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_reference_signal_filepath, fake_reference_signal)
|
||||
self._fake_tested_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_test.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_tested_signal_filepath, fake_tested_signal)
|
||||
# Save fake audio tracks.
|
||||
self._fake_reference_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_ref.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_reference_signal_filepath, fake_reference_signal)
|
||||
self._fake_tested_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_test.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_tested_signal_filepath, fake_tested_signal)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
shutil.rmtree(self._output_path)
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
shutil.rmtree(self._output_path)
|
||||
|
||||
def testRegisteredClasses(self):
|
||||
# Evaluation score names to exclude (tested separately).
|
||||
exceptions = ['thd', 'echo_metric']
|
||||
def testRegisteredClasses(self):
|
||||
# Evaluation score names to exclude (tested separately).
|
||||
exceptions = ['thd', 'echo_metric']
|
||||
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._output_path))
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._output_path))
|
||||
|
||||
# Check that there is at least one registered evaluation score worker.
|
||||
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
# Check that there is at least one registered evaluation score worker.
|
||||
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance evaluation score workers factory with fake dependencies.
|
||||
eval_score_workers_factory = (
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
))
|
||||
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
|
||||
# Instance evaluation score workers factory with fake dependencies.
|
||||
eval_score_workers_factory = (
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None))
|
||||
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
|
||||
|
||||
# Try each registered evaluation score worker.
|
||||
for eval_score_name in registered_classes:
|
||||
if eval_score_name in exceptions:
|
||||
continue
|
||||
# Try each registered evaluation score worker.
|
||||
for eval_score_name in registered_classes:
|
||||
if eval_score_name in exceptions:
|
||||
continue
|
||||
|
||||
# Instance evaluation score worker.
|
||||
eval_score_worker = eval_score_workers_factory.GetInstance(
|
||||
registered_classes[eval_score_name])
|
||||
# Instance evaluation score worker.
|
||||
eval_score_worker = eval_score_workers_factory.GetInstance(
|
||||
registered_classes[eval_score_name])
|
||||
|
||||
# Set fake input metadata and reference and test file paths, then run.
|
||||
eval_score_worker.SetReferenceSignalFilepath(
|
||||
self._fake_reference_signal_filepath)
|
||||
eval_score_worker.SetTestedSignalFilepath(
|
||||
self._fake_tested_signal_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
# Set fake input metadata and reference and test file paths, then run.
|
||||
eval_score_worker.SetReferenceSignalFilepath(
|
||||
self._fake_reference_signal_filepath)
|
||||
eval_score_worker.SetTestedSignalFilepath(
|
||||
self._fake_tested_signal_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
|
||||
# Check output.
|
||||
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
|
||||
self.assertTrue(isinstance(score, float))
|
||||
# Check output.
|
||||
score = data_access.ScoreFile.Load(
|
||||
eval_score_worker.output_filepath)
|
||||
self.assertTrue(isinstance(score, float))
|
||||
|
||||
def testTotalHarmonicDistorsionScore(self):
|
||||
# Init.
|
||||
pure_tone_freq = 5000.0
|
||||
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
|
||||
eval_score_worker.SetInputSignalMetadata({
|
||||
'signal': 'pure_tone',
|
||||
'frequency': pure_tone_freq,
|
||||
'test_data_gen_name': 'identity',
|
||||
'test_data_gen_config': 'default',
|
||||
})
|
||||
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
def testTotalHarmonicDistorsionScore(self):
|
||||
# Init.
|
||||
pure_tone_freq = 5000.0
|
||||
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
|
||||
eval_score_worker.SetInputSignalMetadata({
|
||||
'signal':
|
||||
'pure_tone',
|
||||
'frequency':
|
||||
pure_tone_freq,
|
||||
'test_data_gen_name':
|
||||
'identity',
|
||||
'test_data_gen_config':
|
||||
'default',
|
||||
})
|
||||
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
|
||||
# Create 3 test signals: pure tone, pure tone + white noise, white noise
|
||||
# only.
|
||||
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, pure_tone_freq)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
template)
|
||||
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
pure_tone, white_noise)
|
||||
# Create 3 test signals: pure tone, pure tone + white noise, white noise
|
||||
# only.
|
||||
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, pure_tone_freq)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
template)
|
||||
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
pure_tone, white_noise)
|
||||
|
||||
# Compute scores for increasingly distorted pure tone signals.
|
||||
scores = [None, None, None]
|
||||
for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]):
|
||||
# Save signal.
|
||||
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
tmp_filepath, tested_signal)
|
||||
# Compute scores for increasingly distorted pure tone signals.
|
||||
scores = [None, None, None]
|
||||
for index, tested_signal in enumerate(
|
||||
[pure_tone, noisy_tone, white_noise]):
|
||||
# Save signal.
|
||||
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
tmp_filepath, tested_signal)
|
||||
|
||||
# Compute score.
|
||||
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
scores[index] = eval_score_worker.score
|
||||
# Compute score.
|
||||
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
scores[index] = eval_score_worker.score
|
||||
|
||||
# Remove output file to avoid caching.
|
||||
os.remove(eval_score_worker.output_filepath)
|
||||
# Remove output file to avoid caching.
|
||||
os.remove(eval_score_worker.output_filepath)
|
||||
|
||||
# Validate scores (lowest score with a pure tone).
|
||||
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
|
||||
# Validate scores (lowest score with a pure tone).
|
||||
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Evaluator of the APM module.
|
||||
"""
|
||||
|
||||
@ -13,17 +12,17 @@ import logging
|
||||
|
||||
|
||||
class ApmModuleEvaluator(object):
|
||||
"""APM evaluator class.
|
||||
"""APM evaluator class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Run(cls, evaluation_score_workers, apm_input_metadata,
|
||||
apm_output_filepath, reference_input_filepath,
|
||||
render_input_filepath, output_path):
|
||||
"""Runs the evaluation.
|
||||
@classmethod
|
||||
def Run(cls, evaluation_score_workers, apm_input_metadata,
|
||||
apm_output_filepath, reference_input_filepath,
|
||||
render_input_filepath, output_path):
|
||||
"""Runs the evaluation.
|
||||
|
||||
Iterates over the given evaluation score workers.
|
||||
|
||||
@ -37,20 +36,22 @@ class ApmModuleEvaluator(object):
|
||||
Returns:
|
||||
A dict of evaluation score name and score pairs.
|
||||
"""
|
||||
# Init.
|
||||
scores = {}
|
||||
# Init.
|
||||
scores = {}
|
||||
|
||||
for evaluation_score_worker in evaluation_score_workers:
|
||||
logging.info(' computing <%s> score', evaluation_score_worker.NAME)
|
||||
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
|
||||
evaluation_score_worker.SetReferenceSignalFilepath(
|
||||
reference_input_filepath)
|
||||
evaluation_score_worker.SetTestedSignalFilepath(
|
||||
apm_output_filepath)
|
||||
evaluation_score_worker.SetRenderSignalFilepath(
|
||||
render_input_filepath)
|
||||
for evaluation_score_worker in evaluation_score_workers:
|
||||
logging.info(' computing <%s> score',
|
||||
evaluation_score_worker.NAME)
|
||||
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
|
||||
evaluation_score_worker.SetReferenceSignalFilepath(
|
||||
reference_input_filepath)
|
||||
evaluation_score_worker.SetTestedSignalFilepath(
|
||||
apm_output_filepath)
|
||||
evaluation_score_worker.SetRenderSignalFilepath(
|
||||
render_input_filepath)

      evaluation_score_worker.Run(output_path)
      scores[evaluation_score_worker.NAME] = evaluation_score_worker.score
            evaluation_score_worker.Run(output_path)
            scores[
                evaluation_score_worker.NAME] = evaluation_score_worker.score

    return scores
        return scores
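
A usage sketch of the evaluator (all argument values are placeholders):

# Illustrative only.
scores = ApmModuleEvaluator.Run(
    evaluation_score_workers=workers,          # instances built by the factory
    apm_input_metadata=apm_input_metadata,
    apm_output_filepath='out/apm_output.wav',
    reference_input_filepath='in/reference.wav',
    render_input_filepath='in/render.wav',
    output_path='out')
# scores maps each worker's NAME to the score it computed.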
|
||||
@ -5,42 +5,41 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Exception classes.
|
||||
"""
|
||||
|
||||
|
||||
class FileNotFoundError(Exception):
|
||||
"""File not found exception.
|
||||
"""File not found exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class SignalProcessingException(Exception):
|
||||
"""Signal processing exception.
|
||||
"""Signal processing exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class InputMixerException(Exception):
|
||||
"""Input mixer exception.
|
||||
"""Input mixer exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class InputSignalCreatorException(Exception):
|
||||
"""Input signal creator exception.
|
||||
"""Input signal creator exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class EvaluationScoreException(Exception):
|
||||
"""Evaluation score exception.
|
||||
"""Evaluation score exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class InitializationException(Exception):
|
||||
"""Initialization exception.
|
||||
"""Initialization exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -14,58 +14,58 @@ import re
|
||||
import sys
|
||||
|
||||
try:
|
||||
import csscompressor
|
||||
import csscompressor
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package csscompressor')
|
||||
sys.exit(1)
|
||||
logging.critical(
|
||||
'Cannot import the third-party Python package csscompressor')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import jsmin
|
||||
import jsmin
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package jsmin')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package jsmin')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class HtmlExport(object):
|
||||
"""HTML exporter class for APM quality scores."""
|
||||
"""HTML exporter class for APM quality scores."""
|
||||
|
||||
_NEW_LINE = '\n'
|
||||
_NEW_LINE = '\n'
|
||||
|
||||
# CSS and JS file paths.
|
||||
_PATH = os.path.dirname(os.path.realpath(__file__))
|
||||
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
|
||||
_CSS_MINIFIED = True
|
||||
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
|
||||
_JS_MINIFIED = True
|
||||
# CSS and JS file paths.
|
||||
_PATH = os.path.dirname(os.path.realpath(__file__))
|
||||
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
|
||||
_CSS_MINIFIED = True
|
||||
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
|
||||
_JS_MINIFIED = True
|
||||
|
||||
def __init__(self, output_filepath):
|
||||
self._scores_data_frame = None
|
||||
self._output_filepath = output_filepath
|
||||
def __init__(self, output_filepath):
|
||||
self._scores_data_frame = None
|
||||
self._output_filepath = output_filepath
|
||||
|
||||
def Export(self, scores_data_frame):
|
||||
"""Exports scores into an HTML file.
|
||||
def Export(self, scores_data_frame):
|
||||
"""Exports scores into an HTML file.
|
||||
|
||||
Args:
|
||||
scores_data_frame: DataFrame instance.
|
||||
"""
|
||||
self._scores_data_frame = scores_data_frame
|
||||
html = ['<html>',
|
||||
self._scores_data_frame = scores_data_frame
|
||||
html = [
|
||||
'<html>',
|
||||
self._BuildHeader(),
|
||||
('<script type="text/javascript">'
|
||||
'(function () {'
|
||||
'window.addEventListener(\'load\', function () {'
|
||||
'var inspector = new AudioInspector();'
|
||||
'});'
|
||||
'(function () {'
|
||||
'window.addEventListener(\'load\', function () {'
|
||||
'var inspector = new AudioInspector();'
|
||||
'});'
|
||||
'})();'
|
||||
'</script>'),
|
||||
'<body>',
|
||||
self._BuildBody(),
|
||||
'</body>',
|
||||
'</html>']
|
||||
self._Save(self._output_filepath, self._NEW_LINE.join(html))
|
||||
'</script>'), '<body>',
|
||||
self._BuildBody(), '</body>', '</html>'
|
||||
]
|
||||
self._Save(self._output_filepath, self._NEW_LINE.join(html))
|
||||
|
||||
def _BuildHeader(self):
|
||||
"""Builds the <head> section of the HTML file.
|
||||
def _BuildHeader(self):
|
||||
"""Builds the <head> section of the HTML file.
|
||||
|
||||
The header contains the page title and either embedded or linked CSS and JS
|
||||
files.
|
||||
@ -73,325 +73,349 @@ class HtmlExport(object):
|
||||
Returns:
|
||||
A string with <head>...</head> HTML.
|
||||
"""
|
||||
html = ['<head>', '<title>Results</title>']
|
||||
html = ['<head>', '<title>Results</title>']
|
||||
|
||||
# Add Material Design hosted libs.
|
||||
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
|
||||
'css?family=Roboto:300,400,500,700" type="text/css">')
|
||||
html.append('<link rel="stylesheet" href="https://fonts.googleapis.com/'
|
||||
'icon?family=Material+Icons">')
|
||||
html.append('<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
|
||||
'material.indigo-pink.min.css">')
|
||||
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
|
||||
'material.min.js"></script>')
|
||||
|
||||
# Embed custom JavaScript and CSS files.
|
||||
html.append('<script>')
|
||||
with open(self._JS_FILEPATH) as f:
|
||||
html.append(jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</script>')
|
||||
html.append('<style>')
|
||||
with open(self._CSS_FILEPATH) as f:
|
||||
html.append(csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</style>')
|
||||
|
||||
html.append('</head>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildBody(self):
|
||||
"""Builds the content of the <body> section."""
|
||||
score_names = self._scores_data_frame['eval_score_name'].drop_duplicates(
|
||||
).values.tolist()
|
||||
|
||||
html = [
|
||||
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
|
||||
'mdl-layout--fixed-tabs">'),
|
||||
'<header class="mdl-layout__header">',
|
||||
'<div class="mdl-layout__header-row">',
|
||||
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
|
||||
self._output_filepath),
|
||||
'</div>',
|
||||
]
|
||||
|
||||
# Tab selectors.
|
||||
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
is_active = tab_index == 0
|
||||
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
|
||||
'{}</a>'.format(tab_index,
|
||||
' is-active' if is_active else '',
|
||||
self._FormatName(score_name)))
|
||||
html.append('</div>')
|
||||
|
||||
html.append('</header>')
|
||||
html.append('<main class="mdl-layout__content" style="overflow-x: auto;">')
|
||||
|
||||
# Tabs content.
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
html.append('<section class="mdl-layout__tab-panel{}" '
|
||||
'id="score-tab-{}">'.format(
|
||||
' is-active' if is_active else '', tab_index))
|
||||
html.append('<div class="page-content">')
|
||||
html.append(self._BuildScoreTab(score_name, ('s{}'.format(tab_index),)))
|
||||
html.append('</div>')
|
||||
html.append('</section>')
|
||||
|
||||
html.append('</main>')
|
||||
html.append('</div>')
|
||||
|
||||
# Add snackbar for notifications.
|
||||
html.append(
|
||||
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
|
||||
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
|
||||
'<div class="mdl-snackbar__text"></div>'
|
||||
'<button type="button" class="mdl-snackbar__action"></button>'
|
||||
'</div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreTab(self, score_name, anchor_data):
|
||||
"""Builds the content of a tab."""
|
||||
# Find unique values.
|
||||
scores = self._scores_data_frame[
|
||||
self._scores_data_frame.eval_score_name == score_name]
|
||||
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
|
||||
test_data_gen_configs = sorted(self._FindUniqueTuples(
|
||||
scores, ['test_data_gen', 'test_data_gen_params']))
|
||||
|
||||
html = [
|
||||
'<div class="mdl-grid">',
|
||||
'<div class="mdl-layout-spacer"></div>',
|
||||
'<div class="mdl-cell mdl-cell--10-col">',
|
||||
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
|
||||
'style="width: 100%;">'),
|
||||
]
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>APM config / Test data generator</th>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
html.append('<th>{} {}</th>'.format(
|
||||
self._FormatName(test_data_gen_info[0]), test_data_gen_info[1]))
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for apm_config in apm_configs:
|
||||
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
# Add Material Design hosted libs.
|
||||
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
|
||||
'css?family=Roboto:300,400,500,700" type="text/css">')
|
||||
html.append(
|
||||
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.format(
|
||||
dialog_id, self._BuildScoreTableCell(
|
||||
score_name, test_data_gen_info[0], test_data_gen_info[1],
|
||||
apm_config[0])))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
'<link rel="stylesheet" href="https://fonts.googleapis.com/'
|
||||
'icon?family=Material+Icons">')
|
||||
html.append(
|
||||
'<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
|
||||
'material.indigo-pink.min.css">')
|
||||
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
|
||||
'material.min.js"></script>')
|
||||
|
||||
html.append('</table></div><div class="mdl-layout-spacer"></div></div>')
|
||||
# Embed custom JavaScript and CSS files.
|
||||
html.append('<script>')
|
||||
with open(self._JS_FILEPATH) as f:
|
||||
html.append(
|
||||
jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</script>')
|
||||
html.append('<style>')
|
||||
with open(self._CSS_FILEPATH) as f:
|
||||
html.append(
|
||||
csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</style>')
|
||||
|
||||
html.append(self._BuildScoreStatsInspectorDialogs(
|
||||
score_name, apm_configs, test_data_gen_configs,
|
||||
anchor_data))
|
||||
html.append('</head>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreTableCell(self, score_name, test_data_gen,
|
||||
test_data_gen_params, apm_config):
|
||||
"""Builds the content of a table cell for a score table."""
|
||||
scores = self._SliceDataForScoreTableCell(
|
||||
score_name, apm_config, test_data_gen, test_data_gen_params)
|
||||
stats = self._ComputeScoreStats(scores)
|
||||
def _BuildBody(self):
|
||||
"""Builds the content of the <body> section."""
|
||||
score_names = self._scores_data_frame[
|
||||
'eval_score_name'].drop_duplicates().values.tolist()
|
||||
|
||||
html = []
|
||||
items_id_prefix = (
|
||||
score_name + test_data_gen + test_data_gen_params + apm_config)
|
||||
if stats['count'] == 1:
|
||||
# Show the only available score.
|
||||
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
|
||||
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
|
||||
item_id, scores['score'].mean()))
|
||||
html.append('<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
|
||||
'</div>'.format(item_id, 'single value'))
|
||||
else:
|
||||
# Show stats.
|
||||
for stat_name in ['min', 'max', 'mean', 'std dev']:
|
||||
item_id = hashlib.md5(
|
||||
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
|
||||
html.append('<div id="stats-{0}">{1:f}</div>'.format(
|
||||
item_id, stats[stat_name]))
|
||||
html.append('<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
|
||||
'</div>'.format(item_id, stat_name))
|
||||
html = [
|
||||
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
|
||||
'mdl-layout--fixed-tabs">'),
|
||||
'<header class="mdl-layout__header">',
|
||||
'<div class="mdl-layout__header-row">',
|
||||
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
|
||||
self._output_filepath),
|
||||
'</div>',
|
||||
]
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorDialogs(
|
||||
self, score_name, apm_configs, test_data_gen_configs, anchor_data):
|
||||
"""Builds a set of score stats inspector dialogs."""
|
||||
html = []
|
||||
for apm_config in apm_configs:
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0],
|
||||
test_data_gen_info[0], test_data_gen_info[1])
|
||||
|
||||
html.append('<dialog class="mdl-dialog" id="{}" '
|
||||
'style="width: 40%;">'.format(dialog_id))
|
||||
|
||||
# Content.
|
||||
html.append('<div class="mdl-dialog__content">')
|
||||
html.append('<h6><strong>APM config preset</strong>: {}<br/>'
|
||||
'<strong>Test data generator</strong>: {} ({})</h6>'.format(
|
||||
self._FormatName(apm_config[0]),
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append(self._BuildScoreStatsInspectorDialog(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1], anchor_data + (dialog_id,)))
|
||||
# Tab selectors.
|
||||
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
is_active = tab_index == 0
|
||||
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
|
||||
'{}</a>'.format(tab_index,
|
||||
' is-active' if is_active else '',
|
||||
self._FormatName(score_name)))
|
||||
html.append('</div>')
|
||||
|
||||
# Actions.
|
||||
html.append('<div class="mdl-dialog__actions">')
|
||||
html.append('<button type="button" class="mdl-button" '
|
||||
'onclick="closeScoreStatsInspector()">'
|
||||
'Close</button>')
|
||||
html.append('</header>')
|
||||
html.append(
|
||||
'<main class="mdl-layout__content" style="overflow-x: auto;">')
|
||||
|
||||
# Tabs content.
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
html.append('<section class="mdl-layout__tab-panel{}" '
|
||||
'id="score-tab-{}">'.format(
|
||||
' is-active' if is_active else '', tab_index))
|
||||
html.append('<div class="page-content">')
|
||||
html.append(
|
||||
self._BuildScoreTab(score_name, ('s{}'.format(tab_index), )))
|
||||
html.append('</div>')
|
||||
html.append('</section>')
|
||||
|
||||
html.append('</main>')
|
||||
html.append('</div>')
|
||||
|
||||
html.append('</dialog>')
|
||||
# Add snackbar for notifications.
|
||||
html.append(
|
||||
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
|
||||
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
|
||||
'<div class="mdl-snackbar__text"></div>'
|
||||
'<button type="button" class="mdl-snackbar__action"></button>'
|
||||
'</div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorDialog(
|
||||
self, score_name, apm_config, test_data_gen, test_data_gen_params,
|
||||
anchor_data):
|
||||
"""Builds one score stats inspector dialog."""
|
||||
scores = self._SliceDataForScoreTableCell(
|
||||
score_name, apm_config, test_data_gen, test_data_gen_params)
|
||||
def _BuildScoreTab(self, score_name, anchor_data):
|
||||
"""Builds the content of a tab."""
|
||||
# Find unique values.
|
||||
scores = self._scores_data_frame[
|
||||
self._scores_data_frame.eval_score_name == score_name]
|
||||
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
|
||||
test_data_gen_configs = sorted(
|
||||
self._FindUniqueTuples(scores,
|
||||
['test_data_gen', 'test_data_gen_params']))
|
||||
|
||||
capture_render_pairs = sorted(self._FindUniqueTuples(
|
||||
scores, ['capture', 'render']))
|
||||
echo_simulators = sorted(self._FindUniqueTuples(scores, ['echo_simulator']))
|
||||
html = [
|
||||
'<div class="mdl-grid">',
|
||||
'<div class="mdl-layout-spacer"></div>',
|
||||
'<div class="mdl-cell mdl-cell--10-col">',
|
||||
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
|
||||
'style="width: 100%;">'),
|
||||
]
|
||||
|
||||
html = ['<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">']
|
||||
# Header.
|
||||
html.append('<thead><tr><th>APM config / Test data generator</th>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
html.append('<th>{} {}</th>'.format(
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
|
||||
for echo_simulator in echo_simulators:
|
||||
html.append('<th>' + self._FormatName(echo_simulator[0]) +'</th>')
|
||||
html.append('</tr></thead>')
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for apm_config in apm_configs:
|
||||
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
html.append(
|
||||
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.
|
||||
format(
|
||||
dialog_id,
|
||||
self._BuildScoreTableCell(score_name,
|
||||
test_data_gen_info[0],
|
||||
test_data_gen_info[1],
|
||||
apm_config[0])))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for row, (capture, render) in enumerate(capture_render_pairs):
|
||||
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
|
||||
capture, render))
|
||||
for col, echo_simulator in enumerate(echo_simulators):
|
||||
score_tuple = self._SliceDataForScoreStatsTableCell(
|
||||
scores, capture, render, echo_simulator[0])
|
||||
cell_class = 'r{}c{}'.format(row, col)
|
||||
html.append('<td class="single-score-cell {}">{}</td>'.format(
|
||||
cell_class, self._BuildScoreStatsInspectorTableCell(
|
||||
score_tuple, anchor_data + (cell_class,))))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
html.append(
|
||||
'</table></div><div class="mdl-layout-spacer"></div></div>')
|
||||
|
||||
html.append('</table>')
|
||||
html.append(
|
||||
self._BuildScoreStatsInspectorDialogs(score_name, apm_configs,
|
||||
test_data_gen_configs,
|
||||
anchor_data))
|
||||
|
||||
# Placeholder for the audio inspector.
|
||||
html.append('<div class="audio-inspector-placeholder"></div>')
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
def _BuildScoreTableCell(self, score_name, test_data_gen,
|
||||
test_data_gen_params, apm_config):
|
||||
"""Builds the content of a table cell for a score table."""
|
||||
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
|
||||
test_data_gen,
|
||||
test_data_gen_params)
|
||||
stats = self._ComputeScoreStats(scores)
|
||||
|
||||
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
|
||||
"""Builds the content of a cell of a score stats inspector."""
|
||||
anchor = '&'.join(anchor_data)
|
||||
html = [('<div class="v">{}</div>'
|
||||
'<button class="mdl-button mdl-js-button mdl-button--icon"'
|
||||
' data-anchor="{}">'
|
||||
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
|
||||
'</button>').format(score_tuple.score, anchor)]
|
||||
html = []
|
||||
items_id_prefix = (score_name + test_data_gen + test_data_gen_params +
|
||||
apm_config)
|
||||
if stats['count'] == 1:
|
||||
# Show the only available score.
|
||||
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
|
||||
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
|
||||
item_id, scores['score'].mean()))
|
||||
html.append(
|
||||
'<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
|
||||
'</div>'.format(item_id, 'single value'))
|
||||
else:
|
||||
# Show stats.
|
||||
for stat_name in ['min', 'max', 'mean', 'std dev']:
|
||||
item_id = hashlib.md5(
|
||||
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
|
||||
html.append('<div id="stats-{0}">{1:f}</div>'.format(
|
||||
item_id, stats[stat_name]))
|
||||
html.append(
|
||||
'<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
|
||||
'</div>'.format(item_id, stat_name))
|
||||
|
||||
# Add all the available file paths as hidden data.
|
||||
for field_name in score_tuple.keys():
|
||||
if field_name.endswith('_filepath'):
|
||||
html.append('<input type="hidden" name="{}" value="{}">'.format(
|
||||
field_name, score_tuple[field_name]))
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
def _BuildScoreStatsInspectorDialogs(self, score_name, apm_configs,
|
||||
test_data_gen_configs, anchor_data):
|
||||
"""Builds a set of score stats inspector dialogs."""
|
||||
html = []
|
||||
for apm_config in apm_configs:
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
|
||||
def _SliceDataForScoreTableCell(
|
||||
self, score_name, apm_config, test_data_gen, test_data_gen_params):
|
||||
"""Slices |self._scores_data_frame| to extract the data for a tab."""
|
||||
masks = []
|
||||
masks.append(self._scores_data_frame.eval_score_name == score_name)
|
||||
masks.append(self._scores_data_frame.apm_config == apm_config)
|
||||
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
|
||||
masks.append(
|
||||
self._scores_data_frame.test_data_gen_params == test_data_gen_params)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
return self._scores_data_frame[mask]
|
||||
html.append('<dialog class="mdl-dialog" id="{}" '
|
||||
'style="width: 40%;">'.format(dialog_id))
|
||||
|
||||
@classmethod
|
||||
def _SliceDataForScoreStatsTableCell(
|
||||
cls, scores, capture, render, echo_simulator):
|
||||
"""Slices |scores| to extract the data for a tab."""
|
||||
masks = []
|
||||
# Content.
|
||||
html.append('<div class="mdl-dialog__content">')
|
||||
html.append(
|
||||
'<h6><strong>APM config preset</strong>: {}<br/>'
|
||||
'<strong>Test data generator</strong>: {} ({})</h6>'.
|
||||
format(self._FormatName(apm_config[0]),
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append(
|
||||
self._BuildScoreStatsInspectorDialog(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1], anchor_data + (dialog_id, )))
|
||||
html.append('</div>')
|
||||
|
||||
masks.append(scores.capture == capture)
|
||||
masks.append(scores.render == render)
|
||||
masks.append(scores.echo_simulator == echo_simulator)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
# Actions.
|
||||
html.append('<div class="mdl-dialog__actions">')
|
||||
html.append('<button type="button" class="mdl-button" '
|
||||
'onclick="closeScoreStatsInspector()">'
|
||||
'Close</button>')
|
||||
html.append('</div>')
|
||||
|
||||
sliced_data = scores[mask]
|
||||
assert len(sliced_data) == 1, 'single score is expected'
|
||||
return sliced_data.iloc[0]
|
||||
html.append('</dialog>')
|
||||
|
||||
@classmethod
|
||||
def _FindUniqueTuples(cls, data_frame, fields):
|
||||
"""Slices |data_frame| to a list of fields and finds unique tuples."""
|
||||
return data_frame[fields].drop_duplicates().values.tolist()
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
@classmethod
|
||||
def _ComputeScoreStats(cls, data_frame):
|
||||
"""Computes score stats."""
|
||||
scores = data_frame['score']
|
||||
return {
|
||||
'count': scores.count(),
|
||||
'min': scores.min(),
|
||||
'max': scores.max(),
|
||||
'mean': scores.mean(),
|
||||
'std dev': scores.std(),
|
||||
}
|
||||
def _BuildScoreStatsInspectorDialog(self, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params,
|
||||
anchor_data):
|
||||
"""Builds one score stats inspector dialog."""
|
||||
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
|
||||
test_data_gen,
|
||||
test_data_gen_params)
|
||||
|
||||
@classmethod
|
||||
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config, test_data_gen,
|
||||
test_data_gen_params):
|
||||
"""Assigns a unique name to a dialog."""
|
||||
return 'score-stats-dialog-' + hashlib.md5(
|
||||
'score-stats-inspector-{}-{}-{}-{}'.format(
|
||||
score_name, apm_config, test_data_gen,
|
||||
test_data_gen_params).encode('utf-8')).hexdigest()
|
||||
capture_render_pairs = sorted(
|
||||
self._FindUniqueTuples(scores, ['capture', 'render']))
|
||||
echo_simulators = sorted(
|
||||
self._FindUniqueTuples(scores, ['echo_simulator']))
|
||||
|
||||
@classmethod
|
||||
def _Save(cls, output_filepath, html):
|
||||
"""Writes the HTML file.
|
||||
html = [
|
||||
'<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">'
|
||||
]
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
|
||||
for echo_simulator in echo_simulators:
|
||||
html.append('<th>' + self._FormatName(echo_simulator[0]) + '</th>')
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for row, (capture, render) in enumerate(capture_render_pairs):
|
||||
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
|
||||
capture, render))
|
||||
for col, echo_simulator in enumerate(echo_simulators):
|
||||
score_tuple = self._SliceDataForScoreStatsTableCell(
|
||||
scores, capture, render, echo_simulator[0])
|
||||
cell_class = 'r{}c{}'.format(row, col)
|
||||
html.append('<td class="single-score-cell {}">{}</td>'.format(
|
||||
cell_class,
|
||||
self._BuildScoreStatsInspectorTableCell(
|
||||
score_tuple, anchor_data + (cell_class, ))))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
|
||||
html.append('</table>')
|
||||
|
||||
# Placeholder for the audio inspector.
|
||||
html.append('<div class="audio-inspector-placeholder"></div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
|
||||
"""Builds the content of a cell of a score stats inspector."""
|
||||
anchor = '&'.join(anchor_data)
|
||||
html = [('<div class="v">{}</div>'
|
||||
'<button class="mdl-button mdl-js-button mdl-button--icon"'
|
||||
' data-anchor="{}">'
|
||||
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
|
||||
'</button>').format(score_tuple.score, anchor)]
|
||||
|
||||
# Add all the available file paths as hidden data.
|
||||
for field_name in score_tuple.keys():
|
||||
if field_name.endswith('_filepath'):
|
||||
html.append(
|
||||
'<input type="hidden" name="{}" value="{}">'.format(
|
||||
field_name, score_tuple[field_name]))
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _SliceDataForScoreTableCell(self, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params):
|
||||
"""Slices |self._scores_data_frame| to extract the data for a tab."""
|
||||
masks = []
|
||||
masks.append(self._scores_data_frame.eval_score_name == score_name)
|
||||
masks.append(self._scores_data_frame.apm_config == apm_config)
|
||||
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
|
||||
masks.append(self._scores_data_frame.test_data_gen_params ==
|
||||
test_data_gen_params)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
return self._scores_data_frame[mask]
|
||||
|
||||
@classmethod
|
||||
def _SliceDataForScoreStatsTableCell(cls, scores, capture, render,
|
||||
echo_simulator):
|
||||
"""Slices |scores| to extract the data for a tab."""
|
||||
masks = []
|
||||
|
||||
masks.append(scores.capture == capture)
|
||||
masks.append(scores.render == render)
|
||||
masks.append(scores.echo_simulator == echo_simulator)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
|
||||
sliced_data = scores[mask]
|
||||
assert len(sliced_data) == 1, 'single score is expected'
|
||||
return sliced_data.iloc[0]
|
||||
|
||||
@classmethod
|
||||
def _FindUniqueTuples(cls, data_frame, fields):
|
||||
"""Slices |data_frame| to a list of fields and finds unique tuples."""
|
||||
return data_frame[fields].drop_duplicates().values.tolist()
|
||||
|
||||
@classmethod
|
||||
def _ComputeScoreStats(cls, data_frame):
|
||||
"""Computes score stats."""
|
||||
scores = data_frame['score']
|
||||
return {
|
||||
'count': scores.count(),
|
||||
'min': scores.min(),
|
||||
'max': scores.max(),
|
||||
'mean': scores.mean(),
|
||||
'std dev': scores.std(),
|
||||
}
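For reference, a minimal standalone sketch of the same statistics, assuming a pandas DataFrame with a 'score' column like the one the slicing helpers above return (the sample values are hypothetical):

import pandas as pd

# Hypothetical scores; _SliceDataForScoreTableCell would normally provide them.
data_frame = pd.DataFrame({'score': [0.69, 0.71, 0.74]})
scores = data_frame['score']
stats = {
    'count': scores.count(),  # 3
    'min': scores.min(),      # 0.69
    'max': scores.max(),      # 0.74
    'mean': scores.mean(),    # ~0.713
    'std dev': scores.std(),  # Sample standard deviation (pandas default ddof=1).
}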
@classmethod
|
||||
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params):
|
||||
"""Assigns a unique name to a dialog."""
|
||||
return 'score-stats-dialog-' + hashlib.md5(
|
||||
'score-stats-inspector-{}-{}-{}-{}'.format(
|
||||
score_name, apm_config, test_data_gen,
|
||||
test_data_gen_params).encode('utf-8')).hexdigest()
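As a quick illustration (argument values below are hypothetical), the dialog id is simply the MD5 digest of the formatted parameter string, so equal parameter tuples always map to the same dialog:

import hashlib

# Hypothetical parameters; the exporter passes the current score/config names.
key = 'score-stats-inspector-{}-{}-{}-{}'.format(
    'audio_level_peak', 'default', 'white_noise', '')
dialog_id = 'score-stats-dialog-' + hashlib.md5(key.encode('utf-8')).hexdigest()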
@classmethod
|
||||
def _Save(cls, output_filepath, html):
|
||||
"""Writes the HTML file.
|
||||
|
||||
Args:
|
||||
output_filepath: output file path.
|
||||
html: string with the HTML content.
|
||||
"""
|
||||
with open(output_filepath, 'w') as f:
|
||||
f.write(html)
|
||||
with open(output_filepath, 'w') as f:
|
||||
f.write(html)
|
||||
|
||||
@classmethod
|
||||
def _FormatName(cls, name):
|
||||
"""Formats a name.
|
||||
@classmethod
|
||||
def _FormatName(cls, name):
|
||||
"""Formats a name.
|
||||
|
||||
Args:
|
||||
name: a string.
|
||||
@ -399,4 +423,4 @@ class HtmlExport(object):
|
||||
Returns:
|
||||
A copy of name in which underscores and dashes are replaced with a space.
|
||||
"""
|
||||
return re.sub(r'[_\-]', ' ', name)
|
||||
return re.sub(r'[_\-]', ' ', name)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the export module.
|
||||
"""
|
||||
|
||||
@ -27,60 +26,61 @@ from . import test_data_generation_factory
|
||||
|
||||
|
||||
class TestExport(unittest.TestCase):
|
||||
"""Unit tests for the export module.
|
||||
"""Unit tests for the export module.
|
||||
"""
|
||||
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data to export."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Creates temporary data to export."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Run a fake experiment to produce data to export.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
simulator.Run(
|
||||
config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=[
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
|
||||
],
|
||||
test_data_generator_names=['identity', 'white_noise'],
|
||||
eval_score_names=['audio_level_peak', 'audio_level_mean'],
|
||||
output_dir=self._tmp_path)
|
||||
# Run a fake experiment to produce data to export.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
simulator.Run(
|
||||
config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=[
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
|
||||
],
|
||||
test_data_generator_names=['identity', 'white_noise'],
|
||||
eval_score_names=['audio_level_peak', 'audio_level_mean'],
|
||||
output_dir=self._tmp_path)
|
||||
|
||||
# Export results.
|
||||
p = collect_data.InstanceArgumentsParser()
|
||||
args = p.parse_args(['--output_dir', self._tmp_path])
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
self._data_to_export = collect_data.FindScores(src_path, args)
|
||||
# Export results.
|
||||
p = collect_data.InstanceArgumentsParser()
|
||||
args = p.parse_args(['--output_dir', self._tmp_path])
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
self._data_to_export = collect_data.FindScores(src_path, args)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' + (
|
||||
self._tmp_path))
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' +
|
||||
(self._tmp_path))
|
||||
|
||||
def testCreateHtmlReport(self):
|
||||
fn_out = os.path.join(self._tmp_path, 'results.html')
|
||||
exporter = export.HtmlExport(fn_out)
|
||||
exporter.Export(self._data_to_export)
|
||||
def testCreateHtmlReport(self):
|
||||
fn_out = os.path.join(self._tmp_path, 'results.html')
|
||||
exporter = export.HtmlExport(fn_out)
|
||||
exporter.Export(self._data_to_export)
|
||||
|
||||
document = pq.PyQuery(filename=fn_out)
|
||||
self.assertIsInstance(document, pq.PyQuery)
|
||||
# TODO(alessiob): Use PyQuery API to check the HTML file.
|
||||
document = pq.PyQuery(filename=fn_out)
|
||||
self.assertIsInstance(document, pq.PyQuery)
|
||||
# TODO(alessiob): Use PyQuery API to check the HTML file.
|
||||
|
||||
@ -16,62 +16,60 @@ import sys
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import signal_processing
|
||||
|
||||
class ExternalVad(object):
|
||||
|
||||
def __init__(self, path_to_binary, name):
|
||||
"""Args:
|
||||
class ExternalVad(object):
|
||||
def __init__(self, path_to_binary, name):
|
||||
"""Args:
|
||||
path_to_binary: path to binary that accepts '-i <wav>', '-o
|
||||
<float probabilities>'. There must be one float value per
|
||||
10ms audio
|
||||
name: a name to identify the external VAD. Used for saving
|
||||
the output as extvad_output-<name>.
|
||||
"""
|
||||
self._path_to_binary = path_to_binary
|
||||
self.name = name
|
||||
assert os.path.exists(self._path_to_binary), (
|
||||
self._path_to_binary)
|
||||
self._vad_output = None
|
||||
self._path_to_binary = path_to_binary
|
||||
self.name = name
|
||||
assert os.path.exists(self._path_to_binary), (self._path_to_binary)
|
||||
self._vad_output = None
|
||||
|
||||
def Run(self, wav_file_path):
|
||||
_signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
|
||||
if _signal.channels != 1:
|
||||
raise NotImplementedError('Multiple-channel'
|
||||
' annotations not implemented')
|
||||
if _signal.frame_rate != 48000:
|
||||
raise NotImplementedError('Frame rates '
|
||||
'other than 48000 not implemented')
|
||||
def Run(self, wav_file_path):
|
||||
_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
wav_file_path)
|
||||
if _signal.channels != 1:
|
||||
raise NotImplementedError('Multiple-channel'
|
||||
' annotations not implemented')
|
||||
if _signal.frame_rate != 48000:
|
||||
raise NotImplementedError('Frame rates '
|
||||
'other than 48000 not implemented')
|
||||
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
try:
|
||||
output_file_path = os.path.join(
|
||||
tmp_path, self.name + '_vad.tmp')
|
||||
subprocess.call([
|
||||
self._path_to_binary,
|
||||
'-i', wav_file_path,
|
||||
'-o', output_file_path
|
||||
])
|
||||
self._vad_output = np.fromfile(output_file_path, np.float32)
|
||||
except Exception as e:
|
||||
logging.error('Error while running the ' + self.name +
|
||||
' VAD (' + e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
try:
|
||||
output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
|
||||
subprocess.call([
|
||||
self._path_to_binary, '-i', wav_file_path, '-o',
|
||||
output_file_path
|
||||
])
|
||||
self._vad_output = np.fromfile(output_file_path, np.float32)
|
||||
except Exception as e:
|
||||
logging.error('Error while running the ' + self.name + ' VAD (' +
|
||||
e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
|
||||
def GetVadOutput(self):
|
||||
assert self._vad_output is not None
|
||||
return self._vad_output
|
||||
def GetVadOutput(self):
|
||||
assert self._vad_output is not None
|
||||
return self._vad_output
|
||||
|
||||
@classmethod
|
||||
def ConstructVadDict(cls, vad_paths, vad_names):
|
||||
external_vads = {}
|
||||
for path, name in zip(vad_paths, vad_names):
|
||||
external_vads[name] = ExternalVad(path, name)
|
||||
return external_vads
|
||||
@classmethod
|
||||
def ConstructVadDict(cls, vad_paths, vad_names):
|
||||
external_vads = {}
|
||||
for path, name in zip(vad_paths, vad_names):
|
||||
external_vads[name] = ExternalVad(path, name)
|
||||
return external_vads
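A minimal usage sketch, assuming a hypothetical VAD binary that honors the -i/-o contract described in the constructor docstring:

# '/path/to/my_vad' and 'capture.wav' are placeholders, not real artifacts.
external_vads = ExternalVad.ConstructVadDict(['/path/to/my_vad'], ['my_vad'])
external_vads['my_vad'].Run('capture.wav')  # Mono, 48 kHz wav expected.
probabilities = external_vads['my_vad'].GetVadOutput()  # float32, one value per 10 ms.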
@ -9,16 +9,17 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', required=True)
|
||||
parser.add_argument('-o', required=True)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', required=True)
|
||||
parser.add_argument('-o', required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
array = np.arange(100, dtype=np.float32)
|
||||
array.tofile(open(args.o, 'w'))
|
||||
array = np.arange(100, dtype=np.float32)
|
||||
array.tofile(open(args.o, 'w'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
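The fake tool ignores the input wav and always writes 100 float32 values; a downstream check could read them back the same way ExternalVad does (the file name below is hypothetical):

import numpy as np

values = np.fromfile('fake_vad_output.tmp', np.float32)
assert len(values) == 100  # np.arange(100) written by the fake tool above.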
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Input mixer module.
|
||||
"""
|
||||
|
||||
@ -17,24 +16,24 @@ from . import signal_processing
|
||||
|
||||
|
||||
class ApmInputMixer(object):
|
||||
"""Class to mix a set of audio segments down to the APM input."""
|
||||
"""Class to mix a set of audio segments down to the APM input."""
|
||||
|
||||
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
|
||||
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def HardClippingLogMessage(cls):
|
||||
"""Returns the log message used when hard clipping is detected in the mix.
|
||||
@classmethod
|
||||
def HardClippingLogMessage(cls):
|
||||
"""Returns the log message used when hard clipping is detected in the mix.
|
||||
|
||||
This method is mainly intended to be used by the unit tests.
|
||||
"""
|
||||
return cls._HARD_CLIPPING_LOG_MSG
|
||||
return cls._HARD_CLIPPING_LOG_MSG
|
||||
|
||||
@classmethod
|
||||
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
|
||||
"""Mixes capture and echo.
|
||||
@classmethod
|
||||
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
|
||||
"""Mixes capture and echo.
|
||||
|
||||
Creates the overall capture input for APM by mixing the "echo-free" capture
|
||||
signal with the echo signal (e.g., echo simulated via the
|
||||
@ -58,38 +57,41 @@ class ApmInputMixer(object):
|
||||
Returns:
|
||||
Path to the mix audio track file.
|
||||
"""
|
||||
if echo_filepath is None:
|
||||
return capture_input_filepath
|
||||
if echo_filepath is None:
|
||||
return capture_input_filepath
|
||||
|
||||
# Build the mix output file name as a function of the echo file name.
|
||||
# This ensures that if the internal parameters of the echo path simulator
|
||||
# change, no erroneous cache hit occurs.
|
||||
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
|
||||
capture_input_file_name, _ = os.path.splitext(
|
||||
os.path.split(capture_input_filepath)[1])
|
||||
mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format(
|
||||
capture_input_file_name, echo_file_name))
|
||||
# Build the mix output file name as a function of the echo file name.
|
||||
# This ensures that if the internal parameters of the echo path simulator
|
||||
# change, no erroneous cache hit occurs.
|
||||
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
|
||||
capture_input_file_name, _ = os.path.splitext(
|
||||
os.path.split(capture_input_filepath)[1])
|
||||
mix_filepath = os.path.join(
|
||||
output_path,
|
||||
'mix_capture_{}_{}.wav'.format(capture_input_file_name,
|
||||
echo_file_name))
|
||||
|
||||
# Create the mix if not done yet.
|
||||
mix = None
|
||||
if not os.path.exists(mix_filepath):
|
||||
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
capture_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
|
||||
# Create the mix if not done yet.
|
||||
mix = None
|
||||
if not os.path.exists(mix_filepath):
|
||||
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
capture_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
echo_filepath)
|
||||
|
||||
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(
|
||||
echo_free_capture)):
|
||||
raise exceptions.InputMixerException(
|
||||
'echo cannot be shorter than capture')
|
||||
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(
|
||||
echo_free_capture)):
|
||||
raise exceptions.InputMixerException(
|
||||
'echo cannot be shorter than capture')
|
||||
|
||||
mix = echo_free_capture.overlay(echo)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
|
||||
mix = echo_free_capture.overlay(echo)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
|
||||
|
||||
# Check if hard clipping occurs.
|
||||
if mix is None:
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
|
||||
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
|
||||
# Check if hard clipping occurs.
|
||||
if mix is None:
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
|
||||
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
|
||||
|
||||
return mix_filepath
|
||||
return mix_filepath
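A usage sketch from a caller's perspective, with hypothetical file paths; the mix is cached under output_path and, when no echo file is given, the capture file path is returned unchanged:

# Both wav files must exist and the echo must not be shorter than the capture.
mix_path = input_mixer.ApmInputMixer.Mix('/tmp/out', 'capture.wav', 'echo.wav')
no_echo_path = input_mixer.ApmInputMixer.Mix('/tmp/out', 'capture.wav', None)
assert no_echo_path == 'capture.wav'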
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the input mixer module.
|
||||
"""
|
||||
|
||||
@ -23,122 +22,119 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestApmInputMixer(unittest.TestCase):
|
||||
"""Unit tests for the ApmInputMixer class.
|
||||
"""Unit tests for the ApmInputMixer class.
|
||||
"""
|
||||
|
||||
# Audio track file names created in setUp().
|
||||
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
|
||||
# Audio track file names created in setUp().
|
||||
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
|
||||
|
||||
# Target peak power level (dBFS) of each audio track file created in setUp().
|
||||
# These values are hand-crafted in order to make saturation happen when
|
||||
# capture and echo_2 are mixed and the contrary for capture and echo_1.
|
||||
# None means that the power is not changed.
|
||||
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
|
||||
# Target peak power level (dBFS) of each audio track file created in setUp().
|
||||
# These values are hand-crafted in order to make saturation happen when
|
||||
# capture and echo_2 are mixed and the contrary for capture and echo_1.
|
||||
# None means that the power is not changed.
|
||||
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
|
||||
|
||||
# Audio track file durations in milliseconds.
|
||||
_DURATIONS = [1000, 1000, 1000, 800, 1200]
|
||||
# Audio track file durations in milliseconds.
|
||||
_DURATIONS = [1000, 1000, 1000, 800, 1200]
|
||||
|
||||
_SAMPLE_RATE = 48000
|
||||
_SAMPLE_RATE = 48000
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Create audio track files.
|
||||
self._audio_tracks = {}
|
||||
for filename, peak_power, duration in zip(
|
||||
self._FILENAMES, self._MAX_PEAK_POWER_LEVELS, self._DURATIONS):
|
||||
audio_track_filepath = os.path.join(self._tmp_path, '{}.wav'.format(
|
||||
filename))
|
||||
# Create audio track files.
|
||||
self._audio_tracks = {}
|
||||
for filename, peak_power, duration in zip(self._FILENAMES,
|
||||
self._MAX_PEAK_POWER_LEVELS,
|
||||
self._DURATIONS):
|
||||
audio_track_filepath = os.path.join(self._tmp_path,
|
||||
'{}.wav'.format(filename))
|
||||
|
||||
# Create a pure tone with the target peak power level.
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration=duration, sample_rate=self._SAMPLE_RATE)
|
||||
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template)
|
||||
if peak_power is not None:
|
||||
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
|
||||
# Create a pure tone with the target peak power level.
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration=duration, sample_rate=self._SAMPLE_RATE)
|
||||
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template)
|
||||
if peak_power is not None:
|
||||
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
|
||||
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
audio_track_filepath, signal)
|
||||
self._audio_tracks[filename] = {
|
||||
'filepath': audio_track_filepath,
|
||||
'num_samples': signal_processing.SignalProcessingUtils.CountSamples(
|
||||
signal)
|
||||
}
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
audio_track_filepath, signal)
|
||||
self._audio_tracks[filename] = {
|
||||
'filepath':
|
||||
audio_track_filepath,
|
||||
'num_samples':
|
||||
signal_processing.SignalProcessingUtils.CountSamples(signal)
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testCheckMixSameDuration(self):
|
||||
"""Checks the duration when mixing capture and echo with same duration."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
def testCheckMixSameDuration(self):
|
||||
"""Checks the duration when mixing capture and echo with same duration."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(
|
||||
self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
|
||||
def testRejectShorterEcho(self):
|
||||
"""Rejects echo signals that are shorter than the capture signal."""
|
||||
try:
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['shorter']['filepath'])
|
||||
self.fail('no exception raised')
|
||||
except exceptions.InputMixerException:
|
||||
pass
|
||||
def testRejectShorterEcho(self):
|
||||
"""Rejects echo signals that are shorter than the capture signal."""
|
||||
try:
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['shorter']['filepath'])
|
||||
self.fail('no exception raised')
|
||||
except exceptions.InputMixerException:
|
||||
pass
|
||||
|
||||
def testCheckMixDurationWithLongerEcho(self):
|
||||
"""Checks the duration when mixing an echo longer than the capture."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['longer']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
def testCheckMixDurationWithLongerEcho(self):
|
||||
"""Checks the duration when mixing an echo longer than the capture."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['longer']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(
|
||||
self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
|
||||
def testCheckOutputFileNamesConflict(self):
|
||||
"""Checks that different echo files lead to different output file names."""
|
||||
mix1_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix1_filepath))
|
||||
def testCheckOutputFileNamesConflict(self):
|
||||
"""Checks that different echo files lead to different output file names."""
|
||||
mix1_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix1_filepath))
|
||||
|
||||
mix2_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix2_filepath))
|
||||
mix2_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix2_filepath))
|
||||
|
||||
self.assertNotEqual(mix1_filepath, mix2_filepath)
|
||||
self.assertNotEqual(mix1_filepath, mix2_filepath)
|
||||
|
||||
def testHardClippingLogExpected(self):
|
||||
"""Checks that hard clipping warning is raised when occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
logging.warning.assert_called_once_with(
|
||||
input_mixer.ApmInputMixer.HardClippingLogMessage())
|
||||
def testHardClippingLogExpected(self):
|
||||
"""Checks that hard clipping warning is raised when occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
logging.warning.assert_called_once_with(
|
||||
input_mixer.ApmInputMixer.HardClippingLogMessage())
|
||||
|
||||
def testHardClippingLogNotExpected(self):
|
||||
"""Checks that hard clipping warning is not raised when not occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertNotIn(
|
||||
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
|
||||
logging.warning.call_args_list)
|
||||
def testHardClippingLogNotExpected(self):
|
||||
"""Checks that hard clipping warning is not raised when not occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertNotIn(
|
||||
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
|
||||
logging.warning.call_args_list)
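The two tests above monkey-patch logging.warning directly, which is not restored afterwards; an equivalent self-cleaning variant (an alternative sketch, not what this file does) would wrap the call in mock.patch.object inside the same test method:

with mock.patch.object(logging, 'warning') as mock_warning:
    _ = input_mixer.ApmInputMixer.Mix(
        self._tmp_path, self._audio_tracks['capture']['filepath'],
        self._audio_tracks['echo_2']['filepath'])
    mock_warning.assert_called_once_with(
        input_mixer.ApmInputMixer.HardClippingLogMessage())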
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Input signal creator module.
|
||||
"""
|
||||
|
||||
@ -14,12 +13,12 @@ from . import signal_processing
|
||||
|
||||
|
||||
class InputSignalCreator(object):
|
||||
"""Input signal creator class.
|
||||
"""Input signal creator class.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def Create(cls, name, raw_params):
|
||||
"""Creates a input signal and its metadata.
|
||||
@classmethod
|
||||
def Create(cls, name, raw_params):
|
||||
"""Creates a input signal and its metadata.
|
||||
|
||||
Args:
|
||||
name: Input signal creator name.
|
||||
@ -28,29 +27,30 @@ class InputSignalCreator(object):
|
||||
Returns:
|
||||
(AudioSegment, dict) tuple.
|
||||
"""
|
||||
try:
|
||||
signal = {}
|
||||
params = {}
|
||||
try:
|
||||
signal = {}
|
||||
params = {}
|
||||
|
||||
if name == 'pure_tone':
|
||||
params['frequency'] = float(raw_params[0])
|
||||
params['duration'] = int(raw_params[1])
|
||||
signal = cls._CreatePureTone(params['frequency'], params['duration'])
|
||||
else:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid input signal creator name')
|
||||
if name == 'pure_tone':
|
||||
params['frequency'] = float(raw_params[0])
|
||||
params['duration'] = int(raw_params[1])
|
||||
signal = cls._CreatePureTone(params['frequency'],
|
||||
params['duration'])
|
||||
else:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid input signal creator name')
|
||||
|
||||
# Complete metadata.
|
||||
params['signal'] = name
|
||||
# Complete metadata.
|
||||
params['signal'] = name
|
||||
|
||||
return signal, params
|
||||
except (TypeError, AssertionError) as e:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid signal creator parameters: {}'.format(e))
|
||||
return signal, params
|
||||
except (TypeError, AssertionError) as e:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid signal creator parameters: {}'.format(e))
|
||||
|
||||
@classmethod
|
||||
def _CreatePureTone(cls, frequency, duration):
|
||||
"""
|
||||
@classmethod
|
||||
def _CreatePureTone(cls, frequency, duration):
|
||||
"""
|
||||
Generates a pure tone at 48000 Hz.
|
||||
|
||||
Args:
|
||||
@ -60,8 +60,9 @@ class InputSignalCreator(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
assert 0 < frequency <= 24000
|
||||
assert duration > 0
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(duration)
|
||||
return signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, frequency)
|
||||
assert 0 < frequency <= 24000
|
||||
assert duration > 0
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration)
|
||||
return signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, frequency)
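For example (raw parameters arrive as strings, as when parsed from a test configuration):

signal, params = InputSignalCreator.Create('pure_tone', ['440.0', '1000'])
# signal: 1000 ms, 440 Hz pure tone AudioSegment at 48000 Hz.
# params: {'frequency': 440.0, 'duration': 1000, 'signal': 'pure_tone'}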
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Signal processing utility module.
|
||||
"""
|
||||
|
||||
@ -16,44 +15,44 @@ import sys
|
||||
import enum
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import pydub
|
||||
import pydub.generators
|
||||
import pydub
|
||||
import pydub.generators
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package pydub')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package pydub')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import scipy.signal
|
||||
import scipy.fftpack
|
||||
import scipy.signal
|
||||
import scipy.fftpack
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import exceptions
|
||||
|
||||
|
||||
class SignalProcessingUtils(object):
|
||||
"""Collection of signal processing utilities.
|
||||
"""Collection of signal processing utilities.
|
||||
"""
|
||||
|
||||
@enum.unique
|
||||
class MixPadding(enum.Enum):
|
||||
NO_PADDING = 0
|
||||
ZERO_PADDING = 1
|
||||
LOOP = 2
|
||||
@enum.unique
|
||||
class MixPadding(enum.Enum):
|
||||
NO_PADDING = 0
|
||||
ZERO_PADDING = 1
|
||||
LOOP = 2
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def LoadWav(cls, filepath, channels=1):
|
||||
"""Loads wav file.
|
||||
@classmethod
|
||||
def LoadWav(cls, filepath, channels=1):
|
||||
"""Loads wav file.
|
||||
|
||||
Args:
|
||||
filepath: path to the wav audio track file to load.
|
||||
@ -62,25 +61,26 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
if not os.path.exists(filepath):
|
||||
logging.error('cannot find the <%s> audio track file', filepath)
|
||||
raise exceptions.FileNotFoundError()
|
||||
return pydub.AudioSegment.from_file(
|
||||
filepath, format='wav', channels=channels)
|
||||
if not os.path.exists(filepath):
|
||||
logging.error('cannot find the <%s> audio track file', filepath)
|
||||
raise exceptions.FileNotFoundError()
|
||||
return pydub.AudioSegment.from_file(filepath,
|
||||
format='wav',
|
||||
channels=channels)
|
||||
|
||||
@classmethod
|
||||
def SaveWav(cls, output_filepath, signal):
|
||||
"""Saves wav file.
|
||||
@classmethod
|
||||
def SaveWav(cls, output_filepath, signal):
|
||||
"""Saves wav file.
|
||||
|
||||
Args:
|
||||
output_filepath: path to the wav audio track file to save.
|
||||
signal: AudioSegment instance.
|
||||
"""
|
||||
return signal.export(output_filepath, format='wav')
|
||||
return signal.export(output_filepath, format='wav')
|
||||
|
||||
@classmethod
|
||||
def CountSamples(cls, signal):
|
||||
"""Number of samples per channel.
|
||||
@classmethod
|
||||
def CountSamples(cls, signal):
|
||||
"""Number of samples per channel.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -88,14 +88,14 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
An integer.
|
||||
"""
|
||||
number_of_samples = len(signal.get_array_of_samples())
|
||||
assert signal.channels > 0
|
||||
assert number_of_samples % signal.channels == 0
|
||||
return number_of_samples / signal.channels
|
||||
number_of_samples = len(signal.get_array_of_samples())
|
||||
assert signal.channels > 0
|
||||
assert number_of_samples % signal.channels == 0
|
||||
return number_of_samples / signal.channels
|
||||
|
||||
@classmethod
|
||||
def GenerateSilence(cls, duration=1000, sample_rate=48000):
|
||||
"""Generates silence.
|
||||
@classmethod
|
||||
def GenerateSilence(cls, duration=1000, sample_rate=48000):
|
||||
"""Generates silence.
|
||||
|
||||
This method can also be used to create a template AudioSegment instance.
|
||||
A template can then be used with other Generate*() methods accepting an
|
||||
@ -108,11 +108,11 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
return pydub.AudioSegment.silent(duration, sample_rate)
|
||||
return pydub.AudioSegment.silent(duration, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def GeneratePureTone(cls, template, frequency=440.0):
|
||||
"""Generates a pure tone.
|
||||
@classmethod
|
||||
def GeneratePureTone(cls, template, frequency=440.0):
|
||||
"""Generates a pure tone.
|
||||
|
||||
The pure tone is generated with the same duration and in the same format of
|
||||
the given template signal.
|
||||
@ -124,21 +124,18 @@ class SignalProcessingUtils(object):
|
||||
Return:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
if frequency > template.frame_rate >> 1:
|
||||
raise exceptions.SignalProcessingException('Invalid frequency')
|
||||
if frequency > template.frame_rate >> 1:
|
||||
raise exceptions.SignalProcessingException('Invalid frequency')
|
||||
|
||||
generator = pydub.generators.Sine(
|
||||
sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8,
|
||||
freq=frequency)
|
||||
generator = pydub.generators.Sine(sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8,
|
||||
freq=frequency)
|
||||
|
||||
return generator.to_audio_segment(
|
||||
duration=len(template),
|
||||
volume=0.0)
|
||||
return generator.to_audio_segment(duration=len(template), volume=0.0)
|
||||
|
||||
@classmethod
|
||||
def GenerateWhiteNoise(cls, template):
|
||||
"""Generates white noise.
|
||||
@classmethod
|
||||
def GenerateWhiteNoise(cls, template):
|
||||
"""Generates white noise.
|
||||
|
||||
The white noise is generated with the same duration and in the same format
|
||||
of the given template signal.
|
||||
@ -149,33 +146,32 @@ class SignalProcessingUtils(object):
|
||||
Return:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
generator = pydub.generators.WhiteNoise(
|
||||
sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8)
|
||||
return generator.to_audio_segment(
|
||||
duration=len(template),
|
||||
volume=0.0)
|
||||
generator = pydub.generators.WhiteNoise(
|
||||
sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8)
|
||||
return generator.to_audio_segment(duration=len(template), volume=0.0)
|
||||
|
||||
@classmethod
|
||||
def AudioSegmentToRawData(cls, signal):
|
||||
samples = signal.get_array_of_samples()
|
||||
if samples.typecode != 'h':
|
||||
raise exceptions.SignalProcessingException('Unsupported samples type')
|
||||
return np.array(signal.get_array_of_samples(), np.int16)
|
||||
@classmethod
|
||||
def AudioSegmentToRawData(cls, signal):
|
||||
samples = signal.get_array_of_samples()
|
||||
if samples.typecode != 'h':
|
||||
raise exceptions.SignalProcessingException(
|
||||
'Unsupported samples type')
|
||||
return np.array(signal.get_array_of_samples(), np.int16)
|
||||
|
||||
@classmethod
|
||||
def Fft(cls, signal, normalize=True):
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError('multiple-channel FFT not implemented')
|
||||
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
|
||||
if normalize:
|
||||
x /= max(abs(np.max(x)), 1.0)
|
||||
y = scipy.fftpack.fft(x)
|
||||
return y[:len(y) / 2]
|
||||
@classmethod
|
||||
def Fft(cls, signal, normalize=True):
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError('multiple-channel FFT not implemented')
|
||||
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
|
||||
if normalize:
|
||||
x /= max(abs(np.max(x)), 1.0)
|
||||
y = scipy.fftpack.fft(x)
|
||||
return y[:len(y) / 2]
|
||||
|
||||
@classmethod
|
||||
def DetectHardClipping(cls, signal, threshold=2):
|
||||
"""Detects hard clipping.
|
||||
@classmethod
|
||||
def DetectHardClipping(cls, signal, threshold=2):
|
||||
"""Detects hard clipping.
|
||||
|
||||
Hard clipping is simply detected by counting samples that touch either the
|
||||
lower or upper bound too many times in a row (according to |threshold|).
|
||||
@ -189,32 +185,33 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
True if hard clipping is detect, False otherwise.
|
||||
"""
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError('multiple-channel clipping not implemented')
|
||||
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
|
||||
raise exceptions.SignalProcessingException(
|
||||
'hard-clipping detection only supported for 16 bit samples')
|
||||
samples = cls.AudioSegmentToRawData(signal)
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError(
|
||||
'multiple-channel clipping not implemented')
|
||||
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
|
||||
raise exceptions.SignalProcessingException(
|
||||
'hard-clipping detection only supported for 16 bit samples')
|
||||
samples = cls.AudioSegmentToRawData(signal)
|
||||
|
||||
# Detect adjacent clipped samples.
|
||||
samples_type_info = np.iinfo(samples.dtype)
|
||||
mask_min = samples == samples_type_info.min
|
||||
mask_max = samples == samples_type_info.max
|
||||
# Detect adjacent clipped samples.
|
||||
samples_type_info = np.iinfo(samples.dtype)
|
||||
mask_min = samples == samples_type_info.min
|
||||
mask_max = samples == samples_type_info.max
|
||||
|
||||
def HasLongSequence(vector, min_legth=threshold):
|
||||
"""Returns True if there are one or more long sequences of True flags."""
|
||||
seq_length = 0
|
||||
for b in vector:
|
||||
seq_length = seq_length + 1 if b else 0
|
||||
if seq_length >= min_legth:
|
||||
return True
|
||||
return False
|
||||
def HasLongSequence(vector, min_legth=threshold):
|
||||
"""Returns True if there are one or more long sequences of True flags."""
|
||||
seq_length = 0
|
||||
for b in vector:
|
||||
seq_length = seq_length + 1 if b else 0
|
||||
if seq_length >= min_legth:
|
||||
return True
|
||||
return False
|
||||
|
||||
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
|
||||
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
|
||||
|
||||
@classmethod
|
||||
def ApplyImpulseResponse(cls, signal, impulse_response):
|
||||
"""Applies an impulse response to a signal.
|
||||
@classmethod
|
||||
def ApplyImpulseResponse(cls, signal, impulse_response):
|
||||
"""Applies an impulse response to a signal.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -223,44 +220,48 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
# Get samples.
|
||||
assert signal.channels == 1, (
|
||||
'multiple-channel recordings not supported')
|
||||
samples = signal.get_array_of_samples()
|
||||
# Get samples.
|
||||
assert signal.channels == 1, (
|
||||
'multiple-channel recordings not supported')
|
||||
samples = signal.get_array_of_samples()
|
||||
|
||||
# Convolve.
|
||||
logging.info('applying %d order impulse response to a signal lasting %d ms',
|
||||
len(impulse_response), len(signal))
|
||||
convolved_samples = scipy.signal.fftconvolve(
|
||||
in1=samples,
|
||||
in2=impulse_response,
|
||||
mode='full').astype(np.int16)
|
||||
logging.info('convolution computed')
|
||||
# Convolve.
|
||||
logging.info(
|
||||
'applying %d order impulse response to a signal lasting %d ms',
|
||||
len(impulse_response), len(signal))
|
||||
convolved_samples = scipy.signal.fftconvolve(in1=samples,
|
||||
in2=impulse_response,
|
||||
mode='full').astype(
|
||||
np.int16)
|
||||
logging.info('convolution computed')
|
||||
|
||||
# Cast.
|
||||
convolved_samples = array.array(signal.array_type, convolved_samples)
|
||||
# Cast.
|
||||
convolved_samples = array.array(signal.array_type, convolved_samples)
|
||||
|
||||
# Verify.
|
||||
logging.debug('signal length: %d samples', len(samples))
|
||||
logging.debug('convolved signal length: %d samples', len(convolved_samples))
|
||||
assert len(convolved_samples) > len(samples)
|
||||
# Verify.
|
||||
logging.debug('signal length: %d samples', len(samples))
|
||||
logging.debug('convolved signal length: %d samples',
|
||||
len(convolved_samples))
|
||||
assert len(convolved_samples) > len(samples)
|
||||
|
||||
# Generate convolved signal AudioSegment instance.
|
||||
convolved_signal = pydub.AudioSegment(
|
||||
data=convolved_samples,
|
||||
metadata={
|
||||
'sample_width': signal.sample_width,
|
||||
'frame_rate': signal.frame_rate,
|
||||
'frame_width': signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
assert len(convolved_signal) > len(signal)
|
||||
# Generate convolved signal AudioSegment instance.
|
||||
convolved_signal = pydub.AudioSegment(data=convolved_samples,
|
||||
metadata={
|
||||
'sample_width':
|
||||
signal.sample_width,
|
||||
'frame_rate':
|
||||
signal.frame_rate,
|
||||
'frame_width':
|
||||
signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
assert len(convolved_signal) > len(signal)
|
||||
|
||||
return convolved_signal
|
||||
return convolved_signal
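A usage sketch with a hypothetical decaying impulse response; the convolved output is slightly longer than the input, as the assertions above check:

import numpy as np

template = SignalProcessingUtils.GenerateSilence(duration=1000)
tone = SignalProcessingUtils.GeneratePureTone(template, frequency=440.0)
# Hypothetical 10 ms exponential decay at 48 kHz, normalized to unit sum.
impulse_response = np.exp(-np.arange(480) / 100.0)
impulse_response /= impulse_response.sum()
reverberated = SignalProcessingUtils.ApplyImpulseResponse(tone, impulse_response)
assert len(reverberated) > len(tone)  # Duration grows by the impulse response tail.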
|
||||
|
||||
@classmethod
|
||||
def Normalize(cls, signal):
|
||||
"""Normalizes a signal.
|
||||
@classmethod
|
||||
def Normalize(cls, signal):
|
||||
"""Normalizes a signal.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -268,11 +269,11 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
An AudioSegment instance.
|
||||
"""
|
||||
return signal.apply_gain(-signal.max_dBFS)
|
||||
return signal.apply_gain(-signal.max_dBFS)
|
||||
|
||||
@classmethod
|
||||
def Copy(cls, signal):
|
||||
"""Makes a copy os a signal.
|
||||
@classmethod
|
||||
def Copy(cls, signal):
|
||||
"""Makes a copy os a signal.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -280,19 +281,21 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
An AudioSegment instance.
|
||||
"""
|
||||
return pydub.AudioSegment(
|
||||
data=signal.get_array_of_samples(),
|
||||
metadata={
|
||||
'sample_width': signal.sample_width,
|
||||
'frame_rate': signal.frame_rate,
|
||||
'frame_width': signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
return pydub.AudioSegment(data=signal.get_array_of_samples(),
|
||||
metadata={
|
||||
'sample_width': signal.sample_width,
|
||||
'frame_rate': signal.frame_rate,
|
||||
'frame_width': signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def MixSignals(cls, signal, noise, target_snr=0.0,
|
||||
pad_noise=MixPadding.NO_PADDING):
|
||||
"""Mixes |signal| and |noise| with a target SNR.
|
||||
@classmethod
|
||||
def MixSignals(cls,
|
||||
signal,
|
||||
noise,
|
||||
target_snr=0.0,
|
||||
pad_noise=MixPadding.NO_PADDING):
|
||||
"""Mixes |signal| and |noise| with a target SNR.
|
||||
|
||||
        Mix |signal| and |noise| with a desired SNR by scaling |noise|.

        If the target SNR is +/- infinite, a copy of signal/noise is returned.

@ -312,45 +315,45 @@ class SignalProcessingUtils(object):

        Returns:
          An AudioSegment instance.
        """
        # Handle infinite target SNR.
        if target_snr == -np.Inf:
            # Return a copy of noise.
            logging.warning('SNR = -Inf, returning noise')
            return cls.Copy(noise)
        elif target_snr == np.Inf:
            # Return a copy of signal.
            logging.warning('SNR = +Inf, returning signal')
            return cls.Copy(signal)

        # Check signal and noise power.
        signal_power = float(signal.dBFS)
        noise_power = float(noise.dBFS)
        if signal_power == -np.Inf:
            logging.error('signal has -Inf power, cannot mix')
            raise exceptions.SignalProcessingException(
                'cannot mix a signal with -Inf power')
        if noise_power == -np.Inf:
            logging.error('noise has -Inf power, cannot mix')
            raise exceptions.SignalProcessingException(
                'cannot mix a signal with -Inf power')

        # Mix.
        gain_db = signal_power - noise_power - target_snr
        signal_duration = len(signal)
        noise_duration = len(noise)
        if signal_duration <= noise_duration:
            # Ignore |pad_noise|; |noise| is truncated if longer than |signal|,
            # so the mix will have the same length as |signal|.
            return signal.overlay(noise.apply_gain(gain_db))
        elif pad_noise == cls.MixPadding.NO_PADDING:
            # |signal| is longer than |noise|, but no padding is applied to
            # |noise|. Truncate |signal|.
            return noise.overlay(signal, gain_during_overlay=gain_db)
        elif pad_noise == cls.MixPadding.ZERO_PADDING:
            # TODO(alessiob): Check that this works as expected.
            return signal.overlay(noise.apply_gain(gain_db))
        elif pad_noise == cls.MixPadding.LOOP:
            # |signal| is longer than |noise|; extend |noise| by looping.
            return signal.overlay(noise.apply_gain(gain_db), loop=True)
        else:
            raise exceptions.SignalProcessingException('invalid padding type')
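The mixing rule above reduces to a single gain applied to |noise| before the overlay. A minimal self-contained sketch, assuming only that pydub is installed; mix_with_target_snr and the generated test segments are invented for illustration and are not the toolbox API:

from pydub.generators import Sine, WhiteNoise

def mix_with_target_snr(signal, noise, target_snr_db):
    # Scale the noise so that signal_power - noise_power equals the target SNR,
    # then overlay it onto the signal; the result keeps the signal's duration
    # when the signal is the shorter of the two segments.
    gain_db = float(signal.dBFS) - float(noise.dBFS) - target_snr_db
    return signal.overlay(noise.apply_gain(gain_db))

tone = Sine(440).to_audio_segment(duration=1000)
noise = WhiteNoise().to_audio_segment(duration=1000)
mixed = mix_with_target_snr(tone, noise, target_snr_db=10.0)
print(len(mixed), mixed.dBFS)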
@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the signal_processing module.
"""

@ -19,168 +18,166 @@ from . import signal_processing

class TestSignalProcessing(unittest.TestCase):
    """Unit tests for the signal_processing module.
    """

    def testMixSignals(self):
        # Generate a template signal with which white noise can be generated.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

        # Generate two distinct AudioSegment instances with 1 second of white noise.
        signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)
        noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)

        # Extract samples.
        signal_samples = signal.get_array_of_samples()
        noise_samples = noise.get_array_of_samples()

        # Test target SNR -Inf (noise expected).
        mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, -np.Inf)
        self.assertTrue(len(noise), len(mix_neg_inf))  # Check duration.
        mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
        self.assertTrue(  # Check samples.
            all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))

        # Test target SNR 0.0 (different data expected).
        mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, 0.0)
        self.assertTrue(len(signal), len(mix_0))  # Check duration.
        self.assertTrue(len(noise), len(mix_0))
        mix_0_samples = mix_0.get_array_of_samples()
        self.assertTrue(
            any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
        self.assertTrue(
            any([x != y for x, y in zip(noise_samples, mix_0_samples)]))

        # Test target SNR +Inf (signal expected).
        mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, np.Inf)
        self.assertTrue(len(signal), len(mix_pos_inf))  # Check duration.
        mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
        self.assertTrue(  # Check samples.
            all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))

def testMixSignalsMinInfPower(self):
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
def testMixSignalsMinInfPower(self):
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, silence, 0.0)
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, silence, 0.0)
|
||||
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
silence, signal, 0.0)
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
silence, signal, 0.0)
|
||||
|
||||
def testMixSignalNoiseDifferentLengths(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
|
||||
def testMixSignalNoiseDifferentLengths(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
|
||||
|
||||
# When the signal is shorter than the noise, the mix length always equals
|
||||
# that of the signal regardless of whether padding is applied.
|
||||
# No noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# When the signal is shorter than the noise, the mix length always equals
|
||||
# that of the signal regardless of whether padding is applied.
|
||||
# No noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
ZERO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
|
||||
# When the signal is longer than the noise, the mix length depends on
|
||||
# whether padding is applied.
|
||||
# No noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
|
||||
self.assertEqual(len(longer), len(mix))
|
||||
# When the signal is longer than the noise, the mix length depends on
|
||||
# whether padding is applied.
|
||||
# No noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
ZERO_PADDING)
|
||||
self.assertEqual(len(longer), len(mix))
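The four length checks in the test above follow a single rule; here is a stand-alone restatement using the 1000/2000 ms durations of the test signals (expected_mix_length is an invented helper, not part of the toolbox):

def expected_mix_length(signal_ms, noise_ms, zero_pad):
    # A signal shorter than (or equal to) the noise always wins: the noise is
    # truncated and padding is irrelevant. A longer signal keeps its length
    # only when the noise is zero-padded; with NO_PADDING the signal is
    # truncated to the noise length instead.
    if signal_ms <= noise_ms:
        return signal_ms
    return signal_ms if zero_pad else noise_ms

assert expected_mix_length(1000, 2000, zero_pad=False) == 1000
assert expected_mix_length(1000, 2000, zero_pad=True) == 1000
assert expected_mix_length(2000, 1000, zero_pad=False) == 1000
assert expected_mix_length(2000, 1000, zero_pad=True) == 2000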
|
||||
|
||||
def testMixSignalNoisePaddingTypes(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
|
||||
def testMixSignalNoisePaddingTypes(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
|
||||
|
||||
# Zero padding: expect pure tone only in 1-2s.
|
||||
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
|
||||
# Zero padding: expect pure tone only in 1-2s.
|
||||
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
ZERO_PADDING)
|
||||
|
||||
# Loop: expect pure tone plus noise in 1-2s.
|
||||
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
|
||||
# Loop: expect pure tone plus noise in 1-2s.
|
||||
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
|
||||
|
||||
def Energy(signal):
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
signal).astype(np.float32)
|
||||
return np.sum(samples * samples)
|
||||
def Energy(signal):
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
signal).astype(np.float32)
|
||||
return np.sum(samples * samples)
|
||||
|
||||
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
|
||||
e_mix_loop = Energy(mix_loop[-1000:])
|
||||
self.assertLess(0, e_mix_zero_pad)
|
||||
self.assertLess(e_mix_zero_pad, e_mix_loop)
|
||||
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
|
||||
e_mix_loop = Energy(mix_loop[-1000:])
|
||||
self.assertLess(0, e_mix_zero_pad)
|
||||
self.assertLess(e_mix_zero_pad, e_mix_loop)
|
||||
|
||||
def testMixSignalSnr(self):
|
||||
# Test signals.
|
||||
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
|
||||
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
|
||||
def testMixSignalSnr(self):
|
||||
# Test signals.
|
||||
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
|
||||
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
|
||||
|
||||
def ToneAmplitudes(mix):
|
||||
"""Returns the amplitude of the coefficients #16 and #192, which
|
||||
def ToneAmplitudes(mix):
|
||||
"""Returns the amplitude of the coefficients #16 and #192, which
|
||||
correspond to the tones at 250 and 3k Hz respectively."""
|
||||
mix_fft = np.absolute(signal_processing.SignalProcessingUtils.Fft(mix))
|
||||
return mix_fft[16], mix_fft[192]
|
||||
mix_fft = np.absolute(
|
||||
signal_processing.SignalProcessingUtils.Fft(mix))
|
||||
return mix_fft[16], mix_fft[192]
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low,
|
||||
noise=tone_high,
|
||||
target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low, noise=tone_high, target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high,
|
||||
noise=tone_low,
|
||||
target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high, noise=tone_low, target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low,
|
||||
noise=tone_high,
|
||||
target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low, noise=tone_high, target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high,
|
||||
noise=tone_low,
|
||||
target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high, noise=tone_low, target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
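The bin indices hard-coded in the ToneAmplitudes helper above follow from the frame length and sample rate of the test tones; a short arithmetic check, using only the 64 ms / 8 kHz figures taken from the test itself:

sample_rate_hz = 8000
num_samples = int(0.064 * sample_rate_hz)            # 512 samples in a 64 ms frame.
bin_width_hz = sample_rate_hz / float(num_samples)   # 15.625 Hz per FFT bin.
print(250.0 / bin_width_hz)    # 16.0  -> coefficient #16 (250 Hz tone).
print(3000.0 / bin_width_hz)   # 192.0 -> coefficient #192 (3 kHz tone).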
@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""APM module simulator.
"""

@ -25,85 +24,93 @@ from . import test_data_generation

class ApmModuleSimulator(object):
    """Audio processing module (APM) simulator class.
    """

    _TEST_DATA_GENERATOR_CLASSES = (
        test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
    _EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES

    _PREFIX_APM_CONFIG = 'apmcfg-'
    _PREFIX_CAPTURE = 'capture-'
    _PREFIX_RENDER = 'render-'
    _PREFIX_ECHO_SIMULATOR = 'echosim-'
    _PREFIX_TEST_DATA_GEN = 'datagen-'
    _PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
    _PREFIX_SCORE = 'score-'

    def __init__(self,
                 test_data_generator_factory,
                 evaluation_score_factory,
                 ap_wrapper,
                 evaluator,
                 external_vads=None):
        if external_vads is None:
            external_vads = {}
        self._test_data_generator_factory = test_data_generator_factory
        self._evaluation_score_factory = evaluation_score_factory
        self._audioproc_wrapper = ap_wrapper
        self._evaluator = evaluator
        self._annotator = annotations.AudioAnnotationsExtractor(
            annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD
            | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO
            | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
            external_vads)

        # Init.
        self._test_data_generator_factory.SetOutputDirectoryPrefix(
            self._PREFIX_TEST_DATA_GEN_PARAMS)
        self._evaluation_score_factory.SetScoreFilenamePrefix(
            self._PREFIX_SCORE)

        # Properties for each run.
        self._base_output_path = None
        self._output_cache_path = None
        self._test_data_generators = None
        self._evaluation_score_workers = None
        self._config_filepaths = None
        self._capture_input_filepaths = None
        self._render_input_filepaths = None
        self._echo_path_simulator_class = None

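The three VadType values OR-ed together when constructing the annotator above behave as bit flags; a tiny sketch of that pattern (the numeric values below are invented for illustration):

ENERGY_THRESHOLD = 1 << 0
WEBRTC_COMMON_AUDIO = 1 << 1
WEBRTC_APM = 1 << 2

selected_vads = ENERGY_THRESHOLD | WEBRTC_COMMON_AUDIO | WEBRTC_APM
assert selected_vads & WEBRTC_APM       # This VAD type is enabled.
assert not selected_vads & (1 << 3)     # An unrelated flag is not.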
@classmethod
|
||||
def GetPrefixApmConfig(cls):
|
||||
return cls._PREFIX_APM_CONFIG
|
||||
@classmethod
|
||||
def GetPrefixApmConfig(cls):
|
||||
return cls._PREFIX_APM_CONFIG
|
||||
|
||||
@classmethod
|
||||
def GetPrefixCapture(cls):
|
||||
return cls._PREFIX_CAPTURE
|
||||
@classmethod
|
||||
def GetPrefixCapture(cls):
|
||||
return cls._PREFIX_CAPTURE
|
||||
|
||||
@classmethod
|
||||
def GetPrefixRender(cls):
|
||||
return cls._PREFIX_RENDER
|
||||
@classmethod
|
||||
def GetPrefixRender(cls):
|
||||
return cls._PREFIX_RENDER
|
||||
|
||||
@classmethod
|
||||
def GetPrefixEchoSimulator(cls):
|
||||
return cls._PREFIX_ECHO_SIMULATOR
|
||||
@classmethod
|
||||
def GetPrefixEchoSimulator(cls):
|
||||
return cls._PREFIX_ECHO_SIMULATOR
|
||||
|
||||
@classmethod
|
||||
def GetPrefixTestDataGenerator(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN
|
||||
@classmethod
|
||||
def GetPrefixTestDataGenerator(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN
|
||||
|
||||
@classmethod
|
||||
def GetPrefixTestDataGeneratorParameters(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN_PARAMS
|
||||
@classmethod
|
||||
def GetPrefixTestDataGeneratorParameters(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN_PARAMS
|
||||
|
||||
@classmethod
|
||||
def GetPrefixScore(cls):
|
||||
return cls._PREFIX_SCORE
|
||||
@classmethod
|
||||
def GetPrefixScore(cls):
|
||||
return cls._PREFIX_SCORE
|
||||
|
||||
def Run(self, config_filepaths, capture_input_filepaths,
|
||||
test_data_generator_names, eval_score_names, output_dir,
|
||||
render_input_filepaths=None, echo_path_simulator_name=(
|
||||
echo_path_simulation.NoEchoPathSimulator.NAME)):
|
||||
"""Runs the APM simulation.
|
||||
def Run(self,
|
||||
config_filepaths,
|
||||
capture_input_filepaths,
|
||||
test_data_generator_names,
|
||||
eval_score_names,
|
||||
output_dir,
|
||||
render_input_filepaths=None,
|
||||
echo_path_simulator_name=(
|
||||
echo_path_simulation.NoEchoPathSimulator.NAME)):
|
||||
"""Runs the APM simulation.
|
||||
|
||||
Initializes paths and required instances, then runs all the simulations.
|
||||
The render input can be optionally added. If added, the number of capture
|
||||
@ -120,132 +127,140 @@ class ApmModuleSimulator(object):
|
||||
echo_path_simulator_name: name of the echo path simulator to use when
|
||||
render input is provided.
|
||||
"""
|
||||
assert render_input_filepaths is None or (
|
||||
len(capture_input_filepaths) == len(render_input_filepaths)), (
|
||||
'render input set size not matching input set size')
|
||||
assert render_input_filepaths is None or echo_path_simulator_name in (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
|
||||
'invalid echo path simulator')
|
||||
self._base_output_path = os.path.abspath(output_dir)
|
||||
assert render_input_filepaths is None or (
|
||||
len(capture_input_filepaths) == len(render_input_filepaths)), (
|
||||
'render input set size not matching input set size')
|
||||
assert render_input_filepaths is None or echo_path_simulator_name in (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
|
||||
'invalid echo path simulator')
|
||||
self._base_output_path = os.path.abspath(output_dir)
|
||||
|
||||
# Output path used to cache the data shared across simulations.
|
||||
self._output_cache_path = os.path.join(self._base_output_path, '_cache')
|
||||
# Output path used to cache the data shared across simulations.
|
||||
self._output_cache_path = os.path.join(self._base_output_path,
|
||||
'_cache')
|
||||
|
||||
# Instance test data generators.
|
||||
self._test_data_generators = [self._test_data_generator_factory.GetInstance(
|
||||
test_data_generators_class=(
|
||||
self._TEST_DATA_GENERATOR_CLASSES[name])) for name in (
|
||||
test_data_generator_names)]
|
||||
# Instance test data generators.
|
||||
self._test_data_generators = [
|
||||
self._test_data_generator_factory.GetInstance(
|
||||
test_data_generators_class=(
|
||||
self._TEST_DATA_GENERATOR_CLASSES[name]))
|
||||
for name in (test_data_generator_names)
|
||||
]
|
||||
|
||||
# Instance evaluation score workers.
|
||||
self._evaluation_score_workers = [
|
||||
self._evaluation_score_factory.GetInstance(
|
||||
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name]) for (
|
||||
name) in eval_score_names]
|
||||
# Instance evaluation score workers.
|
||||
self._evaluation_score_workers = [
|
||||
self._evaluation_score_factory.GetInstance(
|
||||
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name])
|
||||
for (name) in eval_score_names
|
||||
]
|
||||
|
||||
# Set APM configuration file paths.
|
||||
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
|
||||
# Set APM configuration file paths.
|
||||
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
|
||||
|
||||
# Set probing signal file paths.
|
||||
if render_input_filepaths is None:
|
||||
# Capture input only.
|
||||
self._capture_input_filepaths = self._CreatePathsCollection(
|
||||
capture_input_filepaths)
|
||||
self._render_input_filepaths = None
|
||||
else:
|
||||
# Set both capture and render input signals.
|
||||
self._SetTestInputSignalFilePaths(
|
||||
capture_input_filepaths, render_input_filepaths)
|
||||
# Set probing signal file paths.
|
||||
if render_input_filepaths is None:
|
||||
# Capture input only.
|
||||
self._capture_input_filepaths = self._CreatePathsCollection(
|
||||
capture_input_filepaths)
|
||||
self._render_input_filepaths = None
|
||||
else:
|
||||
# Set both capture and render input signals.
|
||||
self._SetTestInputSignalFilePaths(capture_input_filepaths,
|
||||
render_input_filepaths)
|
||||
|
||||
# Set the echo path simulator class.
|
||||
self._echo_path_simulator_class = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES[
|
||||
echo_path_simulator_name])
|
||||
# Set the echo path simulator class.
|
||||
self._echo_path_simulator_class = (
|
||||
echo_path_simulation.EchoPathSimulator.
|
||||
REGISTERED_CLASSES[echo_path_simulator_name])
|
||||
|
||||
self._SimulateAll()
|
||||
self._SimulateAll()
|
||||
|
||||
def _SimulateAll(self):
|
||||
"""Runs all the simulations.
|
||||
def _SimulateAll(self):
|
||||
"""Runs all the simulations.
|
||||
|
||||
Iterates over the combinations of APM configurations, probing signals, and
|
||||
test data generators. This method is mainly responsible for the creation of
|
||||
the cache and output directories required in order to call _Simulate().
|
||||
"""
|
||||
without_render_input = self._render_input_filepaths is None
|
||||
without_render_input = self._render_input_filepaths is None
|
||||
|
||||
# Try different APM config files.
|
||||
for config_name in self._config_filepaths:
|
||||
config_filepath = self._config_filepaths[config_name]
|
||||
# Try different APM config files.
|
||||
for config_name in self._config_filepaths:
|
||||
config_filepath = self._config_filepaths[config_name]
|
||||
|
||||
# Try different capture-render pairs.
|
||||
for capture_input_name in self._capture_input_filepaths:
|
||||
# Output path for the capture signal annotations.
|
||||
capture_annotations_cache_path = os.path.join(
|
||||
self._output_cache_path,
|
||||
self._PREFIX_CAPTURE + capture_input_name)
|
||||
data_access.MakeDirectory(capture_annotations_cache_path)
|
||||
# Try different capture-render pairs.
|
||||
for capture_input_name in self._capture_input_filepaths:
|
||||
# Output path for the capture signal annotations.
|
||||
capture_annotations_cache_path = os.path.join(
|
||||
self._output_cache_path,
|
||||
self._PREFIX_CAPTURE + capture_input_name)
|
||||
data_access.MakeDirectory(capture_annotations_cache_path)
|
||||
|
||||
# Capture.
|
||||
capture_input_filepath = self._capture_input_filepaths[
|
||||
capture_input_name]
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
# If the input signal file does not exist, try to create using the
|
||||
# available input signal creators.
|
||||
self._CreateInputSignal(capture_input_filepath)
|
||||
assert os.path.exists(capture_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
capture_input_filepath, capture_annotations_cache_path)
|
||||
# Capture.
|
||||
capture_input_filepath = self._capture_input_filepaths[
|
||||
capture_input_name]
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
# If the input signal file does not exist, try to create using the
|
||||
# available input signal creators.
|
||||
self._CreateInputSignal(capture_input_filepath)
|
||||
assert os.path.exists(capture_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
capture_input_filepath, capture_annotations_cache_path)
|
||||
|
||||
# Render and simulated echo path (optional).
|
||||
render_input_filepath = None if without_render_input else (
|
||||
self._render_input_filepaths[capture_input_name])
|
||||
render_input_name = '(none)' if without_render_input else (
|
||||
self._ExtractFileName(render_input_filepath))
|
||||
echo_path_simulator = (
|
||||
echo_path_simulation_factory.EchoPathSimulatorFactory.GetInstance(
|
||||
self._echo_path_simulator_class, render_input_filepath))
|
||||
# Render and simulated echo path (optional).
|
||||
render_input_filepath = None if without_render_input else (
|
||||
self._render_input_filepaths[capture_input_name])
|
||||
render_input_name = '(none)' if without_render_input else (
|
||||
self._ExtractFileName(render_input_filepath))
|
||||
echo_path_simulator = (echo_path_simulation_factory.
|
||||
EchoPathSimulatorFactory.GetInstance(
|
||||
self._echo_path_simulator_class,
|
||||
render_input_filepath))
|
||||
|
||||
# Try different test data generators.
|
||||
for test_data_generators in self._test_data_generators:
|
||||
logging.info('APM config preset: <%s>, capture: <%s>, render: <%s>,'
|
||||
'test data generator: <%s>, echo simulator: <%s>',
|
||||
config_name, capture_input_name, render_input_name,
|
||||
test_data_generators.NAME, echo_path_simulator.NAME)
|
||||
# Try different test data generators.
|
||||
for test_data_generators in self._test_data_generators:
|
||||
logging.info(
|
||||
'APM config preset: <%s>, capture: <%s>, render: <%s>,'
|
||||
'test data generator: <%s>, echo simulator: <%s>',
|
||||
config_name, capture_input_name, render_input_name,
|
||||
test_data_generators.NAME, echo_path_simulator.NAME)
|
||||
|
||||
# Output path for the generated test data.
|
||||
test_data_cache_path = os.path.join(
|
||||
capture_annotations_cache_path,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(test_data_cache_path)
|
||||
logging.debug('test data cache path: <%s>', test_data_cache_path)
|
||||
# Output path for the generated test data.
|
||||
test_data_cache_path = os.path.join(
|
||||
capture_annotations_cache_path,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(test_data_cache_path)
|
||||
logging.debug('test data cache path: <%s>',
|
||||
test_data_cache_path)
|
||||
|
||||
# Output path for the echo simulator and APM input mixer output.
|
||||
echo_test_data_cache_path = os.path.join(
|
||||
test_data_cache_path, 'echosim-{}'.format(
|
||||
echo_path_simulator.NAME))
|
||||
data_access.MakeDirectory(echo_test_data_cache_path)
|
||||
logging.debug('echo test data cache path: <%s>',
|
||||
echo_test_data_cache_path)
|
||||
# Output path for the echo simulator and APM input mixer output.
|
||||
echo_test_data_cache_path = os.path.join(
|
||||
test_data_cache_path,
|
||||
'echosim-{}'.format(echo_path_simulator.NAME))
|
||||
data_access.MakeDirectory(echo_test_data_cache_path)
|
||||
logging.debug('echo test data cache path: <%s>',
|
||||
echo_test_data_cache_path)
|
||||
|
||||
# Full output path.
|
||||
output_path = os.path.join(
|
||||
self._base_output_path,
|
||||
self._PREFIX_APM_CONFIG + config_name,
|
||||
self._PREFIX_CAPTURE + capture_input_name,
|
||||
self._PREFIX_RENDER + render_input_name,
|
||||
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(output_path)
|
||||
logging.debug('output path: <%s>', output_path)
|
||||
# Full output path.
|
||||
output_path = os.path.join(
|
||||
self._base_output_path,
|
||||
self._PREFIX_APM_CONFIG + config_name,
|
||||
self._PREFIX_CAPTURE + capture_input_name,
|
||||
self._PREFIX_RENDER + render_input_name,
|
||||
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(output_path)
|
||||
logging.debug('output path: <%s>', output_path)
|
||||
|
||||
self._Simulate(test_data_generators, capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path,
|
||||
config_filepath, echo_path_simulator)
|
||||
self._Simulate(test_data_generators,
|
||||
capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path,
|
||||
config_filepath, echo_path_simulator)
|
||||
|
||||
@staticmethod
|
||||
def _CreateInputSignal(input_signal_filepath):
|
||||
"""Creates a missing input signal file.
|
||||
@staticmethod
|
||||
def _CreateInputSignal(input_signal_filepath):
|
||||
"""Creates a missing input signal file.
|
||||
|
||||
The file name is parsed to extract input signal creator and params. If a
|
||||
creator is matched and the parameters are valid, a new signal is generated
|
||||
@ -257,30 +272,33 @@ class ApmModuleSimulator(object):
|
||||
Raises:
|
||||
InputSignalCreatorException
|
||||
"""
|
||||
filename = os.path.splitext(os.path.split(input_signal_filepath)[-1])[0]
|
||||
filename_parts = filename.split('-')
|
||||
filename = os.path.splitext(
|
||||
os.path.split(input_signal_filepath)[-1])[0]
|
||||
filename_parts = filename.split('-')
|
||||
|
||||
if len(filename_parts) < 2:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Cannot parse input signal file name')
|
||||
if len(filename_parts) < 2:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Cannot parse input signal file name')
|
||||
|
||||
signal, metadata = input_signal_creator.InputSignalCreator.Create(
|
||||
filename_parts[0], filename_parts[1].split('_'))
|
||||
signal, metadata = input_signal_creator.InputSignalCreator.Create(
|
||||
filename_parts[0], filename_parts[1].split('_'))
|
||||
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
input_signal_filepath, signal)
|
||||
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
input_signal_filepath, signal)
|
||||
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
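A missing capture file name encodes the input signal creator to use and its parameters, which is what the splitting above extracts. An illustrative run of that parsing on a hypothetical path (the frequency/duration meaning of the two parameters is inferred from the pure-tone tests further below):

import os

filepath = '/tmp/pure_tone-440_1000.wav'             # Hypothetical path.
filename = os.path.splitext(os.path.split(filepath)[-1])[0]
parts = filename.split('-')
creator_name, params = parts[0], parts[1].split('_')
print(creator_name, params)                          # pure_tone ['440', '1000']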
|
||||
|
||||
def _ExtractCaptureAnnotations(self, input_filepath, output_path,
|
||||
annotation_name=""):
|
||||
self._annotator.Extract(input_filepath)
|
||||
self._annotator.Save(output_path, annotation_name)
|
||||
def _ExtractCaptureAnnotations(self,
|
||||
input_filepath,
|
||||
output_path,
|
||||
annotation_name=""):
|
||||
self._annotator.Extract(input_filepath)
|
||||
self._annotator.Save(output_path, annotation_name)
|
||||
|
||||
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path, config_filepath,
|
||||
echo_path_simulator):
|
||||
"""Runs a single set of simulation.
|
||||
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path, config_filepath,
|
||||
echo_path_simulator):
|
||||
"""Runs a single set of simulation.
|
||||
|
||||
Simulates a given combination of APM configuration, probing signal, and
|
||||
test data generator. It iterates over the test data generator
|
||||
@ -298,90 +316,92 @@ class ApmModuleSimulator(object):
|
||||
config_filepath: APM configuration file to test.
|
||||
echo_path_simulator: EchoPathSimulator instance.
|
||||
"""
|
||||
# Generate pairs of noisy input and reference signal files.
|
||||
test_data_generators.Generate(
|
||||
input_signal_filepath=clean_capture_input_filepath,
|
||||
test_data_cache_path=test_data_cache_path,
|
||||
base_output_path=output_path)
|
||||
# Generate pairs of noisy input and reference signal files.
|
||||
test_data_generators.Generate(
|
||||
input_signal_filepath=clean_capture_input_filepath,
|
||||
test_data_cache_path=test_data_cache_path,
|
||||
base_output_path=output_path)
|
||||
|
||||
# Extract metadata linked to the clean input file (if any).
|
||||
apm_input_metadata = None
|
||||
try:
|
||||
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
|
||||
clean_capture_input_filepath)
|
||||
except IOError as e:
|
||||
apm_input_metadata = {}
|
||||
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
|
||||
apm_input_metadata['test_data_gen_config'] = None
|
||||
# Extract metadata linked to the clean input file (if any).
|
||||
apm_input_metadata = None
|
||||
try:
|
||||
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
|
||||
clean_capture_input_filepath)
|
||||
except IOError as e:
|
||||
apm_input_metadata = {}
|
||||
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
|
||||
apm_input_metadata['test_data_gen_config'] = None
|
||||
|
||||
# For each test data pair, simulate a call and evaluate.
|
||||
for config_name in test_data_generators.config_names:
|
||||
logging.info(' - test data generator config: <%s>', config_name)
|
||||
apm_input_metadata['test_data_gen_config'] = config_name
|
||||
# For each test data pair, simulate a call and evaluate.
|
||||
for config_name in test_data_generators.config_names:
|
||||
logging.info(' - test data generator config: <%s>', config_name)
|
||||
apm_input_metadata['test_data_gen_config'] = config_name
|
||||
|
||||
# Paths to the test data generator output.
|
||||
# Note that the reference signal does not depend on the render input
|
||||
# which is optional.
|
||||
noisy_capture_input_filepath = (
|
||||
test_data_generators.noisy_signal_filepaths[config_name])
|
||||
reference_signal_filepath = (
|
||||
test_data_generators.reference_signal_filepaths[config_name])
|
||||
# Paths to the test data generator output.
|
||||
# Note that the reference signal does not depend on the render input
|
||||
# which is optional.
|
||||
noisy_capture_input_filepath = (
|
||||
test_data_generators.noisy_signal_filepaths[config_name])
|
||||
reference_signal_filepath = (
|
||||
test_data_generators.reference_signal_filepaths[config_name])
|
||||
|
||||
# Output path for the evaluation (e.g., APM output file).
|
||||
evaluation_output_path = test_data_generators.apm_output_paths[
|
||||
config_name]
|
||||
# Output path for the evaluation (e.g., APM output file).
|
||||
evaluation_output_path = test_data_generators.apm_output_paths[
|
||||
config_name]
|
||||
|
||||
# Paths to the APM input signals.
|
||||
echo_path_filepath = echo_path_simulator.Simulate(
|
||||
echo_test_data_cache_path)
|
||||
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
echo_test_data_cache_path, noisy_capture_input_filepath,
|
||||
echo_path_filepath)
|
||||
# Paths to the APM input signals.
|
||||
echo_path_filepath = echo_path_simulator.Simulate(
|
||||
echo_test_data_cache_path)
|
||||
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
echo_test_data_cache_path, noisy_capture_input_filepath,
|
||||
echo_path_filepath)
|
||||
|
||||
# Extract annotations for the APM input mix.
|
||||
apm_input_basepath, apm_input_filename = os.path.split(
|
||||
apm_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
apm_input_filepath, apm_input_basepath,
|
||||
os.path.splitext(apm_input_filename)[0] + '-')
|
||||
# Extract annotations for the APM input mix.
|
||||
apm_input_basepath, apm_input_filename = os.path.split(
|
||||
apm_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
apm_input_filepath, apm_input_basepath,
|
||||
os.path.splitext(apm_input_filename)[0] + '-')
|
||||
|
||||
# Simulate a call using APM.
|
||||
self._audioproc_wrapper.Run(
|
||||
config_filepath=config_filepath,
|
||||
capture_input_filepath=apm_input_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path)
|
||||
# Simulate a call using APM.
|
||||
self._audioproc_wrapper.Run(
|
||||
config_filepath=config_filepath,
|
||||
capture_input_filepath=apm_input_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path)
|
||||
|
||||
try:
|
||||
# Evaluate.
|
||||
self._evaluator.Run(
|
||||
evaluation_score_workers=self._evaluation_score_workers,
|
||||
apm_input_metadata=apm_input_metadata,
|
||||
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
||||
reference_input_filepath=reference_signal_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path,
|
||||
)
|
||||
try:
|
||||
# Evaluate.
|
||||
self._evaluator.Run(
|
||||
evaluation_score_workers=self._evaluation_score_workers,
|
||||
apm_input_metadata=apm_input_metadata,
|
||||
apm_output_filepath=self._audioproc_wrapper.
|
||||
output_filepath,
|
||||
reference_input_filepath=reference_signal_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path,
|
||||
)
|
||||
|
||||
# Save simulation metadata.
|
||||
data_access.Metadata.SaveAudioTestDataPaths(
|
||||
output_path=evaluation_output_path,
|
||||
clean_capture_input_filepath=clean_capture_input_filepath,
|
||||
echo_free_capture_filepath=noisy_capture_input_filepath,
|
||||
echo_filepath=echo_path_filepath,
|
||||
render_filepath=render_input_filepath,
|
||||
capture_filepath=apm_input_filepath,
|
||||
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
||||
apm_reference_filepath=reference_signal_filepath,
|
||||
apm_config_filepath=config_filepath,
|
||||
)
|
||||
except exceptions.EvaluationScoreException as e:
|
||||
logging.warning('the evaluation failed: %s', e.message)
|
||||
continue
|
||||
# Save simulation metadata.
|
||||
data_access.Metadata.SaveAudioTestDataPaths(
|
||||
output_path=evaluation_output_path,
|
||||
clean_capture_input_filepath=clean_capture_input_filepath,
|
||||
echo_free_capture_filepath=noisy_capture_input_filepath,
|
||||
echo_filepath=echo_path_filepath,
|
||||
render_filepath=render_input_filepath,
|
||||
capture_filepath=apm_input_filepath,
|
||||
apm_output_filepath=self._audioproc_wrapper.
|
||||
output_filepath,
|
||||
apm_reference_filepath=reference_signal_filepath,
|
||||
apm_config_filepath=config_filepath,
|
||||
)
|
||||
except exceptions.EvaluationScoreException as e:
|
||||
logging.warning('the evaluation failed: %s', e.message)
|
||||
continue
|
||||
|
||||
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
|
||||
render_input_filepaths):
|
||||
"""Sets input and render input file paths collections.
|
||||
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
|
||||
render_input_filepaths):
|
||||
"""Sets input and render input file paths collections.
|
||||
|
||||
Pairs the input and render input files by storing the file paths into two
|
||||
collections. The key is the file name of the input file.
|
||||
@ -390,20 +410,20 @@ class ApmModuleSimulator(object):
|
||||
capture_input_filepaths: list of file paths.
|
||||
render_input_filepaths: list of file paths.
|
||||
"""
|
||||
self._capture_input_filepaths = {}
|
||||
self._render_input_filepaths = {}
|
||||
assert len(capture_input_filepaths) == len(render_input_filepaths)
|
||||
for capture_input_filepath, render_input_filepath in zip(
|
||||
capture_input_filepaths, render_input_filepaths):
|
||||
name = self._ExtractFileName(capture_input_filepath)
|
||||
self._capture_input_filepaths[name] = os.path.abspath(
|
||||
capture_input_filepath)
|
||||
self._render_input_filepaths[name] = os.path.abspath(
|
||||
render_input_filepath)
|
||||
self._capture_input_filepaths = {}
|
||||
self._render_input_filepaths = {}
|
||||
assert len(capture_input_filepaths) == len(render_input_filepaths)
|
||||
for capture_input_filepath, render_input_filepath in zip(
|
||||
capture_input_filepaths, render_input_filepaths):
|
||||
name = self._ExtractFileName(capture_input_filepath)
|
||||
self._capture_input_filepaths[name] = os.path.abspath(
|
||||
capture_input_filepath)
|
||||
self._render_input_filepaths[name] = os.path.abspath(
|
||||
render_input_filepath)
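The pairing described in the docstring above (capture and render files matched by index, keyed by the capture file's base name) in isolation, with invented file names:

import os

captures = ['/data/speech1.wav', '/data/speech2.wav']
renders = ['/data/echo1.wav', '/data/echo2.wav']
pairs = {
    os.path.splitext(os.path.basename(c))[0]: (os.path.abspath(c),
                                               os.path.abspath(r))
    for c, r in zip(captures, renders)
}
print(pairs['speech1'])   # ('/data/speech1.wav', '/data/echo1.wav')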
|
||||
|
||||
@classmethod
|
||||
def _CreatePathsCollection(cls, filepaths):
|
||||
"""Creates a collection of file paths.
|
||||
@classmethod
|
||||
def _CreatePathsCollection(cls, filepaths):
|
||||
"""Creates a collection of file paths.
|
||||
|
||||
Given a list of file paths, makes a collection with one item for each file
|
||||
path. The value is the absolute path and the key is the file name without
|
||||
@ -415,12 +435,12 @@ class ApmModuleSimulator(object):
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
filepaths_collection = {}
|
||||
for filepath in filepaths:
|
||||
name = cls._ExtractFileName(filepath)
|
||||
filepaths_collection[name] = os.path.abspath(filepath)
|
||||
return filepaths_collection
|
||||
filepaths_collection = {}
|
||||
for filepath in filepaths:
|
||||
name = cls._ExtractFileName(filepath)
|
||||
filepaths_collection[name] = os.path.abspath(filepath)
|
||||
return filepaths_collection
|
||||
|
||||
@classmethod
|
||||
def _ExtractFileName(cls, filepath):
|
||||
return os.path.splitext(os.path.split(filepath)[-1])[0]
|
||||
@classmethod
|
||||
def _ExtractFileName(cls, filepath):
|
||||
return os.path.splitext(os.path.split(filepath)[-1])[0]
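Each simulation run writes into a directory assembled from the prefix constants defined at the top of this class; a rough sketch of the os.path.join pattern used in _SimulateAll, with invented configuration and file names:

import os

base_output_path = '/tmp/apm_quality_assessment_output'   # Hypothetical.
output_path = os.path.join(
    base_output_path,
    'apmcfg-' + 'default',     # _PREFIX_APM_CONFIG + APM config name.
    'capture-' + 'speech',     # _PREFIX_CAPTURE + capture file name.
    'render-' + '(none)',      # _PREFIX_RENDER + render file name.
    'echosim-' + 'noecho',     # _PREFIX_ECHO_SIMULATOR + echo simulator name.
    'datagen-' + 'identity')   # _PREFIX_TEST_DATA_GEN + test data generator name.
print(output_path)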
@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the simulation module.
"""

@ -28,177 +27,177 @@ from . import test_data_generation_factory

class TestApmModuleSimulator(unittest.TestCase):
    """Unit tests for the ApmModuleSimulator class.
    """

    def setUp(self):
        """Create temporary folders and fake audio track."""
        self._output_path = tempfile.mkdtemp()
        self._tmp_path = tempfile.mkdtemp()

        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)
        self._fake_audio_track_path = os.path.join(self._output_path,
                                                   'fake.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_audio_track_path, fake_signal)

    def tearDown(self):
        """Recursively delete temporary folders."""
        shutil.rmtree(self._output_path)
        shutil.rmtree(self._tmp_path)

def testSimulation(self):
|
||||
# Instance dependencies to mock and inject.
|
||||
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
evaluator = evaluation.ApmModuleEvaluator()
|
||||
ap_wrapper.Run = mock.MagicMock(name='Run')
|
||||
evaluator.Run = mock.MagicMock(name='Run')
|
||||
def testSimulation(self):
|
||||
# Instance dependencies to mock and inject.
|
||||
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
evaluator = evaluation.ApmModuleEvaluator()
|
||||
ap_wrapper.Run = mock.MagicMock(name='Run')
|
||||
evaluator.Run = mock.MagicMock(name='Run')
|
||||
|
||||
# Instance non-mocked dependencies.
|
||||
test_data_generator_factory = (
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False))
|
||||
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(__file__), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)
|
||||
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=test_data_generator_factory,
|
||||
evaluation_score_factory=evaluation_score_factory,
|
||||
ap_wrapper=ap_wrapper,
|
||||
evaluator=evaluator,
|
||||
external_vads={'fake': external_vad.ExternalVad(os.path.join(
|
||||
os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')}
|
||||
)
|
||||
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [self._fake_audio_track_path]
|
||||
test_data_generators = ['identity', 'white_noise']
|
||||
eval_scores = ['audio_level_mean', 'polqa']
|
||||
|
||||
# Run all simulations.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=test_data_generators,
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
|
||||
# Check.
|
||||
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
|
||||
# the client code (e.g., number of SNR pairs for the white noise test data
|
||||
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
|
||||
# is known; use that with assertEqual.
|
||||
min_number_of_simulations = len(config_files) * len(input_files) * len(
|
||||
test_data_generators)
|
||||
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
|
||||
def testInputSignalCreation(self):
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
# Instance non-mocked dependencies.
|
||||
test_data_generator_factory = (
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(__file__), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
copy_with_identity=False))
|
||||
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)
|
||||
|
||||
# Inexistent input files to be silently created.
|
||||
input_files = [
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
|
||||
]
|
||||
self.assertFalse(any([os.path.exists(input_file) for input_file in (
|
||||
input_files)]))
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=test_data_generator_factory,
|
||||
evaluation_score_factory=evaluation_score_factory,
|
||||
ap_wrapper=ap_wrapper,
|
||||
evaluator=evaluator,
|
||||
external_vads={
|
||||
'fake':
|
||||
external_vad.ExternalVad(
|
||||
os.path.join(os.path.dirname(__file__),
|
||||
'fake_external_vad.py'), 'fake')
|
||||
})
|
||||
|
||||
# The input files are created during the simulation.
|
||||
simulator.Run(
|
||||
config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=['audio_level_peak'],
|
||||
output_dir=self._output_path)
|
||||
self.assertTrue(all([os.path.exists(input_file) for input_file in (
|
||||
input_files)]))
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [self._fake_audio_track_path]
|
||||
test_data_generators = ['identity', 'white_noise']
|
||||
eval_scores = ['audio_level_mean', 'polqa']
|
||||
|
||||
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
# Run all simulations.
|
||||
simulator.Run(config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=test_data_generators,
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(__file__), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
# Check.
|
||||
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
|
||||
# the client code (e.g., number of SNR pairs for the white noise test data
|
||||
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
|
||||
# is known; use that with assertEqual.
|
||||
min_number_of_simulations = len(config_files) * len(input_files) * len(
|
||||
test_data_generators)
|
||||
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
|
||||
eval_scores = ['thd']
|
||||
def testInputSignalCreation(self):
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
|
||||
# Should work.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
self.assertFalse(logging.warning.called)
|
||||
# Inexistent input files to be silently created.
|
||||
input_files = [
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
|
||||
]
|
||||
self.assertFalse(
|
||||
any([os.path.exists(input_file) for input_file in (input_files)]))
|
||||
|
||||
# Warning expected.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['white_noise'], # Not allowed with THD.
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
logging.warning.assert_called_with('the evaluation failed: %s', (
|
||||
'The THD score cannot be used with any test data generator other than '
|
||||
'"identity"'))
|
||||
# The input files are created during the simulation.
|
||||
simulator.Run(config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=['audio_level_peak'],
|
||||
output_dir=self._output_path)
|
||||
self.assertTrue(
|
||||
all([os.path.exists(input_file) for input_file in (input_files)]))
|
||||
|
||||
# # Init.
|
||||
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
|
||||
# input_signal_filepath = os.path.join(
|
||||
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
|
||||
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
|
||||
# # Check that the input signal is generated.
|
||||
# self.assertFalse(os.path.exists(input_signal_filepath))
|
||||
# generator.Generate(
|
||||
# input_signal_filepath=input_signal_filepath,
|
||||
# test_data_cache_path=self._test_data_cache_path,
|
||||
# base_output_path=self._base_output_path)
|
||||
# self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
|
||||
# # Check input signal properties.
|
||||
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
# input_signal_filepath)
|
||||
# self.assertEqual(1000, len(input_signal))
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
|
||||
eval_scores = ['thd']
|
||||
|
||||
# Should work.
|
||||
simulator.Run(config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
self.assertFalse(logging.warning.called)
|
||||
|
||||
# Warning expected.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['white_noise'], # Not allowed with THD.
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
logging.warning.assert_called_with('the evaluation failed: %s', (
|
||||
'The THD score cannot be used with any test data generator other than '
|
||||
'"identity"'))
|
||||
|
||||
# # Init.
|
||||
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
|
||||
# input_signal_filepath = os.path.join(
|
||||
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
|
||||
|
||||
# # Check that the input signal is generated.
|
||||
# self.assertFalse(os.path.exists(input_signal_filepath))
|
||||
# generator.Generate(
|
||||
# input_signal_filepath=input_signal_filepath,
|
||||
# test_data_cache_path=self._test_data_cache_path,
|
||||
# base_output_path=self._base_output_path)
|
||||
# self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
# # Check input signal properties.
|
||||
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
# input_signal_filepath)
|
||||
# self.assertEqual(1000, len(input_signal))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Test data generators producing signals pairs intended to be used to
|
||||
test the APM module. Each pair consists of a noisy input and a reference signal.
|
||||
The former is used as APM input and it is generated by adding noise to a
|
||||
@ -27,10 +26,10 @@ import shutil
|
||||
import sys
|
||||
|
||||
try:
    import scipy.io
except ImportError:
    logging.critical('Cannot import the third-party Python package scipy')
    sys.exit(1)

from . import data_access
|
||||
from . import exceptions
|
||||
@ -38,7 +37,7 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestDataGenerator(object):
    """Abstract class responsible for the generation of noisy signals.

Given a clean signal, it generates two streams named noisy signal and
|
||||
reference. The former is the clean signal deteriorated by the noise source,
|
||||
@ -50,24 +49,24 @@ class TestDataGenerator(object):
|
||||
A test data generator generates one or more pairs.
|
||||
"""
|
||||
|
||||
    NAME = None
    REGISTERED_CLASSES = {}

    def __init__(self, output_directory_prefix):
        self._output_directory_prefix = output_directory_prefix
        # Init dictionaries with one entry for each test data generator
        # configuration (e.g., different SNRs).
        # Noisy audio track files (stored separately in a cache folder).
        self._noisy_signal_filepaths = None
        # Path to be used for the APM simulation output files.
        self._apm_output_paths = None
        # Reference audio track files (stored separately in a cache folder).
        self._reference_signal_filepaths = None
        self.Clear()

    @classmethod
    def RegisterClass(cls, class_to_register):
        """Registers a TestDataGenerator implementation.

Decorator to automatically register the classes that extend
|
||||
TestDataGenerator.
|
||||
@ -77,28 +76,28 @@ class TestDataGenerator(object):
|
||||
class IdentityGenerator(TestDataGenerator):
|
||||
pass
|
||||
"""
|
||||
        cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
        return class_to_register

    @property
    def config_names(self):
        return self._noisy_signal_filepaths.keys()

    @property
    def noisy_signal_filepaths(self):
        return self._noisy_signal_filepaths

    @property
    def apm_output_paths(self):
        return self._apm_output_paths

    @property
    def reference_signal_filepaths(self):
        return self._reference_signal_filepaths

    def Generate(self, input_signal_filepath, test_data_cache_path,
                 base_output_path):
        """Generates a set of noisy input and reference audiotrack file pairs.

This method initializes an empty set of pairs and calls the _Generate()
|
||||
method implemented in a concrete class.
|
||||
@ -109,26 +108,26 @@ class TestDataGenerator(object):
|
||||
files.
|
||||
base_output_path: base path where output is written.
|
||||
"""
|
||||
        self.Clear()
        self._Generate(input_signal_filepath, test_data_cache_path,
                       base_output_path)

    def Clear(self):
        """Clears the generated output path dictionaries.
        """
        self._noisy_signal_filepaths = {}
        self._apm_output_paths = {}
        self._reference_signal_filepaths = {}

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        """Abstract method to be implemented in each concrete class.
        """
        raise NotImplementedError()

    def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
                          snr_value_pairs):
        """Adds noisy-reference signal pairs.

Args:
|
||||
base_output_path: noisy tracks base output path.
|
||||
@ -136,22 +135,22 @@ class TestDataGenerator(object):
|
||||
by noisy track name and SNR level.
|
||||
snr_value_pairs: list of SNR pairs.
|
||||
"""
|
||||
        for noise_track_name in noisy_mix_filepaths:
            for snr_noisy, snr_refence in snr_value_pairs:
                config_name = '{0}_{1:d}_{2:d}_SNR'.format(
                    noise_track_name, snr_noisy, snr_refence)
                output_path = self._MakeDir(base_output_path, config_name)
                self._AddNoiseReferenceFilesPair(
                    config_name=config_name,
                    noisy_signal_filepath=noisy_mix_filepaths[noise_track_name]
                    [snr_noisy],
                    reference_signal_filepath=noisy_mix_filepaths[
                        noise_track_name][snr_refence],
                    output_path=output_path)

    def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
                                    reference_signal_filepath, output_path):
        """Adds one noisy-reference signal pair.

Args:
|
||||
config_name: name of the APM configuration.
|
||||
@ -159,264 +158,275 @@ class TestDataGenerator(object):
|
||||
reference_signal_filepath: path to reference audio track file.
|
||||
output_path: APM output path.
|
||||
"""
|
||||
        assert config_name not in self._noisy_signal_filepaths
        self._noisy_signal_filepaths[config_name] = os.path.abspath(
            noisy_signal_filepath)
        self._apm_output_paths[config_name] = os.path.abspath(output_path)
        self._reference_signal_filepaths[config_name] = os.path.abspath(
            reference_signal_filepath)

    def _MakeDir(self, base_output_path, test_data_generator_config_name):
        output_path = os.path.join(
            base_output_path,
            self._output_directory_prefix + test_data_generator_config_name)
        data_access.MakeDirectory(output_path)
        return output_path


@TestDataGenerator.RegisterClass
|
||||
class IdentityTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds no noise.
|
||||
"""Generator that adds no noise.
|
||||
|
||||
Both the noisy and the reference signals are the input signal.
|
||||
"""
|
||||
|
||||
NAME = 'identity'
|
||||
NAME = 'identity'
|
||||
|
||||
def __init__(self, output_directory_prefix, copy_with_identity):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._copy_with_identity = copy_with_identity
|
||||
def __init__(self, output_directory_prefix, copy_with_identity):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._copy_with_identity = copy_with_identity
|
||||
|
||||
@property
|
||||
def copy_with_identity(self):
|
||||
return self._copy_with_identity
|
||||
@property
|
||||
def copy_with_identity(self):
|
||||
return self._copy_with_identity
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
config_name = 'default'
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
config_name = 'default'
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
|
||||
if self._copy_with_identity:
|
||||
input_signal_filepath_new = os.path.join(
|
||||
test_data_cache_path, os.path.split(input_signal_filepath)[1])
|
||||
logging.info('copying ' + input_signal_filepath + ' to ' + (
|
||||
input_signal_filepath_new))
|
||||
shutil.copy(input_signal_filepath, input_signal_filepath_new)
|
||||
input_signal_filepath = input_signal_filepath_new
|
||||
if self._copy_with_identity:
|
||||
input_signal_filepath_new = os.path.join(
|
||||
test_data_cache_path,
|
||||
os.path.split(input_signal_filepath)[1])
|
||||
logging.info('copying ' + input_signal_filepath + ' to ' +
|
||||
(input_signal_filepath_new))
|
||||
shutil.copy(input_signal_filepath, input_signal_filepath_new)
|
||||
input_signal_filepath = input_signal_filepath_new
|
||||
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=input_signal_filepath,
|
||||
reference_signal_filepath=input_signal_filepath,
|
||||
output_path=output_path)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=input_signal_filepath,
|
||||
reference_signal_filepath=input_signal_filepath,
|
||||
output_path=output_path)
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class WhiteNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds white noise.
|
||||
"""Generator that adds white noise.
|
||||
"""
|
||||
|
||||
NAME = 'white_noise'
|
||||
NAME = 'white_noise'
|
||||
|
||||
    # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
    # The reference (second value of each pair) always has a lower amount of noise
    # - i.e., the SNR is 10 dB higher.
    _SNR_VALUE_PAIRS = [
        [20, 30],  # Smallest noise.
        [10, 20],
        [5, 15],
        [0, 10],  # Largest noise.
    ]

_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
# Create the noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
input_signal)
|
||||
# Create the noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
input_signal)
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths = {}
|
||||
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths = {}
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[snr] = noisy_signal_filepath
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noisy-reference signal pairs.
|
||||
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
|
||||
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
|
||||
output_path=output_path)
|
||||
# Add all the noisy-reference signal pairs.
|
||||
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
|
||||
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
|
||||
output_path=output_path)
|
||||
|
||||
|
||||
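# --- Illustration only (not part of this CL) ---------------------------------
# A minimal, standalone sketch of how the _SNR_VALUE_PAIRS above expand during
# generation: one noisy mix is cached per *unique* SNR value, and one
# noisy/reference config is produced per pair. The _example_* names below are
# hypothetical and exist only for this illustration.
_EXAMPLE_SNR_VALUE_PAIRS = [[20, 30], [10, 20], [5, 15], [0, 10]]

_example_unique_snrs = set(
    snr for pair in _EXAMPLE_SNR_VALUE_PAIRS for snr in pair)
assert _example_unique_snrs == {0, 5, 10, 15, 20, 30}  # 6 cached mixes.

_example_config_names = [
    '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_reference)
    for snr_noisy, snr_reference in _EXAMPLE_SNR_VALUE_PAIRS
]
assert _example_config_names == [
    '20_30_SNR', '10_20_SNR', '5_15_SNR', '0_10_SNR'
]  # 4 noisy/reference pairs.
# ------------------------------------------------------------------------------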
# TODO(alessiob): remove comment when class implemented.
|
||||
# @TestDataGenerator.RegisterClass
|
||||
class NarrowBandNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds narrow-band noise.
|
||||
"""Generator that adds narrow-band noise.
|
||||
"""
|
||||
|
||||
NAME = 'narrow_band_noise'
|
||||
NAME = 'narrow_band_noise'
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
# TODO(alessiob): implement.
|
||||
pass
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
# TODO(alessiob): implement.
|
||||
pass
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class AdditiveNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds noise loops.
|
||||
"""Generator that adds noise loops.
|
||||
|
||||
This generator uses all the wav files in a given path (default: noise_tracks/)
|
||||
and mixes them to the clean speech with different target SNRs (hard-coded).
|
||||
"""
|
||||
|
||||
NAME = 'additive_noise'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
NAME = 'additive_noise'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
|
||||
DEFAULT_NOISE_TRACKS_PATH = os.path.join(
|
||||
os.path.dirname(__file__), os.pardir, 'noise_tracks')
|
||||
DEFAULT_NOISE_TRACKS_PATH = os.path.join(os.path.dirname(__file__),
|
||||
os.pardir, 'noise_tracks')
|
||||
|
||||
# TODO(alessiob): Make the list of SNR pairs customizable.
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
# TODO(alessiob): Make the list of SNR pairs customizable.
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
|
||||
def __init__(self, output_directory_prefix, noise_tracks_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._noise_tracks_file_names = [n for n in os.listdir(
|
||||
self._noise_tracks_path) if n.lower().endswith('.wav')]
|
||||
if len(self._noise_tracks_file_names) == 0:
|
||||
raise exceptions.InitializationException(
|
||||
'No wav files found in the noise tracks path %s' % (
|
||||
self._noise_tracks_path))
|
||||
def __init__(self, output_directory_prefix, noise_tracks_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._noise_tracks_file_names = [
|
||||
n for n in os.listdir(self._noise_tracks_path)
|
||||
if n.lower().endswith('.wav')
|
||||
]
|
||||
if len(self._noise_tracks_file_names) == 0:
|
||||
raise exceptions.InitializationException(
|
||||
'No wav files found in the noise tracks path %s' %
|
||||
(self._noise_tracks_path))
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
"""Generates test data pairs using environmental noise.
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates test data pairs using environmental noise.
|
||||
|
||||
For each noise track and pair of SNR values, the following two audio tracks
|
||||
are created: the noisy signal and the reference signal. The former is
|
||||
obtained by mixing the (clean) input signal to the corresponding noise
|
||||
track enforcing the target SNR.
|
||||
"""
|
||||
# Init.
|
||||
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
# Init.
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
noisy_mix_filepaths = {}
|
||||
for noise_track_filename in self._noise_tracks_file_names:
|
||||
# Load the noise track.
|
||||
noise_track_name, _ = os.path.splitext(noise_track_filename)
|
||||
noise_track_filepath = os.path.join(
|
||||
self._noise_tracks_path, noise_track_filename)
|
||||
if not os.path.exists(noise_track_filepath):
|
||||
logging.error('cannot find the <%s> noise track', noise_track_filename)
|
||||
raise exceptions.FileNotFoundError()
|
||||
noisy_mix_filepaths = {}
|
||||
for noise_track_filename in self._noise_tracks_file_names:
|
||||
# Load the noise track.
|
||||
noise_track_name, _ = os.path.splitext(noise_track_filename)
|
||||
noise_track_filepath = os.path.join(self._noise_tracks_path,
|
||||
noise_track_filename)
|
||||
if not os.path.exists(noise_track_filepath):
|
||||
logging.error('cannot find the <%s> noise track',
|
||||
noise_track_filename)
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[noise_track_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(noise_track_name, snr))
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[noise_track_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
noise_track_name, snr))
|
||||
|
||||
                # Create and save if not done.
                if not os.path.exists(noisy_signal_filepath):
                    # Create noisy signal.
                    noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
                        input_signal,
                        noise_signal,
                        snr,
                        pad_noise=signal_processing.SignalProcessingUtils.
                        MixPadding.LOOP)

# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[noise_track_name][snr] = noisy_signal_filepath
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[noise_track_name][
|
||||
snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(
|
||||
base_output_path, noisy_mix_filepaths, self._SNR_VALUE_PAIRS)
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds reverberation noise.
|
||||
"""Generator that adds reverberation noise.
|
||||
|
||||
TODO(alessiob): Make this class more generic since the impulse response can be
|
||||
anything (not just reverberation); call it e.g.,
|
||||
ConvolutionalNoiseTestDataGenerator.
|
||||
"""
|
||||
|
||||
NAME = 'reverberation'
|
||||
NAME = 'reverberation'
|
||||
|
||||
_IMPULSE_RESPONSES = {
|
||||
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
|
||||
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
|
||||
}
|
||||
_MAX_IMPULSE_RESPONSE_LENGTH = None
|
||||
_IMPULSE_RESPONSES = {
|
||||
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
|
||||
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
|
||||
}
|
||||
_MAX_IMPULSE_RESPONSE_LENGTH = None
|
||||
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 5 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[3, 8], # Smallest noise.
|
||||
[-3, 2], # Largest noise.
|
||||
]
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 5 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[3, 8], # Smallest noise.
|
||||
[-3, 2], # Largest noise.
|
||||
]
|
||||
|
||||
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
|
||||
def __init__(self, output_directory_prefix, aechen_ir_database_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
def __init__(self, output_directory_prefix, aechen_ir_database_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
"""Generates test data pairs using reverberation noise.
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates test data pairs using reverberation noise.
|
||||
|
||||
For each impulse response, one noise track is created. For each impulse
|
||||
response and pair of SNR values, the following 2 audio tracks are
|
||||
@ -424,61 +434,64 @@ class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
obtained by mixing the (clean) input signal to the corresponding noise
|
||||
track enforcing the target SNR.
|
||||
"""
|
||||
# Init.
|
||||
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
# Init.
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
noisy_mix_filepaths = {}
|
||||
for impulse_response_name in self._IMPULSE_RESPONSES:
|
||||
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name)
|
||||
noise_track_filepath = os.path.join(
|
||||
test_data_cache_path, noise_track_filename)
|
||||
noise_signal = None
|
||||
try:
|
||||
# Load noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
except exceptions.FileNotFoundError:
|
||||
# Generate noise track by applying the impulse response.
|
||||
impulse_response_filepath = os.path.join(
|
||||
self._aechen_ir_database_path,
|
||||
self._IMPULSE_RESPONSES[impulse_response_name])
|
||||
noise_signal = self._GenerateNoiseTrack(
|
||||
noise_track_filepath, input_signal, impulse_response_filepath)
|
||||
assert noise_signal is not None
|
||||
noisy_mix_filepaths = {}
|
||||
for impulse_response_name in self._IMPULSE_RESPONSES:
|
||||
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name)
|
||||
noise_track_filepath = os.path.join(test_data_cache_path,
|
||||
noise_track_filename)
|
||||
noise_signal = None
|
||||
try:
|
||||
# Load noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
except exceptions.FileNotFoundError:
|
||||
# Generate noise track by applying the impulse response.
|
||||
impulse_response_filepath = os.path.join(
|
||||
self._aechen_ir_database_path,
|
||||
self._IMPULSE_RESPONSES[impulse_response_name])
|
||||
noise_signal = self._GenerateNoiseTrack(
|
||||
noise_track_filepath, input_signal,
|
||||
impulse_response_filepath)
|
||||
assert noise_signal is not None
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[impulse_response_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name, snr))
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[impulse_response_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name, snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[impulse_response_name][snr] = noisy_signal_filepath
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[impulse_response_name][
|
||||
snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
|
||||
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
|
||||
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
|
||||
impulse_response_filepath):
|
||||
"""Generates noise track.
|
||||
"""Generates noise track.
|
||||
|
||||
Generate a signal by convolving input_signal with the impulse response in
|
||||
impulse_response_filepath; then save to noise_track_filepath.
|
||||
@ -491,21 +504,23 @@ class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
        # Load impulse response.
        data = scipy.io.loadmat(impulse_response_filepath)
        impulse_response = data['h_air'].flatten()
        if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
            logging.info('truncating impulse response from %d to %d samples',
                         len(impulse_response),
                         self._MAX_IMPULSE_RESPONSE_LENGTH)
            impulse_response = impulse_response[:self.
                                                _MAX_IMPULSE_RESPONSE_LENGTH]

# Apply impulse response.
|
||||
processed_signal = (
|
||||
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
input_signal, impulse_response))
|
||||
# Apply impulse response.
|
||||
processed_signal = (
|
||||
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
input_signal, impulse_response))
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noise_track_filepath, processed_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noise_track_filepath, processed_signal)
|
||||
|
||||
        return processed_signal
|
||||
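# --- Illustration only (not part of this CL) ---------------------------------
# _GenerateNoiseTrack above builds the reverberant "noise" track by convolving
# the clean input with a measured impulse response (via SignalProcessingUtils).
# A rough standalone equivalent with numpy/scipy, assuming 1-D float arrays;
# the function name and the peak-normalization step are assumptions made only
# for this sketch.
import numpy as np
import scipy.signal


def _example_apply_impulse_response(samples, impulse_response):
    """Convolves samples with an impulse response and normalizes the peak."""
    reverberant = scipy.signal.fftconvolve(samples, impulse_response)
    peak = np.max(np.abs(reverberant))
    return reverberant / peak if peak > 0 else reverberant
# ------------------------------------------------------------------------------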
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""TestDataGenerator factory class.
|
||||
"""
|
||||
|
||||
@ -16,15 +15,15 @@ from . import test_data_generation
|
||||
|
||||
|
||||
class TestDataGeneratorFactory(object):
|
||||
"""Factory class used to create test data generators.
|
||||
"""Factory class used to create test data generators.
|
||||
|
||||
Usage: Create a factory passing parameters to the ctor with which the
|
||||
generators will be produced.
|
||||
"""
|
||||
|
||||
def __init__(self, aechen_ir_database_path, noise_tracks_path,
|
||||
copy_with_identity):
|
||||
"""Ctor.
|
||||
def __init__(self, aechen_ir_database_path, noise_tracks_path,
|
||||
copy_with_identity):
|
||||
"""Ctor.
|
||||
|
||||
Args:
|
||||
aechen_ir_database_path: Path to the Aechen Impulse Response database.
|
||||
@ -32,16 +31,16 @@ class TestDataGeneratorFactory(object):
|
||||
copy_with_identity: Flag indicating whether the identity generator has to
|
||||
make copies of the clean speech input files.
|
||||
"""
|
||||
self._output_directory_prefix = None
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._copy_with_identity = copy_with_identity
|
||||
self._output_directory_prefix = None
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._copy_with_identity = copy_with_identity
|
||||
|
||||
def SetOutputDirectoryPrefix(self, prefix):
|
||||
self._output_directory_prefix = prefix
|
||||
def SetOutputDirectoryPrefix(self, prefix):
|
||||
self._output_directory_prefix = prefix
|
||||
|
||||
def GetInstance(self, test_data_generators_class):
|
||||
"""Creates an TestDataGenerator instance given a class object.
|
||||
def GetInstance(self, test_data_generators_class):
|
||||
"""Creates an TestDataGenerator instance given a class object.
|
||||
|
||||
Args:
|
||||
test_data_generators_class: TestDataGenerator class object (not an
|
||||
@ -50,22 +49,23 @@ class TestDataGeneratorFactory(object):
|
||||
Returns:
|
||||
TestDataGenerator instance.
|
||||
"""
|
||||
if self._output_directory_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The output directory prefix for test data generators is not set')
|
||||
logging.debug('factory producing %s', test_data_generators_class)
|
||||
if self._output_directory_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The output directory prefix for test data generators is not set'
|
||||
)
|
||||
logging.debug('factory producing %s', test_data_generators_class)
|
||||
|
||||
if test_data_generators_class == (
|
||||
test_data_generation.IdentityTestDataGenerator):
|
||||
return test_data_generation.IdentityTestDataGenerator(
|
||||
self._output_directory_prefix, self._copy_with_identity)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.ReverberationTestDataGenerator):
|
||||
return test_data_generation.ReverberationTestDataGenerator(
|
||||
self._output_directory_prefix, self._aechen_ir_database_path)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.AdditiveNoiseTestDataGenerator):
|
||||
return test_data_generation.AdditiveNoiseTestDataGenerator(
|
||||
self._output_directory_prefix, self._noise_tracks_path)
|
||||
else:
|
||||
return test_data_generators_class(self._output_directory_prefix)
|
||||
if test_data_generators_class == (
|
||||
test_data_generation.IdentityTestDataGenerator):
|
||||
return test_data_generation.IdentityTestDataGenerator(
|
||||
self._output_directory_prefix, self._copy_with_identity)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.ReverberationTestDataGenerator):
|
||||
return test_data_generation.ReverberationTestDataGenerator(
|
||||
self._output_directory_prefix, self._aechen_ir_database_path)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.AdditiveNoiseTestDataGenerator):
|
||||
return test_data_generation.AdditiveNoiseTestDataGenerator(
|
||||
self._output_directory_prefix, self._noise_tracks_path)
|
||||
else:
|
||||
return test_data_generators_class(self._output_directory_prefix)
|
||||
|
||||
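# --- Illustration only (not part of this CL) ---------------------------------
# Hedged usage sketch of the factory above, mirroring the unit tests below; it
# assumes the code runs inside the quality_assessment package so that the
# relative imports resolve.
from . import test_data_generation
from . import test_data_generation_factory

factory = test_data_generation_factory.TestDataGeneratorFactory(
    aechen_ir_database_path='', noise_tracks_path='', copy_with_identity=False)
factory.SetOutputDirectoryPrefix('datagen-')
generator = factory.GetInstance(
    test_data_generation.IdentityTestDataGenerator)
# ------------------------------------------------------------------------------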
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""
|
||||
|
||||
@ -23,141 +22,143 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestTestDataGenerators(unittest.TestCase):
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folders."""
|
||||
self._base_output_path = tempfile.mkdtemp()
|
||||
self._test_data_cache_path = tempfile.mkdtemp()
|
||||
self._fake_air_db_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Create temporary folders."""
|
||||
self._base_output_path = tempfile.mkdtemp()
|
||||
self._test_data_cache_path = tempfile.mkdtemp()
|
||||
self._fake_air_db_path = tempfile.mkdtemp()
|
||||
|
||||
# Fake AIR DB impulse responses.
|
||||
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
|
||||
# impulse responses. When changed, the coupling below between
|
||||
# impulse_response_mat_file_names and
|
||||
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
|
||||
impulse_response_mat_file_names = [
|
||||
'air_binaural_lecture_0_0_1.mat',
|
||||
'air_binaural_booth_0_0_1.mat',
|
||||
]
|
||||
for impulse_response_mat_file_name in impulse_response_mat_file_names:
|
||||
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
|
||||
scipy.io.savemat(os.path.join(
|
||||
self._fake_air_db_path, impulse_response_mat_file_name), data)
|
||||
# Fake AIR DB impulse responses.
|
||||
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
|
||||
# impulse responses. When changed, the coupling below between
|
||||
# impulse_response_mat_file_names and
|
||||
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
|
||||
impulse_response_mat_file_names = [
|
||||
'air_binaural_lecture_0_0_1.mat',
|
||||
'air_binaural_booth_0_0_1.mat',
|
||||
]
|
||||
for impulse_response_mat_file_name in impulse_response_mat_file_names:
|
||||
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
|
||||
scipy.io.savemat(
|
||||
os.path.join(self._fake_air_db_path,
|
||||
impulse_response_mat_file_name), data)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._base_output_path)
|
||||
shutil.rmtree(self._test_data_cache_path)
|
||||
shutil.rmtree(self._fake_air_db_path)
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._base_output_path)
|
||||
shutil.rmtree(self._test_data_cache_path)
|
||||
shutil.rmtree(self._fake_air_db_path)
|
||||
|
||||
def testTestDataGenerators(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
def testTestDataGenerators(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
|
||||
# Check that there is at least one registered test data generator.
|
||||
registered_classes = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
# Check that there is at least one registered test data generator.
|
||||
registered_classes = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance generators factory.
|
||||
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=self._fake_air_db_path,
|
||||
noise_tracks_path=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH,
|
||||
copy_with_identity=False)
|
||||
generators_factory.SetOutputDirectoryPrefix('datagen-')
|
||||
# Instance generators factory.
|
||||
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=self._fake_air_db_path,
|
||||
noise_tracks_path=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH,
|
||||
copy_with_identity=False)
|
||||
generators_factory.SetOutputDirectoryPrefix('datagen-')
|
||||
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(
|
||||
os.getcwd(), 'probing_signals', 'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
|
||||
'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
# Load input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
# Load input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
# Try each registered test data generator.
|
||||
for generator_name in registered_classes:
|
||||
# Instance test data generator.
|
||||
generator = generators_factory.GetInstance(
|
||||
registered_classes[generator_name])
|
||||
# Try each registered test data generator.
|
||||
for generator_name in registered_classes:
|
||||
# Instance test data generator.
|
||||
generator = generators_factory.GetInstance(
|
||||
registered_classes[generator_name])
|
||||
|
||||
# Generate the noisy input - reference pairs.
|
||||
generator.Generate(
|
||||
input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
# Generate the noisy input - reference pairs.
|
||||
generator.Generate(input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
|
||||
# Perform checks.
|
||||
self._CheckGeneratedPairsListSizes(generator)
|
||||
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
|
||||
self._CheckGeneratedPairsOutputPaths(generator)
|
||||
# Perform checks.
|
||||
self._CheckGeneratedPairsListSizes(generator)
|
||||
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
|
||||
self._CheckGeneratedPairsOutputPaths(generator)
|
||||
|
||||
def testTestidentityDataGenerator(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
def testTestidentityDataGenerator(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(
|
||||
os.getcwd(), 'probing_signals', 'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
|
||||
'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
def GetNoiseReferenceFilePaths(identity_generator):
|
||||
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
|
||||
reference_signal_filepaths = identity_generator.reference_signal_filepaths
|
||||
assert noisy_signal_filepaths.keys() == reference_signal_filepaths.keys()
|
||||
assert len(noisy_signal_filepaths.keys()) == 1
|
||||
key = noisy_signal_filepaths.keys()[0]
|
||||
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
|
||||
def GetNoiseReferenceFilePaths(identity_generator):
|
||||
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
|
||||
reference_signal_filepaths = identity_generator.reference_signal_filepaths
|
||||
assert noisy_signal_filepaths.keys(
|
||||
) == reference_signal_filepaths.keys()
|
||||
assert len(noisy_signal_filepaths.keys()) == 1
|
||||
key = noisy_signal_filepaths.keys()[0]
|
||||
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
|
||||
|
||||
# Test the |copy_with_identity| flag.
|
||||
for copy_with_identity in [False, True]:
|
||||
# Instance the generator through the factory.
|
||||
factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='', noise_tracks_path='',
|
||||
copy_with_identity=copy_with_identity)
|
||||
factory.SetOutputDirectoryPrefix('datagen-')
|
||||
generator = factory.GetInstance(
|
||||
test_data_generation.IdentityTestDataGenerator)
|
||||
# Check |copy_with_identity| is set correctly.
|
||||
self.assertEqual(copy_with_identity, generator.copy_with_identity)
|
||||
# Test the |copy_with_identity| flag.
|
||||
for copy_with_identity in [False, True]:
|
||||
# Instance the generator through the factory.
|
||||
factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=copy_with_identity)
|
||||
factory.SetOutputDirectoryPrefix('datagen-')
|
||||
generator = factory.GetInstance(
|
||||
test_data_generation.IdentityTestDataGenerator)
|
||||
# Check |copy_with_identity| is set correctly.
|
||||
self.assertEqual(copy_with_identity, generator.copy_with_identity)
|
||||
|
||||
# Generate test data and extract the paths to the noise and the reference
|
||||
# files.
|
||||
generator.Generate(
|
||||
input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
noisy_signal_filepath, reference_signal_filepath = (
|
||||
GetNoiseReferenceFilePaths(generator))
|
||||
# Generate test data and extract the paths to the noise and the reference
|
||||
# files.
|
||||
generator.Generate(input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
noisy_signal_filepath, reference_signal_filepath = (
|
||||
GetNoiseReferenceFilePaths(generator))
|
||||
|
||||
# Check that a copy is made if and only if |copy_with_identity| is True.
|
||||
if copy_with_identity:
|
||||
self.assertNotEqual(noisy_signal_filepath, input_signal_filepath)
|
||||
self.assertNotEqual(reference_signal_filepath, input_signal_filepath)
|
||||
else:
|
||||
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
|
||||
self.assertEqual(reference_signal_filepath, input_signal_filepath)
|
||||
# Check that a copy is made if and only if |copy_with_identity| is True.
|
||||
if copy_with_identity:
|
||||
self.assertNotEqual(noisy_signal_filepath,
|
||||
input_signal_filepath)
|
||||
self.assertNotEqual(reference_signal_filepath,
|
||||
input_signal_filepath)
|
||||
else:
|
||||
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
|
||||
self.assertEqual(reference_signal_filepath,
|
||||
input_signal_filepath)
|
||||
|
||||
def _CheckGeneratedPairsListSizes(self, generator):
|
||||
config_names = generator.config_names
|
||||
number_of_pairs = len(config_names)
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.noisy_signal_filepaths))
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.apm_output_paths))
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.reference_signal_filepaths))
|
||||
def _CheckGeneratedPairsListSizes(self, generator):
|
||||
config_names = generator.config_names
|
||||
number_of_pairs = len(config_names)
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.noisy_signal_filepaths))
|
||||
self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.reference_signal_filepaths))
|
||||
|
||||
def _CheckGeneratedPairsSignalDurations(
|
||||
self, generator, input_signal):
|
||||
"""Checks duration of the generated signals.
|
||||
def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
|
||||
"""Checks duration of the generated signals.
|
||||
|
||||
Checks that the noisy input and the reference tracks are audio files
|
||||
with duration equal to or greater than that of the input signal.
|
||||
@ -166,41 +167,41 @@ class TestTestDataGenerators(unittest.TestCase):
|
||||
generator: TestDataGenerator instance.
|
||||
input_signal: AudioSegment instance.
|
||||
"""
|
||||
input_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
|
||||
input_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
|
||||
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
# Load the noisy input file.
|
||||
noisy_signal_filepath = generator.noisy_signal_filepaths[
|
||||
config_name]
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noisy_signal_filepath)
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
# Load the noisy input file.
|
||||
noisy_signal_filepath = generator.noisy_signal_filepaths[
|
||||
config_name]
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noisy_signal_filepath)
|
||||
|
||||
# Check noisy input signal length.
|
||||
noisy_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(noisy_signal))
|
||||
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
|
||||
# Check noisy input signal length.
|
||||
noisy_signal_length = (signal_processing.SignalProcessingUtils.
|
||||
CountSamples(noisy_signal))
|
||||
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
|
||||
|
||||
# Load the reference file.
|
||||
reference_signal_filepath = generator.reference_signal_filepaths[
|
||||
config_name]
|
||||
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
reference_signal_filepath)
|
||||
# Load the reference file.
|
||||
reference_signal_filepath = generator.reference_signal_filepaths[
|
||||
config_name]
|
||||
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
reference_signal_filepath)
|
||||
|
||||
# Check noisy input signal length.
|
||||
reference_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(
|
||||
reference_signal))
|
||||
self.assertGreaterEqual(reference_signal_length, input_signal_length)
|
||||
# Check noisy input signal length.
|
||||
reference_signal_length = (signal_processing.SignalProcessingUtils.
|
||||
CountSamples(reference_signal))
|
||||
self.assertGreaterEqual(reference_signal_length,
|
||||
input_signal_length)
|
||||
|
||||
def _CheckGeneratedPairsOutputPaths(self, generator):
|
||||
"""Checks that the output path created by the generator exists.
|
||||
def _CheckGeneratedPairsOutputPaths(self, generator):
|
||||
"""Checks that the output path created by the generator exists.
|
||||
|
||||
Args:
|
||||
generator: TestDataGenerator instance.
|
||||
"""
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
output_path = generator.apm_output_paths[config_name]
|
||||
self.assertTrue(os.path.exists(output_path))
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
output_path = generator.apm_output_paths[config_name]
|
||||
self.assertTrue(os.path.exists(output_path))