Use tf.compat.v1 to silence deprecation warnings and enable TF 2.0 testing

Reuben Morais 2019-06-28 10:34:47 -03:00
parent dc78f8d1e6
commit 6f3e824ef7
3 changed files with 56 additions and 54 deletions
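The pattern throughout this commit: symbols that are deprecated or removed in TF 2.0 move to the tensorflow.compat.v1 namespace (aliased as tfv1), while ops that survive unchanged in 2.0 stay on plain tf. A minimal sketch of the convention, assuming TF 1.14 graph mode (or tfv1.disable_eager_execution() under a TF 2.x runtime):

```python
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

# v1-only APIs (placeholders, sessions, v1 summaries) go through the alias...
x = tfv1.placeholder(tf.float32, [None, 16], name='x')
# ...while ops that exist unchanged in TF 2.0 stay on plain tf.
y = tf.reduce_mean(x)

with tfv1.Session() as session:
    print(session.run(y, feed_dict={x: [[0.5] * 16]}))  # 0.5
```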

View File

@@ -13,6 +13,7 @@ import numpy as np
 import progressbar
 import shutil
 import tensorflow as tf
+import tensorflow.compat.v1 as tfv1

 from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
@@ -37,7 +38,7 @@ def variable_on_cpu(name, shape, initializer):
     # Use the /cpu:0 device for scoped operations
     with tf.device(Config.cpu_device):
         # Create or get apropos variable
-        var = tf.get_variable(name=name, shape=shape, initializer=initializer)
+        var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
     return var
@@ -62,7 +63,7 @@ def create_overlapping_windows(batch_x):
 def dense(name, x, units, dropout_rate=None, relu=True):
-    with tf.variable_scope(name):
+    with tfv1.variable_scope(name):
         bias = variable_on_cpu('bias', [units], tf.zeros_initializer())
         weights = variable_on_cpu('weights', [x.shape[-1], units], tf.contrib.layers.xavier_initializer())
@@ -186,7 +187,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse)

     # Compute the CTC loss using TensorFlow's `ctc_loss`
-    total_loss = tf.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
+    total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)

     # Calculate the average loss across the batch
     avg_loss = tf.reduce_mean(total_loss)
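Worth noting why ctc_loss needs the alias while reduce_mean does not: TF 2.0's tf.nn.ctc_loss changes the signature (dense labels, a logits keyword), so only the v1 entry point keeps accepting SparseTensor labels via inputs=. A self-contained sketch of the v1 call shape, with all dimensions invented for illustration:

```python
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

# Illustrative shapes only: 50 time steps, batch of 2, 29 character classes.
logits = tfv1.placeholder(tf.float32, [50, 2, 29])   # time-major, as ctc_loss expects
seq_len = tfv1.placeholder(tf.int32, [2])
labels = tfv1.sparse_placeholder(tf.int32)           # v1 takes sparse labels

total_loss = tfv1.nn.ctc_loss(labels=labels, inputs=logits, sequence_length=seq_len)
avg_loss = tf.reduce_mean(total_loss)                # unchanged in TF 2.0, stays on tf
```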
@@ -205,10 +206,10 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
 # we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
 # because, generally, it requires less fine-tuning.
 def create_optimizer():
-    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
-                                       beta1=FLAGS.beta1,
-                                       beta2=FLAGS.beta2,
-                                       epsilon=FLAGS.epsilon)
+    optimizer = tfv1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
+                                         beta1=FLAGS.beta1,
+                                         beta2=FLAGS.beta2,
+                                         epsilon=FLAGS.epsilon)
     return optimizer
@@ -240,7 +241,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
     # Tower gradients to return
     tower_gradients = []

-    with tf.variable_scope(tf.get_variable_scope()):
+    with tfv1.variable_scope(tfv1.get_variable_scope()):
         # Loop over available_devices
         for i in range(len(Config.available_devices)):
             # Execute operations of tower i on device i
@@ -253,7 +254,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
                 avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)

                 # Allow for variables to be re-used by the next tower
-                tf.get_variable_scope().reuse_variables()
+                tfv1.get_variable_scope().reuse_variables()

                 # Retain tower's avg losses
                 tower_avg_losses.append(avg_loss)
@@ -267,7 +268,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
     avg_loss_across_towers = tf.reduce_mean(tower_avg_losses, 0)

-    tf.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
+    tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

     # Return gradients and the average loss
     return tower_gradients, avg_loss_across_towers
@@ -322,18 +323,18 @@ def log_variable(variable, gradient=None):
     '''
     name = variable.name.replace(':', '_')
     mean = tf.reduce_mean(variable)
-    tf.summary.scalar(name='%s/mean' % name, tensor=mean)
-    tf.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(tf.square(variable - mean))))
-    tf.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(variable))
-    tf.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(variable))
-    tf.summary.histogram(name=name, values=variable)
+    tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
+    tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(tf.square(variable - mean))))
+    tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(variable))
+    tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(variable))
+    tfv1.summary.histogram(name=name, values=variable)
     if gradient is not None:
         if isinstance(gradient, tf.IndexedSlices):
             grad_values = gradient.values
         else:
             grad_values = gradient
         if grad_values is not None:
-            tf.summary.histogram(name='%s/gradients' % name, values=grad_values)
+            tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)


 def log_grads_and_vars(grads_and_vars):
@@ -351,7 +352,7 @@ def try_loading(session, saver, checkpoint_filename, caption):
             return False
         checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
-        restored_step = session.run(tf.train.get_global_step())
+        restored_step = session.run(tfv1.train.get_global_step())
         log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
         return True
     except tf.errors.InvalidArgumentError as e:
@@ -369,9 +370,9 @@ def train():
                                batch_size=FLAGS.train_batch_size,
                                cache_path=FLAGS.feature_cache)

-    iterator = tf.data.Iterator.from_structure(train_set.output_types,
-                                               train_set.output_shapes,
-                                               output_classes=train_set.output_classes)
+    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
+                                                 tfv1.data.get_output_shapes(train_set),
+                                                 output_classes=tfv1.data.get_output_classes(train_set))

     # Make initialization ops for switching between the two sets
     train_init_op = iterator.make_initializer(train_set)
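The dataset attributes output_types/output_shapes/output_classes are deprecated in TF 1.14, which is why the hunk above switches to the standalone tfv1.data.get_output_* helpers. The same pattern on a toy dataset, as a sketch:

```python
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0]).batch(2)

# The helper functions replace the deprecated dataset.output_* attributes.
iterator = tfv1.data.Iterator.from_structure(
    tfv1.data.get_output_types(dataset),
    tfv1.data.get_output_shapes(dataset),
    output_classes=tfv1.data.get_output_classes(dataset))
init_op = iterator.make_initializer(dataset)
next_element = iterator.get_next()

with tfv1.Session() as session:
    session.run(init_op)
    print(session.run(next_element))  # [1. 2.]
```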
@@ -382,7 +383,7 @@ def train():
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

     # Dropout
-    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
+    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
     dropout_feed_dict = {
         dropout_rates[0]: FLAGS.dropout_rate,
         dropout_rates[1]: FLAGS.dropout_rate2,
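Placeholders are the part of this file that can never become TF 2.0-native: they exist only in the v1 API and, under a TF 2.x runtime, additionally require eager execution to be off. A sketch of what the "TF 2.0 testing" in the commit title presumably needs at interpreter startup (the disable call is an assumption about the test harness, not part of this diff):

```python
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

# Required once, before any graph is built, when running on TF 2.x.
tfv1.disable_eager_execution()

dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i))
                 for i in range(6)]
# At session.run time each placeholder is fed a concrete rate,
# e.g. all zeros to switch dropout off for validation.
no_dropout_feed_dict = {rate: 0.0 for rate in dropout_rates}
```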
@@ -404,31 +405,31 @@ def train():
     log_grads_and_vars(avg_tower_gradients)

     # global_step is automagically incremented by the optimizer
-    global_step = tf.train.get_or_create_global_step()
+    global_step = tfv1.train.get_or_create_global_step()
     apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

     # Summaries
-    step_summaries_op = tf.summary.merge_all('step_summaries')
+    step_summaries_op = tfv1.summary.merge_all('step_summaries')
     step_summary_writers = {
-        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
-        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
+        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
+        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
     }

     # Checkpointing
-    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
+    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
     checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
     checkpoint_filename = 'checkpoint'

-    best_dev_saver = tf.train.Saver(max_to_keep=1)
+    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
     best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
     best_dev_filename = 'best_dev_checkpoint'

-    initializer = tf.global_variables_initializer()
+    initializer = tfv1.global_variables_initializer()

-    with tf.Session() as session:
+    with tfv1.Session() as session:
         log_debug('Session opened.')

-        tf.get_default_graph().finalize()
+        tfv1.get_default_graph().finalize()

         # Loading or initializing
         loaded = False
@@ -558,7 +559,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     batch_size = batch_size if batch_size > 0 else None

     # Create feature computation graph
-    input_samples = tf.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
+    input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
     samples = tf.expand_dims(input_samples, -1)
     mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
     mfccs = tf.identity(mfccs, name='mfccs')
@@ -567,15 +568,15 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     # This shape is read by the native_client in DS_CreateModel to know the
     # value of n_steps, n_context and n_input. Make sure you update the code
     # there if this shape is changed.
-    input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
-    seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
+    input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
+    seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')

     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
         previous_state = previous_state_c = previous_state_h = None
     else:
-        previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
-        previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
+        previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
+        previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

         previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
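The previous_state_c/h placeholders exist so a caller can stream audio chunk by chunk, feeding each run's final LSTM state back in as the next run's previous state. A hypothetical sketch of that loop from the Python side (the input and state placeholder names come from the hunk above; the fetch names and the 2048 cell dimension are illustrative assumptions, not guaranteed by this diff):

```python
import numpy as np
import tensorflow.compat.v1 as tfv1

def stream_chunks(session, chunks, batch_size=1, n_cell_dim=2048):
    """Run inference chunk by chunk, threading LSTM state between runs.

    `chunks` yields (features, seq_length) pairs; the fetch names below
    are hypothetical, only the feed names appear in the diff.
    """
    state_c = np.zeros([batch_size, n_cell_dim], dtype=np.float32)
    state_h = np.zeros([batch_size, n_cell_dim], dtype=np.float32)
    for features, seq_length in chunks:
        logits, state_c, state_h = session.run(
            ['logits:0', 'new_state_c:0', 'new_state_h:0'],  # assumed names
            feed_dict={'input_node:0': features,
                       'input_lengths:0': seq_length,
                       'previous_state_c:0': state_c,
                       'previous_state_h:0': state_h})
        yield logits
```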
@@ -671,7 +672,7 @@ def export():
     mapping = {fixup(v.op.name): v for v in tf.global_variables()}

-    saver = tf.train.Saver(mapping)
+    saver = tfv1.train.Saver(mapping)

     # Restore variables from training checkpoint
     checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
@@ -690,7 +691,7 @@ def export():
     def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
         frozen = freeze_graph.freeze_graph_with_def_protos(
-            input_graph_def=tf.get_default_graph().as_graph_def(),
+            input_graph_def=tfv1.get_default_graph().as_graph_def(),
             input_saver_def=saver.as_saver_def(),
             input_checkpoint=checkpoint_path,
             output_node_names=output_node_names,
@@ -745,11 +746,11 @@ def export():
 def do_single_file_inference(input_file_path):
-    with tf.Session() as session:
+    with tfv1.Session() as session:
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

         # Create a saver using variables from the above newly created graph
-        saver = tf.train.Saver()
+        saver = tfv1.train.Saver()

         # Restore variables from training checkpoint
         # TODO: This restores the most recent checkpoint, but if we use validation to counteract
@@ -795,22 +796,22 @@ def main(_):
     initialize_globals()

     if FLAGS.train_files:
-        tf.reset_default_graph()
-        tf.set_random_seed(FLAGS.random_seed)
+        tfv1.reset_default_graph()
+        tfv1.set_random_seed(FLAGS.random_seed)
         train()

     if FLAGS.test_files:
-        tf.reset_default_graph()
+        tfv1.reset_default_graph()
         test()

     if FLAGS.export_dir:
-        tf.reset_default_graph()
+        tfv1.reset_default_graph()
         export()

     if FLAGS.one_shot_infer:
-        tf.reset_default_graph()
+        tfv1.reset_default_graph()
         do_single_file_inference(FLAGS.one_shot_infer)

 if __name__ == '__main__':
     create_flags()
-    tf.app.run(main)
+    tfv1.app.run(main)

View File

@@ -10,6 +10,7 @@ from multiprocessing import cpu_count
 import numpy as np
 import progressbar
 import tensorflow as tf
+import tensorflow.compat.v1 as tfv1

 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
@@ -46,9 +47,9 @@ def evaluate(test_csvs, create_model, try_loading):
     test_csvs = FLAGS.test_files.split(',')
     test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
-    iterator = tf.data.Iterator.from_structure(test_sets[0].output_types,
-                                               test_sets[0].output_shapes,
-                                               output_classes=test_sets[0].output_classes)
+    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
+                                                 tfv1.data.get_output_shapes(test_sets[0]),
+                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))
     test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

     (batch_x, batch_x_len), batch_y = iterator.get_next()
@@ -62,11 +63,11 @@ def evaluate(test_csvs, create_model, try_loading):
     # Transpose to batch major and apply softmax for decoder
     transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

-    loss = tf.nn.ctc_loss(labels=batch_y,
-                          inputs=logits,
-                          sequence_length=batch_x_len)
+    loss = tfv1.nn.ctc_loss(labels=batch_y,
+                            inputs=logits,
+                            sequence_length=batch_x_len)

-    tf.train.get_or_create_global_step()
+    tfv1.train.get_or_create_global_step()

     # Get number of accessible CPU cores for this process
     try:
@@ -75,9 +76,9 @@ def evaluate(test_csvs, create_model, try_loading):
         num_processes = 1

     # Create a saver using variables from the above newly created graph
-    saver = tf.train.Saver()
+    saver = tfv1.train.Saver()

-    with tf.Session() as session:
+    with tfv1.Session() as session:
         # Restore variables from training checkpoint
         loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
         if not loaded:
@@ -163,4 +164,4 @@ def main(_):
 if __name__ == '__main__':
     create_flags()
     tf.app.flags.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples')
-    tf.app.run(main)
+    tfv1.app.run(main)

View File

@@ -42,7 +42,7 @@ def samples_to_mfccs(samples, sample_rate):
 def audiofile_to_features(wav_filename):
-    samples = tf.read_file(wav_filename)
+    samples = tf.io.read_file(wav_filename)
     decoded = contrib_audio.decode_wav(samples, desired_channels=1)
     features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate)
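tf.read_file is a rename rather than a compat.v1 pin: tf.io.read_file is the canonical name in both 1.14 and 2.0, so this line needs no alias. A self-contained sketch of the same load path, substituting tf.audio.decode_wav for the contrib op in the diff ('/tmp/sample.wav' is a hypothetical path):

```python
import tensorflow as tf

# tf.io.read_file is the non-deprecated spelling of tf.read_file.
samples = tf.io.read_file('/tmp/sample.wav')
# tf.audio.decode_wav stands in here for contrib_audio.decode_wav; both
# return (audio, sample_rate) with audio shaped [num_samples, channels].
decoded = tf.audio.decode_wav(samples, desired_channels=1)
audio, sample_rate = decoded.audio, decoded.sample_rate
```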