Don't escape non-ASCII characters in test_output_file
This commit is contained in:
parent
65b7c41746
commit
b283aadae6
@ -17,7 +17,7 @@ from six.moves import zip
|
|||||||
|
|
||||||
from .util.config import Config, initialize_globals
|
from .util.config import Config, initialize_globals
|
||||||
from .util.checkpoints import load_graph_for_evaluation
|
from .util.checkpoints import load_graph_for_evaluation
|
||||||
from .util.evaluate_tools import calculate_and_print_report
|
from .util.evaluate_tools import calculate_and_print_report, save_samples_json
|
||||||
from .util.feeding import create_dataset
|
from .util.feeding import create_dataset
|
||||||
from .util.flags import create_flags, FLAGS
|
from .util.flags import create_flags, FLAGS
|
||||||
from .util.helpers import check_ctcdecoder_version
|
from .util.helpers import check_ctcdecoder_version
|
||||||
@ -143,8 +143,7 @@ def main(_):
|
|||||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||||
|
|
||||||
if FLAGS.test_output_file:
|
if FLAGS.test_output_file:
|
||||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
save_samples_json(samples, FLAGS.test_output_file)
|
||||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
|
||||||
|
|
||||||
|
|
||||||
def run_script():
|
def run_script():
|
||||||
|
@ -31,6 +31,7 @@ from .evaluate import evaluate
|
|||||||
from six.moves import zip, range
|
from six.moves import zip, range
|
||||||
from .util.config import Config, initialize_globals
|
from .util.config import Config, initialize_globals
|
||||||
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
|
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
|
||||||
|
from .util.evaluate_tools import save_samples_json
|
||||||
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||||
from .util.flags import create_flags, FLAGS
|
from .util.flags import create_flags, FLAGS
|
||||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||||
@ -641,8 +642,7 @@ def train():
|
|||||||
def test():
|
def test():
|
||||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||||
if FLAGS.test_output_file:
|
if FLAGS.test_output_file:
|
||||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
save_samples_json(samples, FLAGS.test_output_file)
|
||||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
|
||||||
|
|
||||||
|
|
||||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import json
|
||||||
from multiprocessing.dummy import Pool
|
from multiprocessing.dummy import Pool
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from attrdict import AttrDict
|
from attrdict import AttrDict
|
||||||
|
|
||||||
from .flags import FLAGS
|
from .flags import FLAGS
|
||||||
@ -115,3 +116,13 @@ def print_report(samples, losses, wer, cer, dataset_name):
|
|||||||
print('Worst WER:', '\n' + '-' * 80)
|
print('Worst WER:', '\n' + '-' * 80)
|
||||||
for s in worst_samples:
|
for s in worst_samples:
|
||||||
print_single_sample(s)
|
print_single_sample(s)
|
||||||
|
|
||||||
|
|
||||||
|
def save_samples_json(samples, output_path):
|
||||||
|
''' Save decoded tuples as JSON, converting NumPy floats to Python floats.
|
||||||
|
|
||||||
|
We set ensure_ascii=False to prevent json from escaping non-ASCII chars
|
||||||
|
in the texts.
|
||||||
|
'''
|
||||||
|
with open(output_path, 'w') as fout:
|
||||||
|
json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user