Add comments marking nodes with names/shapes known by native client

This commit is contained in:
Reuben Morais 2021-08-19 18:33:48 +02:00
parent f90408d3ab
commit f9556d2236
1 changed files with 27 additions and 0 deletions

View File

@ -912,17 +912,28 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
input_samples = tfv1.placeholder(
tf.float32, [Config.audio_window_samples], "input_samples"
)
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
mfccs = tf.identity(mfccs, name="mfccs")
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in STT_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
#
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
input_tensor = tfv1.placeholder(
tf.float32,
[
@ -933,15 +944,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
],
name="input_node",
)
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths")
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
previous_state = None
else:
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
previous_state_c = tfv1.placeholder(
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c"
)
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
previous_state_h = tfv1.placeholder(
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h"
)
@ -971,6 +991,10 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
#
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
if tflite:
logits = tf.squeeze(logits, [1])
@ -1045,6 +1069,9 @@ def export():
graph_version = int(file_relative_read("GRAPH_VERSION").strip())
assert graph_version > 0
# native_client: these nodes's names and shapes are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
outputs["metadata_sample_rate"] = tf.constant(
[Config.audio_sample_rate], name="metadata_sample_rate"