Merge pull request #1948 from coqui-ai/remove-exception-box

Remove ExceptionBox and remember_exception
Reuben Morais 2021-08-26 21:13:18 +02:00 committed by GitHub
commit f94d16bcc3
3 changed files with 10 additions and 46 deletions

View File

@@ -58,7 +58,7 @@ from .util.config import (
     log_warn,
 )
 from .util.feeding import create_dataset
-from .util.helpers import ExceptionBox, check_ctcdecoder_version
+from .util.helpers import check_ctcdecoder_version
 from .util.io import (
     is_remote_path,
     open_remote,
@@ -266,9 +266,11 @@ def early_training_checks():
 )


-def create_training_datasets(
-    exception_box,
-) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
+def create_training_datasets() -> (
+    tf.data.Dataset,
+    [tf.data.Dataset],
+    [tf.data.Dataset],
+):
     """Creates training datasets from input flags.

     Returns a single training dataset and two lists of datasets for validation
@@ -282,7 +284,6 @@ def create_training_datasets(
         augmentations=Config.augmentations,
         cache_path=Config.feature_cache,
         train_phase=True,
-        exception_box=exception_box,
         process_ahead=len(Config.available_devices) * Config.train_batch_size * 2,
         reverse=Config.reverse_train,
         limit=Config.limit_train,
@@ -297,7 +298,6 @@ def create_training_datasets(
             batch_size=Config.dev_batch_size,
             train_phase=False,
             augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
-            exception_box=exception_box,
             process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
             reverse=Config.reverse_dev,
             limit=Config.limit_dev,
@@ -314,7 +314,6 @@ def create_training_datasets(
             batch_size=Config.dev_batch_size,
             train_phase=False,
             augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
-            exception_box=exception_box,
             process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
             reverse=Config.reverse_dev,
             limit=Config.limit_dev,
@@ -332,9 +331,7 @@ def train():
     tfv1.reset_default_graph()
     tfv1.set_random_seed(Config.random_seed)

-    exception_box = ExceptionBox()
-
-    train_set, dev_sets, metrics_sets = create_training_datasets(exception_box)
+    train_set, dev_sets, metrics_sets = create_training_datasets()

     iterator = tfv1.data.Iterator.from_structure(
         tfv1.data.get_output_types(train_set),
@@ -512,9 +509,7 @@ def train():
                     ],
                     feed_dict=feed_dict,
                 )
-                exception_box.raise_if_set()
             except tf.errors.OutOfRangeError:
-                exception_box.raise_if_set()
                 break

             if problem_files.size > 0:
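Note: with the box removed, end-of-data handling in train() reduces to the standard OutOfRangeError pattern visible in the last hunk. Below is a minimal, self-contained sketch of that loop shape with a toy dataset standing in for the STT pipeline; it assumes, as this commit appears to, that TensorFlow itself surfaces generator errors from session.run, so no separate raise_if_set() check is needed:

import tensorflow as tf
import tensorflow.compat.v1 as tfv1

tfv1.disable_eager_execution()

# Toy dataset standing in for the train_set built by create_training_datasets().
dataset = tfv1.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0]).batch(2)
iterator = tfv1.data.make_initializable_iterator(dataset)
next_batch = iterator.get_next()

with tfv1.Session() as session:
    session.run(iterator.initializer)
    while True:
        try:
            batch = session.run(next_batch)
            print(batch)
        except tf.errors.OutOfRangeError:
            # End of epoch; any error raised inside a dataset generator would
            # have surfaced from session.run above rather than being swallowed.
            break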

View File

@@ -12,7 +12,7 @@ import tensorflow as tf
 from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
 from .augmentations import apply_graph_augmentations, apply_sample_augmentations
 from .config import Config
-from .helpers import MEGABYTE, remember_exception
+from .helpers import MEGABYTE
 from .sample_collections import samples_from_sources
 from .text import text_to_char_array
@@ -138,7 +138,6 @@ def create_dataset(
     train_phase=False,
     reverse=False,
     limit=0,
-    exception_box=None,
     process_ahead=None,
     buffering=1 * MEGABYTE,
 ):
@@ -197,7 +196,7 @@ def create_dataset(
     )

     dataset = tf.data.Dataset.from_generator(
-        remember_exception(generate_values, exception_box),
+        generate_values,
         output_types=(
             tf.string,
             tf.float32,
@@ -223,7 +222,6 @@ def split_audio_file(
     aggressiveness=3,
     outlier_duration_ms=10000,
     outlier_batch_size=1,
-    exception_box=None,
 ):
     def generate_values():
         frames = read_frames_from_file(audio_path)
@@ -240,7 +238,7 @@ def split_audio_file(
     def create_batch_set(bs, criteria):
         return (
             tf.data.Dataset.from_generator(
-                remember_exception(generate_values, exception_box),
+                generate_values,
                 output_types=(tf.int32, tf.int32, tf.float32),
             )
             .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
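Note: create_dataset and split_audio_file now hand the bare generator to tf.data.Dataset.from_generator. A minimal sketch, independent of this repo, of the assumption that makes the wrapper removable: an exception raised inside the generator surfaces at the consuming iteration instead of silently ending the dataset. The broad except is deliberate, since the exact wrapping of the error varies across TensorFlow versions:

import tensorflow as tf

def generate_values():
    yield 1
    raise ValueError("broken sample")  # simulated failure inside the generator

dataset = tf.data.Dataset.from_generator(generate_values, output_types=tf.int32)

try:
    for value in dataset:  # eager iteration; graph mode raises at session.run instead
        print(int(value))
except (tf.errors.OpError, ValueError) as err:
    # The generator's error arrives here rather than being swallowed.
    print("caught:", type(err).__name__)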

View File

@@ -163,35 +163,6 @@ class LimitingPool:
         self.pool.close()


-class ExceptionBox:
-    """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
-
-    Used in conjunction with `remember_exception`."""
-
-    def __init__(self):
-        self.exception = None
-
-    def raise_if_set(self):
-        if self.exception is not None:
-            exception = self.exception
-            self.exception = None
-            raise exception  # pylint: disable = raising-bad-type
-
-
-def remember_exception(iterable, exception_box=None):
-    """Wraps a TensorFlow dataset generator for catching its actual exceptions
-    that would otherwise just interrupt iteration w/o bubbling up."""
-
-    def do_iterate():
-        try:
-            yield from iterable()
-        except StopIteration:
-            return
-        except Exception as ex:  # pylint: disable = broad-except
-            exception_box.exception = ex
-
-    return iterable if exception_box is None else do_iterate
-
-
 def get_value_range(value, target_type):
     """
     This function converts all possible supplied values for augmentation
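Note: for reference, the removed pair was used together as reconstructed below from the hunks above. generate_samples, fetches, and session are hypothetical placeholders, not real STT code:

import tensorflow as tf

exception_box = ExceptionBox()

def generate_values():
    yield from generate_samples()  # placeholder sample source

dataset = tf.data.Dataset.from_generator(
    remember_exception(generate_values, exception_box),  # stores errors instead of raising
    output_types=tf.float32,
)

while True:
    try:
        session.run(fetches)
        exception_box.raise_if_set()  # re-raise anything captured in the generator
    except tf.errors.OutOfRangeError:
        exception_box.raise_if_set()  # iteration may have ended because the generator failed
        break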