Work remote I/O into train script

This commit is contained in:
CatalinVoss 2020-11-12 10:45:35 -08:00
parent 53e3f5374f
commit 579921cc92

View File

@ -35,6 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
from .util.io import open_remote, remove_remote, listdir_remote
check_ctcdecoder_version()
@ -514,7 +515,7 @@ def train():
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
with open_remote(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
@ -541,7 +542,7 @@ def train():
feature_cache_index = FLAGS.feature_cache + '.index'
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
log_info('Invalidating feature cache')
os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files
remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
@ -773,7 +774,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
return open_remote(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
@ -809,14 +810,14 @@ def export():
load_graph_for_evaluation(session)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
if FLAGS.remove_remote_export:
if isdir_remote(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
if not isdir_remote(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
@ -829,7 +830,7 @@ def export():
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
with open_remote(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
@ -840,7 +841,7 @@ def export():
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
with open_remote(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
@ -851,7 +852,7 @@ def export():
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
with open_remote(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
@ -959,7 +960,7 @@ def main(_):
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
if listdir_remote(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)