Work remote I/O into train script
This commit is contained in:
parent
53e3f5374f
commit
579921cc92
@ -35,6 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
||||
from .util.io import open_remote, remove_remote, listdir_remote
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
@ -514,7 +515,7 @@ def train():
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open(flags_file, 'w') as fout:
|
||||
with open_remote(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
@ -541,7 +542,7 @@ def train():
|
||||
feature_cache_index = FLAGS.feature_cache + '.index'
|
||||
if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
|
||||
log_info('Invalidating feature cache')
|
||||
os.remove(feature_cache_index) # this will let TF also overwrite the related cache data files
|
||||
remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
@ -773,7 +774,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
|
||||
|
||||
def file_relative_read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
return open_remote(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
|
||||
def export():
|
||||
@ -809,14 +810,14 @@ def export():
|
||||
load_graph_for_evaluation(session)
|
||||
|
||||
output_filename = FLAGS.export_file_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
if FLAGS.remove_remote_export:
|
||||
if isdir_remote(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
if not isdir_remote(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
@ -829,7 +830,7 @@ def export():
|
||||
dest_nodes=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
with open_remote(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
@ -840,7 +841,7 @@ def export():
|
||||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with open(output_tflite_path, 'wb') as fout:
|
||||
with open_remote(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
@ -851,7 +852,7 @@ def export():
|
||||
FLAGS.export_model_version))
|
||||
|
||||
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
||||
with open(metadata_fname, 'w') as f:
|
||||
with open_remote(metadata_fname, 'w') as f:
|
||||
f.write('---\n')
|
||||
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
||||
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
||||
@ -959,7 +960,7 @@ def main(_):
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
if os.listdir(FLAGS.export_dir):
|
||||
if listdir_remote(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user