From 5d5ef15ab701b6a190bd852b21ee9655ad73a226 Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Thu, 1 Aug 2019 21:00:55 -0300 Subject: [PATCH 1/7] -data-aug via additive and multiplicative noise in feature-space --- util/feeding.py | 10 +++++++++- util/flags.py | 7 +++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/util/feeding.py b/util/feeding.py index ba11ebb0..fd8c400d 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -15,7 +15,7 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio from util.config import Config from util.logging import log_error from util.text import text_to_char_array - +from util.flags import FLAGS def read_csvs(csv_files): source_data = None @@ -47,6 +47,14 @@ def audiofile_to_features(wav_filename): decoded = contrib_audio.decode_wav(samples, desired_channels=1) features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate) + + if FLAGS.data_aug_features_multiplicative > 0: + features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features)) + + if FLAGS.data_aug_features_additive > 0: + features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features)) + + return features, features_len diff --git a/util/flags.py b/util/flags.py index e1eb1788..3dc58fdd 100644 --- a/util/flags.py +++ b/util/flags.py @@ -21,6 +21,13 @@ def create_flags(): f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds') f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model') + # Data Augmentation + # ================ + + f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise') + f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise') + + # Global Constants # ================ From 0cc5ff230f572040c993dd8d061addb8f0df42a6 Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Thu, 1 Aug 2019 21:26:45 -0300 Subject: [PATCH 2/7] -spectrogram augmentations --- util/feeding.py | 27 +++++ util/flags.py | 19 ++++ util/sparse_image_warp.py | 177 ++++++++++++++++++++++++++++++ util/spectrogram_augmentations.py | 97 ++++++++++++++++ 4 files changed, 320 insertions(+) create mode 100644 util/sparse_image_warp.py create mode 100644 util/spectrogram_augmentations.py diff --git a/util/feeding.py b/util/feeding.py index fd8c400d..c65590ee 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -16,6 +16,7 @@ from util.config import Config from util.logging import log_error from util.text import text_to_char_array from util.flags import FLAGS +from util.spectrogram_augmentations import augment_sparse_deform, augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up def read_csvs(csv_files): source_data = None @@ -36,6 +37,32 @@ def samples_to_mfccs(samples, sample_rate): window_size=Config.audio_window_samples, stride=Config.audio_step_samples, magnitude_squared=True) + + if FLAGS.augmention_sparse_deform: + spectrogram = augment_sparse_deform(spectrogram, + time_warping_para=FLAGS.augmentation_time_warp_max_warping, + normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp) + + if FLAGS.augmentation_spec_dropout_keeprate < 1: + spectrogram = augment_dropout(spectrogram, + keep_prob=FLAGS.augmentation_spec_dropout_keeprate) + + if FLAGS.augmentation_freq_and_time_masking: + spectrogram = augment_freq_time_mask(spectrogram, + 
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range, + time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range, + frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks, + time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks) + + if FLAGS.augmentation_pitch_and_tempo_scaling: + spectrogram = augment_pitch_and_tempo(spectrogram, + max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo, + max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch, + min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch) + + if FLAGS.augmentation_speed_up_std > 0: + spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std) + mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input) mfccs = tf.reshape(mfccs, [-1, Config.n_input]) diff --git a/util/flags.py b/util/flags.py index 3dc58fdd..7119e4a1 100644 --- a/util/flags.py +++ b/util/flags.py @@ -27,6 +27,25 @@ def create_flags(): f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise') f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise') + f.DEFINE_integer('augmention_sparse_deform', 0, 'whether to use time-warping augmentation') + f.DEFINE_integer('augmentation_time_warp_max_warping', 12, 'max value for warping') + f.DEFINE_float('augmentation_sparse_deform_std_warp', 0.5, 'std for warping different values to different frequencies') + + f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)') + + f.DEFINE_integer('augmentation_freq_and_time_masking', 0, 'whether to use frequency and time masking augmentation') + f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation') + f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation') + f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation') + f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation') + + f.DEFINE_float('augmentation_speed_up_std', 0.5, 'std for speeding-up tempo. 
If std is 0, this augmentation is not performed') + + f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling') + f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling') + f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling') + f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling') + # Global Constants # ================ diff --git a/util/sparse_image_warp.py b/util/sparse_image_warp.py new file mode 100644 index 00000000..2bd69d45 --- /dev/null +++ b/util/sparse_image_warp.py @@ -0,0 +1,177 @@ +## Implementation of sparse_image_warp that handles dynamic shapes +from tensorflow.contrib.image.python.ops import dense_image_warp +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _get_grid_locations(image_height, image_width): + """Wrapper for array_ops.meshgrid.""" + + y_range = math_ops.linspace(0.0, math_ops.to_float(image_height) - 1, + image_height) + x_range = math_ops.linspace(0.0, math_ops.to_float(image_width) - 1, + image_width) + y_grid, x_grid = array_ops.meshgrid(y_range, x_range, indexing='ij') + return array_ops.stack((y_grid, x_grid), -1) + + +def _expand_to_minibatch(array, batch_size): + """Tile arbitrarily-sized array to include new batch dimension.""" + batch_size = array_ops.expand_dims(batch_size, 0) + array_ones = array_ops.ones((array_ops.rank(array)), dtype=dtypes.int32) + tiles = array_ops.concat([batch_size, array_ones], axis=0) + return array_ops.tile(array_ops.expand_dims(array, 0), tiles) + + +def _get_boundary_locations(image_height, image_width, num_points_per_edge): + """Compute evenly-spaced indices along edge of image.""" + image_height = math_ops.to_float(image_height) + image_width = math_ops.to_float(image_width) + y_range = math_ops.linspace(0.0, image_height - 1, num_points_per_edge + 2) + x_range = math_ops.linspace(0.0, image_width - 1, num_points_per_edge + 2) + ys, xs = array_ops.meshgrid(y_range, x_range, indexing='ij') + is_boundary = math_ops.logical_or( + math_ops.logical_or(math_ops.equal(xs, 0), # pylint: disable=bad-continuation + math_ops.equal(xs, image_width - 1)), + math_ops.logical_or(math_ops.equal(ys, 0), # pylint: disable=bad-continuation + math_ops.equal(ys, image_height - 1))) + return array_ops.stack([array_ops.boolean_mask(ys, is_boundary), + array_ops.boolean_mask(xs, is_boundary)], axis=-1) + + +def _add_zero_flow_controls_at_boundary(control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge): + """Add control points for zero-flow boundary conditions. + Augment the set of control points with extra points on the + boundary of the image that have zero flow. + Args: + control_point_locations: input control points + control_point_flows: their flows + image_height: image height + image_width: image width + boundary_points_per_edge: number of points to add in the middle of each + edge (not including the corners). + The total number of points added is + 4 + 4*(boundary_points_per_edge). 
+ Returns: + merged_control_point_locations: augmented set of control point locations + merged_control_point_flows: augmented set of control point flows + """ + + batch_size = tensor_shape.dimension_value(control_point_locations.shape[0]) + + boundary_point_locations = _get_boundary_locations(image_height, image_width, + boundary_points_per_edge) + + boundary_point_flows = array_ops.zeros([array_ops.shape(boundary_point_locations)[0], 2]) + + boundary_point_locations = _expand_to_minibatch(boundary_point_locations, + batch_size) + + boundary_point_flows = _expand_to_minibatch(boundary_point_flows, batch_size) + + merged_control_point_locations = array_ops.concat([control_point_locations, boundary_point_locations], 1) + + merged_control_point_flows = array_ops.concat([control_point_flows, boundary_point_flows], 1) + + return merged_control_point_locations, merged_control_point_flows + + +def sparse_image_warp(image, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundary_points=0, + name='sparse_image_warp'): + """Image warping using correspondences between sparse control points. + Apply a non-linear warp to the image, where the warp is specified by + the source and destination locations of a (potentially small) number of + control points. First, we use a polyharmonic spline + (`tf.contrib.image.interpolate_spline`) to interpolate the displacements + between the corresponding control points to a dense flow field. + Then, we warp the image using this dense flow field + (`tf.contrib.image.dense_image_warp`). + Let t index our control points. For regularization_weight=0, we have: + warped_image[b, dest_control_point_locations[b, t, 0], + dest_control_point_locations[b, t, 1], :] = + image[b, source_control_point_locations[b, t, 0], + source_control_point_locations[b, t, 1], :]. + For regularization_weight > 0, this condition is met approximately, since + regularized interpolation trades off smoothness of the interpolant vs. + reconstruction of the interpolant at the control points. + See `tf.contrib.image.interpolate_spline` for further documentation of the + interpolation_order and regularization_weight arguments. + Args: + image: `[batch, height, width, channels]` float `Tensor` + source_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + dest_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + interpolation_order: polynomial order used by the spline interpolation + regularization_weight: weight on smoothness regularizer in interpolation + num_boundary_points: How many zero-flow boundary points to include at + each image edge.Usage: + num_boundary_points=0: don't add zero-flow points + num_boundary_points=1: 4 corners of the image + num_boundary_points=2: 4 corners and one in the middle of each edge + (8 points total) + num_boundary_points=n: 4 corners and n-1 along each edge + name: A name for the operation (optional). + Note that image and offsets can be of type tf.half, tf.float32, or + tf.float64, and do not necessarily have to be the same type. + Returns: + warped_image: `[batch, height, width, channels]` float `Tensor` with same + type as input image. + flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense + flow field produced by the interpolation. 
+ """ + + image = ops.convert_to_tensor(image) + source_control_point_locations = ops.convert_to_tensor( + source_control_point_locations) + dest_control_point_locations = ops.convert_to_tensor( + dest_control_point_locations) + + control_point_flows = ( + dest_control_point_locations - source_control_point_locations) + + clamp_boundaries = num_boundary_points > 0 + boundary_points_per_edge = num_boundary_points - 1 + + with ops.name_scope(name): + batch_size, image_height, image_width = (array_ops.shape(image)[0], + array_ops.shape(image)[1], + array_ops.shape(image)[2]) + # This generates the dense locations where the interpolant + # will be evaluated. + grid_locations = _get_grid_locations(image_height, image_width) + + flattened_grid_locations = array_ops.reshape(grid_locations, + [image_height*image_width, 2]) + + flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, + batch_size) + + if clamp_boundaries: + (dest_control_point_locations, + control_point_flows) = _add_zero_flow_controls_at_boundary(dest_control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge) + + flattened_flows = interpolate_spline.interpolate_spline(dest_control_point_locations, control_point_flows, + flattened_grid_locations, interpolation_order, + regularization_weight) + + dense_flows = array_ops.reshape(flattened_flows, + [batch_size, image_height, image_width, 2]) + + warped_image = dense_image_warp.dense_image_warp(image, dense_flows) + + return warped_image, dense_flows \ No newline at end of file diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py new file mode 100644 index 00000000..012c1bd2 --- /dev/null +++ b/util/spectrogram_augmentations.py @@ -0,0 +1,97 @@ +import tensorflow as tf +from util.sparse_image_warp import sparse_image_warp + +def augment_sparse_deform(mel_spectrogram, + time_warping_para=12, + normal_around_warping_std=0.5): + mel_spectrogram = tf.expand_dims(mel_spectrogram, -1) + freq_max = tf.shape(mel_spectrogram)[1] + time_max = tf.shape(mel_spectrogram)[2] + center_freq = tf.cast(freq_max, tf.float32)/2.0 + random_time_point = tf.random.uniform(shape=(), minval=time_warping_para, maxval=tf.cast(time_max, tf.float32) - time_warping_para) + chosen_warping = tf.random.uniform(shape=(), minval=0, maxval=time_warping_para) + #add different warping values to different frequencies + normal_around_warping = tf.random.normal(mean=chosen_warping, stddev=normal_around_warping_std, shape=(3,)) + + control_point_freqs = tf.stack([0.0, center_freq, tf.cast(freq_max, tf.float32)], axis=0) + control_point_times_src = tf.stack([random_time_point, random_time_point, random_time_point], axis=0) + control_point_times_dst = control_point_times_src+normal_around_warping + + control_src = tf.expand_dims(tf.stack([control_point_freqs, control_point_times_src], axis=-1), 0) + control_dst = tf.expand_dims(tf.stack([control_point_freqs, control_point_times_dst], axis=1), 0) + warped_mel_spectrogram, _ = sparse_image_warp(mel_spectrogram, + source_control_point_locations=control_src, + dest_control_point_locations=control_dst, + interpolation_order=2, + regularization_weight=0, + num_boundary_points=1 + ) + warped_mel_spectrogram = warped_mel_spectrogram[:, :, :, 0] + return warped_mel_spectrogram + +def augment_freq_time_mask(mel_spectrogram, + frequency_masking_para=30, + time_masking_para=10, + frequency_mask_num=3, + time_mask_num=3): + freq_max = tf.shape(mel_spectrogram)[1] + time_max = 
tf.shape(mel_spectrogram)[2] + # Frequency masking + # Testing without loop + for _ in range(frequency_mask_num): + f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32) + f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32) + value_ones_freq_prev = tf.ones(shape=[1, f0, time_max]) + value_zeros_freq = tf.zeros(shape=[1, f, time_max]) + value_ones_freq_next = tf.ones(shape=[1, freq_max-(f0+f), time_max]) + freq_mask = tf.concat([value_ones_freq_prev, value_zeros_freq, value_ones_freq_next], axis=1) + #mel_spectrogram[:, f0:f0 + f, :] = 0 #can't assign to tensor + #mel_spectrogram[:, f0:f0 + f, :] = value_zeros_freq #can't assign to tensor + mel_spectrogram = mel_spectrogram*freq_mask + + # Time masking + # Testing without loop + for _ in range(time_mask_num): + t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32) + t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32) + value_zeros_time_prev = tf.ones(shape=[1, freq_max, t0]) + value_zeros_time = tf.zeros(shape=[1, freq_max, t]) + value_zeros_time_next = tf.ones(shape=[1, freq_max, time_max-(t0+t)]) + time_mask = tf.concat([value_zeros_time_prev, value_zeros_time, value_zeros_time_next], axis=2) + #mel_spectrogram[:, :, t0:t0 + t] = 0 #can't assign to tensor + #mel_spectrogram[:, :, t0:t0 + t] = value_zeros_time #can't assign to tensor + mel_spectrogram = mel_spectrogram*time_mask + + return mel_spectrogram + +def augment_pitch_and_tempo(spectrogram, + max_tempo=1.2, + max_pitch=1.1, + min_pitch=0.95): + original_shape = tf.shape(spectrogram) + choosen_pitch = tf.random.uniform(shape=(), minval=min_pitch, maxval=max_pitch) + choosen_tempo = tf.random.uniform(shape=(), minval=1, maxval=max_tempo) + new_height = tf.cast(tf.cast(original_shape[1], tf.float32)*choosen_pitch, tf.int32) + new_width = tf.cast(tf.cast(original_shape[2], tf.float32)/(choosen_tempo), tf.int32) + spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_height, new_width]) + spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, target_height=tf.minimum(original_shape[1],new_height), target_width=tf.shape(spectrogram_aug)[2]) + spectrogram_aug = tf.cond(choosen_pitch < 1, + lambda: tf.image.pad_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, + target_height=original_shape[1], target_width=tf.shape(spectrogram_aug)[2]), + lambda: spectrogram_aug) + return spectrogram_aug[:, :, :, 0] + + +def augment_speed_up(spectrogram, + speed_std=0.1): + original_shape = tf.shape(spectrogram) + choosen_speed = tf.math.abs(tf.random.normal(shape=(), stddev=speed_std)) # abs makes sure the augmention will only speed up + choosen_speed = 1 + choosen_speed + new_height = tf.cast(tf.cast(original_shape[1], tf.float32), tf.int32) + new_width = tf.cast(tf.cast(original_shape[2], tf.float32)/(choosen_speed), tf.int32) + spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_height, new_width]) + return spectrogram_aug[:, :, :, 0] + +def augment_dropout(spectrogram, + keep_prob=0.95): + return tf.nn.dropout(spectrogram, rate=1-keep_prob) From 49c6a9c9736f44d6fe27ef8e2bf15f76a95d4308 Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Thu, 1 Aug 2019 22:09:06 -0300 Subject: [PATCH 3/7] adding 'train_phase' to create_dataset. Now we can augment only the training-set. 
--- DeepSpeech.py | 5 ++- evaluate.py | 2 +- util/feeding.py | 71 +++++++++++++++++-------------- util/flags.py | 2 +- util/spectrogram_augmentations.py | 2 - 5 files changed, 44 insertions(+), 38 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index 5183407e..c9344ea1 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -415,7 +415,8 @@ def train(): # Create training and validation datasets train_set = create_dataset(FLAGS.train_files.split(','), batch_size=FLAGS.train_batch_size, - cache_path=FLAGS.feature_cache) + cache_path=FLAGS.feature_cache, + train_phase=True) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), tfv1.data.get_output_shapes(train_set), @@ -426,7 +427,7 @@ def train(): if FLAGS.dev_files: dev_csvs = FLAGS.dev_files.split(',') - dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs] + dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] # Dropout diff --git a/evaluate.py b/evaluate.py index c86ebc1e..32c45367 100755 --- a/evaluate.py +++ b/evaluate.py @@ -47,7 +47,7 @@ def evaluate(test_csvs, create_model, try_loading): Config.alphabet) test_csvs = FLAGS.test_files.split(',') - test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs] + test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs] iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]), tfv1.data.get_output_shapes(test_sets[0]), output_classes=tfv1.data.get_output_classes(test_sets[0])) diff --git a/util/feeding.py b/util/feeding.py index c65590ee..817338cf 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -32,36 +32,38 @@ def read_csvs(csv_files): return source_data -def samples_to_mfccs(samples, sample_rate): +def samples_to_mfccs(samples, sample_rate, train_phase=False): spectrogram = contrib_audio.audio_spectrogram(samples, window_size=Config.audio_window_samples, stride=Config.audio_step_samples, magnitude_squared=True) - if FLAGS.augmention_sparse_deform: - spectrogram = augment_sparse_deform(spectrogram, - time_warping_para=FLAGS.augmentation_time_warp_max_warping, - normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp) + # Data Augmentations + if train_phase: + if FLAGS.augmention_sparse_deform: + spectrogram = augment_sparse_deform(spectrogram, + time_warping_para=FLAGS.augmentation_time_warp_max_warping, + normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp) - if FLAGS.augmentation_spec_dropout_keeprate < 1: - spectrogram = augment_dropout(spectrogram, - keep_prob=FLAGS.augmentation_spec_dropout_keeprate) + if FLAGS.augmentation_spec_dropout_keeprate < 1: + spectrogram = augment_dropout(spectrogram, + keep_prob=FLAGS.augmentation_spec_dropout_keeprate) - if FLAGS.augmentation_freq_and_time_masking: - spectrogram = augment_freq_time_mask(spectrogram, - frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range, - time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range, - frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks, - time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks) + if FLAGS.augmentation_freq_and_time_masking: + spectrogram = augment_freq_time_mask(spectrogram, + frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range, + 
time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range, + frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks, + time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks) - if FLAGS.augmentation_pitch_and_tempo_scaling: - spectrogram = augment_pitch_and_tempo(spectrogram, - max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo, - max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch, - min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch) + if FLAGS.augmentation_pitch_and_tempo_scaling: + spectrogram = augment_pitch_and_tempo(spectrogram, + max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo, + max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch, + min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch) - if FLAGS.augmentation_speed_up_std > 0: - spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std) + if FLAGS.augmentation_speed_up_std > 0: + spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std) mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input) mfccs = tf.reshape(mfccs, [-1, Config.n_input]) @@ -69,25 +71,29 @@ def samples_to_mfccs(samples, sample_rate): return mfccs, tf.shape(input=mfccs)[0] -def audiofile_to_features(wav_filename): +def audiofile_to_features(wav_filename, train_phase=False): samples = tf.io.read_file(wav_filename) decoded = contrib_audio.decode_wav(samples, desired_channels=1) - features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate) + features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase) + if train_phase: + if FLAGS.data_aug_features_multiplicative > 0: + features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features)) - if FLAGS.data_aug_features_multiplicative > 0: - features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features)) - - if FLAGS.data_aug_features_additive > 0: - features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features)) - + if FLAGS.data_aug_features_additive > 0: + features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features)) return features, features_len -def entry_to_features(wav_filename, transcript): +def entry_to_features_not_augmented(wav_filename, transcript): # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename) + features, features_len = audiofile_to_features(wav_filename, train_phase=False) + return wav_filename, features, features_len, tf.SparseTensor(*transcript) + +def entry_to_features_augmented(wav_filename, transcript): + # https://bugs.python.org/issue32117 + features, features_len = audiofile_to_features(wav_filename, train_phase=True) return wav_filename, features, features_len, tf.SparseTensor(*transcript) @@ -100,7 +106,7 @@ def to_sparse_tuple(sequence): return indices, sequence, shape -def create_dataset(csvs, batch_size, cache_path=''): +def create_dataset(csvs, batch_size, cache_path='', train_phase=True): df = read_csvs(csvs) df.sort_values(by='wav_filesize', inplace=True) @@ -112,6 +118,7 @@ def create_dataset(csvs, batch_size, cache_path=''): log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message)) exit(1) + entry_to_features = 
entry_to_features_augmented if train_phase else entry_to_features_not_augmented def generate_values(): for _, row in df.iterrows(): yield row.wav_filename, to_sparse_tuple(row.transcript) diff --git a/util/flags.py b/util/flags.py index 7119e4a1..68c7ca7d 100644 --- a/util/flags.py +++ b/util/flags.py @@ -39,7 +39,7 @@ def create_flags(): f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation') f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation') - f.DEFINE_float('augmentation_speed_up_std', 0.5, 'std for speeding-up tempo. If std is 0, this augmentation is not performed') + f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed') f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling') f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling') diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py index 012c1bd2..fb47f713 100644 --- a/util/spectrogram_augmentations.py +++ b/util/spectrogram_augmentations.py @@ -37,7 +37,6 @@ def augment_freq_time_mask(mel_spectrogram, freq_max = tf.shape(mel_spectrogram)[1] time_max = tf.shape(mel_spectrogram)[2] # Frequency masking - # Testing without loop for _ in range(frequency_mask_num): f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32) f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32) @@ -50,7 +49,6 @@ def augment_freq_time_mask(mel_spectrogram, mel_spectrogram = mel_spectrogram*freq_mask # Time masking - # Testing without loop for _ in range(time_mask_num): t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32) t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32) From b89fb04b97ee698f13a61160b87d832cb5b8be99 Mon Sep 17 00:00:00 2001 From: Bernardo Date: Fri, 2 Aug 2019 10:24:58 -0300 Subject: [PATCH 4/7] space after comma --- util/spectrogram_augmentations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py index fb47f713..98159826 100644 --- a/util/spectrogram_augmentations.py +++ b/util/spectrogram_augmentations.py @@ -72,7 +72,7 @@ def augment_pitch_and_tempo(spectrogram, new_height = tf.cast(tf.cast(original_shape[1], tf.float32)*choosen_pitch, tf.int32) new_width = tf.cast(tf.cast(original_shape[2], tf.float32)/(choosen_tempo), tf.int32) spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_height, new_width]) - spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, target_height=tf.minimum(original_shape[1],new_height), target_width=tf.shape(spectrogram_aug)[2]) + spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, target_height=tf.minimum(original_shape[1], new_height), target_width=tf.shape(spectrogram_aug)[2]) spectrogram_aug = tf.cond(choosen_pitch < 1, lambda: tf.image.pad_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, target_height=original_shape[1], target_width=tf.shape(spectrogram_aug)[2]), From 0e4eed7be3389fce471e826c2231f1a75c5c534d Mon Sep 17 00:00:00 2001 From: 
Bernardo Date: Fri, 2 Aug 2019 20:27:38 -0300 Subject: [PATCH 5/7] removing trailing space --- util/flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/flags.py b/util/flags.py index 68c7ca7d..8ec0ce60 100644 --- a/util/flags.py +++ b/util/flags.py @@ -30,7 +30,7 @@ def create_flags(): f.DEFINE_integer('augmention_sparse_deform', 0, 'whether to use time-warping augmentation') f.DEFINE_integer('augmentation_time_warp_max_warping', 12, 'max value for warping') f.DEFINE_float('augmentation_sparse_deform_std_warp', 0.5, 'std for warping different values to different frequencies') - + f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)') f.DEFINE_integer('augmentation_freq_and_time_masking', 0, 'whether to use frequency and time masking augmentation') From d051d4fd0e5cf096f0973abf7f0f4417f5a1c0a0 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 9 Sep 2019 12:11:28 +0200 Subject: [PATCH 6/7] Remove sparse image warp, fix boolean flags type, rebase to master --- util/feeding.py | 7 +- util/flags.py | 8 +- util/sparse_image_warp.py | 177 ------------------------------ util/spectrogram_augmentations.py | 29 ----- 4 files changed, 3 insertions(+), 218 deletions(-) delete mode 100644 util/sparse_image_warp.py diff --git a/util/feeding.py b/util/feeding.py index 817338cf..829d2ffe 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -16,7 +16,7 @@ from util.config import Config from util.logging import log_error from util.text import text_to_char_array from util.flags import FLAGS -from util.spectrogram_augmentations import augment_sparse_deform, augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up +from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up def read_csvs(csv_files): source_data = None @@ -40,11 +40,6 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False): # Data Augmentations if train_phase: - if FLAGS.augmention_sparse_deform: - spectrogram = augment_sparse_deform(spectrogram, - time_warping_para=FLAGS.augmentation_time_warp_max_warping, - normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp) - if FLAGS.augmentation_spec_dropout_keeprate < 1: spectrogram = augment_dropout(spectrogram, keep_prob=FLAGS.augmentation_spec_dropout_keeprate) diff --git a/util/flags.py b/util/flags.py index 8ec0ce60..faa59f7a 100644 --- a/util/flags.py +++ b/util/flags.py @@ -27,13 +27,9 @@ def create_flags(): f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise') f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise') - f.DEFINE_integer('augmention_sparse_deform', 0, 'whether to use time-warping augmentation') - f.DEFINE_integer('augmentation_time_warp_max_warping', 12, 'max value for warping') - f.DEFINE_float('augmentation_sparse_deform_std_warp', 0.5, 'std for warping different values to different frequencies') - f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)') - f.DEFINE_integer('augmentation_freq_and_time_masking', 0, 'whether to use frequency and time masking augmentation') + f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation') 
f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation') f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation') f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation') @@ -41,7 +37,7 @@ def create_flags(): f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed') - f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling') + f.DEFINE_boolean('augmentation_pitch_and_tempo_scaling', False, 'whether to use spectrogram speed and tempo scaling') f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling') f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling') f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling') diff --git a/util/sparse_image_warp.py b/util/sparse_image_warp.py deleted file mode 100644 index 2bd69d45..00000000 --- a/util/sparse_image_warp.py +++ /dev/null @@ -1,177 +0,0 @@ -## Implementation of sparse_image_warp that handles dynamic shapes -from tensorflow.contrib.image.python.ops import dense_image_warp -from tensorflow.contrib.image.python.ops import interpolate_spline - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - - -def _get_grid_locations(image_height, image_width): - """Wrapper for array_ops.meshgrid.""" - - y_range = math_ops.linspace(0.0, math_ops.to_float(image_height) - 1, - image_height) - x_range = math_ops.linspace(0.0, math_ops.to_float(image_width) - 1, - image_width) - y_grid, x_grid = array_ops.meshgrid(y_range, x_range, indexing='ij') - return array_ops.stack((y_grid, x_grid), -1) - - -def _expand_to_minibatch(array, batch_size): - """Tile arbitrarily-sized array to include new batch dimension.""" - batch_size = array_ops.expand_dims(batch_size, 0) - array_ones = array_ops.ones((array_ops.rank(array)), dtype=dtypes.int32) - tiles = array_ops.concat([batch_size, array_ones], axis=0) - return array_ops.tile(array_ops.expand_dims(array, 0), tiles) - - -def _get_boundary_locations(image_height, image_width, num_points_per_edge): - """Compute evenly-spaced indices along edge of image.""" - image_height = math_ops.to_float(image_height) - image_width = math_ops.to_float(image_width) - y_range = math_ops.linspace(0.0, image_height - 1, num_points_per_edge + 2) - x_range = math_ops.linspace(0.0, image_width - 1, num_points_per_edge + 2) - ys, xs = array_ops.meshgrid(y_range, x_range, indexing='ij') - is_boundary = math_ops.logical_or( - math_ops.logical_or(math_ops.equal(xs, 0), # pylint: disable=bad-continuation - math_ops.equal(xs, image_width - 1)), - math_ops.logical_or(math_ops.equal(ys, 0), # pylint: disable=bad-continuation - math_ops.equal(ys, image_height - 1))) - return array_ops.stack([array_ops.boolean_mask(ys, is_boundary), - array_ops.boolean_mask(xs, is_boundary)], axis=-1) - - -def _add_zero_flow_controls_at_boundary(control_point_locations, - control_point_flows, image_height, - image_width, 
boundary_points_per_edge): - """Add control points for zero-flow boundary conditions. - Augment the set of control points with extra points on the - boundary of the image that have zero flow. - Args: - control_point_locations: input control points - control_point_flows: their flows - image_height: image height - image_width: image width - boundary_points_per_edge: number of points to add in the middle of each - edge (not including the corners). - The total number of points added is - 4 + 4*(boundary_points_per_edge). - Returns: - merged_control_point_locations: augmented set of control point locations - merged_control_point_flows: augmented set of control point flows - """ - - batch_size = tensor_shape.dimension_value(control_point_locations.shape[0]) - - boundary_point_locations = _get_boundary_locations(image_height, image_width, - boundary_points_per_edge) - - boundary_point_flows = array_ops.zeros([array_ops.shape(boundary_point_locations)[0], 2]) - - boundary_point_locations = _expand_to_minibatch(boundary_point_locations, - batch_size) - - boundary_point_flows = _expand_to_minibatch(boundary_point_flows, batch_size) - - merged_control_point_locations = array_ops.concat([control_point_locations, boundary_point_locations], 1) - - merged_control_point_flows = array_ops.concat([control_point_flows, boundary_point_flows], 1) - - return merged_control_point_locations, merged_control_point_flows - - -def sparse_image_warp(image, - source_control_point_locations, - dest_control_point_locations, - interpolation_order=2, - regularization_weight=0.0, - num_boundary_points=0, - name='sparse_image_warp'): - """Image warping using correspondences between sparse control points. - Apply a non-linear warp to the image, where the warp is specified by - the source and destination locations of a (potentially small) number of - control points. First, we use a polyharmonic spline - (`tf.contrib.image.interpolate_spline`) to interpolate the displacements - between the corresponding control points to a dense flow field. - Then, we warp the image using this dense flow field - (`tf.contrib.image.dense_image_warp`). - Let t index our control points. For regularization_weight=0, we have: - warped_image[b, dest_control_point_locations[b, t, 0], - dest_control_point_locations[b, t, 1], :] = - image[b, source_control_point_locations[b, t, 0], - source_control_point_locations[b, t, 1], :]. - For regularization_weight > 0, this condition is met approximately, since - regularized interpolation trades off smoothness of the interpolant vs. - reconstruction of the interpolant at the control points. - See `tf.contrib.image.interpolate_spline` for further documentation of the - interpolation_order and regularization_weight arguments. - Args: - image: `[batch, height, width, channels]` float `Tensor` - source_control_point_locations: `[batch, num_control_points, 2]` float - `Tensor` - dest_control_point_locations: `[batch, num_control_points, 2]` float - `Tensor` - interpolation_order: polynomial order used by the spline interpolation - regularization_weight: weight on smoothness regularizer in interpolation - num_boundary_points: How many zero-flow boundary points to include at - each image edge.Usage: - num_boundary_points=0: don't add zero-flow points - num_boundary_points=1: 4 corners of the image - num_boundary_points=2: 4 corners and one in the middle of each edge - (8 points total) - num_boundary_points=n: 4 corners and n-1 along each edge - name: A name for the operation (optional). 
- Note that image and offsets can be of type tf.half, tf.float32, or - tf.float64, and do not necessarily have to be the same type. - Returns: - warped_image: `[batch, height, width, channels]` float `Tensor` with same - type as input image. - flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense - flow field produced by the interpolation. - """ - - image = ops.convert_to_tensor(image) - source_control_point_locations = ops.convert_to_tensor( - source_control_point_locations) - dest_control_point_locations = ops.convert_to_tensor( - dest_control_point_locations) - - control_point_flows = ( - dest_control_point_locations - source_control_point_locations) - - clamp_boundaries = num_boundary_points > 0 - boundary_points_per_edge = num_boundary_points - 1 - - with ops.name_scope(name): - batch_size, image_height, image_width = (array_ops.shape(image)[0], - array_ops.shape(image)[1], - array_ops.shape(image)[2]) - # This generates the dense locations where the interpolant - # will be evaluated. - grid_locations = _get_grid_locations(image_height, image_width) - - flattened_grid_locations = array_ops.reshape(grid_locations, - [image_height*image_width, 2]) - - flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, - batch_size) - - if clamp_boundaries: - (dest_control_point_locations, - control_point_flows) = _add_zero_flow_controls_at_boundary(dest_control_point_locations, - control_point_flows, image_height, - image_width, boundary_points_per_edge) - - flattened_flows = interpolate_spline.interpolate_spline(dest_control_point_locations, control_point_flows, - flattened_grid_locations, interpolation_order, - regularization_weight) - - dense_flows = array_ops.reshape(flattened_flows, - [batch_size, image_height, image_width, 2]) - - warped_image = dense_image_warp.dense_image_warp(image, dense_flows) - - return warped_image, dense_flows \ No newline at end of file diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py index 98159826..9cf36a24 100644 --- a/util/spectrogram_augmentations.py +++ b/util/spectrogram_augmentations.py @@ -1,33 +1,4 @@ import tensorflow as tf -from util.sparse_image_warp import sparse_image_warp - -def augment_sparse_deform(mel_spectrogram, - time_warping_para=12, - normal_around_warping_std=0.5): - mel_spectrogram = tf.expand_dims(mel_spectrogram, -1) - freq_max = tf.shape(mel_spectrogram)[1] - time_max = tf.shape(mel_spectrogram)[2] - center_freq = tf.cast(freq_max, tf.float32)/2.0 - random_time_point = tf.random.uniform(shape=(), minval=time_warping_para, maxval=tf.cast(time_max, tf.float32) - time_warping_para) - chosen_warping = tf.random.uniform(shape=(), minval=0, maxval=time_warping_para) - #add different warping values to different frequencies - normal_around_warping = tf.random.normal(mean=chosen_warping, stddev=normal_around_warping_std, shape=(3,)) - - control_point_freqs = tf.stack([0.0, center_freq, tf.cast(freq_max, tf.float32)], axis=0) - control_point_times_src = tf.stack([random_time_point, random_time_point, random_time_point], axis=0) - control_point_times_dst = control_point_times_src+normal_around_warping - - control_src = tf.expand_dims(tf.stack([control_point_freqs, control_point_times_src], axis=-1), 0) - control_dst = tf.expand_dims(tf.stack([control_point_freqs, control_point_times_dst], axis=1), 0) - warped_mel_spectrogram, _ = sparse_image_warp(mel_spectrogram, - source_control_point_locations=control_src, - dest_control_point_locations=control_dst, - interpolation_order=2, 
- regularization_weight=0, - num_boundary_points=1 - ) - warped_mel_spectrogram = warped_mel_spectrogram[:, :, :, 0] - return warped_mel_spectrogram def augment_freq_time_mask(mel_spectrogram, frequency_masking_para=30, From b6af8c5dc7d4326bc0c022d251bdd8c67e758af1 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 9 Sep 2019 12:20:16 +0200 Subject: [PATCH 7/7] Remove some duplicated code --- util/feeding.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/util/feeding.py b/util/feeding.py index 829d2ffe..a041503a 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -81,14 +81,9 @@ def audiofile_to_features(wav_filename, train_phase=False): return features, features_len -def entry_to_features_not_augmented(wav_filename, transcript): +def entry_to_features(wav_filename, transcript, train_phase): # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename, train_phase=False) - return wav_filename, features, features_len, tf.SparseTensor(*transcript) - -def entry_to_features_augmented(wav_filename, transcript): - # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename, train_phase=True) + features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase) return wav_filename, features, features_len, tf.SparseTensor(*transcript) @@ -101,7 +96,7 @@ def to_sparse_tuple(sequence): return indices, sequence, shape -def create_dataset(csvs, batch_size, cache_path='', train_phase=True): +def create_dataset(csvs, batch_size, cache_path='', train_phase=False): df = read_csvs(csvs) df.sort_values(by='wav_filesize', inplace=True) @@ -113,7 +108,6 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True): log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message)) exit(1) - entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented def generate_values(): for _, row in df.iterrows(): yield row.wav_filename, to_sparse_tuple(row.transcript) @@ -134,10 +128,11 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True): return tf.data.Dataset.zip((wav_filenames, features, transcripts)) num_gpus = len(Config.available_devices) + process_fn = partial(entry_to_features, train_phase=train_phase) dataset = (tf.data.Dataset.from_generator(generate_values, output_types=(tf.string, (tf.int64, tf.int32, tf.int64))) - .map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE) + .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) .cache(cache_path) .window(batch_size, drop_remainder=True).flat_map(batch_fn) .prefetch(num_gpus))
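
Usage note (not part of the patches above): the sketch below is one way the augment_freq_time_mask helper added in util/spectrogram_augmentations.py could be exercised on its own under TensorFlow 1.x graph mode, to sanity-check the masking behaviour outside the feeding pipeline. The dummy tensor shape and the parameter values are illustrative assumptions, not values taken from the patch series.

    import numpy as np
    import tensorflow as tf

    from util.spectrogram_augmentations import augment_freq_time_mask

    # Rank-3 tensor shaped like the spectrogram the feeding code passes in
    # ([1, dim1, dim2]); the sizes and random values are arbitrary for this demo.
    dummy_spectrogram = tf.constant(np.random.rand(1, 50, 40), dtype=tf.float32)

    # Apply the same masking augmentation the training path uses, with small
    # (assumed) mask ranges so the effect is easy to see on a tiny input.
    masked = augment_freq_time_mask(dummy_spectrogram,
                                    frequency_masking_para=5,
                                    time_masking_para=2,
                                    frequency_mask_num=3,
                                    time_mask_num=3)

    with tf.compat.v1.Session() as sess:
        out = sess.run(masked)
        # Same shape as the input, with a few randomly placed bands along
        # dims 1 and 2 zeroed out by the multiplicative masks.
        print(out.shape)

During training, the equivalent behaviour is enabled through the flags defined in util/flags.py (--augmentation_freq_and_time_masking plus the related range/count flags); as of the final patch in this series, every augmentation added here is disabled by default (boolean flags False, std flags 0, dropout keep rate 1).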