Merge pull request #2959 from mozilla/test-output-utf8

Don't escape non-ASCII chars in test_output_file JSON & other small fixes
Reuben Morais 2020-04-29 18:05:57 +02:00 committed by GitHub
commit 26e2f88bfe
5 changed files with 20 additions and 10 deletions

DeepSpeech.py (0 line changes) Normal file → Executable file
View File

View File

@@ -112,7 +112,7 @@ If, for example, Common Voice language ``en`` was extracted to ``../data/CV/en/``
.. code-block:: bash

-   ./DeepSpeech.py --train_files ../data/CV/en/clips/train.csv --dev_files ../data/CV/en/clips/dev.csv --test_files ../data/CV/en/clips/test.csv
+   python3 DeepSpeech.py --train_files ../data/CV/en/clips/train.csv --dev_files ../data/CV/en/clips/dev.csv --test_files ../data/CV/en/clips/test.csv
Training a model
^^^^^^^^^^^^^^^^
@@ -121,7 +121,7 @@ The central (Python) script is ``DeepSpeech.py`` in the project's root directory
.. code-block:: bash

-   ./DeepSpeech.py --helpfull
+   python3 DeepSpeech.py --helpfull
To get the output of this in a slightly better-formatted way, you can also look at the flag definitions in :ref:`training-flags`.
@@ -163,7 +163,7 @@ Automatic Mixed Precision (AMP) training on GPU for TensorFlow has been recently
Mixed precision training makes use of both FP32 and FP16 precisions where appropriate. FP16 operations can leverage the Tensor cores on NVIDIA GPUs (Volta, Turing or newer architectures) for improved throughput. Mixed precision training also often allows larger batch sizes. DeepSpeech GPU automatic mixed precision training can be enabled via the flag value `--automatic_mixed_precision=True`.
```
-DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_files ./test.csv --automatic_mixed_precision=True
+python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_files ./test.csv --automatic_mixed_precision=True
```
On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%.
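As context (not part of this diff): the flag maps onto TensorFlow's automatic mixed precision graph rewrite. Below is a minimal sketch of enabling the same rewrite directly, assuming TensorFlow 1.15; the Adam optimizer and its learning rate are placeholders, not DeepSpeech's actual training setup.

```python
import tensorflow as tf

# Placeholder optimizer; the learning rate is illustrative only.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

# Wrap the optimizer so eligible ops are rewritten to FP16 with automatic
# loss scaling. Requires TF >= 1.14; Tensor cores (Volta/Turing or newer)
# are needed to see the speedup mentioned above.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

# The wrapped optimizer is then used exactly like the FP32 one, e.g.:
# train_op = optimizer.minimize(loss)
```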

View File

@@ -17,7 +17,7 @@ from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_graph_for_evaluation
-from .util.evaluate_tools import calculate_and_print_report
+from .util.evaluate_tools import calculate_and_print_report, save_samples_json
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version
@@ -143,8 +143,7 @@ def main(_):
    samples = evaluate(FLAGS.test_files.split(','), create_model)
    if FLAGS.test_output_file:
-        # Save decoded tuples as JSON, converting NumPy floats to Python floats
-        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
+        save_samples_json(samples, FLAGS.test_output_file)
def run_script():

View File

@@ -31,10 +31,11 @@ from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
+from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
-from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
+from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
check_ctcdecoder_version()
@@ -641,8 +642,7 @@ def train():
def test():
    samples = evaluate(FLAGS.test_files.split(','), create_model)
    if FLAGS.test_output_file:
-        # Save decoded tuples as JSON, converting NumPy floats to Python floats
-        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
+        save_samples_json(samples, FLAGS.test_output_file)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):

View File

@@ -2,9 +2,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
+import json
from multiprocessing.dummy import Pool
import numpy as np
from attrdict import AttrDict
from .flags import FLAGS
@@ -115,3 +116,13 @@ def print_report(samples, losses, wer, cer, dataset_name):
    print('Worst WER:', '\n' + '-' * 80)
    for s in worst_samples:
        print_single_sample(s)
+def save_samples_json(samples, output_path):
+    ''' Save decoded tuples as JSON, converting NumPy floats to Python floats.
+        We set ensure_ascii=False to prevent json from escaping non-ASCII chars
+        in the texts.
+    '''
+    with open(output_path, 'w') as fout:
+        json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
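To illustrate what the new helper changes (this snippet is not part of the diff; the record below is a made-up stand-in for the decoded samples returned by `evaluate()`, and its field names are illustrative only):

```python
import json
import numpy as np

# Hypothetical decoded sample; real entries come from evaluate() and may differ.
samples = [{'src': 'größe straße', 'res': 'grosse strasse', 'loss': np.float32(1.25)}]

# Old behavior (json.dump defaults): non-ASCII characters are escaped.
print(json.dumps(samples, default=float))
# [{"src": "gr\u00f6\u00dfe stra\u00dfe", "res": "grosse strasse", "loss": 1.25}]

# New behavior (as in save_samples_json): ensure_ascii=False keeps the text
# readable, and default=float still converts the NumPy float to a Python float.
print(json.dumps(samples, default=float, ensure_ascii=False, indent=2))
```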