diff --git a/DeepSpeech.py b/DeepSpeech.py index daa80a8c..f3162140 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -595,23 +595,13 @@ def train(): # Batch loop while True: try: - try: - _, current_step, batch_loss, problem_files, step_summary = \ - session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], - feed_dict=feed_dict) - except errors_impl.InvalidArgumentError as err: - if FLAGS.augmentation_sparse_warp: - # recover twice for sparse warp, if still error, abort it!!! - try: - print('recovering the invertible error: {}'.format(err)) - _, current_step, batch_loss, problem_files, step_summary = \ - session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], - feed_dict=feed_dict) - except errors_impl.InvalidArgumentError as err: - print('recovering the invertible error `AGAIN`: {}'.format(err)) - _, current_step, batch_loss, problem_files, step_summary = \ - session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], - feed_dict=feed_dict) + _, current_step, batch_loss, problem_files, step_summary = \ + session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], + feed_dict=feed_dict) + except tf.errors.InvalidArgumentError as err: + if FLAGS.augmentation_sparse_warp: + log_info("skip sparse warp error: {}".format(err)) + continue except tf.errors.OutOfRangeError: break diff --git a/util/flags.py b/util/flags.py index d2a45766..624217e0 100644 --- a/util/flags.py +++ b/util/flags.py @@ -29,7 +29,7 @@ def create_flags(): f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)') - f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp') + f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp. USE OF THIS FLAG IS UNSUPPORTED, enable sparse warp will increase training time drastically, and the paper also mentioned that this is not a major factor to improve accuracy.') f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points') f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para') f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order') diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py index 8ba903be..58980e8c 100644 --- a/util/spectrogram_augmentations.py +++ b/util/spectrogram_augmentations.py @@ -117,11 +117,6 @@ def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2 source_control_point_locations = tf.cast([sources], tf.float32) dest_control_point_locations = tf.cast([dests], tf.float32) - # debug - # spectrogram = tf.Print(spectrogram, [tf.shape(spectrogram)], message='spectrogram', first_n=1000) - # spectrogram = tf.Print(spectrogram, sources, message='sources', first_n=1000) - # spectrogram = tf.Print(spectrogram, dests, message='dests', first_n=1000) - warped_spectrogram, _ = sparse_image_warp(spectrogram, source_control_point_locations=source_control_point_locations, dest_control_point_locations=dest_control_point_locations,