From 53e3f5374fad861585d5877823533af7d024dd22 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:44:19 -0800 Subject: [PATCH 01/30] Add I/O helpers for remote file access --- training/deepspeech_training/util/io.py | 76 +++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 training/deepspeech_training/util/io.py diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py new file mode 100644 index 00000000..4801c075 --- /dev/null +++ b/training/deepspeech_training/util/io.py @@ -0,0 +1,76 @@ +""" +A set of I/O utils that allow us to open files on remote storage as if they were present locally. +Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets. +""" +import inspect +import os +import sys + +def path_exists_remote(path): + """ + Wrapper that allows existance check of local and remote paths like + `gs://...` + """ + # Conditional import + if path.startswith("gs://"): + from tensorflow.io import gfile + return gfile.exists(path) + return path_exists_remotes(path) + + +def open_remote(path, mode): + """ + Wrapper around open_remote() method that can handle remote paths like `gs://...` + off Google Cloud using Tensorflow's IO helpers. 
+ + This enables us to do: + with open_remote('gs://.....', mode='w+') as f: + do something with the file f, whether or not we have local access to it + """ + # Conditional import + if path.startswith("gs://"): + from tensorflow.io import gfile + return gfile.GFile(path, mode=mode) + return open_remote(path, mode) + + +def isdir_remote(path): + """ + Wrapper to check if remote and local paths are directories + """ + # Conditional import + if path.startswith("gs://"): + from tensorflow.io import gfile + return gfile.isdir(path) + return isdir_remote(path) + + +def listdir_remote(path): + """ + Wrapper to list paths in local dirs (alternative to using a glob, I suppose) + """ + # Conditional import + if path.startswith("gs://"): + from tensorflow.io import gfile + return gfile.listdir(path) + return os.listdir(path) + + +def glob_remote(filename): + """ + Wrapper that provides globs on local and remote paths like `gs://...` + """ + # Conditional import + from tensorflow.io import gfile + + return gfile.glob(filename) + + +def remove_remote(filename): + """ + Wrapper that can remove_remote local and remote files like `gs://...` + """ + # Conditional import + from tensorflow.io import gfile + + return gfile.remove_remote(filename) \ No newline at end of file From 579921cc9250e86e4aee566df8898e10fffad67b Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:45:35 -0800 Subject: [PATCH 02/30] Work remote I/O into train script --- training/deepspeech_training/train.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index 8bf7a354..d94e8a45 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -35,6 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur from .util.flags import create_flags, FLAGS from .util.helpers import check_ctcdecoder_version, ExceptionBox 
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn +from .util.io import open_remote, remove_remote, listdir_remote check_ctcdecoder_version() @@ -514,7 +515,7 @@ def train(): # Save flags next to checkpoints os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') - with open(flags_file, 'w') as fout: + with open_remote(flags_file, 'w') as fout: fout.write(FLAGS.flags_into_string()) with tfv1.Session(config=Config.session_config) as session: @@ -541,7 +542,7 @@ def train(): feature_cache_index = FLAGS.feature_cache + '.index' if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index): log_info('Invalidating feature cache') - os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files + remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files # Setup progress bar class LossWidget(progressbar.widgets.FormatLabel): @@ -773,7 +774,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): def file_relative_read(fname): - return open(os.path.join(os.path.dirname(__file__), fname)).read() + return open_remote(os.path.join(os.path.dirname(__file__), fname)).read() def export(): @@ -809,14 +810,14 @@ def export(): load_graph_for_evaluation(session) output_filename = FLAGS.export_file_name + '.pb' - if FLAGS.remove_export: - if os.path.isdir(FLAGS.export_dir): + if FLAGS.remove_remote_export: + if isdir_remote(FLAGS.export_dir): log_info('Removing old export') shutil.rmtree(FLAGS.export_dir) output_graph_path = os.path.join(FLAGS.export_dir, output_filename) - if not os.path.isdir(FLAGS.export_dir): + if not isdir_remote(FLAGS.export_dir): os.makedirs(FLAGS.export_dir) frozen_graph = tfv1.graph_util.convert_variables_to_constants( @@ -829,7 +830,7 @@ def export(): dest_nodes=output_names) if not FLAGS.export_tflite: - with open(output_graph_path, 'wb') as 
fout: + with open_remote(output_graph_path, 'wb') as fout: fout.write(frozen_graph.SerializeToString()) else: output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) @@ -840,7 +841,7 @@ def export(): converter.allow_custom_ops = True tflite_model = converter.convert() - with open(output_tflite_path, 'wb') as fout: + with open_remote(output_tflite_path, 'wb') as fout: fout.write(tflite_model) log_info('Models exported at %s' % (FLAGS.export_dir)) @@ -851,7 +852,7 @@ def export(): FLAGS.export_model_version)) model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow' - with open(metadata_fname, 'w') as f: + with open_remote(metadata_fname, 'w') as f: f.write('---\n') f.write('author: {}\n'.format(FLAGS.export_author_id)) f.write('model_name: {}\n'.format(FLAGS.export_model_name)) @@ -959,7 +960,7 @@ def main(_): tfv1.reset_default_graph() FLAGS.export_tflite = True - if os.listdir(FLAGS.export_dir): + if listdir_remote(FLAGS.export_dir): log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) sys.exit(1) From 83e5cf0416fdbd51b90ea3aec52042a9374d3c1a Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:46:15 -0800 Subject: [PATCH 03/30] Remote I/O fro check_characters --- training/deepspeech_training/util/check_characters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/training/deepspeech_training/util/check_characters.py b/training/deepspeech_training/util/check_characters.py index f155b4ac..b40c5b3a 100644 --- a/training/deepspeech_training/util/check_characters.py +++ b/training/deepspeech_training/util/check_characters.py @@ -19,6 +19,7 @@ import csv import os import sys import unicodedata +from .util.io import open_remote def main(): parser = argparse.ArgumentParser() @@ -34,7 +35,7 @@ def main(): all_text = set() for in_file in in_files: - with open(in_file, "r") as csv_file: + with open_remote(in_file, "r") as csv_file: reader = csv.reader(csv_file) 
try: next(reader, None) # skip the file header (i.e. "transcript") From 42170a57eb4d14120b847cde95998b3c91d9b7d7 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:46:49 -0800 Subject: [PATCH 04/30] Remote I/O for config --- training/deepspeech_training/util/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training/deepspeech_training/util/config.py b/training/deepspeech_training/util/config.py index 0b9929e5..17d8a5a0 100755 --- a/training/deepspeech_training/util/config.py +++ b/training/deepspeech_training/util/config.py @@ -13,7 +13,7 @@ from .gpu import get_available_gpus from .logging import log_error, log_warn from .helpers import parse_file_size from .augmentations import parse_augmentations - +from .util.io import path_exists_remote class ConfigSingleton: _config = None @@ -139,7 +139,7 @@ def initialize_globals(): c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000) if FLAGS.one_shot_infer: - if not os.path.exists(FLAGS.one_shot_infer): + if not path_exists_remote(FLAGS.one_shot_infer): log_error('Path specified in --one_shot_infer is not a valid file.') sys.exit(1) From 933d96dc7435074a3861627c29ee88fecfef773a Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:47:26 -0800 Subject: [PATCH 05/30] Fix relative imports --- training/deepspeech_training/util/check_characters.py | 2 +- training/deepspeech_training/util/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/training/deepspeech_training/util/check_characters.py b/training/deepspeech_training/util/check_characters.py index b40c5b3a..bde69d74 100644 --- a/training/deepspeech_training/util/check_characters.py +++ b/training/deepspeech_training/util/check_characters.py @@ -19,7 +19,7 @@ import csv import os import sys import unicodedata -from .util.io import open_remote +from .io import open_remote def main(): parser = argparse.ArgumentParser() diff --git 
a/training/deepspeech_training/util/config.py b/training/deepspeech_training/util/config.py index 17d8a5a0..18da6eed 100755 --- a/training/deepspeech_training/util/config.py +++ b/training/deepspeech_training/util/config.py @@ -13,7 +13,7 @@ from .gpu import get_available_gpus from .logging import log_error, log_warn from .helpers import parse_file_size from .augmentations import parse_augmentations -from .util.io import path_exists_remote +from .io import path_exists_remote class ConfigSingleton: _config = None From 396ac7fe4685c9eeeed5e1dd8a9c9d69e56019c7 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:48:49 -0800 Subject: [PATCH 06/30] Remote I/O for downloader --- training/deepspeech_training/util/downloader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/training/deepspeech_training/util/downloader.py b/training/deepspeech_training/util/downloader.py index 9fcbf674..0a40c481 100644 --- a/training/deepspeech_training/util/downloader.py +++ b/training/deepspeech_training/util/downloader.py @@ -2,6 +2,7 @@ import requests import progressbar from os import path, makedirs +from .io import open_remote, path_exists_remote SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed'] @@ -9,16 +10,16 @@ def maybe_download(archive_name, target_dir, archive_url): # If archive file does not exist, download it... archive_path = path.join(target_dir, archive_name) - if not path.exists(target_dir): + if not path_exists_remote(target_dir): print('No path "%s" - creating ...' % target_dir) makedirs(target_dir) - if not path.exists(archive_path): + if not path_exists_remote(archive_path): print('No archive "%s" - downloading...' 
% archive_path) req = requests.get(archive_url, stream=True) total_size = int(req.headers.get('content-length', 0)) done = 0 - with open(archive_path, 'wb') as f: + with open_remote(archive_path, 'wb') as f: bar = progressbar.ProgressBar(max_value=total_size, widgets=SIMPLE_BAR) for data in req.iter_content(1024*1024): done += len(data) From 7de317cf59289ece0d3cce92f7171d9d68554aa5 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:49:33 -0800 Subject: [PATCH 07/30] Remote I/O for evaluate_tools --- training/deepspeech_training/util/evaluate_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training/deepspeech_training/util/evaluate_tools.py b/training/deepspeech_training/util/evaluate_tools.py index 66fc8293..68d29f3e 100644 --- a/training/deepspeech_training/util/evaluate_tools.py +++ b/training/deepspeech_training/util/evaluate_tools.py @@ -10,7 +10,7 @@ from attrdict import AttrDict from .flags import FLAGS from .text import levenshtein - +from .io import open_remote def pmap(fun, iterable): pool = Pool() @@ -124,5 +124,5 @@ def save_samples_json(samples, output_path): We set ensure_ascii=True to prevent json from escaping non-ASCII chars in the texts. 
''' - with open(output_path, 'w') as fout: + with open_remote(output_path, 'w') as fout: json.dump(samples, fout, default=float, ensure_ascii=False, indent=2) From 296b74e01a9409beb593a69ae885b30875031bb2 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 10:54:44 -0800 Subject: [PATCH 08/30] Remote I/O for sample_collections --- .../deepspeech_training/util/sample_collections.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index 3f1b55ea..2467854d 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -18,6 +18,7 @@ from .audio import ( get_audio_type_from_extension, write_wav ) +from .io import open_remote BIG_ENDIAN = 'big' INT_SIZE = 4 @@ -80,7 +81,7 @@ def load_sample(filename, label=None): audio_type = get_audio_type_from_extension(ext) if audio_type is None: raise ValueError('Unknown audio type extension "{}"'.format(ext)) - with open(filename, 'rb') as audio_file: + with open_remote(filename, 'rb') as audio_file: if label is None: return Sample(audio_type, audio_file.read(), sample_id=filename) return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename) @@ -119,7 +120,7 @@ class DirectSDBWriter: raise ValueError('Audio type "{}" not supported'.format(audio_type)) self.audio_type = audio_type self.bitrate = bitrate - self.sdb_file = open(sdb_filename, 'wb', buffering=buffering) + self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering) self.offsets = [] self.num_samples = 0 @@ -215,7 +216,7 @@ class SDB: # pylint: disable=too-many-instance-attributes """ self.sdb_filename = sdb_filename self.id_prefix = sdb_filename if id_prefix is None else id_prefix - self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering) + self.sdb_file = 
open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering) self.offsets = [] if self.sdb_file.read(len(MAGIC)) != MAGIC: raise RuntimeError('No Sample Database') @@ -345,7 +346,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes self.labeled = labeled if labeled: fieldnames.append('transcript') - self.csv_file = open(csv_filename, 'w', encoding='utf-8', newline='') + self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='') self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) self.csv_writer.writeheader() self.counter = 0 @@ -399,7 +400,7 @@ class TarWriter: # pylint: disable=too-many-instance-attributes include : str[] List of files to include into tar root. """ - self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w') + self.tar = tarfile.open_remote(tar_filename, 'w:gz' if gz else 'w') samples_dir = tarfile.TarInfo('samples') samples_dir.type = tarfile.DIRTYPE self.tar.addfile(samples_dir) @@ -499,7 +500,7 @@ class CSV(SampleList): """ rows = [] csv_dir = Path(csv_filename).parent - with open(csv_filename, 'r', encoding='utf8') as csv_file: + with open_remote(csv_filename, 'r', encoding='utf8') as csv_file: reader = csv.DictReader(csv_file) if 'transcript' in reader.fieldnames: if labeled is None: From abe5dd2eb4bbab122c96138f841d9f6a572f0ca9 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 12:49:44 -0800 Subject: [PATCH 09/30] Remote I/O for taskcluster --- training/deepspeech_training/util/taskcluster.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/training/deepspeech_training/util/taskcluster.py b/training/deepspeech_training/util/taskcluster.py index d0053c7d..1a5200ab 100644 --- a/training/deepspeech_training/util/taskcluster.py +++ b/training/deepspeech_training/util/taskcluster.py @@ -14,6 +14,7 @@ import sys from pkg_resources import parse_version +from .io import isdir_remote, open_remote DEFAULT_SCHEMES = { 'deepspeech': 
'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s', @@ -48,7 +49,7 @@ def maybe_download_tc(target_dir, tc_url, progress=True): except OSError as e: if e.errno != errno.EEXIST: raise e - assert os.path.isdir(os.path.dirname(target_dir)) + assert isdir_remote(os.path.dirname(target_dir)) tc_filename = os.path.basename(tc_url) target_file = os.path.join(target_dir, tc_filename) @@ -61,7 +62,7 @@ def maybe_download_tc(target_dir, tc_url, progress=True): print('File already exists: %s' % target_file) if is_gzip: - with open(target_file, "r+b") as frw: + with open_remote(target_file, "r+b") as frw: decompressed = gzip.decompress(frw.read()) frw.seek(0) frw.write(decompressed) @@ -75,7 +76,7 @@ def maybe_download_tc_bin(**kwargs): os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC) def read(fname): - return open(os.path.join(os.path.dirname(__file__), fname)).read() + return open_remote(os.path.join(os.path.dirname(__file__), fname)).read() def main(): parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.') From c3dc4c0d5c1a301c661d957bfaf6e2aae36dc20c Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 14:06:22 -0800 Subject: [PATCH 10/30] Fix bad I/O helper fn replace errors --- training/deepspeech_training/util/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py index 4801c075..eb177b76 100644 --- a/training/deepspeech_training/util/io.py +++ b/training/deepspeech_training/util/io.py @@ -15,7 +15,7 @@ def path_exists_remote(path): if path.startswith("gs://"): from tensorflow.io import gfile return gfile.exists(path) - return path_exists_remotes(path) + return os.path.exists(path) def open_remote(path, mode): @@ -42,7 +42,7 @@ def isdir_remote(path): if 
path.startswith("gs://"): from tensorflow.io import gfile return gfile.isdir(path) - return isdir_remote(path) + return os.path.isdir(path) def listdir_remote(path): From 3d503bd69ec11e455cad2b39275e033c5642eb32 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 14:16:37 -0800 Subject: [PATCH 11/30] Add universal is_remote_path to I/O helper --- training/deepspeech_training/util/io.py | 28 ++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py index eb177b76..3d9b3dc0 100644 --- a/training/deepspeech_training/util/io.py +++ b/training/deepspeech_training/util/io.py @@ -1,23 +1,41 @@ """ -A set of I/O utils that allow us to open files on remote storage as if they were present locally. +A set of I/O utils that allow us to open files on remote storage as if they were present locally and access +into HDFS storage using Tensorflow's C++ FileStream API. Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets. 
""" import inspect import os import sys + +def is_remote_path(path): + """ + Returns True iff the path is one of the remote formats that this + module supports + """ + return path.startswith('gs://') or path.starts_with('hdfs://') + + def path_exists_remote(path): """ Wrapper that allows existance check of local and remote paths like `gs://...` """ # Conditional import - if path.startswith("gs://"): + if is_remote_path(path): from tensorflow.io import gfile return gfile.exists(path) return os.path.exists(path) +def copy_remote(src, dst, overwrite=False): + """ + Allows us to copy a file from local to remote or vice versa + """ + from tensorflow.io import gfile + return gfile.copy(src, dst, overwrite) + + def open_remote(path, mode): """ Wrapper around open_remote() method that can handle remote paths like `gs://...` @@ -28,7 +46,7 @@ def open_remote(path, mode): do something with the file f, whether or not we have local access to it """ # Conditional import - if path.startswith("gs://"): + if is_remote_path(path): from tensorflow.io import gfile return gfile.GFile(path, mode=mode) return open_remote(path, mode) @@ -39,7 +57,7 @@ def isdir_remote(path): Wrapper to check if remote and local paths are directories """ # Conditional import - if path.startswith("gs://"): + if is_remote_path(path): from tensorflow.io import gfile return gfile.isdir(path) return os.path.isdir(path) @@ -50,7 +68,7 @@ def listdir_remote(path): Wrapper to list paths in local dirs (alternative to using a glob, I suppose) """ # Conditional import - if path.startswith("gs://"): + if is_remote_path(path): from tensorflow.io import gfile return gfile.listdir(path) return os.listdir(path) From ad0883042126f265bcc6c7180313746beb257535 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 14:17:03 -0800 Subject: [PATCH 12/30] Work remote I/O into audio utils -- a bit more involved --- training/deepspeech_training/util/audio.py | 34 ++++++++++++++++------ 1 file changed, 25 insertions(+), 
9 deletions(-) diff --git a/training/deepspeech_training/util/audio.py b/training/deepspeech_training/util/audio.py index 031f13ed..da1a9acb 100644 --- a/training/deepspeech_training/util/audio.py +++ b/training/deepspeech_training/util/audio.py @@ -8,6 +8,7 @@ import numpy as np from .helpers import LimitingPool from collections import namedtuple +from .io import open_remote, remove_remote, copy_remote, is_remote_path AudioFormat = namedtuple('AudioFormat', 'rate channels width') @@ -168,29 +169,44 @@ class AudioFile: self.audio_format = audio_format self.as_path = as_path self.open_file = None + self.open_wav = None self.tmp_file_path = None def __enter__(self): if self.audio_path.endswith('.wav'): - self.open_file = wave.open(self.audio_path, 'r') - if read_audio_format_from_wav_file(self.open_file) == self.audio_format: + self.open_file = open_remote(self.audio_path, 'r') + self.open_wav = wave.open(self.open_file) + if read_audio_format_from_wav_file(self.open_wav) == self.audio_format: if self.as_path: + self.open_wav.close() self.open_file.close() return self.audio_path - return self.open_file + return self.open_wav + self.open_wav.close() self.open_file.close() + + # If the format isn't right, copy the file to local tmp dir and do the conversion on disk + if is_remote_path(self.audio_path): + _, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav') + copy_remote(self.audio_path, self.tmp_src_file_path) + self.audio_path = self.tmp_file_path + _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav') convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format) if self.as_path: return self.tmp_file_path - self.open_file = wave.open(self.tmp_file_path, 'r') - return self.open_file + self.open_wav = wave.open(self.tmp_file_path, 'r') + return self.open_wav def __exit__(self, *args): if not self.as_path: - self.open_file.close() + self.open_wav.close() + if self.open_file: + self.open_file.close() if self.tmp_file_path is 
not None: os.remove(self.tmp_file_path) + if self.tmp_src_file_path is not None: + os.remove(self.tmp_src_file_path) def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False): @@ -320,7 +336,7 @@ def read_opus(opus_file): def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): - with wave.open(wav_file, 'wb') as wav_file_writer: + with wave.open_remote(wav_file, 'wb') as wav_file_writer: wav_file_writer.setframerate(audio_format.rate) wav_file_writer.setnchannels(audio_format.channels) wav_file_writer.setsampwidth(audio_format.width) @@ -329,7 +345,7 @@ def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): def read_wav(wav_file): wav_file.seek(0) - with wave.open(wav_file, 'rb') as wav_file_reader: + with wave.open_remote(wav_file, 'rb') as wav_file_reader: audio_format = read_audio_format_from_wav_file(wav_file_reader) pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes()) return audio_format, pcm_data @@ -353,7 +369,7 @@ def write_audio(audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, b def read_wav_duration(wav_file): wav_file.seek(0) - with wave.open(wav_file, 'rb') as wav_file_reader: + with wave.open_remote(wav_file, 'rb') as wav_file_reader: return wav_file_reader.getnframes() / wav_file_reader.getframerate() From 90e2e1f7d26cffe603dc47d75b1fda6f330a4799 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 14:45:05 -0800 Subject: [PATCH 13/30] Respect buffering, encoding, newline, closefd, and opener if we're looking at a local file --- training/deepspeech_training/util/io.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py index 3d9b3dc0..7d72f910 100644 --- a/training/deepspeech_training/util/io.py +++ b/training/deepspeech_training/util/io.py @@ -36,11 +36,13 @@ def copy_remote(src, dst, overwrite=False): return gfile.copy(src, dst, overwrite) -def open_remote(path, mode): +def 
open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None): """ Wrapper around open_remote() method that can handle remote paths like `gs://...` off Google Cloud using Tensorflow's IO helpers. + buffering, encoding, newline, closefd, and opener are ignored for remote files + This enables us to do: with open_remote('gs://.....', mode='w+') as f: do something with the file f, whether or not we have local access to it @@ -49,7 +51,7 @@ def open_remote(path, mode): if is_remote_path(path): from tensorflow.io import gfile return gfile.GFile(path, mode=mode) - return open_remote(path, mode) + return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener) def isdir_remote(path): From 8f310729989db0fd1b9368d258fe89bad352b06b Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 15:09:42 -0800 Subject: [PATCH 14/30] Fix startswith check --- training/deepspeech_training/util/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py index 7d72f910..885a276d 100644 --- a/training/deepspeech_training/util/io.py +++ b/training/deepspeech_training/util/io.py @@ -13,7 +13,7 @@ def is_remote_path(path): Returns True iff the path is one of the remote formats that this module supports """ - return path.startswith('gs://') or path.starts_with('hdfs://') + return path.startswith('gs://') or path.startswith('hdfs://') def path_exists_remote(path): From a6322b384e9c0c55f72151799ce78e2728117626 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 16:29:16 -0800 Subject: [PATCH 15/30] Fix remote I/O handling in train --- training/deepspeech_training/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index d94e8a45..2e7263b1 100644 --- a/training/deepspeech_training/train.py 
+++ b/training/deepspeech_training/train.py @@ -35,7 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur from .util.flags import create_flags, FLAGS from .util.helpers import check_ctcdecoder_version, ExceptionBox from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn -from .util.io import open_remote, remove_remote, listdir_remote +from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote check_ctcdecoder_version() @@ -513,7 +513,8 @@ def train(): best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev') # Save flags next to checkpoints - os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) + if not is_remote_path(FLAGS.save_checkpoint_dir): + os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') with open_remote(flags_file, 'w') as fout: fout.write(FLAGS.flags_into_string()) @@ -813,11 +814,11 @@ def export(): if FLAGS.remove_remote_export: if isdir_remote(FLAGS.export_dir): log_info('Removing old export') - shutil.rmtree(FLAGS.export_dir) + remove_remote(FLAGS.export_dir) output_graph_path = os.path.join(FLAGS.export_dir, output_filename) - if not isdir_remote(FLAGS.export_dir): + if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir): os.makedirs(FLAGS.export_dir) frozen_graph = tfv1.graph_util.convert_variables_to_constants( From 0030cab22078592134ce9950c8e6fb603af0679b Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 16:29:23 -0800 Subject: [PATCH 16/30] Skip remote zipping for now --- training/deepspeech_training/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index 2e7263b1..6ebe29a6 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -875,8 +875,12 @@ def export(): def 
package_zip(): # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/' - zip_filename = os.path.dirname(export_dir) + if is_remote_path(export_dir): + log_error("Cannot package remote path zip %s. Please do this manually." % export_dir) + return + zip_filename = os.path.dirname(export_dir) + shutil.copy(FLAGS.scorer_path, export_dir) archive = shutil.make_archive(zip_filename, 'zip', export_dir) From 64d278560dc20b6f623bc3770e527e6d9f551829 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 16:29:43 -0800 Subject: [PATCH 17/30] Why do we need absolute paths everywhere here? --- training/deepspeech_training/util/check_characters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/deepspeech_training/util/check_characters.py b/training/deepspeech_training/util/check_characters.py index bde69d74..7e6cdd0b 100644 --- a/training/deepspeech_training/util/check_characters.py +++ b/training/deepspeech_training/util/check_characters.py @@ -28,7 +28,7 @@ def main(): parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true") parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. 
DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true") args = parser.parse_args() - in_files = [os.path.abspath(i) for i in args.csv_files.split(",")] + in_files = args.csv_files.split(",") print("### Reading in the following transcript files: ###") print("### {} ###".format(in_files)) From 783cdad8db471cc33c0f9d9fa79b0b1c8d4c198b Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 16:30:11 -0800 Subject: [PATCH 18/30] Fix downloader and taskcluster directory mgmt with remote I/O --- training/deepspeech_training/util/downloader.py | 4 ++-- training/deepspeech_training/util/taskcluster.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/training/deepspeech_training/util/downloader.py b/training/deepspeech_training/util/downloader.py index 0a40c481..b8fcdb8d 100644 --- a/training/deepspeech_training/util/downloader.py +++ b/training/deepspeech_training/util/downloader.py @@ -2,7 +2,7 @@ import requests import progressbar from os import path, makedirs -from .io import open_remote, path_exists_remote +from .io import open_remote, path_exists_remote, is_remote_path SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed'] @@ -10,7 +10,7 @@ def maybe_download(archive_name, target_dir, archive_url): # If archive file does not exist, download it... archive_path = path.join(target_dir, archive_name) - if not path_exists_remote(target_dir): + if not is_remote_path(target_dir) and not path.exists(target_dir): print('No path "%s" - creating ...' 
% target_dir) makedirs(target_dir) diff --git a/training/deepspeech_training/util/taskcluster.py b/training/deepspeech_training/util/taskcluster.py index 1a5200ab..ba4f2019 100644 --- a/training/deepspeech_training/util/taskcluster.py +++ b/training/deepspeech_training/util/taskcluster.py @@ -14,7 +14,7 @@ import sys from pkg_resources import parse_version -from .io import isdir_remote, open_remote +from .io import isdir_remote, open_remote, is_remote_path DEFAULT_SCHEMES = { 'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s', @@ -43,13 +43,13 @@ def maybe_download_tc(target_dir, tc_url, progress=True): assert target_dir is not None - target_dir = os.path.abspath(target_dir) - try: - os.makedirs(target_dir) - except OSError as e: - if e.errno != errno.EEXIST: - raise e - assert isdir_remote(os.path.dirname(target_dir)) + if not is_remote_path(target_dir): + try: + os.makedirs(target_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise e + assert os.path.isdir(os.path.dirname(target_dir)) tc_filename = os.path.basename(tc_url) target_file = os.path.join(target_dir, tc_filename) From 8fe972eb6f296f0bb1bbb4b8f51657c7660656b6 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 16:40:40 -0800 Subject: [PATCH 19/30] Fix wave file reading helpers --- training/deepspeech_training/util/audio.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/training/deepspeech_training/util/audio.py b/training/deepspeech_training/util/audio.py index da1a9acb..5e2ed5d9 100644 --- a/training/deepspeech_training/util/audio.py +++ b/training/deepspeech_training/util/audio.py @@ -336,7 +336,8 @@ def read_opus(opus_file): def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): - with wave.open_remote(wav_file, 'wb') as wav_file_writer: + # wav_file is already a file-pointer here + with 
wave.open(wav_file, 'wb') as wav_file_writer: wav_file_writer.setframerate(audio_format.rate) wav_file_writer.setnchannels(audio_format.channels) wav_file_writer.setsampwidth(audio_format.width) @@ -345,7 +346,7 @@ def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): def read_wav(wav_file): wav_file.seek(0) - with wave.open_remote(wav_file, 'rb') as wav_file_reader: + with wave.open(wav_file, 'rb') as wav_file_reader: audio_format = read_audio_format_from_wav_file(wav_file_reader) pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes()) return audio_format, pcm_data @@ -369,7 +370,7 @@ def write_audio(audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, b def read_wav_duration(wav_file): wav_file.seek(0) - with wave.open_remote(wav_file, 'rb') as wav_file_reader: + with wave.open(wav_file, 'rb') as wav_file_reader: return wav_file_reader.getnframes() / wav_file_reader.getframerate() From 86cba458c556227e5c685fce81a7112cb76af7dd Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 16:40:59 -0800 Subject: [PATCH 20/30] Fix remote path handling for CSV sample reading --- .../deepspeech_training/util/sample_collections.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index 2467854d..e0e6b12b 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -18,7 +18,7 @@ from .audio import ( get_audio_type_from_extension, write_wav ) -from .io import open_remote +from .io import open_remote, is_remote_path BIG_ENDIAN = 'big' INT_SIZE = 4 @@ -499,7 +499,6 @@ class CSV(SampleList): If the order of the samples should be reversed """ rows = [] - csv_dir = Path(csv_filename).parent with open_remote(csv_filename, 'r', encoding='utf8') as csv_file: reader = csv.DictReader(csv_file) if 'transcript' in reader.fieldnames: 
@@ -509,9 +508,12 @@ class CSV(SampleList): raise RuntimeError('No transcript data (missing CSV column)') for row in reader: wav_filename = Path(row['wav_filename']) - if not wav_filename.is_absolute(): - wav_filename = csv_dir / wav_filename - wav_filename = str(wav_filename) + if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']): + wav_filename = Path(csv_filename).parent / wav_filename + wav_filename = str(wav_filename) + else: + # Pathlib otherwise removes a / from filenames like hdfs:// + wav_filename = row['wav_filename'] wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0 if labeled: rows.append((wav_filename, wav_filesize, row['transcript'])) From fc0b4956431271f0b7caa834492aaf71fd2768d2 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 16:46:59 -0800 Subject: [PATCH 21/30] TODO: CSVWriter still totally breaks with remote paths --- training/deepspeech_training/util/sample_collections.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index e0e6b12b..d9856484 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -334,6 +334,8 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes If True: Writes labeled samples (util.sample_collections.LabeledSample) only. If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. 
""" + + # TODO: This all breaks with remote paths self.csv_filename = Path(csv_filename) self.csv_base_dir = self.csv_filename.parent.resolve().absolute() self.set_name = self.csv_filename.stem From be39d3354dc71499b5fa461c8ce2983779b9f262 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Thu, 12 Nov 2020 21:46:39 -0800 Subject: [PATCH 22/30] Perform data loading I/O within worker process rather than main process by wrapping Sample --- .../deepspeech_training/util/augmentations.py | 10 ++++-- .../util/sample_collections.py | 34 +++++++++++++++---- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/training/deepspeech_training/util/augmentations.py b/training/deepspeech_training/util/augmentations.py index 941c17f2..0934fbd5 100644 --- a/training/deepspeech_training/util/augmentations.py +++ b/training/deepspeech_training/util/augmentations.py @@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context): AUGMENTATION_CONTEXT = preparation_context +def _load_and_augment_sample(timed_sample, context=None): + sample, clock = timed_sample + realized_sample = sample.unpack() + return _augment_sample((realized_sample, clock), context) + + def _augment_sample(timed_sample, context=None): context = AUGMENTATION_CONTEXT if context is None else context sample, clock = timed_sample @@ -213,12 +219,12 @@ def apply_sample_augmentations(samples, context = AugmentationContext(audio_type, augmentations) if process_ahead == 0: for timed_sample in timed_samples(): - yield _augment_sample(timed_sample, context=context) + yield _load_and_augment_sample(timed_sample, context=context) else: with LimitingPool(process_ahead=process_ahead, initializer=_init_augmentation_worker, initargs=(context,)) as pool: - yield from pool.imap(_augment_sample, timed_samples()) + yield from pool.imap(_load_and_augment_sample, timed_samples()) finally: for augmentation in augmentations: augmentation.stop() diff --git a/training/deepspeech_training/util/sample_collections.py 
b/training/deepspeech_training/util/sample_collections.py index d9856484..23b0422b 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -60,6 +60,27 @@ class LabeledSample(Sample): self.transcript = transcript +class PackedSample: + """ + A wrapper that we can carry around in an iterator and pass to a child process in order to + have the child process do the loading/unpacking of the sample, allowing for parallel file + I/O. + """ + def __init__(self, filename, audio_type, label): + self.filename = filename + self.audio_type = audio_type + self.label = label + + def unpack(self): + print("Unpacking sample: %s" % self.filename) + with open_remote(self.filename, 'rb') as audio_file: + data = audio_file.read() + if self.label is None: + s = Sample(self.audio_type, data, sample_id=self.filename) + s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename) + print("unpacked!") + return s + def load_sample(filename, label=None): """ Loads audio-file as a (labeled or unlabeled) sample @@ -70,21 +91,20 @@ def load_sample(filename, label=None): Filename of the audio-file to load as sample label : str Label (transcript) of the sample. 
- If None: return util.audio.Sample instance - Otherwise: return util.sample_collections.LabeledSample instance + If None: returned result.unpack() will return util.audio.Sample instance + Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance Returns ------- - util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance + util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return + util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance """ + print("loading sample!") ext = os.path.splitext(filename)[1].lower() audio_type = get_audio_type_from_extension(ext) if audio_type is None: raise ValueError('Unknown audio type extension "{}"'.format(ext)) - with open_remote(filename, 'rb') as audio_file: - if label is None: - return Sample(audio_type, audio_file.read(), sample_id=filename) - return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename) + return PackedSample(filename, audio_type, label) class DirectSDBWriter: From 2332e7fb76c72dc9d7bc2ca73823ebfa83ec85b9 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Fri, 13 Nov 2020 10:45:53 -0800 Subject: [PATCH 23/30] Linter fix: define self.tmp_src_file_path in init --- training/deepspeech_training/util/audio.py | 1 + 1 file changed, 1 insertion(+) diff --git a/training/deepspeech_training/util/audio.py b/training/deepspeech_training/util/audio.py index 5e2ed5d9..05ceba38 100644 --- a/training/deepspeech_training/util/audio.py +++ b/training/deepspeech_training/util/audio.py @@ -171,6 +171,7 @@ class AudioFile: self.open_file = None self.open_wav = None self.tmp_file_path = None + self.tmp_src_file_path = None def __enter__(self): if self.audio_path.endswith('.wav'): From 3d2b09b951241885d773a42d3a5c20188216a2bb Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Fri, 13 Nov 2020 10:47:06 -0800 Subject: [PATCH 24/30] Linter seems unhappy with conditional imports. 
Make gfile a module-level import. I usually do this as a conditional because tf takes a while to load and it's nice to skip it when you want to run a script that just preps data or something like that, but it doesn't seem like a big deal. --- training/deepspeech_training/util/io.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py index 885a276d..5f1de483 100644 --- a/training/deepspeech_training/util/io.py +++ b/training/deepspeech_training/util/io.py @@ -3,9 +3,8 @@ A set of I/O utils that allow us to open files on remote storage as if they were into HDFS storage using Tensorflow's C++ FileStream API. Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets. """ -import inspect import os -import sys +from tensorflow.io import gfile def is_remote_path(path): @@ -21,9 +20,7 @@ def path_exists_remote(path): Wrapper that allows existance check of local and remote paths like `gs://...` """ - # Conditional import if is_remote_path(path): - from tensorflow.io import gfile return gfile.exists(path) return os.path.exists(path) @@ -32,7 +29,6 @@ def copy_remote(src, dst, overwrite=False): """ Allows us to copy a file from local to remote or vice versa """ - from tensorflow.io import gfile return gfile.copy(src, dst, overwrite) @@ -47,9 +43,7 @@ def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, close with open_remote('gs://.....', mode='w+') as f: do something with the file f, whether or not we have local access to it """ - # Conditional import if is_remote_path(path): - from tensorflow.io import gfile return gfile.GFile(path, mode=mode) return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener) @@ -58,9 +52,7 @@ def isdir_remote(path): """ Wrapper to check if remote and local paths are directories """ - # Conditional import if 
is_remote_path(path): - from tensorflow.io import gfile return gfile.isdir(path) return os.path.isdir(path) @@ -69,9 +61,7 @@ def listdir_remote(path): """ Wrapper to list paths in local dirs (alternative to using a glob, I suppose) """ - # Conditional import if is_remote_path(path): - from tensorflow.io import gfile return gfile.listdir(path) return os.listdir(path) @@ -80,9 +70,6 @@ def glob_remote(filename): """ Wrapper that provides globs on local and remote paths like `gs://...` """ - # Conditional import - from tensorflow.io import gfile - return gfile.glob(filename) @@ -91,6 +78,4 @@ def remove_remote(filename): Wrapper that can remove_remote local and remote files like `gs://...` """ # Conditional import - from tensorflow.io import gfile - return gfile.remove_remote(filename) \ No newline at end of file From 47020e4ecbcb30976104e3ad9dbf7af5b9945cd7 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Fri, 13 Nov 2020 19:20:02 -0800 Subject: [PATCH 25/30] Add an imap_unordered helper to LimitPool -- I might experiment with this --- training/deepspeech_training/util/helpers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/training/deepspeech_training/util/helpers.py b/training/deepspeech_training/util/helpers.py index 195c117e..8d35e149 100644 --- a/training/deepspeech_training/util/helpers.py +++ b/training/deepspeech_training/util/helpers.py @@ -103,6 +103,11 @@ class LimitingPool: self.processed -= 1 yield obj + def imap_unordered(self, fun, it): + for obj in self.pool.imap_unordered(fun, self._limit(it)): + self.processed -= 1 + yield obj + def terminate(self): self.pool.terminate() From 8c1a183c671063ae113c2f9d1ac710dc4b8efc76 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Fri, 13 Nov 2020 19:24:09 -0800 Subject: [PATCH 26/30] Clean up print debugging statements --- training/deepspeech_training/util/sample_collections.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/training/deepspeech_training/util/sample_collections.py 
b/training/deepspeech_training/util/sample_collections.py index 23b0422b..d075b440 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -72,13 +72,12 @@ class PackedSample: self.label = label def unpack(self): - print("Unpacking sample: %s" % self.filename) with open_remote(self.filename, 'rb') as audio_file: data = audio_file.read() if self.label is None: s = Sample(self.audio_type, data, sample_id=self.filename) - s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename) + else: + s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename) - print("unpacked!") return s def load_sample(filename, label=None): @@ -99,7 +97,6 @@ def load_sample(filename, label=None): util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance """ - print("loading sample!") ext = os.path.splitext(filename)[1].lower() audio_type = get_audio_type_from_extension(ext) if audio_type is None: From fb6d4ca361da6283f75ca5e57edea4f55d08bf68 Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Fri, 13 Nov 2020 19:36:07 -0800 Subject: [PATCH 27/30] Add disclaimers to CSV and Tar writers --- .../deepspeech_training/util/sample_collections.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index d075b440..227d152c 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -350,9 +350,9 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes labeled : bool or None If True: Writes labeled samples (util.sample_collections.LabeledSample) only. If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
""" - - # TODO: This all breaks with remote paths self.csv_filename = Path(csv_filename) self.csv_base_dir = self.csv_filename.parent.resolve().absolute() self.set_name = self.csv_filename.stem @@ -400,7 +400,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes class TarWriter: # pylint: disable=too-many-instance-attributes - """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file""" + """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file.""" def __init__(self, tar_filename, gz=False, @@ -418,8 +418,10 @@ class TarWriter: # pylint: disable=too-many-instance-attributes If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. include : str[] List of files to include into tar root. + + Currently only works with local files (not gs:// or hdfs://...) """ - self.tar = tarfile.open_remote(tar_filename, 'w:gz' if gz else 'w') + self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w') samples_dir = tarfile.TarInfo('samples') samples_dir.type = tarfile.DIRTYPE self.tar.addfile(samples_dir) From b5b3b2546ca5ba9581b833194921fe9f23daaf3e Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Mon, 16 Nov 2020 13:46:34 -0800 Subject: [PATCH 28/30] Clean up remote I/O docs --- training/deepspeech_training/util/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py index 5f1de483..947b43af 100644 --- a/training/deepspeech_training/util/io.py +++ b/training/deepspeech_training/util/io.py @@ -34,7 +34,7 @@ def copy_remote(src, dst, overwrite=False): def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None): """ - Wrapper around open_remote() method that can handle remote paths like `gs://...` + Wrapper around open() method that can handle remote paths like `gs://...` off Google Cloud using 
Tensorflow's IO helpers. buffering, encoding, newline, closefd, and opener are ignored for remote files @@ -75,7 +75,6 @@ def glob_remote(filename): def remove_remote(filename): """ - Wrapper that can remove_remote local and remote files like `gs://...` + Wrapper that can remove local and remote files like `gs://...` """ - # Conditional import - return gfile.remove_remote(filename) + return gfile.remove(filename) \ No newline at end of file From 611633fcf64faaa168dbf4c50f799499902a2b2a Mon Sep 17 00:00:00 2001 From: CatalinVoss Date: Mon, 16 Nov 2020 13:47:06 -0800 Subject: [PATCH 29/30] Remove unnecessary uses of `open_remote()` where we know `__file__` will always be local --- training/deepspeech_training/train.py | 2 +- training/deepspeech_training/util/taskcluster.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index 6ebe29a6..3428598d 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -775,7 +775,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): def file_relative_read(fname): - return open_remote(os.path.join(os.path.dirname(__file__), fname)).read() + return open(os.path.join(os.path.dirname(__file__), fname)).read() def export(): diff --git a/training/deepspeech_training/util/taskcluster.py b/training/deepspeech_training/util/taskcluster.py index ba4f2019..4471659d 100644 --- a/training/deepspeech_training/util/taskcluster.py +++ b/training/deepspeech_training/util/taskcluster.py @@ -76,7 +76,7 @@ def maybe_download_tc_bin(**kwargs): os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC) def read(fname): - return open_remote(os.path.join(os.path.dirname(__file__), fname)).read() + return open(os.path.join(os.path.dirname(__file__), fname)).read() def main(): parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.') From d0678cd1b70d2207dfb29c29863a31eb255971a7 Mon Sep 17
00:00:00 2001 From: CatalinVoss Date: Mon, 16 Nov 2020 13:47:21 -0800 Subject: [PATCH 30/30] Remove unused unordered imap from LimitPool --- training/deepspeech_training/util/helpers.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/training/deepspeech_training/util/helpers.py b/training/deepspeech_training/util/helpers.py index 8d35e149..195c117e 100644 --- a/training/deepspeech_training/util/helpers.py +++ b/training/deepspeech_training/util/helpers.py @@ -103,11 +103,6 @@ class LimitingPool: self.processed -= 1 yield obj - def imap_unordered(self, fun, it): - for obj in self.pool.imap_unordered(fun, self._limit(it)): - self.processed -= 1 - yield obj - def terminate(self): self.pool.terminate()