Merge pull request #2959 from mozilla/test-output-utf8
Don't escape non-ASCII chars in test_output_file JSON & other small fixes
This commit is contained in:
commit
26e2f88bfe
0
DeepSpeech.py
Normal file → Executable file
0
DeepSpeech.py
Normal file → Executable file
@ -112,7 +112,7 @@ If, for example, Common Voice language ``en`` was extracted to ``../data/CV/en/`
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./DeepSpeech.py --train_files ../data/CV/en/clips/train.csv --dev_files ../data/CV/en/clips/dev.csv --test_files ../data/CV/en/clips/test.csv
|
||||
python3 DeepSpeech.py --train_files ../data/CV/en/clips/train.csv --dev_files ../data/CV/en/clips/dev.csv --test_files ../data/CV/en/clips/test.csv
|
||||
|
||||
Training a model
|
||||
^^^^^^^^^^^^^^^^
|
||||
@ -121,7 +121,7 @@ The central (Python) script is ``DeepSpeech.py`` in the project's root directory
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./DeepSpeech.py --helpfull
|
||||
python3 DeepSpeech.py --helpfull
|
||||
|
||||
To get the output of this in a slightly better-formatted way, you can also look at the flag definitions in :ref:`training-flags`.
|
||||
|
||||
@ -163,7 +163,7 @@ Automatic Mixed Precision (AMP) training on GPU for TensorFlow has been recently
|
||||
Mixed precision training makes use of both FP32 and FP16 precisions where appropriate. FP16 operations can leverage the Tensor cores on NVIDIA GPUs (Volta, Turing or newer architectures) for improved throughput. Mixed precision training also often allows larger batch sizes. DeepSpeech GPU automatic mixed precision training can be enabled via the flag value `--automatic_mixed_precision=True`.
|
||||
|
||||
```
|
||||
DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_files ./test.csv --automatic_mixed_precision=True
|
||||
python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_files ./test.csv --automatic_mixed_precision=True
|
||||
```
|
||||
|
||||
On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%.
|
||||
|
@ -17,7 +17,7 @@ from six.moves import zip
|
||||
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_graph_for_evaluation
|
||||
from .util.evaluate_tools import calculate_and_print_report
|
||||
from .util.evaluate_tools import calculate_and_print_report, save_samples_json
|
||||
from .util.feeding import create_dataset
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version
|
||||
@ -143,8 +143,7 @@ def main(_):
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
save_samples_json(samples, FLAGS.test_output_file)
|
||||
|
||||
|
||||
def run_script():
|
||||
|
@ -31,10 +31,11 @@ from .evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
|
||||
from .util.evaluate_tools import save_samples_json
|
||||
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
@ -641,8 +642,7 @@ def train():
|
||||
def test():
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
save_samples_json(samples, FLAGS.test_output_file)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
|
@ -2,9 +2,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
from multiprocessing.dummy import Pool
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
from attrdict import AttrDict
|
||||
|
||||
from .flags import FLAGS
|
||||
@ -115,3 +116,13 @@ def print_report(samples, losses, wer, cer, dataset_name):
|
||||
print('Worst WER:', '\n' + '-' * 80)
|
||||
for s in worst_samples:
|
||||
print_single_sample(s)
|
||||
|
||||
|
||||
def save_samples_json(samples, output_path):
    ''' Save decoded tuples as JSON, converting NumPy floats to Python floats.

    We set ensure_ascii=False to prevent json from escaping non-ASCII chars
    in the texts. Because the raw characters are then written as-is, the
    output file is opened explicitly as UTF-8 so the result does not depend
    on the platform's default locale encoding.

    :param samples: JSON-serializable collection of decoded samples. Values
                    that json cannot encode natively are converted with
                    ``float`` (via ``default=float``).
    :param output_path: path of the JSON file to write; an existing file is
                        overwritten.
    '''
    # Explicit UTF-8: with ensure_ascii=False a locale-dependent default
    # encoding could raise UnicodeEncodeError or produce non-UTF-8 JSON.
    with open(output_path, 'w', encoding='utf-8') as fout:
        json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user