Merge pull request #1948 from coqui-ai/remove-exception-box
Remove ExceptionBox and remember_exception
This commit is contained in:
commit
f94d16bcc3
|
@ -58,7 +58,7 @@ from .util.config import (
|
||||||
log_warn,
|
log_warn,
|
||||||
)
|
)
|
||||||
from .util.feeding import create_dataset
|
from .util.feeding import create_dataset
|
||||||
from .util.helpers import ExceptionBox, check_ctcdecoder_version
|
from .util.helpers import check_ctcdecoder_version
|
||||||
from .util.io import (
|
from .util.io import (
|
||||||
is_remote_path,
|
is_remote_path,
|
||||||
open_remote,
|
open_remote,
|
||||||
|
@ -266,9 +266,11 @@ def early_training_checks():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_training_datasets(
|
def create_training_datasets() -> (
|
||||||
exception_box,
|
tf.data.Dataset,
|
||||||
) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
|
[tf.data.Dataset],
|
||||||
|
[tf.data.Dataset],
|
||||||
|
):
|
||||||
"""Creates training datasets from input flags.
|
"""Creates training datasets from input flags.
|
||||||
|
|
||||||
Returns a single training dataset and two lists of datasets for validation
|
Returns a single training dataset and two lists of datasets for validation
|
||||||
|
@ -282,7 +284,6 @@ def create_training_datasets(
|
||||||
augmentations=Config.augmentations,
|
augmentations=Config.augmentations,
|
||||||
cache_path=Config.feature_cache,
|
cache_path=Config.feature_cache,
|
||||||
train_phase=True,
|
train_phase=True,
|
||||||
exception_box=exception_box,
|
|
||||||
process_ahead=len(Config.available_devices) * Config.train_batch_size * 2,
|
process_ahead=len(Config.available_devices) * Config.train_batch_size * 2,
|
||||||
reverse=Config.reverse_train,
|
reverse=Config.reverse_train,
|
||||||
limit=Config.limit_train,
|
limit=Config.limit_train,
|
||||||
|
@ -297,7 +298,6 @@ def create_training_datasets(
|
||||||
batch_size=Config.dev_batch_size,
|
batch_size=Config.dev_batch_size,
|
||||||
train_phase=False,
|
train_phase=False,
|
||||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||||
exception_box=exception_box,
|
|
||||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||||
reverse=Config.reverse_dev,
|
reverse=Config.reverse_dev,
|
||||||
limit=Config.limit_dev,
|
limit=Config.limit_dev,
|
||||||
|
@ -314,7 +314,6 @@ def create_training_datasets(
|
||||||
batch_size=Config.dev_batch_size,
|
batch_size=Config.dev_batch_size,
|
||||||
train_phase=False,
|
train_phase=False,
|
||||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||||
exception_box=exception_box,
|
|
||||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||||
reverse=Config.reverse_dev,
|
reverse=Config.reverse_dev,
|
||||||
limit=Config.limit_dev,
|
limit=Config.limit_dev,
|
||||||
|
@ -332,9 +331,7 @@ def train():
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
tfv1.set_random_seed(Config.random_seed)
|
tfv1.set_random_seed(Config.random_seed)
|
||||||
|
|
||||||
exception_box = ExceptionBox()
|
train_set, dev_sets, metrics_sets = create_training_datasets()
|
||||||
|
|
||||||
train_set, dev_sets, metrics_sets = create_training_datasets(exception_box)
|
|
||||||
|
|
||||||
iterator = tfv1.data.Iterator.from_structure(
|
iterator = tfv1.data.Iterator.from_structure(
|
||||||
tfv1.data.get_output_types(train_set),
|
tfv1.data.get_output_types(train_set),
|
||||||
|
@ -512,9 +509,7 @@ def train():
|
||||||
],
|
],
|
||||||
feed_dict=feed_dict,
|
feed_dict=feed_dict,
|
||||||
)
|
)
|
||||||
exception_box.raise_if_set()
|
|
||||||
except tf.errors.OutOfRangeError:
|
except tf.errors.OutOfRangeError:
|
||||||
exception_box.raise_if_set()
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if problem_files.size > 0:
|
if problem_files.size > 0:
|
||||||
|
|
|
@ -12,7 +12,7 @@ import tensorflow as tf
|
||||||
from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
|
from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
|
||||||
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
|
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .helpers import MEGABYTE, remember_exception
|
from .helpers import MEGABYTE
|
||||||
from .sample_collections import samples_from_sources
|
from .sample_collections import samples_from_sources
|
||||||
from .text import text_to_char_array
|
from .text import text_to_char_array
|
||||||
|
|
||||||
|
@ -138,7 +138,6 @@ def create_dataset(
|
||||||
train_phase=False,
|
train_phase=False,
|
||||||
reverse=False,
|
reverse=False,
|
||||||
limit=0,
|
limit=0,
|
||||||
exception_box=None,
|
|
||||||
process_ahead=None,
|
process_ahead=None,
|
||||||
buffering=1 * MEGABYTE,
|
buffering=1 * MEGABYTE,
|
||||||
):
|
):
|
||||||
|
@ -197,7 +196,7 @@ def create_dataset(
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = tf.data.Dataset.from_generator(
|
dataset = tf.data.Dataset.from_generator(
|
||||||
remember_exception(generate_values, exception_box),
|
generate_values,
|
||||||
output_types=(
|
output_types=(
|
||||||
tf.string,
|
tf.string,
|
||||||
tf.float32,
|
tf.float32,
|
||||||
|
@ -223,7 +222,6 @@ def split_audio_file(
|
||||||
aggressiveness=3,
|
aggressiveness=3,
|
||||||
outlier_duration_ms=10000,
|
outlier_duration_ms=10000,
|
||||||
outlier_batch_size=1,
|
outlier_batch_size=1,
|
||||||
exception_box=None,
|
|
||||||
):
|
):
|
||||||
def generate_values():
|
def generate_values():
|
||||||
frames = read_frames_from_file(audio_path)
|
frames = read_frames_from_file(audio_path)
|
||||||
|
@ -240,7 +238,7 @@ def split_audio_file(
|
||||||
def create_batch_set(bs, criteria):
|
def create_batch_set(bs, criteria):
|
||||||
return (
|
return (
|
||||||
tf.data.Dataset.from_generator(
|
tf.data.Dataset.from_generator(
|
||||||
remember_exception(generate_values, exception_box),
|
generate_values,
|
||||||
output_types=(tf.int32, tf.int32, tf.float32),
|
output_types=(tf.int32, tf.int32, tf.float32),
|
||||||
)
|
)
|
||||||
.map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
.map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||||
|
|
|
@ -163,35 +163,6 @@ class LimitingPool:
|
||||||
self.pool.close()
|
self.pool.close()
|
||||||
|
|
||||||
|
|
||||||
class ExceptionBox:
|
|
||||||
"""Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
|
|
||||||
Used in conjunction with `remember_exception`."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.exception = None
|
|
||||||
|
|
||||||
def raise_if_set(self):
|
|
||||||
if self.exception is not None:
|
|
||||||
exception = self.exception
|
|
||||||
self.exception = None
|
|
||||||
raise exception # pylint: disable = raising-bad-type
|
|
||||||
|
|
||||||
|
|
||||||
def remember_exception(iterable, exception_box=None):
|
|
||||||
"""Wraps a TensorFlow dataset generator for catching its actual exceptions
|
|
||||||
that would otherwise just interrupt iteration w/o bubbling up."""
|
|
||||||
|
|
||||||
def do_iterate():
|
|
||||||
try:
|
|
||||||
yield from iterable()
|
|
||||||
except StopIteration:
|
|
||||||
return
|
|
||||||
except Exception as ex: # pylint: disable = broad-except
|
|
||||||
exception_box.exception = ex
|
|
||||||
|
|
||||||
return iterable if exception_box is None else do_iterate
|
|
||||||
|
|
||||||
|
|
||||||
def get_value_range(value, target_type):
|
def get_value_range(value, target_type):
|
||||||
"""
|
"""
|
||||||
This function converts all possible supplied values for augmentation
|
This function converts all possible supplied values for augmentation
|
||||||
|
|
Loading…
Reference in New Issue