299 lines
9.7 KiB
Python
299 lines
9.7 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
import os
|
|
import sys
|
|
|
|
# Honor an optional "--log_level N" CLI flag before TensorFlow is imported:
# TF_CPP_MIN_LOG_LEVEL is only consulted at import time, so it must be set
# here, ahead of the `import tensorflow` below. Defaults to "3" (errors only).
if "--log_level" in sys.argv:
    LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1
else:
    LOG_LEVEL_INDEX = 0

if 0 < LOG_LEVEL_INDEX < len(sys.argv):
    DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX]
else:
    # Flag absent (or given as the last token with no value): fall back to "3".
    DESIRED_LOG_LEVEL = "3"

os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import tensorflow.compat.v1 as tfv1
|
|
import shutil
|
|
|
|
from .deepspeech_model import create_inference_graph, create_model
|
|
from .util.checkpoints import load_graph_for_evaluation
|
|
from .util.config import Config, initialize_globals_from_cli, log_error, log_info
|
|
from .util.feeding import wavfile_bytes_to_features
|
|
from .util.io import (
|
|
open_remote,
|
|
rmtree_remote,
|
|
listdir_remote,
|
|
is_remote_path,
|
|
isdir_remote,
|
|
)
|
|
|
|
|
|
def file_relative_read(fname):
    """Return the full text of *fname*, resolved relative to this module's directory.

    Used e.g. to read the GRAPH_VERSION file that ships next to this module.
    """
    path = os.path.join(os.path.dirname(__file__), fname)
    # Context manager closes the handle deterministically; the original left
    # the file object open until garbage collection.
    with open(path, encoding="utf-8") as fin:
        return fin.read()
