From 33c21900159f6733590ceb5663b3f6ba885e0c87 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 26 Aug 2021 19:58:13 +0200 Subject: [PATCH] Remove ExceptionBox and remember_exception TensorFlow already handles surfacing dataset exceptions internally. --- training/coqui_stt_training/train.py | 19 +++++--------- training/coqui_stt_training/util/feeding.py | 8 +++--- training/coqui_stt_training/util/helpers.py | 29 --------------------- 3 files changed, 10 insertions(+), 46 deletions(-) diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index c2e61c25..38417d1d 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -58,7 +58,7 @@ from .util.config import ( log_warn, ) from .util.feeding import create_dataset -from .util.helpers import ExceptionBox, check_ctcdecoder_version +from .util.helpers import check_ctcdecoder_version from .util.io import ( is_remote_path, open_remote, @@ -266,9 +266,11 @@ def early_training_checks(): ) -def create_training_datasets( - exception_box, -) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],): +def create_training_datasets() -> ( + tf.data.Dataset, + [tf.data.Dataset], + [tf.data.Dataset], +): """Creates training datasets from input flags. Returns a single training dataset and two lists of datasets for validation @@ -282,7 +284,6 @@ def create_training_datasets( augmentations=Config.augmentations, cache_path=Config.feature_cache, train_phase=True, - exception_box=exception_box, process_ahead=len(Config.available_devices) * Config.train_batch_size * 2, reverse=Config.reverse_train, limit=Config.limit_train, @@ -297,7 +298,6 @@ def create_training_datasets( batch_size=Config.dev_batch_size, train_phase=False, augmentations=[NormalizeSampleRate(Config.audio_sample_rate)], - exception_box=exception_box, process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2, reverse=Config.reverse_dev, limit=Config.limit_dev, @@ -314,7 +314,6 @@ def create_training_datasets( batch_size=Config.dev_batch_size, train_phase=False, augmentations=[NormalizeSampleRate(Config.audio_sample_rate)], - exception_box=exception_box, process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2, reverse=Config.reverse_dev, limit=Config.limit_dev, @@ -332,9 +331,7 @@ def train(): tfv1.reset_default_graph() tfv1.set_random_seed(Config.random_seed) - exception_box = ExceptionBox() - - train_set, dev_sets, metrics_sets = create_training_datasets(exception_box) + train_set, dev_sets, metrics_sets = create_training_datasets() iterator = tfv1.data.Iterator.from_structure( tfv1.data.get_output_types(train_set), @@ -512,9 +509,7 @@ def train(): ], feed_dict=feed_dict, ) - exception_box.raise_if_set() except tf.errors.OutOfRangeError: - exception_box.raise_if_set() break if problem_files.size > 0: diff --git a/training/coqui_stt_training/util/feeding.py b/training/coqui_stt_training/util/feeding.py index 333c78ce..80ff0c20 100644 --- a/training/coqui_stt_training/util/feeding.py +++ b/training/coqui_stt_training/util/feeding.py @@ -12,7 +12,7 @@ import tensorflow as tf from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split from .augmentations import apply_graph_augmentations, apply_sample_augmentations from .config import Config -from .helpers import MEGABYTE, remember_exception +from .helpers import MEGABYTE from .sample_collections import samples_from_sources from .text import text_to_char_array @@ -138,7 +138,6 @@ def create_dataset( train_phase=False, reverse=False, limit=0, - exception_box=None, process_ahead=None, buffering=1 * MEGABYTE, ): @@ -197,7 +196,7 @@ def create_dataset( ) dataset = tf.data.Dataset.from_generator( - remember_exception(generate_values, exception_box), + generate_values, output_types=( tf.string, tf.float32, @@ -223,7 +222,6 @@ def split_audio_file( aggressiveness=3, outlier_duration_ms=10000, outlier_batch_size=1, - exception_box=None, ): def generate_values(): frames = read_frames_from_file(audio_path) @@ -240,7 +238,7 @@ def split_audio_file( def create_batch_set(bs, criteria): return ( tf.data.Dataset.from_generator( - remember_exception(generate_values, exception_box), + generate_values, output_types=(tf.int32, tf.int32, tf.float32), ) .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE) diff --git a/training/coqui_stt_training/util/helpers.py b/training/coqui_stt_training/util/helpers.py index b897e4a9..81e60bb2 100644 --- a/training/coqui_stt_training/util/helpers.py +++ b/training/coqui_stt_training/util/helpers.py @@ -163,35 +163,6 @@ class LimitingPool: self.pool.close() -class ExceptionBox: - """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator. - Used in conjunction with `remember_exception`.""" - - def __init__(self): - self.exception = None - - def raise_if_set(self): - if self.exception is not None: - exception = self.exception - self.exception = None - raise exception # pylint: disable = raising-bad-type - - -def remember_exception(iterable, exception_box=None): - """Wraps a TensorFlow dataset generator for catching its actual exceptions - that would otherwise just interrupt iteration w/o bubbling up.""" - - def do_iterate(): - try: - yield from iterable() - except StopIteration: - return - except Exception as ex: # pylint: disable = broad-except - exception_box.exception = ex - - return iterable if exception_box is None else do_iterate - - def get_value_range(value, target_type): """ This function converts all possible supplied values for augmentation