Allow exporting as SavedModel
This commit is contained in:
parent
6a9bd1e6b6
commit
419b15b72a
|
@ -9,13 +9,15 @@ DESIRED_LOG_LEVEL = (
|
|||
)
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
import shutil
|
||||
|
||||
from .deepspeech_model import create_inference_graph
|
||||
from .deepspeech_model import create_inference_graph, create_model
|
||||
from .util.checkpoints import load_graph_for_evaluation
|
||||
from .util.config import Config, initialize_globals_from_cli, log_error, log_info
|
||||
from .util.feeding import wavfile_bytes_to_features
|
||||
from .util.io import (
|
||||
open_remote,
|
||||
rmtree_remote,
|
||||
|
@ -35,6 +37,9 @@ def export():
|
|||
"""
|
||||
log_info("Exporting the model...")
|
||||
|
||||
if Config.export_savedmodel:
|
||||
return export_savedmodel()
|
||||
|
||||
tfv1.reset_default_graph()
|
||||
|
||||
inputs, outputs, _ = create_inference_graph(
|
||||
|
@ -172,6 +177,72 @@ def export():
|
|||
)
|
||||
|
||||
|
||||
def export_savedmodel():
|
||||
tfv1.reset_default_graph()
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
input_wavfile_contents = tf.placeholder(tf.string)
|
||||
|
||||
features, features_len = wavfile_bytes_to_features(input_wavfile_contents)
|
||||
previous_state_c = tf.zeros([1, Config.n_cell_dim], tf.float32)
|
||||
previous_state_h = tf.zeros([1, Config.n_cell_dim], tf.float32)
|
||||
|
||||
previous_state = tf.nn.rnn_cell.LSTMStateTuple(
|
||||
previous_state_c, previous_state_h
|
||||
)
|
||||
|
||||
# Add batch dimension
|
||||
features = tf.expand_dims(features, 0)
|
||||
features_len = tf.expand_dims(features_len, 0)
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
|
||||
logits, layers = create_model(
|
||||
batch_x=features,
|
||||
batch_size=1,
|
||||
seq_length=features_len,
|
||||
dropout=no_dropout,
|
||||
previous_state=previous_state,
|
||||
)
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
load_graph_for_evaluation(session)
|
||||
|
||||
probs = tf.nn.softmax(logits)
|
||||
|
||||
# Remove batch dimension
|
||||
squeezed = tf.squeeze(probs)
|
||||
|
||||
builder = tfv1.saved_model.builder.SavedModelBuilder(Config.export_dir)
|
||||
|
||||
input_file_tinfo = tfv1.saved_model.utils.build_tensor_info(
|
||||
input_wavfile_contents
|
||||
)
|
||||
output_probs_tinfo = tfv1.saved_model.utils.build_tensor_info(squeezed)
|
||||
|
||||
forward_sig = tfv1.saved_model.signature_def_utils.build_signature_def(
|
||||
inputs={
|
||||
"input_wavfile": input_file_tinfo,
|
||||
},
|
||||
outputs={
|
||||
"probs": output_probs_tinfo,
|
||||
},
|
||||
method_name="forward",
|
||||
)
|
||||
|
||||
builder.add_meta_graph_and_variables(
|
||||
session,
|
||||
[tfv1.saved_model.tag_constants.SERVING],
|
||||
signature_def_map={
|
||||
tfv1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: forward_sig
|
||||
},
|
||||
)
|
||||
|
||||
builder.save()
|
||||
log_info(f"Exported SavedModel to {Config.export_dir}")
|
||||
|
||||
|
||||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
export_dir = os.path.join(
|
||||
|
|
|
@ -587,6 +587,10 @@ class _SttConfig(Coqpit):
|
|||
default=True,
|
||||
metadata=dict(help="export a quantized model (optimized for size)"),
|
||||
)
|
||||
export_savedmodel: bool = field(
|
||||
default=False,
|
||||
metadata=dict(help="export model in TF SavedModel format"),
|
||||
)
|
||||
n_steps: int = field(
|
||||
default=16,
|
||||
metadata=dict(
|
||||
|
|
|
@ -84,6 +84,14 @@ def audiofile_to_features(
|
|||
wav_filename, clock=0.0, train_phase=False, augmentations=None
|
||||
):
|
||||
samples = tf.io.read_file(wav_filename)
|
||||
return wavfile_bytes_to_features(
|
||||
samples, clock, train_phase, augmentations, sample_id=wav_filename
|
||||
)
|
||||
|
||||
|
||||
def wavfile_bytes_to_features(
|
||||
samples, clock=0.0, train_phase=False, augmentations=None, sample_id=None
|
||||
):
|
||||
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
|
||||
return audio_to_features(
|
||||
decoded.audio,
|
||||
|
@ -91,7 +99,7 @@ def audiofile_to_features(
|
|||
clock=clock,
|
||||
train_phase=train_phase,
|
||||
augmentations=augmentations,
|
||||
sample_id=wav_filename,
|
||||
sample_id=sample_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue