WebRTC VAD wrapper for APM-QA

Alternative VAD based on the existing one in WebRTC.
It is used to extract VAD annotations in APM-QA.

TBR=

Bug: webrtc:7494
Change-Id: I6af412742f804631ad4f3ba3ccf71a30d74de984
Reviewed-on: https://webrtc-review.googlesource.com/14553
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20404}
This commit is contained in:
Alessio Bazzica
2017-10-24 09:56:49 +02:00
committed by Commit Bot
parent ef48df9aeb
commit 330bf4076e
5 changed files with 263 additions and 59 deletions

View File

@ -102,6 +102,7 @@ group("unit_tests") {
":fake_polqa", ":fake_polqa",
":lib_unit_tests", ":lib_unit_tests",
":scripts_unit_tests", ":scripts_unit_tests",
":vad",
] ]
} }
@ -118,6 +119,17 @@ rtc_executable("fake_polqa") {
] ]
} }
rtc_executable("vad") {
sources = [
"quality_assessment/vad.cc",
]
deps = [
"../../../..:webrtc_common",
"../../../../common_audio",
"../../../../rtc_base:rtc_base_approved",
]
}
copy("lib_unit_tests") { copy("lib_unit_tests") {
testonly = true testonly = true
sources = [ sources = [

View File

@ -10,9 +10,14 @@
""" """
from __future__ import division from __future__ import division
import enum
import logging import logging
import os import os
import shutil
import struct
import subprocess
import sys import sys
import tempfile
try: try:
import numpy as np import numpy as np
@ -20,6 +25,7 @@ except ImportError:
logging.critical('Cannot import the third-party Python package numpy') logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1) sys.exit(1)
from . import exceptions
from . import signal_processing from . import signal_processing
@ -27,9 +33,12 @@ class AudioAnnotationsExtractor(object):
"""Extracts annotations from audio files. """Extracts annotations from audio files.
""" """
_LEVEL_FILENAME = 'level.npy' @enum.unique
_VAD_FILENAME = 'vad.npy' class VadType(enum.Enum):
_SPEECH_LEVEL_FILENAME = 'speech_level.npy' ENERGY_THRESHOLD = 0 # TODO(alessiob): Consider switching to P56 standard.
WEBRTC = 1
_OUTPUT_FILENAME = 'annotations.npz'
# Level estimation params. # Level estimation params.
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0) _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
@ -41,36 +50,50 @@ class AudioAnnotationsExtractor(object):
# VAD params. # VAD params.
_VAD_THRESHOLD = 1 _VAD_THRESHOLD = 1
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
os.path.abspath(__file__)), os.pardir, os.pardir)
_VAD_WEBRTC_BIN_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
def __init__(self): def __init__(self, vad_type):
self._signal = None self._signal = None
self._level = None self._level = None
self._vad = None
self._speech_level = None
self._level_frame_size = None self._level_frame_size = None
self._vad_output = None
self._vad_frame_size = None
self._vad_frame_size_ms = None
self._c_attack = None self._c_attack = None
self._c_decay = None self._c_decay = None
@classmethod self._vad_type = vad_type
def GetLevelFileName(cls): if self._vad_type not in self.VadType:
return cls._LEVEL_FILENAME raise exceptions.InitializationException(
'Invalid vad type: ' + self._vad_type)
logging.info('VAD used for annotations: ' + str(self._vad_type))
assert os.path.exists(self._VAD_WEBRTC_BIN_PATH), self._VAD_WEBRTC_BIN_PATH
@classmethod @classmethod
def GetVadFileName(cls): def GetOutputFileName(cls):
return cls._VAD_FILENAME return cls._OUTPUT_FILENAME
@classmethod
def GetSpeechLevelFileName(cls):
return cls._SPEECH_LEVEL_FILENAME
def GetLevel(self): def GetLevel(self):
return self._level return self._level
def GetVad(self): def GetLevelFrameSize(self):
return self._vad return self._level_frame_size
def GetSpeechLevel(self): @classmethod
return self._speech_level def GetLevelFrameSizeMs(cls):
return cls._LEVEL_FRAME_SIZE_MS
def GetVadOutput(self):
return self._vad_output
def GetVadFrameSize(self):
return self._vad_frame_size
def GetVadFrameSizeMs(self):
return self._vad_frame_size_ms
def Extract(self, filepath): def Extract(self, filepath):
# Load signal. # Load signal.
@ -78,7 +101,7 @@ class AudioAnnotationsExtractor(object):
if self._signal.channels != 1: if self._signal.channels != 1:
raise NotImplementedError('multiple-channel annotations not implemented') raise NotImplementedError('multiple-channel annotations not implemented')
# level estimation params. # Level estimation params.
self._level_frame_size = int(self._signal.frame_rate / 1000 * ( self._level_frame_size = int(self._signal.frame_rate / 1000 * (
self._LEVEL_FRAME_SIZE_MS)) self._LEVEL_FRAME_SIZE_MS))
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else ( self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
@ -91,26 +114,26 @@ class AudioAnnotationsExtractor(object):
# Compute level. # Compute level.
self._LevelEstimation() self._LevelEstimation()
# Naive VAD based on level thresholding. It assumes ideal clean speech # Ideal VAD output, it requires clean speech with high SNR as input.
# with high SNR. if self._vad_type == self.VadType.ENERGY_THRESHOLD:
# TODO(alessiob): Maybe replace with a VAD based on stationary-noise # Naive VAD based on level thresholding.
# detection. vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD) self._vad_output = np.uint8(self._level > vad_threshold)
self._vad = np.uint8(self._level > vad_threshold) self._vad_frame_size = self._level_frame_size
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
# Speech level based on VAD output. elif self._vad_type == self.VadType.WEBRTC:
self._speech_level = self._level * self._vad # WebRTC VAD.
self._RunWebRtcVad(filepath, self._signal.frame_rate)
# Expand to one value per sample.
self._level = np.repeat(self._level, self._level_frame_size)
self._vad = np.repeat(self._vad, self._level_frame_size)
self._speech_level = np.repeat(self._speech_level, self._level_frame_size)
def Save(self, output_path): def Save(self, output_path):
np.save(os.path.join(output_path, self._LEVEL_FILENAME), self._level) np.savez_compressed(
np.save(os.path.join(output_path, self._VAD_FILENAME), self._vad) file=os.path.join(output_path, self._OUTPUT_FILENAME),
np.save(os.path.join(output_path, self._SPEECH_LEVEL_FILENAME), level=self._level,
self._speech_level) level_frame_size=self._level_frame_size,
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
vad_output=self._vad_output,
vad_frame_size=self._vad_frame_size,
vad_frame_size_ms=self._vad_frame_size_ms)
def _LevelEstimation(self): def _LevelEstimation(self):
# Read samples. # Read samples.
@ -132,4 +155,47 @@ class AudioAnnotationsExtractor(object):
self._level[i], self._level[i - 1], self._c_attack if ( self._level[i], self._level[i - 1], self._c_attack if (
self._level[i] > self._level[i - 1]) else self._c_decay) self._level[i] > self._level[i - 1]) else self._c_decay)
return self._level def _RunWebRtcVad(self, wav_file_path, sample_rate):
self._vad_output = None
self._vad_frame_size = None
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_BIN_PATH,
'-i', wav_file_path,
'-o', output_file_path
], cwd=self._VAD_WEBRTC_PATH)
# Read bytes.
with open(output_file_path, 'rb') as f:
raw_data = f.read()
# Parse side information.
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
assert self._vad_frame_size_ms in [10, 20, 30]
extra_bits = struct.unpack('B', raw_data[-1])[0]
assert 0 <= extra_bits <= 8
# Init VAD vector.
num_bytes = len(raw_data)
num_frames = 8 * (num_bytes - 2) - extra_bits # 8 frames for each byte.
self._vad_output = np.zeros(num_frames, np.uint8)
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._vad_output[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)

View File

@ -9,6 +9,7 @@
"""Unit tests for the annotations module. """Unit tests for the annotations module.
""" """
from __future__ import division
import logging import logging
import os import os
import shutil import shutil
@ -27,6 +28,7 @@ class TestAnnotationsExtraction(unittest.TestCase):
""" """
_CLEAN_TMP_OUTPUT = True _CLEAN_TMP_OUTPUT = True
_DEBUG_PLOT_VAD = False
def setUp(self): def setUp(self):
"""Create temporary folder.""" """Create temporary folder."""
@ -36,6 +38,7 @@ class TestAnnotationsExtraction(unittest.TestCase):
'pure_tone', [440, 1000]) 'pure_tone', [440, 1000])
signal_processing.SignalProcessingUtils.SaveWav( signal_processing.SignalProcessingUtils.SaveWav(
self._wav_file_path, pure_tone) self._wav_file_path, pure_tone)
self._sample_rate = pure_tone.frame_rate
def tearDown(self): def tearDown(self):
"""Recursively delete temporary folder.""" """Recursively delete temporary folder."""
@ -45,27 +48,49 @@ class TestAnnotationsExtraction(unittest.TestCase):
logging.warning(self.id() + ' did not clean the temporary path ' + ( logging.warning(self.id() + ' did not clean the temporary path ' + (
self._tmp_path)) self._tmp_path))
def testExtraction(self): def testFrameSizes(self):
e = annotations.AudioAnnotationsExtractor() for vad_type in annotations.AudioAnnotationsExtractor.VadType:
e.Extract(self._wav_file_path) e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
vad = e.GetVad() e.Extract(self._wav_file_path)
assert len(vad) > 0 samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertGreaterEqual(float(np.sum(vad)) / len(vad), 0.95) self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
e.GetLevelFrameSizeMs())
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
e.GetVadFrameSizeMs())
def testVoiceActivityDetectors(self):
for vad_type in annotations.AudioAnnotationsExtractor.VadType:
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
e.Extract(self._wav_file_path)
vad_output = e.GetVadOutput()
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95)
if self._DEBUG_PLOT_VAD:
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
num_frames).astype(np.float32) * frame_size_ms / 1000.0
level = e.GetLevel()
t_level = frame_times_s(
num_frames=len(level),
frame_size_ms=e.GetLevelFrameSizeMs())
t_vad = frame_times_s(
num_frames=len(vad_output),
frame_size_ms=e.GetVadFrameSizeMs())
import matplotlib.pyplot as plt
plt.figure()
plt.hold(True)
plt.plot(t_level, level)
plt.plot(t_vad, vad_output * np.max(level), '.')
plt.show()
def testSaveLoad(self): def testSaveLoad(self):
e = annotations.AudioAnnotationsExtractor() e = annotations.AudioAnnotationsExtractor(
vad_type=annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD)
e.Extract(self._wav_file_path) e.Extract(self._wav_file_path)
e.Save(self._tmp_path) e.Save(self._tmp_path)
level = np.load(os.path.join(self._tmp_path, e.GetLevelFileName())) data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName()))
np.testing.assert_array_equal(e.GetLevel(), level) np.testing.assert_array_equal(e.GetLevel(), data['level'])
self.assertEqual(np.float32, level.dtype) self.assertEqual(np.float32, data['level'].dtype)
np.testing.assert_array_equal(e.GetVadOutput(), data['vad_output'])
vad = np.load(os.path.join(self._tmp_path, e.GetVadFileName())) self.assertEqual(np.uint8, data['vad_output'].dtype)
np.testing.assert_array_equal(e.GetVad(), vad)
self.assertEqual(np.uint8, vad.dtype)
speech_level = np.load(os.path.join(
self._tmp_path, e.GetSpeechLevelFileName()))
np.testing.assert_array_equal(e.GetSpeechLevel(), speech_level)
self.assertEqual(np.float32, speech_level.dtype)

