Remove ExceptionBox and remember_exception
TensorFlow already handles surfacing dataset exceptions internally.
This commit is contained in:
parent
497c828dd7
commit
33c2190015
@ -58,7 +58,7 @@ from .util.config import (
|
||||
log_warn,
|
||||
)
|
||||
from .util.feeding import create_dataset
|
||||
from .util.helpers import ExceptionBox, check_ctcdecoder_version
|
||||
from .util.helpers import check_ctcdecoder_version
|
||||
from .util.io import (
|
||||
is_remote_path,
|
||||
open_remote,
|
||||
@ -266,9 +266,11 @@ def early_training_checks():
|
||||
)
|
||||
|
||||
|
||||
def create_training_datasets(
|
||||
exception_box,
|
||||
) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
|
||||
def create_training_datasets() -> (
|
||||
tf.data.Dataset,
|
||||
[tf.data.Dataset],
|
||||
[tf.data.Dataset],
|
||||
):
|
||||
"""Creates training datasets from input flags.
|
||||
|
||||
Returns a single training dataset and two lists of datasets for validation
|
||||
@ -282,7 +284,6 @@ def create_training_datasets(
|
||||
augmentations=Config.augmentations,
|
||||
cache_path=Config.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * Config.train_batch_size * 2,
|
||||
reverse=Config.reverse_train,
|
||||
limit=Config.limit_train,
|
||||
@ -297,7 +298,6 @@ def create_training_datasets(
|
||||
batch_size=Config.dev_batch_size,
|
||||
train_phase=False,
|
||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||
reverse=Config.reverse_dev,
|
||||
limit=Config.limit_dev,
|
||||
@ -314,7 +314,6 @@ def create_training_datasets(
|
||||
batch_size=Config.dev_batch_size,
|
||||
train_phase=False,
|
||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||
reverse=Config.reverse_dev,
|
||||
limit=Config.limit_dev,
|
||||
@ -332,9 +331,7 @@ def train():
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(Config.random_seed)
|
||||
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
train_set, dev_sets, metrics_sets = create_training_datasets(exception_box)
|
||||
train_set, dev_sets, metrics_sets = create_training_datasets()
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(
|
||||
tfv1.data.get_output_types(train_set),
|
||||
@ -512,9 +509,7 @@ def train():
|
||||
],
|
||||
feed_dict=feed_dict,
|
||||
)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
|
@ -12,7 +12,7 @@ import tensorflow as tf
|
||||
from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
|
||||
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
|
||||
from .config import Config
|
||||
from .helpers import MEGABYTE, remember_exception
|
||||
from .helpers import MEGABYTE
|
||||
from .sample_collections import samples_from_sources
|
||||
from .text import text_to_char_array
|
||||
|
||||
@ -138,7 +138,6 @@ def create_dataset(
|
||||
train_phase=False,
|
||||
reverse=False,
|
||||
limit=0,
|
||||
exception_box=None,
|
||||
process_ahead=None,
|
||||
buffering=1 * MEGABYTE,
|
||||
):
|
||||
@ -197,7 +196,7 @@ def create_dataset(
|
||||
)
|
||||
|
||||
dataset = tf.data.Dataset.from_generator(
|
||||
remember_exception(generate_values, exception_box),
|
||||
generate_values,
|
||||
output_types=(
|
||||
tf.string,
|
||||
tf.float32,
|
||||
@ -223,7 +222,6 @@ def split_audio_file(
|
||||
aggressiveness=3,
|
||||
outlier_duration_ms=10000,
|
||||
outlier_batch_size=1,
|
||||
exception_box=None,
|
||||
):
|
||||
def generate_values():
|
||||
frames = read_frames_from_file(audio_path)
|
||||
@ -240,7 +238,7 @@ def split_audio_file(
|
||||
def create_batch_set(bs, criteria):
|
||||
return (
|
||||
tf.data.Dataset.from_generator(
|
||||
remember_exception(generate_values, exception_box),
|
||||
generate_values,
|
||||
output_types=(tf.int32, tf.int32, tf.float32),
|
||||
)
|
||||
.map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
@ -163,35 +163,6 @@ class LimitingPool:
|
||||
self.pool.close()
|
||||
|
||||
|
||||
class ExceptionBox:
|
||||
"""Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
|
||||
Used in conjunction with `remember_exception`."""
|
||||
|
||||
def __init__(self):
|
||||
self.exception = None
|
||||
|
||||
def raise_if_set(self):
|
||||
if self.exception is not None:
|
||||
exception = self.exception
|
||||
self.exception = None
|
||||
raise exception # pylint: disable = raising-bad-type
|
||||
|
||||
|
||||
def remember_exception(iterable, exception_box=None):
|
||||
"""Wraps a TensorFlow dataset generator for catching its actual exceptions
|
||||
that would otherwise just interrupt iteration w/o bubbling up."""
|
||||
|
||||
def do_iterate():
|
||||
try:
|
||||
yield from iterable()
|
||||
except StopIteration:
|
||||
return
|
||||
except Exception as ex: # pylint: disable = broad-except
|
||||
exception_box.exception = ex
|
||||
|
||||
return iterable if exception_box is None else do_iterate
|
||||
|
||||
|
||||
def get_value_range(value, target_type):
|
||||
"""
|
||||
This function converts all possible supplied values for augmentation
|
||||
|
Loading…
Reference in New Issue
Block a user