Allow exporting as SavedModel

This commit is contained in:
Reuben Morais 2021-11-10 18:29:38 +01:00
parent 6a9bd1e6b6
commit 419b15b72a
3 changed files with 85 additions and 2 deletions

View File

@ -9,13 +9,15 @@ DESIRED_LOG_LEVEL = (
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import shutil
from .deepspeech_model import create_inference_graph
from .deepspeech_model import create_inference_graph, create_model
from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals_from_cli, log_error, log_info
from .util.feeding import wavfile_bytes_to_features
from .util.io import (
open_remote,
rmtree_remote,
@ -35,6 +37,9 @@ def export():
"""
log_info("Exporting the model...")
if Config.export_savedmodel:
return export_savedmodel()
tfv1.reset_default_graph()
inputs, outputs, _ = create_inference_graph(
@ -172,6 +177,72 @@ def export():
)
def export_savedmodel():
    """Export the model as a TensorFlow SavedModel under ``Config.export_dir``.

    A fresh default graph is built whose single input is a string
    placeholder fed through ``wavfile_bytes_to_features``. The resulting
    features are given a leading batch dimension of one and passed to
    ``create_model`` together with an all-zero initial LSTM state and no
    dropout. Variables are restored via ``load_graph_for_evaluation`` and
    the softmax of the logits is exposed as the output.

    The SavedModel is tagged ``SERVING`` and carries one signature,
    registered under the default serving key:

    * input ``"input_wavfile"`` -- the string placeholder
    * output ``"probs"`` -- softmax of the logits, squeezed
    * method name ``"forward"``

    NOTE(review): ``SavedModelBuilder`` is documented to refuse an export
    directory that already exists -- confirm callers pass a fresh path.
    """
    tfv1.reset_default_graph()

    with tfv1.Session(config=Config.session_config) as session:
        # Graph-mode-only symbols (placeholder, rnn_cell) are taken from the
        # compat.v1 namespace, matching the tfv1.* usage elsewhere in this
        # function, instead of relying on the plain `tf` alias exposing them.
        input_wavfile_contents = tfv1.placeholder(tf.string)
        features, features_len = wavfile_bytes_to_features(input_wavfile_contents)

        # Every exported forward pass starts from an all-zero LSTM state.
        previous_state_c = tf.zeros([1, Config.n_cell_dim], tf.float32)
        previous_state_h = tf.zeros([1, Config.n_cell_dim], tf.float32)
        previous_state = tfv1.nn.rnn_cell.LSTMStateTuple(
            previous_state_c, previous_state_h
        )

        # Add batch dimension
        features = tf.expand_dims(features, 0)
        features_len = tf.expand_dims(features_len, 0)

        # One rate per layer
        no_dropout = [None] * 6

        logits, layers = create_model(
            batch_x=features,
            batch_size=1,
            seq_length=features_len,
            dropout=no_dropout,
            previous_state=previous_state,
        )

        # Restore variables from training checkpoint
        load_graph_for_evaluation(session)

        probs = tf.nn.softmax(logits)

        # Remove batch dimension
        # NOTE(review): squeeze without an explicit axis drops *every*
        # size-1 dimension, not only the batch one -- confirm the layout of
        # create_model's output and consider passing the batch axis.
        squeezed = tf.squeeze(probs)

        builder = tfv1.saved_model.builder.SavedModelBuilder(Config.export_dir)

        input_file_tinfo = tfv1.saved_model.utils.build_tensor_info(
            input_wavfile_contents
        )
        output_probs_tinfo = tfv1.saved_model.utils.build_tensor_info(squeezed)

        forward_sig = tfv1.saved_model.signature_def_utils.build_signature_def(
            inputs={
                "input_wavfile": input_file_tinfo,
            },
            outputs={
                "probs": output_probs_tinfo,
            },
            method_name="forward",
        )

        builder.add_meta_graph_and_variables(
            session,
            [tfv1.saved_model.tag_constants.SERVING],
            signature_def_map={
                tfv1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: forward_sig
            },
        )

        builder.save()

    log_info(f"Exported SavedModel to {Config.export_dir}")
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(

View File

@ -587,6 +587,10 @@ class _SttConfig(Coqpit):
default=True,
metadata=dict(help="export a quantized model (optimized for size)"),
)
export_savedmodel: bool = field(
default=False,
metadata=dict(help="export model in TF SavedModel format"),
)
n_steps: int = field(
default=16,
metadata=dict(

View File

@ -84,6 +84,14 @@ def audiofile_to_features(
wav_filename, clock=0.0, train_phase=False, augmentations=None
):
samples = tf.io.read_file(wav_filename)
return wavfile_bytes_to_features(
samples, clock, train_phase, augmentations, sample_id=wav_filename
)
def wavfile_bytes_to_features(
samples, clock=0.0, train_phase=False, augmentations=None, sample_id=None
):
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
return audio_to_features(
decoded.audio,
@ -91,7 +99,7 @@ def audiofile_to_features(
clock=clock,
train_phase=train_phase,
augmentations=augmentations,
sample_id=wav_filename,
sample_id=sample_id,
)