View File

@ -46,7 +46,8 @@ class ApmModuleSimulator(object):
self._evaluation_score_factory = evaluation_score_factory self._evaluation_score_factory = evaluation_score_factory
self._audioproc_wrapper = ap_wrapper self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator self._evaluator = evaluator
self._annotator = annotations.AudioAnnotationsExtractor() self._annotator = annotations.AudioAnnotationsExtractor(
vad_type=annotations.AudioAnnotationsExtractor.VadType.WEBRTC)
# Init. # Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix( self._test_data_generator_factory.SetOutputDirectoryPrefix(

View File

@ -0,0 +1,100 @@
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree. An additional intellectual property rights grant can be found
// in the file PATENTS. All contributing project authors may
// be found in the AUTHORS file in the root of the source tree.
#include <array>
#include <fstream>
#include <memory>
#include "common_audio/vad/include/vad.h"
#include "common_audio/wav_file.h"
#include "rtc_base/flags.h"
#include "rtc_base/logging.h"
namespace webrtc {
namespace test {
namespace {
// The allowed values are 10, 20 or 30 ms.
constexpr uint8_t kAudioFrameLengthMilliseconds = 30;
constexpr int kMaxSampleRate = 48000;
constexpr size_t kMaxFrameLen =
kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
constexpr uint8_t kBitmaskBuffSize = 8;
DEFINE_string(i, "", "Input wav file");
DEFINE_string(o, "", "VAD output file");
int main(int argc, char* argv[]) {
if (rtc::FlagList::SetFlagsFromCommandLine(&argc, argv, true))
return 1;
// Open wav input file and check properties.
WavReader wav_reader(FLAG_i);
if (wav_reader.num_channels() != 1) {
LOG(LS_ERROR) << "Only mono wav files supported";
return 1;
}
if (wav_reader.sample_rate() > kMaxSampleRate) {
LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate << ")";
return 1;
}
const size_t kAudioFrameLen = rtc::CheckedDivExact(
kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
if (kAudioFrameLen > kMaxFrameLen) {
LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
return 2;
}
// Create output file and write header.
std::ofstream out_file(FLAG_o, std::ofstream::binary);
const char audio_frame_length_ms = kAudioFrameLengthMilliseconds;
out_file.write(&audio_frame_length_ms, 1); // Header.
// Run VAD and write decisions.
std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal);
std::array<int16_t, kMaxFrameLen> samples;
char buff = 0; // Buffer to write one bit per frame.
uint8_t next = 0; // Points to the next bit to write in |buff|.
while (true) {
// Process frame.
const auto read_samples =
wav_reader.ReadSamples(kAudioFrameLen, samples.data());
if (read_samples < kAudioFrameLen)
break;
const auto is_speech = vad->VoiceActivity(samples.data(), kAudioFrameLen,
wav_reader.sample_rate());
// Write output.
buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
if (++next == kBitmaskBuffSize) {
out_file.write(&buff, 1); // Flush.
buff = 0; // Reset.
next = 0;
}
}
// Finalize.
char extra_bits = 0;
if (next > 0) {
extra_bits = kBitmaskBuffSize - next;
out_file.write(&buff, 1); // Flush.
}
out_file.write(&extra_bits, 1);
out_file.close();
return 0;
}
} // namespace
} // namespace test
} // namespace webrtc
int main(int argc, char* argv[]) {
return webrtc::test::main(argc, argv);
}