Don't escape non-ASCII characters in test_output_file
This commit is contained in:
parent
65b7c41746
commit
b283aadae6
@ -17,7 +17,7 @@ from six.moves import zip
|
||||
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_graph_for_evaluation
|
||||
from .util.evaluate_tools import calculate_and_print_report
|
||||
from .util.evaluate_tools import calculate_and_print_report, save_samples_json
|
||||
from .util.feeding import create_dataset
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version
|
||||
@ -143,8 +143,7 @@ def main(_):
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
save_samples_json(samples, FLAGS.test_output_file)
|
||||
|
||||
|
||||
def run_script():
|
||||
|
@ -31,6 +31,7 @@ from .evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
|
||||
from .util.evaluate_tools import save_samples_json
|
||||
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
@ -641,8 +642,7 @@ def train():
|
||||
def test():
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
save_samples_json(samples, FLAGS.test_output_file)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
|
@ -2,9 +2,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
from multiprocessing.dummy import Pool
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
from attrdict import AttrDict
|
||||
|
||||
from .flags import FLAGS
|
||||
@ -115,3 +116,13 @@ def print_report(samples, losses, wer, cer, dataset_name):
|
||||
print('Worst WER:', '\n' + '-' * 80)
|
||||
for s in worst_samples:
|
||||
print_single_sample(s)
|
||||
|
||||
|
||||
def save_samples_json(samples, output_path):
    ''' Save decoded tuples as JSON, converting NumPy floats to Python floats.

        We set ensure_ascii=False to prevent json from escaping non-ASCII chars
        in the texts, so they are written verbatim instead of as \\uXXXX
        sequences.

        samples: JSON-serializable structure; any value the json module cannot
                 serialize natively is passed through float() (default=float).
        output_path: path of the file to write; it is created or overwritten.
    '''
    # Since non-ASCII characters are written unescaped, pin the file encoding
    # to UTF-8 rather than relying on the locale default, which may be unable
    # to encode them (UnicodeEncodeError) or yield a non-UTF-8 JSON file.
    with open(output_path, 'w', encoding='utf-8') as fout:
        json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user