|
|
|
|
|
|
def export():
    r"""
    Restores the trained variables into a simpler graph that will be exported for serving.

    Depending on configuration this either delegates to ``export_savedmodel``,
    or writes a frozen GraphDef (``.pb``) / TFLite model (``.tflite``) into
    ``Config.export_dir``, followed by a markdown metadata ("model card") file.
    """
    log_info("Exporting the model...")

    # SavedModel export is a fully separate code path.
    if Config.export_savedmodel:
        return export_savedmodel()

    tfv1.reset_default_graph()

    # Build the inference-only graph; `inputs`/`outputs` map logical names to
    # graph tensors/ops (third return value is unused here).
    inputs, outputs, _ = create_inference_graph(
        batch_size=Config.export_batch_size,
        n_steps=Config.n_steps,
        tflite=Config.export_tflite,
    )

    # GRAPH_VERSION is a sibling file of this module; it versions the exported
    # graph's API for the native client.
    graph_version = int(file_relative_read("GRAPH_VERSION").strip())
    # NOTE(review): `assert` is stripped under `python -O`; an explicit raise
    # would be more robust here.
    assert graph_version > 0

    # native_client: these nodes' names and shapes are part of the API boundary
    # with the native client, if you change them you should sync changes with
    # the C++ code.
    outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
    outputs["metadata_sample_rate"] = tf.constant(
        [Config.audio_sample_rate], name="metadata_sample_rate"
    )
    outputs["metadata_feature_win_len"] = tf.constant(
        [Config.feature_win_len], name="metadata_feature_win_len"
    )
    outputs["metadata_feature_win_step"] = tf.constant(
        [Config.feature_win_step], name="metadata_feature_win_step"
    )
    outputs["metadata_beam_width"] = tf.constant(
        [Config.export_beam_width], name="metadata_beam_width"
    )
    outputs["metadata_alphabet"] = tf.constant(
        [Config.alphabet.Serialize()], name="metadata_alphabet"
    )

    # Language metadata is optional.
    if Config.export_language:
        outputs["metadata_language"] = tf.constant(
            [Config.export_language.encode("utf-8")], name="metadata_language"
        )

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    # The freezing APIs below take node names: tensors contribute the name of
    # their producing op, plain ops contribute their own name.
    output_names_tensors = [
        tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)
    ]
    output_names_ops = [
        op.name for op in outputs.values() if isinstance(op, tf.Operation)
    ]
    output_names = output_names_tensors + output_names_ops

    with tf.Session() as session:
        # Restore variables from checkpoint
        load_graph_for_evaluation(session)

        output_filename = Config.export_file_name + ".pb"
        # Optionally wipe any previous export first.
        if Config.remove_export:
            if isdir_remote(Config.export_dir):
                log_info("Removing old export")
                rmtree_remote(Config.export_dir)

        output_graph_path = os.path.join(Config.export_dir, output_filename)

        # Local export dirs may need to be created; remote paths are handled
        # by open_remote on write.
        if not is_remote_path(Config.export_dir) and not os.path.isdir(
            Config.export_dir
        ):
            os.makedirs(Config.export_dir)

        # Bake the restored variables into graph constants...
        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names,
        )

        # ...then drop every node not needed to compute the requested outputs.
        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph, dest_nodes=output_names
        )

        if not Config.export_tflite:
            # Plain frozen-GraphDef (.pb) export.
            with open_remote(output_graph_path, "wb") as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            # TFLite export: same base name, .tflite extension.
            output_tflite_path = os.path.join(
                Config.export_dir, output_filename.replace(".pb", ".tflite")
            )

            converter = tf.lite.TFLiteConverter(
                frozen_graph,
                input_tensors=inputs.values(),
                output_tensors=outputs.values(),
            )

            # Optional post-training quantization.
            if Config.export_quantize:
                converter.optimizations = [tf.lite.Optimize.DEFAULT]

            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open_remote(output_tflite_path, "wb") as fout:
                fout.write(tflite_model)

        log_info("Models exported at %s" % (Config.export_dir))

    # Write a model-card style markdown file alongside the exported model,
    # named <author>_<model_name>_<model_version>.md.
    metadata_fname = os.path.join(
        Config.export_dir,
        "{}_{}_{}.md".format(
            Config.export_author_id,
            Config.export_model_name,
            Config.export_model_version,
        ),
    )

    model_runtime = "tflite" if Config.export_tflite else "tensorflow"
    with open_remote(metadata_fname, "w") as f:
        # YAML front-matter block followed by the free-form description.
        f.write("---\n")
        f.write("author: {}\n".format(Config.export_author_id))
        f.write("model_name: {}\n".format(Config.export_model_name))
        f.write("model_version: {}\n".format(Config.export_model_version))
        f.write("contact_info: {}\n".format(Config.export_contact_info))
        f.write("license: {}\n".format(Config.export_license))
        f.write("language: {}\n".format(Config.export_language))
        f.write("runtime: {}\n".format(model_runtime))
        f.write("min_stt_version: {}\n".format(Config.export_min_stt_version))
        f.write("max_stt_version: {}\n".format(Config.export_max_stt_version))
        # URL fields are placeholders the publisher fills in by hand.
        f.write(
            "acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n"
        )
        f.write(
            "scorer_url: <replace this with a publicly available URL of the scorer, if present>\n"
        )
        f.write("---\n")
        f.write("{}\n".format(Config.export_description))

    log_info(
        "Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.".format(
            metadata_fname
        )
    )
