WebRTC VAD wrapper for APM-QA
Alternative VAD based on the existing one in WebRTC. It is used to extract
VAD annotations in APM-QA.

TBR=
Bug: webrtc:7494
Change-Id: I6af412742f804631ad4f3ba3ccf71a30d74de984
Reviewed-on: https://webrtc-review.googlesource.com/14553
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20404}
committed by: Commit Bot
parent: ef48df9aeb
commit: 330bf4076e
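
The new vad executable is driven from Python through a plain subprocess
call. A minimal sketch of the invocation (the file paths are placeholders;
the real ones come from _VAD_WEBRTC_BIN_PATH and a temporary directory in
annotations.py below):

    import subprocess

    # Hypothetical paths; -i takes a mono wav file (up to 48 kHz) and -o
    # receives the bit-packed VAD decisions, one bit per 30 ms frame.
    subprocess.call([
        './vad',
        '-i', 'speech.wav',
        '-o', 'speech_vad.tmp',
    ])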
BUILD.gn

@@ -102,6 +102,7 @@ group("unit_tests") {
    ":fake_polqa",
    ":lib_unit_tests",
    ":scripts_unit_tests",
    ":vad",
  ]
}

@@ -118,6 +119,17 @@ rtc_executable("fake_polqa") {
  ]
}

rtc_executable("vad") {
  sources = [
    "quality_assessment/vad.cc",
  ]
  deps = [
    "../../../..:webrtc_common",
    "../../../../common_audio",
    "../../../../rtc_base:rtc_base_approved",
  ]
}

copy("lib_unit_tests") {
  testonly = true
  sources = [
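
With these targets in place, the vad binary is built like any other
rtc_executable (e.g., ninja -C out/Default vad, assuming a typical GN output
directory); annotations.py expects to find it at the location encoded in
_VAD_WEBRTC_BIN_PATH.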
quality_assessment/annotations.py

@@ -10,9 +10,14 @@
"""

from __future__ import division
import enum
import logging
import os
import shutil
import struct
import subprocess
import sys
import tempfile

try:
  import numpy as np
@@ -20,6 +25,7 @@ except ImportError:
  logging.critical('Cannot import the third-party Python package numpy')
  sys.exit(1)

from . import exceptions
from . import signal_processing


@@ -27,9 +33,12 @@ class AudioAnnotationsExtractor(object):
  """Extracts annotations from audio files.
  """

  _LEVEL_FILENAME = 'level.npy'
  _VAD_FILENAME = 'vad.npy'
  _SPEECH_LEVEL_FILENAME = 'speech_level.npy'
  @enum.unique
  class VadType(enum.Enum):
    ENERGY_THRESHOLD = 0  # TODO(alessiob): Consider switching to P56 standard.
    WEBRTC = 1

  _OUTPUT_FILENAME = 'annotations.npz'

  # Level estimation params.
  _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
@@ -41,36 +50,50 @@ class AudioAnnotationsExtractor(object):

  # VAD params.
  _VAD_THRESHOLD = 1
  _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
      os.path.abspath(__file__)), os.pardir, os.pardir)
  _VAD_WEBRTC_BIN_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')

  def __init__(self):
  def __init__(self, vad_type):
    self._signal = None
    self._level = None
    self._vad = None
    self._speech_level = None
    self._level_frame_size = None
    self._vad_output = None
    self._vad_frame_size = None
    self._vad_frame_size_ms = None
    self._c_attack = None
    self._c_decay = None

  @classmethod
  def GetLevelFileName(cls):
    return cls._LEVEL_FILENAME
    self._vad_type = vad_type
    if self._vad_type not in self.VadType:
      raise exceptions.InitializationException(
          'Invalid vad type: ' + self._vad_type)
    logging.info('VAD used for annotations: ' + str(self._vad_type))

    assert os.path.exists(self._VAD_WEBRTC_BIN_PATH), self._VAD_WEBRTC_BIN_PATH

  @classmethod
  def GetVadFileName(cls):
    return cls._VAD_FILENAME

  @classmethod
  def GetSpeechLevelFileName(cls):
    return cls._SPEECH_LEVEL_FILENAME
  def GetOutputFileName(cls):
    return cls._OUTPUT_FILENAME

  def GetLevel(self):
    return self._level

  def GetVad(self):
    return self._vad
  def GetLevelFrameSize(self):
    return self._level_frame_size

  def GetSpeechLevel(self):
    return self._speech_level
  @classmethod
  def GetLevelFrameSizeMs(cls):
    return cls._LEVEL_FRAME_SIZE_MS

  def GetVadOutput(self):
    return self._vad_output

  def GetVadFrameSize(self):
    return self._vad_frame_size

  def GetVadFrameSizeMs(self):
    return self._vad_frame_size_ms

  def Extract(self, filepath):
    # Load signal.
@@ -78,7 +101,7 @@ class AudioAnnotationsExtractor(object):
    if self._signal.channels != 1:
      raise NotImplementedError('multiple-channel annotations not implemented')

    # level estimation params.
    # Level estimation params.
    self._level_frame_size = int(self._signal.frame_rate / 1000 * (
        self._LEVEL_FRAME_SIZE_MS))
    self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
@@ -91,26 +114,26 @@ class AudioAnnotationsExtractor(object):
    # Compute level.
    self._LevelEstimation()

    # Naive VAD based on level thresholding. It assumes ideal clean speech
    # with high SNR.
    # TODO(alessiob): Maybe replace with a VAD based on stationary-noise
    # detection.
    # Ideal VAD output, it requires clean speech with high SNR as input.
    if self._vad_type == self.VadType.ENERGY_THRESHOLD:
      # Naive VAD based on level thresholding.
      vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
    self._vad = np.uint8(self._level > vad_threshold)

    # Speech level based on VAD output.
    self._speech_level = self._level * self._vad

    # Expand to one value per sample.
    self._level = np.repeat(self._level, self._level_frame_size)
    self._vad = np.repeat(self._vad, self._level_frame_size)
    self._speech_level = np.repeat(self._speech_level, self._level_frame_size)
      self._vad_output = np.uint8(self._level > vad_threshold)
      self._vad_frame_size = self._level_frame_size
      self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
    elif self._vad_type == self.VadType.WEBRTC:
      # WebRTC VAD.
      self._RunWebRtcVad(filepath, self._signal.frame_rate)

  def Save(self, output_path):
    np.save(os.path.join(output_path, self._LEVEL_FILENAME), self._level)
    np.save(os.path.join(output_path, self._VAD_FILENAME), self._vad)
    np.save(os.path.join(output_path, self._SPEECH_LEVEL_FILENAME),
            self._speech_level)
    np.savez_compressed(
        file=os.path.join(output_path, self._OUTPUT_FILENAME),
        level=self._level,
        level_frame_size=self._level_frame_size,
        level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
        vad_output=self._vad_output,
        vad_frame_size=self._vad_frame_size,
        vad_frame_size_ms=self._vad_frame_size_ms)

  def _LevelEstimation(self):
    # Read samples.
@@ -132,4 +155,47 @@ class AudioAnnotationsExtractor(object):
          self._level[i], self._level[i - 1], self._c_attack if (
              self._level[i] > self._level[i - 1]) else self._c_decay)

    return self._level
  def _RunWebRtcVad(self, wav_file_path, sample_rate):
    self._vad_output = None
    self._vad_frame_size = None

    # Create temporary output path.
    tmp_path = tempfile.mkdtemp()
    output_file_path = os.path.join(
        tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')

    # Call WebRTC VAD.
    try:
      subprocess.call([
          self._VAD_WEBRTC_BIN_PATH,
          '-i', wav_file_path,
          '-o', output_file_path
      ], cwd=self._VAD_WEBRTC_PATH)

      # Read bytes.
      with open(output_file_path, 'rb') as f:
        raw_data = f.read()

      # Parse side information.
      self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
      self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
      assert self._vad_frame_size_ms in [10, 20, 30]
      extra_bits = struct.unpack('B', raw_data[-1])[0]
      assert 0 <= extra_bits <= 8

      # Init VAD vector.
      num_bytes = len(raw_data)
      num_frames = 8 * (num_bytes - 2) - extra_bits  # 8 frames for each byte.
      self._vad_output = np.zeros(num_frames, np.uint8)

      # Read VAD decisions.
      for i, byte in enumerate(raw_data[1:-1]):
        byte = struct.unpack('B', byte)[0]
        for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
          self._vad_output[i * 8 + j] = int(byte & 1)
          byte = byte >> 1
    except Exception as e:
      logging.error('Error while running the WebRTC VAD (' + e.message + ')')
    finally:
      if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)
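
For reference, the output file can also be decoded on its own; a minimal
standalone sketch (the helper name is hypothetical) equivalent to the parsing
logic in _RunWebRtcVad above:

    import struct

    def ReadVadOutputFile(path):  # Hypothetical helper name.
      """Decodes the file written by the vad executable: byte 0 holds the
      frame length in ms, the last byte the number of padding bits, and
      each byte in between packs eight LSB-first VAD decisions."""
      with open(path, 'rb') as f:
        raw_data = f.read()
      frame_size_ms = struct.unpack('B', raw_data[0:1])[0]
      extra_bits = struct.unpack('B', raw_data[-1:])[0]
      num_frames = 8 * (len(raw_data) - 2) - extra_bits
      decisions = []
      for i in range(1, len(raw_data) - 1):
        byte = struct.unpack('B', raw_data[i:i + 1])[0]
        for _ in range(8):
          decisions.append(byte & 1)
          byte >>= 1
      return frame_size_ms, decisions[:num_frames]  # Drop the padding bits.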
quality_assessment/annotations_unittest.py

@@ -9,6 +9,7 @@
"""Unit tests for the annotations module.
"""

from __future__ import division
import logging
import os
import shutil
@@ -27,6 +28,7 @@ class TestAnnotationsExtraction(unittest.TestCase):
  """

  _CLEAN_TMP_OUTPUT = True
  _DEBUG_PLOT_VAD = False

  def setUp(self):
    """Create temporary folder."""
@@ -36,6 +38,7 @@ class TestAnnotationsExtraction(unittest.TestCase):
        'pure_tone', [440, 1000])
    signal_processing.SignalProcessingUtils.SaveWav(
        self._wav_file_path, pure_tone)
    self._sample_rate = pure_tone.frame_rate

  def tearDown(self):
    """Recursively delete temporary folder."""
@@ -45,27 +48,49 @@ class TestAnnotationsExtraction(unittest.TestCase):
      logging.warning(self.id() + ' did not clean the temporary path ' + (
          self._tmp_path))

  def testExtraction(self):
    e = annotations.AudioAnnotationsExtractor()
  def testFrameSizes(self):
    for vad_type in annotations.AudioAnnotationsExtractor.VadType:
      e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
      e.Extract(self._wav_file_path)
    vad = e.GetVad()
    assert len(vad) > 0
    self.assertGreaterEqual(float(np.sum(vad)) / len(vad), 0.95)
      samples_to_ms = lambda n, sr: 1000 * n // sr
      self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
                       e.GetLevelFrameSizeMs())
      self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
                       e.GetVadFrameSizeMs())

  def testVoiceActivityDetectors(self):
    for vad_type in annotations.AudioAnnotationsExtractor.VadType:
      e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
      e.Extract(self._wav_file_path)
      vad_output = e.GetVadOutput()
      self.assertGreater(len(vad_output), 0)
      self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95)

      if self._DEBUG_PLOT_VAD:
        frame_times_s = lambda num_frames, frame_size_ms: np.arange(
            num_frames).astype(np.float32) * frame_size_ms / 1000.0
        level = e.GetLevel()
        t_level = frame_times_s(
            num_frames=len(level),
            frame_size_ms=e.GetLevelFrameSizeMs())
        t_vad = frame_times_s(
            num_frames=len(vad_output),
            frame_size_ms=e.GetVadFrameSizeMs())
        import matplotlib.pyplot as plt
        plt.figure()
        plt.hold(True)
        plt.plot(t_level, level)
        plt.plot(t_vad, vad_output * np.max(level), '.')
        plt.show()

  def testSaveLoad(self):
    e = annotations.AudioAnnotationsExtractor()
    e = annotations.AudioAnnotationsExtractor(
        vad_type=annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD)
    e.Extract(self._wav_file_path)
    e.Save(self._tmp_path)

    level = np.load(os.path.join(self._tmp_path, e.GetLevelFileName()))
    np.testing.assert_array_equal(e.GetLevel(), level)
    self.assertEqual(np.float32, level.dtype)

    vad = np.load(os.path.join(self._tmp_path, e.GetVadFileName()))
    np.testing.assert_array_equal(e.GetVad(), vad)
    self.assertEqual(np.uint8, vad.dtype)

    speech_level = np.load(os.path.join(
        self._tmp_path, e.GetSpeechLevelFileName()))
    np.testing.assert_array_equal(e.GetSpeechLevel(), speech_level)
    self.assertEqual(np.float32, speech_level.dtype)
    data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName()))
    np.testing.assert_array_equal(e.GetLevel(), data['level'])
    self.assertEqual(np.float32, data['level'].dtype)
    np.testing.assert_array_equal(e.GetVadOutput(), data['vad_output'])
    self.assertEqual(np.uint8, data['vad_output'].dtype)
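
The npz round-trip in testSaveLoad also shows how downstream consumers read
the annotations; a minimal sketch (the file path is a placeholder; the keys
are the ones written by Save above):

    import numpy as np

    # 'annotations.npz' is the value returned by GetOutputFileName().
    data = np.load('/tmp/out/annotations.npz')  # Placeholder path.
    level = data['level']                  # float32, one value per level frame.
    vad_output = data['vad_output']        # uint8, one decision per VAD frame.
    frame_ms = data['vad_frame_size_ms']   # 30 when the WebRTC VAD is used.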
quality_assessment/simulation.py

@@ -46,7 +46,8 @@ class ApmModuleSimulator(object):
    self._evaluation_score_factory = evaluation_score_factory
    self._audioproc_wrapper = ap_wrapper
    self._evaluator = evaluator
    self._annotator = annotations.AudioAnnotationsExtractor()
    self._annotator = annotations.AudioAnnotationsExtractor(
        vad_type=annotations.AudioAnnotationsExtractor.VadType.WEBRTC)

    # Init.
    self._test_data_generator_factory.SetOutputDirectoryPrefix(
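
An annotator configured like the simulator's one can also be used standalone;
a minimal sketch (the import path and wav file are placeholders, and the
WEBRTC type requires the vad binary to have been built first):

    from quality_assessment import annotations  # Placeholder import path.

    extractor = annotations.AudioAnnotationsExtractor(
        vad_type=annotations.AudioAnnotationsExtractor.VadType.WEBRTC)
    extractor.Extract('speech.wav')        # Placeholder mono wav file.
    extractor.Save('/tmp/out')             # Writes annotations.npz.
    print(extractor.GetVadFrameSizeMs())   # 30 for the WebRTC VAD.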
quality_assessment/vad.cc (new file)

@@ -0,0 +1,100 @@
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree. An additional intellectual property rights grant can be found
// in the file PATENTS. All contributing project authors may
// be found in the AUTHORS file in the root of the source tree.

#include <array>
#include <fstream>
#include <memory>

#include "common_audio/vad/include/vad.h"
#include "common_audio/wav_file.h"
#include "rtc_base/flags.h"
#include "rtc_base/logging.h"

namespace webrtc {
namespace test {
namespace {

// The allowed values are 10, 20 or 30 ms.
constexpr uint8_t kAudioFrameLengthMilliseconds = 30;
constexpr int kMaxSampleRate = 48000;
constexpr size_t kMaxFrameLen =
    kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;

constexpr uint8_t kBitmaskBuffSize = 8;

DEFINE_string(i, "", "Input wav file");
DEFINE_string(o, "", "VAD output file");

int main(int argc, char* argv[]) {
  if (rtc::FlagList::SetFlagsFromCommandLine(&argc, argv, true))
    return 1;

  // Open wav input file and check properties.
  WavReader wav_reader(FLAG_i);
  if (wav_reader.num_channels() != 1) {
    LOG(LS_ERROR) << "Only mono wav files supported";
    return 1;
  }
  if (wav_reader.sample_rate() > kMaxSampleRate) {
    LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate << ")";
    return 1;
  }
  const size_t kAudioFrameLen = rtc::CheckedDivExact(
      kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
  if (kAudioFrameLen > kMaxFrameLen) {
    LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
    return 2;
  }

  // Create output file and write header.
  std::ofstream out_file(FLAG_o, std::ofstream::binary);
  const char audio_frame_length_ms = kAudioFrameLengthMilliseconds;
  out_file.write(&audio_frame_length_ms, 1);  // Header.

  // Run VAD and write decisions.
  std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal);
  std::array<int16_t, kMaxFrameLen> samples;
  char buff = 0;  // Buffer to write one bit per frame.
  uint8_t next = 0;  // Points to the next bit to write in |buff|.
  while (true) {
    // Process frame.
    const auto read_samples =
        wav_reader.ReadSamples(kAudioFrameLen, samples.data());
    if (read_samples < kAudioFrameLen)
      break;
    const auto is_speech = vad->VoiceActivity(samples.data(), kAudioFrameLen,
                                              wav_reader.sample_rate());

    // Write output.
    buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
    if (++next == kBitmaskBuffSize) {
      out_file.write(&buff, 1);  // Flush.
      buff = 0;  // Reset.
      next = 0;
    }
  }

  // Finalize.
  char extra_bits = 0;
  if (next > 0) {
    extra_bits = kBitmaskBuffSize - next;
    out_file.write(&buff, 1);  // Flush.
  }
  out_file.write(&extra_bits, 1);
  out_file.close();

  return 0;
}

}  // namespace
}  // namespace test
}  // namespace webrtc

int main(int argc, char* argv[]) {
  return webrtc::test::main(argc, argv);
}
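
A quick sanity check of the framing arithmetic above, for a hypothetical
10 second, 48 kHz mono input (the final partial frame is dropped by the
read_samples < kAudioFrameLen break):

    num_frames = (10 * 1000) // 30            # 333 complete 30 ms frames.
    data_bytes = -(-num_frames // 8)          # ceil(333 / 8) = 42 packed bytes.
    extra_bits = data_bytes * 8 - num_frames  # 3 padding bits in the last byte.
    file_bytes = 1 + data_bytes + 1           # Header + data + trailer = 44.
    # Matches the decoding formula used in annotations.py.
    assert 8 * (file_bytes - 2) - extra_bits == num_frames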