Add comments marking nodes with names/shapes known by native client
This commit is contained in:
parent
f90408d3ab
commit
f9556d2236
|
@ -912,17 +912,28 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||
batch_size = batch_size if batch_size > 0 else None
|
||||
|
||||
# Create feature computation graph
|
||||
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
input_samples = tfv1.placeholder(
|
||||
tf.float32, [Config.audio_window_samples], "input_samples"
|
||||
)
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
mfccs = tf.identity(mfccs, name="mfccs")
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
# This shape is read by the native_client in STT_CreateModel to know the
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
#
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
input_tensor = tfv1.placeholder(
|
||||
tf.float32,
|
||||
[
|
||||
|
@ -933,15 +944,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||
],
|
||||
name="input_node",
|
||||
)
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths")
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = None
|
||||
else:
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
previous_state_c = tfv1.placeholder(
|
||||
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c"
|
||||
)
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
previous_state_h = tfv1.placeholder(
|
||||
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h"
|
||||
)
|
||||
|
@ -971,6 +991,10 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
#
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
|
@ -1045,6 +1069,9 @@ def export():
|
|||
graph_version = int(file_relative_read("GRAPH_VERSION").strip())
|
||||
assert graph_version > 0
|
||||
|
||||
# native_client: these nodes's names and shapes are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
|
||||
outputs["metadata_sample_rate"] = tf.constant(
|
||||
[Config.audio_sample_rate], name="metadata_sample_rate"
|
||||
|
|
Loading…
Reference in New Issue