Simplify first-time conditional, allowing SVDF model to run successfully in

android demo app.

PiperOrigin-RevId: 238072222
This commit is contained in:
A. Unique TensorFlower 2019-03-12 12:31:40 -07:00 committed by TensorFlower Gardener
parent 3be3aea56e
commit d9ba90a5ee

View File

@ -530,6 +530,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
shape=[num_filters, batch, input_time_size],
trainable=False,
name='runtime-memory')
first_time_flag = tf.get_variable(
name="first_time_flag",
dtype=tf.int32,
initializer=1)
# Determine the number of new frames in the input, such that we only operate
# on those. For training we do not use the memory, and thus use all frames
# provided in the input.
@ -540,9 +544,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
model_settings['sample_rate'])
num_new_frames = tf.cond(
tf.equal(tf.count_nonzero(memory), 0),
tf.equal(first_time_flag, 1),
lambda: input_time_size,
lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))
first_time_flag = 0
new_fingerprint_input = fingerprint_input[
:, -num_new_frames*input_frequency_size:]
# Expand to add input channels dimension.