Remove ExceptionBox and remember_exception

TensorFlow already handles surfacing dataset exceptions internally.
This commit is contained in:
Reuben Morais 2021-08-26 19:58:13 +02:00
parent 497c828dd7
commit 33c2190015
3 changed files with 10 additions and 46 deletions

View File

@@ -58,7 +58,7 @@ from .util.config import (
log_warn,
)
from .util.feeding import create_dataset
from .util.helpers import ExceptionBox, check_ctcdecoder_version
from .util.helpers import check_ctcdecoder_version
from .util.io import (
is_remote_path,
open_remote,
@@ -266,9 +266,11 @@ def early_training_checks():
)
def create_training_datasets(
exception_box,
) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
def create_training_datasets() -> (
tf.data.Dataset,
[tf.data.Dataset],
[tf.data.Dataset],
):
"""Creates training datasets from input flags.
Returns a single training dataset and two lists of datasets for validation
@@ -282,7 +284,6 @@ def create_training_datasets(
augmentations=Config.augmentations,
cache_path=Config.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * Config.train_batch_size * 2,
reverse=Config.reverse_train,
limit=Config.limit_train,
@@ -297,7 +298,6 @@ def create_training_datasets(
batch_size=Config.dev_batch_size,
train_phase=False,
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
exception_box=exception_box,
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
reverse=Config.reverse_dev,
limit=Config.limit_dev,
@@ -314,7 +314,6 @@ def create_training_datasets(
batch_size=Config.dev_batch_size,
train_phase=False,
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
exception_box=exception_box,
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
reverse=Config.reverse_dev,
limit=Config.limit_dev,
@@ -332,9 +331,7 @@ def train():
tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed)
exception_box = ExceptionBox()
train_set, dev_sets, metrics_sets = create_training_datasets(exception_box)
train_set, dev_sets, metrics_sets = create_training_datasets()
iterator = tfv1.data.Iterator.from_structure(
tfv1.data.get_output_types(train_set),
@@ -512,9 +509,7 @@ def train():
],
feed_dict=feed_dict,
)
exception_box.raise_if_set()
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:

View File

@@ -12,7 +12,7 @@ import tensorflow as tf
from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
from .config import Config
from .helpers import MEGABYTE, remember_exception
from .helpers import MEGABYTE
from .sample_collections import samples_from_sources
from .text import text_to_char_array
@@ -138,7 +138,6 @@ def create_dataset(
train_phase=False,
reverse=False,
limit=0,
exception_box=None,
process_ahead=None,
buffering=1 * MEGABYTE,
):
@@ -197,7 +196,7 @@ def create_dataset(
)
dataset = tf.data.Dataset.from_generator(
remember_exception(generate_values, exception_box),
generate_values,
output_types=(
tf.string,
tf.float32,
@@ -223,7 +222,6 @@ def split_audio_file(
aggressiveness=3,
outlier_duration_ms=10000,
outlier_batch_size=1,
exception_box=None,
):
def generate_values():
frames = read_frames_from_file(audio_path)
@@ -240,7 +238,7 @@ def split_audio_file(
def create_batch_set(bs, criteria):
return (
tf.data.Dataset.from_generator(
remember_exception(generate_values, exception_box),
generate_values,
output_types=(tf.int32, tf.int32, tf.float32),
)
.map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)

View File

@@ -163,35 +163,6 @@ class LimitingPool:
self.pool.close()
class ExceptionBox:
"""Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
Used in conjunction with `remember_exception`."""
def __init__(self):
self.exception = None
def raise_if_set(self):
if self.exception is not None:
exception = self.exception
self.exception = None
raise exception # pylint: disable = raising-bad-type
def remember_exception(iterable, exception_box=None):
"""Wraps a TensorFlow dataset generator for catching its actual exceptions
that would otherwise just interrupt iteration w/o bubbling up."""
def do_iterate():
try:
yield from iterable()
except StopIteration:
return
except Exception as ex: # pylint: disable = broad-except
exception_box.exception = ex
return iterable if exception_box is None else do_iterate
def get_value_range(value, target_type):
"""
This function converts all possible supplied values for augmentation