Don't escape non-ASCII characters in test_output_file

This commit is contained in:
Reuben Morais 2020-04-29 15:42:48 +02:00
parent 65b7c41746
commit b283aadae6
3 changed files with 16 additions and 6 deletions

View File

@ -17,7 +17,7 @@ from six.moves import zip
from .util.config import Config, initialize_globals from .util.config import Config, initialize_globals
from .util.checkpoints import load_graph_for_evaluation from .util.checkpoints import load_graph_for_evaluation
from .util.evaluate_tools import calculate_and_print_report from .util.evaluate_tools import calculate_and_print_report, save_samples_json
from .util.feeding import create_dataset from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version from .util.helpers import check_ctcdecoder_version
@ -143,8 +143,7 @@ def main(_):
samples = evaluate(FLAGS.test_files.split(','), create_model) samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file: if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats save_samples_json(samples, FLAGS.test_output_file)
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def run_script(): def run_script():

View File

@ -31,6 +31,7 @@ from .evaluate import evaluate
from six.moves import zip, range from six.moves import zip, range
from .util.config import Config, initialize_globals from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox from .util.helpers import check_ctcdecoder_version, ExceptionBox
@ -641,8 +642,7 @@ def train():
def test(): def test():
samples = evaluate(FLAGS.test_files.split(','), create_model) samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file: if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats save_samples_json(samples, FLAGS.test_output_file)
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False): def create_inference_graph(batch_size=1, n_steps=16, tflite=False):

View File

@ -2,9 +2,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import json
from multiprocessing.dummy import Pool from multiprocessing.dummy import Pool
import numpy as np
import numpy as np
from attrdict import AttrDict from attrdict import AttrDict
from .flags import FLAGS from .flags import FLAGS
@ -115,3 +116,13 @@ def print_report(samples, losses, wer, cer, dataset_name):
print('Worst WER:', '\n' + '-' * 80) print('Worst WER:', '\n' + '-' * 80)
for s in worst_samples: for s in worst_samples:
print_single_sample(s) print_single_sample(s)
def save_samples_json(samples, output_path):
    ''' Save decoded tuples as JSON, converting NumPy floats to Python floats.

        We set ensure_ascii=False to prevent json from escaping non-ASCII chars
        in the texts. Since those characters are then written out verbatim,
        the file is opened explicitly as UTF-8 so the result does not depend
        on (and cannot fail because of) the platform's default encoding.

        samples: JSON-serializable collection of decoded samples; any value
                 json cannot serialize natively is passed through float().
        output_path: path of the JSON file to create or overwrite.
    '''
    with open(output_path, 'w', encoding='utf-8') as fout:
        json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)