202 lines
6.2 KiB
Python
202 lines
6.2 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import json
|
|
import sys
|
|
from multiprocessing import cpu_count
|
|
|
|
import progressbar
|
|
import tensorflow.compat.v1 as tfv1
|
|
from coqui_stt_ctcdecoder import (
|
|
Scorer,
|
|
flashlight_beam_search_decoder_batch,
|
|
FlashlightDecoderState,
|
|
)
|
|
from six.moves import zip
|
|
|
|
import tensorflow as tf
|
|
|
|
from .deepspeech_model import create_model
|
|
from .util.augmentations import NormalizeSampleRate
|
|
from .util.checkpoints import load_graph_for_evaluation
|
|
from .util.config import (
|
|
Config,
|
|
create_progressbar,
|
|
initialize_globals_from_cli,
|
|
log_error,
|
|
log_progress,
|
|
)
|
|
from .util.evaluate_tools import calculate_and_print_report, save_samples_json
|
|
from .util.feeding import create_dataset
|
|
from .util.helpers import check_ctcdecoder_version
|
|
|
|
|
|
def sparse_tensor_value_to_texts(value, alphabet):
    r"""
    Convert a :class:`tf.SparseTensor` ``value`` into Python strings, one per
    row, using ``alphabet`` to turn the stored tokens into text.

    The tensor is unpacked into an ``(indices, values, dense_shape)`` tuple and
    handed to :func:`sparse_tuple_to_texts`, which does the actual decoding.
    """
    as_tuple = (value.indices, value.values, value.dense_shape)
    return sparse_tuple_to_texts(as_tuple, alphabet)
|
|
|
|
|
|
def sparse_tuple_to_texts(sp_tuple, alphabet):
    r"""
    Decode a sparse ``(indices, values, dense_shape)`` tuple into one result
    per row.

    Each value is appended to the row named by the first coordinate of its
    index, in the order the indices are listed. ``dense_shape[0]`` rows are
    produced in total, so rows without any value are decoded from an empty
    list. Every row is passed to ``alphabet.Decode``.
    """
    coords, tokens, dense_shape = sp_tuple[0], sp_tuple[1], sp_tuple[2]

    # One (initially empty) token list per row of the dense shape.
    rows = [[] for _ in range(dense_shape[0])]
    for position, coord in enumerate(coords):
        rows[coord[0]].append(tokens[position])

    return [alphabet.Decode(row) for row in rows]
|
|
|
|
|
|
def evaluate(test_csvs, create_model):
    """Run the model over every CSV in ``test_csvs`` and collect the reports.

    Builds one dataset per CSV, a single re-initializable iterator shared by
    all of them, the model graph, and a CTC loss op; then, inside one session,
    evaluates each CSV in turn and decodes the softmaxed logits with the
    flashlight beam-search decoder.

    Args:
        test_csvs: Sequence of CSV paths. It is iterated twice and its first
            dataset is used to define the iterator structure, so it must be
            non-empty and re-iterable.
        create_model: Callable invoked as
            ``create_model(batch_x=..., seq_length=..., dropout=...)``; the
            first element of its return value is used as the logits.

    Returns:
        A list holding, for all CSVs in order, the items returned by
        ``calculate_and_print_report``.
    """
    if Config.scorer_path:
        scorer = Scorer(
            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
        )
    else:
        scorer = None

    test_sets = [
        create_dataset(
            [csv],
            batch_size=Config.test_batch_size,
            train_phase=False,
            augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
            reverse=Config.reverse_test,
            limit=Config.limit_test,
        )
        for csv in test_csvs
    ]
    # A single iterator is built from the structure of the first dataset and
    # later re-pointed at each dataset through its own initializer op.
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]),
    )
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(
        batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
    )

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # The tokens are re-encoded as UTF-8 right away, so the file is read as
    # UTF-8 explicitly instead of with the platform's locale default encoding.
    with open(Config.vocab_file, encoding="utf-8") as fin:
        vocab = [line.strip().encode("utf-8") for line in fin]

    with tfv1.Session(config=Config.session_config) as session:
        load_graph_for_evaluation(session)

        def run_test(init_op, dataset):
            """Evaluate one dataset and return its report.

            ``init_op`` points the shared iterator at the dataset; ``dataset``
            is forwarded to ``calculate_and_print_report`` unchanged.
            """
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(
                prefix="Test epoch | ",
                widgets=["Steps: ", progressbar.Counter(), " | ", progressbar.Timer()],
            ).start()
            log_progress("Test epoch...")

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    (
                        batch_wav_filenames,
                        batch_logits,
                        batch_loss,
                        batch_lengths,
                        batch_transcripts,
                    ) = session.run(
                        [batch_wav_filename, transposed, loss, batch_x_len, batch_y]
                    )
                except tf.errors.OutOfRangeError:
                    # The iterator is exhausted: this dataset is done.
                    break

                decoded = flashlight_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    beam_size=Config.export_beam_width,
                    decoder_type=FlashlightDecoderState.DecoderType.LexiconBased,
                    token_type=FlashlightDecoderState.TokenType.Aggregate,
                    lm_tokens=vocab,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_top_n=Config.cutoff_top_n,
                )
                # Only the first candidate of each decoded sample is kept.
                predictions.extend(" ".join(d[0].words) for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)
                )
                wav_filenames.extend(
                    wav_filename.decode("UTF-8") for wav_filename in batch_wav_filenames
                )
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            # Print test summary
            test_samples = calculate_and_print_report(
                wav_filenames, ground_truths, predictions, losses, dataset
            )
            return test_samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print("Testing model on {}".format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
|
|
|
|
|
|
def test():
    """Evaluate ``Config.test_files`` on a fresh default graph.

    The collected samples are written with ``save_samples_json`` only when
    ``Config.test_output_file`` is set.
    """
    tfv1.reset_default_graph()

    results = evaluate(Config.test_files, create_model)
    if not Config.test_output_file:
        return
    save_samples_json(results, Config.test_output_file)
|
|
|
|
|
|
def main():
    """Initialize globals from the CLI, check the decoder version, run the test.

    Raises:
        RuntimeError: if ``Config.test_files`` is empty or unset.
    """
    initialize_globals_from_cli()
    check_ctcdecoder_version()

    missing_files_message = (
        "You need to specify what files to use for evaluation via "
        "the --test_files flag."
    )
    if not Config.test_files:
        raise RuntimeError(missing_files_message)

    test()
|
|
|
|
|
|
# Run the evaluation only when this module is executed as the main program.
if __name__ == "__main__":
    main()
|