STT/training/coqui_stt_training/training_graph_inference.py

88 lines
2.7 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
# Resolve the desired log level straight from the raw command line and export
# it as TF_CPP_MIN_LOG_LEVEL. This runs at module import time, ahead of the
# `import tensorflow` statements that follow in this file.
#
# LOG_LEVEL_INDEX is the position of the value following "--log_level" in
# sys.argv, or 0 when the flag is absent.
LOG_LEVEL_INDEX = 0
if "--log_level" in sys.argv:
    LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1

# Fall back to "3" when the flag is missing or is the last argument (no value).
if 0 < LOG_LEVEL_INDEX < len(sys.argv):
    DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX]
else:
    DESIRED_LOG_LEVEL = "3"

os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import ctc_beam_search_decoder, Scorer
from .deepspeech_model import create_inference_graph, create_overlapping_windows
from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals_from_cli, log_error
from .util.feeding import audiofile_to_features
def do_single_file_inference(input_file_path):
    """Run the inference graph on one audio file and print the best decoding.

    Resets the default graph, builds a batch-size-1 inference graph inside a
    new session, restores its variables via ``load_graph_for_evaluation``,
    feeds the features extracted from ``input_file_path`` together with
    zero-filled previous LSTM states, and decodes the resulting probabilities
    with ``ctc_beam_search_decoder``.

    Args:
        input_file_path: Audio file handed to ``audiofile_to_features``.

    Returns:
        None. ``decoded[0][1]`` is printed to stdout (presumably the
        transcript text of the top beam -- confirm against the decoder API).
    """
    tfv1.reset_default_graph()

    with tfv1.Session(config=Config.session_config) as session:
        graph_inputs, graph_outputs, _ = create_inference_graph(
            batch_size=1, n_steps=-1
        )

        # Restore the graph's variables from the training checkpoint.
        load_graph_for_evaluation(session)

        feature_tensor, length_tensor = audiofile_to_features(input_file_path)

        # Previous cell/hidden states are fed as all zeros.
        zero_state_c = np.zeros([1, Config.n_cell_dim])
        zero_state_h = np.zeros([1, Config.n_cell_dim])

        # Prepend a batch dimension to both tensors.
        batched_features = tf.expand_dims(feature_tensor, 0)
        batched_lengths = tf.expand_dims(length_tensor, 0)

        # Evaluate the feature tensors into concrete values that can be fed.
        window_values = create_overlapping_windows(batched_features).eval(
            session=session
        )
        length_values = batched_lengths.eval(session=session)

        feed = {
            graph_inputs["input"]: window_values,
            graph_inputs["input_lengths"]: length_values,
            graph_inputs["previous_state_c"]: zero_state_c,
            graph_inputs["previous_state_h"]: zero_state_h,
        }
        raw_probs = graph_outputs["outputs"].eval(feed_dict=feed, session=session)

        # Drop all size-1 axes before handing the matrix to the decoder.
        probs = np.squeeze(raw_probs)

        # An external scorer is only used when a scorer path is configured.
        scorer = (
            Scorer(Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet)
            if Config.scorer_path
            else None
        )

        decoded = ctc_beam_search_decoder(
            probs,
            Config.alphabet,
            Config.beam_width,
            scorer=scorer,
            cutoff_prob=Config.cutoff_prob,
            cutoff_top_n=Config.cutoff_top_n,
        )

        # Print highest probability result
        best_result = decoded[0]
        print(best_result[1])
def main():
    """Command-line entry point for one-shot inference.

    Initialises the global configuration from the command line, then runs
    ``do_single_file_inference`` on the file named by
    ``Config.one_shot_infer``.

    Raises:
        RuntimeError: If ``Config.one_shot_infer`` is falsy, i.e. no input
            audio file was specified.
    """
    initialize_globals_from_cli()

    # Guard clause: this script has nothing to do without an input file.
    if not Config.one_shot_infer:
        raise RuntimeError(
            "Calling training_graph_inference script directly but no --one_shot_infer input audio file specified"
        )

    tfv1.reset_default_graph()
    do_single_file_inference(Config.one_shot_infer)
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()