Update transcribe.py to TF2

This commit is contained in:
Reuben Morais 2021-01-02 15:05:28 +00:00
parent 159697738c
commit 9b738dd70d

View File

@ -3,15 +3,19 @@
from __future__ import absolute_import, division, print_function
import os
import sys
import json
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import tensorflow.compat.v1.logging as tflogging
tflogging.set_verbosity(tflogging.ERROR)
import tensorflow.compat.v1 as tfv1
tfv1.logging.set_verbosity(tfv1.logging.ERROR)
import logging
logging.getLogger('sox').setLevel(logging.ERROR)
import absl.flags
import glob
import json
import sys
from deepspeech_training.util.audio import AudioFile
from deepspeech_training.util.config import Config, initialize_globals
@ -28,7 +32,7 @@ def fail(message, code=1):
def transcribe_file(audio_path, tlog_path):
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
from deepspeech_training.train import Model # pylint: disable=cyclic-import,import-outside-toplevel
from deepspeech_training.util.checkpoints import load_graph_for_evaluation
initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
@ -37,21 +41,29 @@ def transcribe_file(audio_path, tlog_path):
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(wav_path,
batch_size=FLAGS.batch_size,
aggressiveness=FLAGS.vad_aggressiveness,
outlier_duration_ms=FLAGS.outlier_duration_ms,
outlier_batch_size=FLAGS.outlier_batch_size)
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
output_classes=data_set.output_classes)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
with tfv1.Session(config=Config.session_config) as session:
data_set = split_audio_file(wav_path,
batch_size=FLAGS.batch_size,
aggressiveness=FLAGS.vad_aggressiveness,
outlier_duration_ms=FLAGS.outlier_duration_ms,
outlier_batch_size=FLAGS.outlier_batch_size)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(data_set),
tfv1.data.get_output_shapes(data_set),
output_classes=tfv1.data.get_output_classes(data_set))
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
model = Model()
logits, _ = model(batch_x)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
# Make sure global step variable is created in this graph so that
# checkpoint loading doesn't fail
tfv1.train.get_or_create_global_step()
load_graph_for_evaluation(session)
session.run(iterator.make_initializer(data_set))
transcripts = []
while True:
try:
@ -149,20 +161,20 @@ def main(_):
if __name__ == '__main__':
create_flags()
tf.app.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
absl.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
'Catalog files should be formatted from DSAlign. A directory will'
'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be '
'written in-place using the source filenames with '
'suffix ".tlog" instead of ".wav".')
tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
absl.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
'If --src is a directory, this one also has to be a directory '
'and the required sub-dir tree of --src will get replicated.')
tf.app.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
absl.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
absl.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
'transcription logs (.tlog)')
tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
absl.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
'split audio')
tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
tf.app.run(main)
absl.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
absl.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
absl.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
absl.app.run(main)