Simplify first-time conditional, allowing SVDF model to run successfully in
android demo app. PiperOrigin-RevId: 238072222
This commit is contained in:
parent
3be3aea56e
commit
d9ba90a5ee
@ -530,6 +530,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
|
||||
shape=[num_filters, batch, input_time_size],
|
||||
trainable=False,
|
||||
name='runtime-memory')
|
||||
first_time_flag = tf.get_variable(
|
||||
name="first_time_flag",
|
||||
dtype=tf.int32,
|
||||
initializer=1)
|
||||
# Determine the number of new frames in the input, such that we only operate
|
||||
# on those. For training we do not use the memory, and thus use all frames
|
||||
# provided in the input.
|
||||
@ -540,9 +544,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
|
||||
window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
|
||||
model_settings['sample_rate'])
|
||||
num_new_frames = tf.cond(
|
||||
tf.equal(tf.count_nonzero(memory), 0),
|
||||
tf.equal(first_time_flag, 1),
|
||||
lambda: input_time_size,
|
||||
lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))
|
||||
first_time_flag = 0
|
||||
new_fingerprint_input = fingerprint_input[
|
||||
:, -num_new_frames*input_frequency_size:]
|
||||
# Expand to add input channels dimension.
|
||||
|
Loading…
Reference in New Issue
Block a user