Remote training I/O once more (#3437)

* Redo remote I/O changes once more; this time without messing with taskcluster

* Add bin changes

* Fix merge-induced issue?

* For the interleaved case with multiple collections, unpack audio on the fly

To reproduce the previous failure:

rm data/smoke_test/ldc93s1.csv
rm data/smoke_test/ldc93s1.sdb
rm -rf /tmp/ldc93s1_cache_sdb_csv
rm -rf /tmp/ckpt_sdb_csv
rm -rf /tmp/train_sdb_csv

./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 16000
python -u DeepSpeech.py --noshow_progressbar --noearly_stop --train_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --train_batch_size 1 --feature_cache /tmp/ldc93s1_cache_sdb_csv --dev_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --dev_batch_size 1 --test_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --test_batch_size 1 --n_hidden 100 --epochs 109 --max_to_keep 1 --checkpoint_dir /tmp/ckpt_sdb_csv --learning_rate 0.001 --dropout_rate 0.05 --export_dir /tmp/train_sdb_csv --scorer_path data/smoke_test/pruned_lm.scorer --audio_sample_rate 16000

* Attempt to preserve length information with a wrapper around `map()`… this gets pretty python-y

* Call the right `__next__()`

* Properly implement the rest of the map wrappers here……

* Fix trailing whitespace situation and other linter complaints

* Remove data accidentally checked in

* Fix overlay augmentations

* Wavs must be opened in rb mode if we're passing in an external file pointer -- this confused me

* Lint whitespace

* Revert "Fix trailing whitespace situation and other linter complaints"

This reverts commit c3c45397a2f98e9b00d00c18c4ced4fc52475032.

* Fix linter issue but without such an aggressive diff

* Move unpack_maybe into sample_collections

* Use unpack_maybe in place of duplicate lambda

* Fix confusing comment

* Add clarifying comment for on-the-fly unpacking
Catalin Voss 2020-12-07 04:07:34 -08:00 committed by GitHub
parent 18b66adf46
commit 6640cf2341
11 changed files with 249 additions and 55 deletions

@@ -15,8 +15,8 @@ def fail(message):
 def compare_samples():
-    sample1 = load_sample(CLI_ARGS.sample1)
-    sample2 = load_sample(CLI_ARGS.sample2)
+    sample1 = load_sample(CLI_ARGS.sample1).unpack()
+    sample2 = load_sample(CLI_ARGS.sample2).unpack()
     if sample1.audio_format != sample2.audio_format:
         fail('Samples differ on: audio-format ({} and {})'.format(sample1.audio_format, sample2.audio_format))
     if sample1.duration != sample2.duration:

@@ -35,6 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
 from .util.flags import create_flags, FLAGS
 from .util.helpers import check_ctcdecoder_version, ExceptionBox
 from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
+from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote

 check_ctcdecoder_version()

@@ -512,9 +513,10 @@ def train():
     best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

     # Save flags next to checkpoints
-    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
+    if not is_remote_path(FLAGS.save_checkpoint_dir):
+        os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
     flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
-    with open(flags_file, 'w') as fout:
+    with open_remote(flags_file, 'w') as fout:
         fout.write(FLAGS.flags_into_string())

     with tfv1.Session(config=Config.session_config) as session:

@@ -541,7 +543,7 @@ def train():
                 feature_cache_index = FLAGS.feature_cache + '.index'
                 if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
                     log_info('Invalidating feature cache')
-                    os.remove(feature_cache_index)  # this will let TF also overwrite the related cache data files
+                    remove_remote(feature_cache_index)  # this will let TF also overwrite the related cache data files

             # Setup progress bar
             class LossWidget(progressbar.widgets.FormatLabel):

@@ -810,13 +812,13 @@ def export():
     output_filename = FLAGS.export_file_name + '.pb'

     if FLAGS.remove_export:
-        if os.path.isdir(FLAGS.export_dir):
+        if isdir_remote(FLAGS.export_dir):
             log_info('Removing old export')
-            shutil.rmtree(FLAGS.export_dir)
+            remove_remote(FLAGS.export_dir)

     output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

-    if not os.path.isdir(FLAGS.export_dir):
+    if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
         os.makedirs(FLAGS.export_dir)

     frozen_graph = tfv1.graph_util.convert_variables_to_constants(
@@ -829,7 +831,7 @@ def export():
         dest_nodes=output_names)

     if not FLAGS.export_tflite:
-        with open(output_graph_path, 'wb') as fout:
+        with open_remote(output_graph_path, 'wb') as fout:
             fout.write(frozen_graph.SerializeToString())
     else:
         output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

@@ -840,7 +842,7 @@ def export():
         converter.allow_custom_ops = True
         tflite_model = converter.convert()

-        with open(output_tflite_path, 'wb') as fout:
+        with open_remote(output_tflite_path, 'wb') as fout:
             fout.write(tflite_model)

     log_info('Models exported at %s' % (FLAGS.export_dir))

@@ -851,7 +853,7 @@ def export():
         FLAGS.export_model_version))

     model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
-    with open(metadata_fname, 'w') as f:
+    with open_remote(metadata_fname, 'w') as f:
         f.write('---\n')
         f.write('author: {}\n'.format(FLAGS.export_author_id))
         f.write('model_name: {}\n'.format(FLAGS.export_model_name))

@@ -873,6 +875,10 @@ def export():
 def package_zip():
     # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
     export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '')  # Force ending '/'
+    if is_remote_path(export_dir):
+        log_error("Cannot package remote path zip %s. Please do this manually." % export_dir)
+        return
+
     zip_filename = os.path.dirname(export_dir)

     shutil.copy(FLAGS.scorer_path, export_dir)

@@ -959,7 +965,7 @@ def main(_):
             tfv1.reset_default_graph()

             FLAGS.export_tflite = True
-            if os.listdir(FLAGS.export_dir):
+            if listdir_remote(FLAGS.export_dir):
                 log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
                 sys.exit(1)

@@ -8,6 +8,7 @@ import numpy as np
 from .helpers import LimitingPool
 from collections import namedtuple
+from .io import open_remote, remove_remote, copy_remote, is_remote_path

 AudioFormat = namedtuple('AudioFormat', 'rate channels width')

@@ -117,15 +118,19 @@ class Sample:
         self.audio_type = new_audio_type

-def _change_audio_type(sample_and_audio_type):
-    sample, audio_type, bitrate = sample_and_audio_type
+def _unpack_and_change_audio_type(sample_and_audio_type):
+    packed_sample, audio_type, bitrate = sample_and_audio_type
+    if hasattr(packed_sample, 'unpack'):
+        sample = packed_sample.unpack()
+    else:
+        sample = packed_sample
     sample.change_audio_type(audio_type, bitrate=bitrate)
     return sample

-def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None):
+def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None):
     with LimitingPool(processes=processes, process_ahead=process_ahead) as pool:
-        yield from pool.imap(_change_audio_type, map(lambda s: (s, audio_type, bitrate), samples))
+        yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples))

 def get_audio_type_from_extension(ext):

@@ -168,29 +173,45 @@ class AudioFile:
         self.audio_format = audio_format
         self.as_path = as_path
         self.open_file = None
+        self.open_wav = None
         self.tmp_file_path = None
+        self.tmp_src_file_path = None

     def __enter__(self):
         if self.audio_path.endswith('.wav'):
-            self.open_file = wave.open(self.audio_path, 'r')
-            if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
+            self.open_file = open_remote(self.audio_path, 'rb')
+            self.open_wav = wave.open(self.open_file)
+            if read_audio_format_from_wav_file(self.open_wav) == self.audio_format:
                 if self.as_path:
+                    self.open_wav.close()
                     self.open_file.close()
                     return self.audio_path
-                return self.open_file
+                return self.open_wav
+            self.open_wav.close()
             self.open_file.close()
+
+        # If the format isn't right, copy the file to local tmp dir and do the conversion on disk
+        if is_remote_path(self.audio_path):
+            _, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav')
+            copy_remote(self.audio_path, self.tmp_src_file_path)
+            self.audio_path = self.tmp_file_path
+
         _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
         convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
         if self.as_path:
             return self.tmp_file_path
-        self.open_file = wave.open(self.tmp_file_path, 'r')
-        return self.open_file
+        self.open_wav = wave.open(self.tmp_file_path, 'rb')
+        return self.open_wav

     def __exit__(self, *args):
         if not self.as_path:
-            self.open_file.close()
+            self.open_wav.close()
+            if self.open_file:
+                self.open_file.close()
         if self.tmp_file_path is not None:
             os.remove(self.tmp_file_path)
+        if self.tmp_src_file_path is not None:
+            os.remove(self.tmp_src_file_path)

 def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):

@@ -320,6 +341,7 @@ def read_opus(opus_file):

 def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
+    # wav_file is already a file-pointer here
     with wave.open(wav_file, 'wb') as wav_file_writer:
         wav_file_writer.setframerate(audio_format.rate)
         wav_file_writer.setnchannels(audio_format.channels)
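
The 'rb' requirement called out in the commit message comes from the wave module itself: when wave.open() is handed an already-open file object rather than a path, that object has to be in binary mode or header parsing fails. A minimal standalone sketch (the path is only an example):

import wave

def read_wav_params(wav_path):
    # A file object passed to wave.open() must be opened in binary mode ('rb'),
    # otherwise the wave module fails while parsing the RIFF header.
    with open(wav_path, 'rb') as wav_io:
        with wave.open(wav_io, 'rb') as wav_reader:
            return wav_reader.getparams()

# read_wav_params('data/smoke_test/LDC93S1.wav')  # example path, adjust as needed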

@@ -8,7 +8,7 @@ import numpy as np
 from multiprocessing import Queue, Process
 from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
 from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
-from .sample_collections import samples_from_source
+from .sample_collections import samples_from_source, unpack_maybe

 BUFFER_SIZE = 1 * MEGABYTE
 SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')

@@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context):
     AUGMENTATION_CONTEXT = preparation_context

+def _load_and_augment_sample(timed_sample, context=None):
+    sample, clock = timed_sample
+    realized_sample = unpack_maybe(sample)
+    return _augment_sample((realized_sample, clock), context)
+
 def _augment_sample(timed_sample, context=None):
     context = AUGMENTATION_CONTEXT if context is None else context
     sample, clock = timed_sample

@@ -213,12 +219,12 @@ def apply_sample_augmentations(samples,
         context = AugmentationContext(audio_type, augmentations)
         if process_ahead == 0:
             for timed_sample in timed_samples():
-                yield _augment_sample(timed_sample, context=context)
+                yield _load_and_augment_sample(timed_sample, context=context)
         else:
             with LimitingPool(process_ahead=process_ahead,
                               initializer=_init_augmentation_worker,
                               initargs=(context,)) as pool:
-                yield from pool.imap(_augment_sample, timed_samples())
+                yield from pool.imap(_load_and_augment_sample, timed_samples())
     finally:
         for augmentation in augmentations:
             augmentation.stop()

@@ -256,6 +262,7 @@ class Overlay(SampleAugmentation):
         self.enqueue_process.start()

     def apply(self, sample, clock=0.0):
+        sample = unpack_maybe(sample)
         sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
         n_layers = pick_value_from_range(self.layers, clock=clock)
         audio = sample.audio

@@ -265,6 +272,7 @@ class Overlay(SampleAugmentation):
         while overlay_offset < len(audio):
             if self.current_sample is None:
                 next_overlay_sample = self.queue.get()
+                next_overlay_sample = unpack_maybe(next_overlay_sample)
                 next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
                 self.current_sample = next_overlay_sample.audio
             n_required = len(audio) - overlay_offset

@@ -19,6 +19,7 @@ import csv
 import os
 import sys
 import unicodedata
+from .io import open_remote

 def main():
     parser = argparse.ArgumentParser()

@@ -27,14 +28,14 @@ def main():
     parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true")
     parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true")
     args = parser.parse_args()
-    in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
+    in_files = args.csv_files.split(",")

     print("### Reading in the following transcript files: ###")
     print("### {} ###".format(in_files))

     all_text = set()
     for in_file in in_files:
-        with open(in_file, "r") as csv_file:
+        with open_remote(in_file, "r") as csv_file:
             reader = csv.reader(csv_file)
             try:
                 next(reader, None)  # skip the file header (i.e. "transcript")

@@ -13,7 +13,7 @@ from .gpu import get_available_gpus
 from .logging import log_error, log_warn
 from .helpers import parse_file_size
 from .augmentations import parse_augmentations
+from .io import path_exists_remote

 class ConfigSingleton:
     _config = None

@@ -139,7 +139,7 @@ def initialize_globals():
     c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)

     if FLAGS.one_shot_infer:
-        if not os.path.exists(FLAGS.one_shot_infer):
+        if not path_exists_remote(FLAGS.one_shot_infer):
            log_error('Path specified in --one_shot_infer is not a valid file.')
            sys.exit(1)

@@ -2,6 +2,7 @@ import requests
 import progressbar
 from os import path, makedirs
+from .io import open_remote, path_exists_remote, is_remote_path

 SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed']

@@ -9,17 +10,18 @@ def maybe_download(archive_name, target_dir, archive_url):
     # If archive file does not exist, download it...
     archive_path = path.join(target_dir, archive_name)

-    if not path.exists(target_dir):
+    if not is_remote_path(target_dir) and not path.exists(target_dir):
         print('No path "%s" - creating ...' % target_dir)
         makedirs(target_dir)

-    if not path.exists(archive_path):
+    if not path_exists_remote(archive_path):
         print('No archive "%s" - downloading...' % archive_path)
         req = requests.get(archive_url, stream=True)
         total_size = int(req.headers.get('content-length', 0))
         done = 0
-        with open(archive_path, 'wb') as f:
+        with open_remote(archive_path, 'wb') as f:
             bar = progressbar.ProgressBar(max_value=total_size if total_size > 0 else progressbar.UnknownLength, widgets=SIMPLE_BAR)
             for data in req.iter_content(1024*1024):
                 done += len(data)
                 f.write(data)

@@ -10,7 +10,7 @@ from attrdict import AttrDict
 from .flags import FLAGS
 from .text import levenshtein
+from .io import open_remote

 def pmap(fun, iterable):
     pool = Pool()

@@ -124,5 +124,5 @@ def save_samples_json(samples, output_path):
     We set ensure_ascii=True to prevent json from escaping non-ASCII chars
     in the texts.
     '''
-    with open(output_path, 'w') as fout:
+    with open_remote(output_path, 'w') as fout:
         json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)

@@ -78,6 +78,32 @@ class Interleaved:
         return self.len

+class LenMap:
+    """
+    Wrapper around python map() output object that preserves the original collection length
+    by implementing __len__.
+    """
+    def __init__(self, fn, iterable):
+        try:
+            self.length = len(iterable)
+        except TypeError:
+            self.length = None
+        self.mapobj = map(fn, iterable)
+
+    def __iter__(self):
+        self.mapobj = self.mapobj.__iter__()
+        return self
+
+    def __next__(self):
+        return self.mapobj.__next__()
+
+    def __getitem__(self, key):
+        return self.mapobj.__getitem__(key)
+
+    def __len__(self):
+        return self.length
+
 class LimitingPool:
     """Limits unbound ahead-processing of multiprocessing.Pool's imap method
     before items get consumed by the iteration caller.
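
A quick standalone sketch of what the wrapper buys us, assuming LenMap is imported from the helpers module above (the exact package path depends on the installation): a bare map() object has no len(), which breaks the interleaving code's bookkeeping, while LenMap keeps the length of the underlying collection.

# from deepspeech_training.util.helpers import LenMap  # assumed import path

durations = [3, 1, 2]
plain = map(lambda d: d * 2, durations)
# len(plain)  -> TypeError: object of type 'map' has no len()

wrapped = LenMap(lambda d: d * 2, durations)
assert len(wrapped) == 3                  # original collection length is preserved
assert list(wrapped) == [6, 2, 4]         # iteration still behaves like map()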

@@ -0,0 +1,81 @@
+"""
+A set of I/O utils that allow us to open files on remote storage as if they were present locally,
+and to access HDFS storage using Tensorflow's C++ FileStream API.
+Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets.
+"""
+import os
+
+from tensorflow.io import gfile
+
+
+def is_remote_path(path):
+    """
+    Returns True iff the path is one of the remote formats that this
+    module supports
+    """
+    return path.startswith('gs://') or path.startswith('hdfs://')
+
+
+def path_exists_remote(path):
+    """
+    Wrapper that allows existence check of local and remote paths like
+    `gs://...`
+    """
+    if is_remote_path(path):
+        return gfile.exists(path)
+    return os.path.exists(path)
+
+
+def copy_remote(src, dst, overwrite=False):
+    """
+    Allows us to copy a file from local to remote or vice versa
+    """
+    return gfile.copy(src, dst, overwrite)
+
+
+def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None):
+    """
+    Wrapper around open() method that can handle remote paths like `gs://...`
+    off Google Cloud using Tensorflow's IO helpers.
+
+    buffering, encoding, newline, closefd, and opener are ignored for remote files
+
+    This enables us to do:
+        with open_remote('gs://.....', mode='w+') as f:
+            do something with the file f, whether or not we have local access to it
+    """
+    if is_remote_path(path):
+        return gfile.GFile(path, mode=mode)
+    return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener)
+
+
+def isdir_remote(path):
+    """
+    Wrapper to check if remote and local paths are directories
+    """
+    if is_remote_path(path):
+        return gfile.isdir(path)
+    return os.path.isdir(path)
+
+
+def listdir_remote(path):
+    """
+    Wrapper to list paths in local and remote dirs (alternative to using a glob, I suppose)
+    """
+    if is_remote_path(path):
+        return gfile.listdir(path)
+    return os.listdir(path)
+
+
+def glob_remote(filename):
+    """
+    Wrapper that provides globs on local and remote paths like `gs://...`
+    """
+    return gfile.glob(filename)
+
+
+def remove_remote(filename):
+    """
+    Wrapper that can remove local and remote files like `gs://...`
+    """
+    return gfile.remove(filename)
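
A small usage sketch of the new helpers (the bucket name is made up, and remote access additionally assumes TensorFlow can reach the storage backend with valid credentials):

# from deepspeech_training.util.io import open_remote, is_remote_path, path_exists_remote  # assumed import path

checkpoint_dir = 'gs://my-bucket/checkpoints'      # hypothetical bucket

flags_file = checkpoint_dir + '/flags.txt'
with open_remote(flags_file, 'w') as fout:         # the same call works for plain local paths
    fout.write('--train_batch_size 1\n')

print(is_remote_path(checkpoint_dir))              # True, so calls dispatch to tf.io.gfile
print(path_exists_remote(flags_file))              # True once the write above succeeded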

@@ -8,7 +8,7 @@ import tarfile
 from pathlib import Path
 from functools import partial

-from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
+from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
 from .audio import (
     Sample,
     DEFAULT_FORMAT,

@@ -18,6 +18,7 @@ from .audio import (
     get_audio_type_from_extension,
     write_wav
 )
+from .io import open_remote, is_remote_path

 BIG_ENDIAN = 'big'
 INT_SIZE = 4

@@ -59,6 +60,37 @@ class LabeledSample(Sample):
         self.transcript = transcript

+class PackedSample:
+    """
+    A wrapper that we can carry around in an iterator and pass to a child process in order to
+    have the child process do the loading/unpacking of the sample, allowing for parallel file
+    I/O.
+    """
+    def __init__(self, filename, audio_type, label):
+        self.filename = filename
+        self.audio_type = audio_type
+        self.label = label
+
+    def unpack(self):
+        with open_remote(self.filename, 'rb') as audio_file:
+            data = audio_file.read()
+        if self.label is None:
+            s = Sample(self.audio_type, data, sample_id=self.filename)
+        s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
+        return s
+
+def unpack_maybe(sample):
+    """
+    Loads the supplied sample from disk (or the network) if the audio isn't loaded into memory already.
+    """
+    if hasattr(sample, 'unpack'):
+        realized_sample = sample.unpack()
+    else:
+        realized_sample = sample
+    return realized_sample
+
 def load_sample(filename, label=None):
     """
     Loads audio-file as a (labeled or unlabeled) sample
@@ -69,21 +101,19 @@ def load_sample(filename, label=None):
         Filename of the audio-file to load as sample
     label : str
         Label (transcript) of the sample.
-        If None: return util.audio.Sample instance
-        Otherwise: return util.sample_collections.LabeledSample instance
+        If None: the returned result.unpack() will return a util.audio.Sample instance
+        Otherwise: the returned result.unpack() will return a util.sample_collections.LabeledSample instance

     Returns
     -------
+    util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return
     util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
     """
     ext = os.path.splitext(filename)[1].lower()
     audio_type = get_audio_type_from_extension(ext)
     if audio_type is None:
         raise ValueError('Unknown audio type extension "{}"'.format(ext))
-    with open(filename, 'rb') as audio_file:
-        if label is None:
-            return Sample(audio_type, audio_file.read(), sample_id=filename)
-        return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename)
+    return PackedSample(filename, audio_type, label)

 class DirectSDBWriter:
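
To make the new contract concrete, here is a sketch of how load_sample is now consumed (the path and label are purely illustrative): the call itself does no I/O and returns a PackedSample; the file is only read when unpack() is called -- or unpack_maybe(), which also passes already-unpacked samples through -- typically inside a worker process.

packed = load_sample('gs://my-bucket/clips/sample-0001.wav', label='hello world')  # no file read yet
print(hasattr(packed, 'unpack'))    # True -> PackedSample

sample = unpack_maybe(packed)       # opens the (remote or local) file and builds the sample
sample = unpack_maybe(sample)       # idempotent: an unpacked sample is returned unchanged
print(sample.sample_id, sample.duration)
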
@@ -119,7 +149,7 @@ class DirectSDBWriter:
             raise ValueError('Audio type "{}" not supported'.format(audio_type))
         self.audio_type = audio_type
         self.bitrate = bitrate
-        self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
+        self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering)
         self.offsets = []
         self.num_samples = 0

@@ -215,7 +245,7 @@ class SDB:  # pylint: disable=too-many-instance-attributes
         """
         self.sdb_filename = sdb_filename
         self.id_prefix = sdb_filename if id_prefix is None else id_prefix
-        self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
+        self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
         self.offsets = []
         if self.sdb_file.read(len(MAGIC)) != MAGIC:
             raise RuntimeError('No Sample Database')

@@ -332,6 +362,8 @@ class CSVWriter:  # pylint: disable=too-many-instance-attributes
         labeled : bool or None
             If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
             If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
+
+        Currently only works with local files (not gs:// or hdfs://...)
         """
         self.csv_filename = Path(csv_filename)
         self.csv_base_dir = self.csv_filename.parent.resolve().absolute()

@@ -345,7 +377,7 @@ class CSVWriter:  # pylint: disable=too-many-instance-attributes
         self.labeled = labeled
         if labeled:
             fieldnames.append('transcript')
-        self.csv_file = open(csv_filename, 'w', encoding='utf-8', newline='')
+        self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='')
         self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
         self.csv_writer.writeheader()
         self.counter = 0

@@ -380,7 +412,7 @@ class CSVWriter:  # pylint: disable=too-many-instance-attributes

 class TarWriter:  # pylint: disable=too-many-instance-attributes
-    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file"""
+    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file."""
     def __init__(self,
                  tar_filename,
                  gz=False,

@@ -398,6 +430,8 @@ class TarWriter:  # pylint: disable=too-many-instance-attributes
             If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
         include : str[]
             List of files to include into tar root.
+
+        Currently only works with local files (not gs:// or hdfs://...)
         """
         self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
         samples_dir = tarfile.TarInfo('samples')
@@ -498,8 +532,7 @@ class CSV(SampleList):
             If the order of the samples should be reversed
         """
         rows = []
-        csv_dir = Path(csv_filename).parent
-        with open(csv_filename, 'r', encoding='utf8') as csv_file:
+        with open_remote(csv_filename, 'r', encoding='utf8') as csv_file:
             reader = csv.DictReader(csv_file)
             if 'transcript' in reader.fieldnames:
                 if labeled is None:

@@ -508,9 +541,12 @@ class CSV(SampleList):
                     raise RuntimeError('No transcript data (missing CSV column)')
             for row in reader:
                 wav_filename = Path(row['wav_filename'])
-                if not wav_filename.is_absolute():
-                    wav_filename = csv_dir / wav_filename
-                wav_filename = str(wav_filename)
+                if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']):
+                    wav_filename = Path(csv_filename).parent / wav_filename
+                    wav_filename = str(wav_filename)
+                else:
+                    # Pathlib otherwise removes a / from filenames like hdfs://
+                    wav_filename = row['wav_filename']
                 wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
                 if labeled:
                     rows.append((wav_filename, wav_filesize, row['transcript']))
@@ -554,6 +590,11 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re
     Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
     keep default sample order from shortest to longest.

+    Note that when using distributed training, it is much faster to call this function with a single
+    pre-sorted sample source, because this allows for parallelization of the file I/O. (If this function is
+    called with multiple sources, the samples have to be unpacked on a single parent process to allow
+    for reading their durations.)
+
     Parameters
     ----------
     sample_sources : list of str

@@ -570,13 +611,20 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re

     Returns
     -------
-    iterable of util.sample_collections.LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len
+    iterable of util.sample_collections.PackedSample if a single collection is provided, wrapping
+    LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len,
+    or LabeledSample / util.audio.Sample directly, if multiple collections are provided
     """
     sample_sources = list(sample_sources)
     if len(sample_sources) == 0:
         raise ValueError('No files')
     if len(sample_sources) == 1:
         return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
-    cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
+
+    # If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking should
+    # be done lazily on the fly so that it respects the LimitingPool logic used in the feeding code.
+    cols = [LenMap(
+        unpack_maybe, samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse))
             for source in sample_sources]
     return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
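
Finally, a sketch of what the two call paths now return, matching the updated docstring (the file names are the smoke-test sources from the commit message above, and the exact item types depend on the source format): a single source stays packed so workers can unpack in parallel, while multiple sources are wrapped in LenMap(unpack_maybe, ...) so Interleaved can compare durations up front.

single = samples_from_sources(['./data/smoke_test/ldc93s1.csv'])
print(hasattr(single[0], 'unpack'))      # True: CSV items are PackedSamples, unpacked later by workers

mixed = samples_from_sources(['./data/smoke_test/ldc93s1.sdb', './data/smoke_test/ldc93s1.csv'])
print(len(mixed))                        # Interleaved still reports the combined length
print(next(iter(mixed)).duration)        # items here are already unpacked Sample/LabeledSample objects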