|
|
|
|
|
|
def export_savedmodel():
    """
    Export the model as a TensorFlow SavedModel in ``Config.export_dir``.

    Builds a single-utterance inference graph that takes raw WAV file bytes as
    input and produces per-timestep character probabilities, restores the
    trained variables, and serializes it with the SavedModel builder. The
    scorer (if configured) and the alphabet file are copied alongside.
    """
    tfv1.reset_default_graph()

    with tfv1.Session(config=Config.session_config) as session:
        # Scalar string placeholder holding the complete WAV file contents.
        input_wavfile_contents = tf.placeholder(tf.string)

        features, features_len = wavfile_bytes_to_features(input_wavfile_contents)
        # Fresh zeroed LSTM state for a single-utterance (batch size 1) pass.
        previous_state_c = tf.zeros([1, Config.n_cell_dim], tf.float32)
        previous_state_h = tf.zeros([1, Config.n_cell_dim], tf.float32)

        previous_state = tf.nn.rnn_cell.LSTMStateTuple(
            previous_state_c, previous_state_h
        )

        # Add batch dimension
        features = tf.expand_dims(features, 0)
        features_len = tf.expand_dims(features_len, 0)

        # One rate per layer
        no_dropout = [None] * 6

        # NOTE(review): `layers` is unused; kept to mirror create_model's API.
        logits, layers = create_model(
            batch_x=features,
            batch_size=1,
            seq_length=features_len,
            dropout=no_dropout,
            previous_state=previous_state,
        )

        # Restore variables from training checkpoint
        load_graph_for_evaluation(session)

        probs = tf.nn.softmax(logits)

        # Remove batch dimension
        squeezed = tf.squeeze(probs)

        builder = tfv1.saved_model.builder.SavedModelBuilder(Config.export_dir)

        input_file_tinfo = tfv1.saved_model.utils.build_tensor_info(
            input_wavfile_contents
        )
        output_probs_tinfo = tfv1.saved_model.utils.build_tensor_info(squeezed)

        # Single signature: wavfile bytes in, probability matrix out.
        forward_sig = tfv1.saved_model.signature_def_utils.build_signature_def(
            inputs={
                "input_wavfile": input_file_tinfo,
            },
            outputs={
                "probs": output_probs_tinfo,
            },
            method_name="forward",
        )

        # Register the graph + restored variables under the SERVING tag, with
        # the forward signature as the default.
        builder.add_meta_graph_and_variables(
            session,
            [tfv1.saved_model.tag_constants.SERVING],
            signature_def_map={
                tfv1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: forward_sig
            },
        )

        builder.save()

        # Copy scorer and alphabet alongside SavedModel
        if Config.scorer_path:
            shutil.copy(
                Config.scorer_path, os.path.join(Config.export_dir, "exported.scorer")
            )
        shutil.copy(
            Config.effective_alphabet_path,
            os.path.join(Config.export_dir, "alphabet.txt"),
        )

        log_info(f"Exported SavedModel to {Config.export_dir}")
|
|
|
|
|
|
def package_zip():
    """
    Zip up an already-exported model directory for distribution.

    Copies the scorer (if configured) into the export directory, then packs
    the directory into a sibling ``<dirname>.zip`` archive. Remote export
    paths cannot be packaged and are reported via ``log_error``.
    """
    # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    export_dir = os.path.join(
        os.path.abspath(Config.export_dir), ""
    )  # Force ending '/'
    if is_remote_path(export_dir):
        log_error(
            "Cannot package remote path zip %s. Please do this manually." % export_dir
        )
        return

    # dirname of a path with a trailing separator strips just that separator,
    # yielding the archive base name.
    zip_filename = os.path.dirname(export_dir)

    # The scorer is optional — guard the copy like export_savedmodel does,
    # instead of crashing when no --scorer_path was given.
    if Config.scorer_path:
        shutil.copy(Config.scorer_path, export_dir)

    archive = shutil.make_archive(zip_filename, "zip", export_dir)
    log_info("Exported packaged model {}".format(archive))
|
|
|
|
|
|
def main():
    """
    CLI entry point: initialize globals from command-line flags, then export
    the model, optionally packaging the result as a zip archive.
    """
    initialize_globals_from_cli()

    # An export destination is mandatory when invoking this script directly.
    if not Config.export_dir:
        raise RuntimeError(
            "Calling export script directly but no --export_dir specified"
        )

    if Config.export_zip:
        # Zip packaging requires an empty destination directory.
        if listdir_remote(Config.export_dir):
            raise RuntimeError(
                "Directory {} is not empty, please fix this.".format(Config.export_dir)
            )
        export()
        package_zip()
        return

    # Export to folder
    export()
|
|
|
|
|
|
# Allow running this module directly as a script (in addition to package use).
if __name__ == "__main__":
    main()
|