From b283aadae6d1907d93a4bc5cb072ed97de4e6c06 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Wed, 29 Apr 2020 15:42:48 +0200 Subject: [PATCH] Don't escape non-ASCII characters in test_output_file --- training/deepspeech_training/evaluate.py | 5 ++--- training/deepspeech_training/train.py | 4 ++-- training/deepspeech_training/util/evaluate_tools.py | 13 ++++++++++++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/training/deepspeech_training/evaluate.py b/training/deepspeech_training/evaluate.py index 5877b618..3ae8b9b8 100755 --- a/training/deepspeech_training/evaluate.py +++ b/training/deepspeech_training/evaluate.py @@ -17,7 +17,7 @@ from six.moves import zip from .util.config import Config, initialize_globals from .util.checkpoints import load_graph_for_evaluation -from .util.evaluate_tools import calculate_and_print_report +from .util.evaluate_tools import calculate_and_print_report, save_samples_json from .util.feeding import create_dataset from .util.flags import create_flags, FLAGS from .util.helpers import check_ctcdecoder_version @@ -143,8 +143,7 @@ def main(_): samples = evaluate(FLAGS.test_files.split(','), create_model) if FLAGS.test_output_file: - # Save decoded tuples as JSON, converting NumPy floats to Python floats - json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float) + save_samples_json(samples, FLAGS.test_output_file) def run_script(): diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index eed8cd9e..143ac377 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -31,6 +31,7 @@ from .evaluate import evaluate from six.moves import zip, range from .util.config import Config, initialize_globals from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation +from .util.evaluate_tools import save_samples_json from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from .util.flags 
import create_flags, FLAGS from .util.helpers import check_ctcdecoder_version, ExceptionBox @@ -641,8 +642,7 @@ def train(): def test(): samples = evaluate(FLAGS.test_files.split(','), create_model) if FLAGS.test_output_file: - # Save decoded tuples as JSON, converting NumPy floats to Python floats - json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float) + save_samples_json(samples, FLAGS.test_output_file) def create_inference_graph(batch_size=1, n_steps=16, tflite=False): diff --git a/training/deepspeech_training/util/evaluate_tools.py b/training/deepspeech_training/util/evaluate_tools.py index 19f06961..e482211e 100644 --- a/training/deepspeech_training/util/evaluate_tools.py +++ b/training/deepspeech_training/util/evaluate_tools.py @@ -2,9 +2,10 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function +import json from multiprocessing.dummy import Pool -import numpy as np +import numpy as np from attrdict import AttrDict from .flags import FLAGS @@ -115,3 +116,13 @@ def print_report(samples, losses, wer, cer, dataset_name): print('Worst WER:', '\n' + '-' * 80) for s in worst_samples: print_single_sample(s) + + +def save_samples_json(samples, output_path): + ''' Save decoded tuples as JSON, converting NumPy floats to Python floats. + + We set ensure_ascii=False to prevent json from escaping non-ASCII chars + in the texts. + ''' + with open(output_path, 'w') as fout: + json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)