Don't escape non-ASCII characters in test_output_file

This commit is contained in:
Reuben Morais 2020-04-29 15:42:48 +02:00
parent 65b7c41746
commit b283aadae6
3 changed files with 16 additions and 6 deletions

View File

@ -17,7 +17,7 @@ from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_graph_for_evaluation
from .util.evaluate_tools import calculate_and_print_report
from .util.evaluate_tools import calculate_and_print_report, save_samples_json
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version
@ -143,8 +143,7 @@ def main(_):
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
save_samples_json(samples, FLAGS.test_output_file)
def run_script():

View File

@ -31,6 +31,7 @@ from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
@ -641,8 +642,7 @@ def train():
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
save_samples_json(samples, FLAGS.test_output_file)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):

View File

@ -2,9 +2,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
from multiprocessing.dummy import Pool
import numpy as np
import numpy as np
from attrdict import AttrDict
from .flags import FLAGS
@ -115,3 +116,13 @@ def print_report(samples, losses, wer, cer, dataset_name):
print('Worst WER:', '\n' + '-' * 80)
for s in worst_samples:
print_single_sample(s)
def save_samples_json(samples, output_path):
''' Save decoded tuples as JSON, converting NumPy floats to Python floats.
We set ensure_ascii=True to prevent json from escaping non-ASCII chars
in the texts.
'''
with open(output_path, 'w') as fout:
json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)