Update transcribe.py to TF2
This commit is contained in:
parent
159697738c
commit
9b738dd70d
@ -3,15 +3,19 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1.logging as tflogging
|
||||
tflogging.set_verbosity(tflogging.ERROR)
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
tfv1.logging.set_verbosity(tfv1.logging.ERROR)
|
||||
|
||||
import logging
|
||||
logging.getLogger('sox').setLevel(logging.ERROR)
|
||||
|
||||
import absl.flags
|
||||
import glob
|
||||
import json
|
||||
import sys
|
||||
|
||||
from deepspeech_training.util.audio import AudioFile
|
||||
from deepspeech_training.util.config import Config, initialize_globals
|
||||
@ -28,7 +32,7 @@ def fail(message, code=1):
|
||||
|
||||
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from deepspeech_training.train import Model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from deepspeech_training.util.checkpoints import load_graph_for_evaluation
|
||||
initialize_globals()
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||
@ -37,21 +41,29 @@ def transcribe_file(audio_path, tlog_path):
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
with AudioFile(audio_path, as_path=True) as wav_path:
|
||||
data_set = split_audio_file(wav_path,
|
||||
batch_size=FLAGS.batch_size,
|
||||
aggressiveness=FLAGS.vad_aggressiveness,
|
||||
outlier_duration_ms=FLAGS.outlier_duration_ms,
|
||||
outlier_batch_size=FLAGS.outlier_batch_size)
|
||||
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
|
||||
output_classes=data_set.output_classes)
|
||||
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||
tf.train.get_or_create_global_step()
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
data_set = split_audio_file(wav_path,
|
||||
batch_size=FLAGS.batch_size,
|
||||
aggressiveness=FLAGS.vad_aggressiveness,
|
||||
outlier_duration_ms=FLAGS.outlier_duration_ms,
|
||||
outlier_batch_size=FLAGS.outlier_batch_size)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(data_set),
|
||||
tfv1.data.get_output_shapes(data_set),
|
||||
output_classes=tfv1.data.get_output_classes(data_set))
|
||||
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
|
||||
|
||||
model = Model()
|
||||
logits, _ = model(batch_x)
|
||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||
|
||||
# Make sure global step variable is created in this graph so that
|
||||
# checkpoint loading doesn't fail
|
||||
tfv1.train.get_or_create_global_step()
|
||||
|
||||
load_graph_for_evaluation(session)
|
||||
session.run(iterator.make_initializer(data_set))
|
||||
|
||||
transcripts = []
|
||||
while True:
|
||||
try:
|
||||
@ -149,20 +161,20 @@ def main(_):
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
tf.app.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
|
||||
absl.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
|
||||
'Catalog files should be formatted from DSAlign. A directory will'
|
||||
'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be '
|
||||
'written in-place using the source filenames with '
|
||||
'suffix ".tlog" instead of ".wav".')
|
||||
tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
|
||||
absl.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
|
||||
'If --src is a directory, this one also has to be a directory '
|
||||
'and the required sub-dir tree of --src will get replicated.')
|
||||
tf.app.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
|
||||
tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
|
||||
absl.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
|
||||
absl.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
|
||||
'transcription logs (.tlog)')
|
||||
tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
|
||||
absl.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
|
||||
'split audio')
|
||||
tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
|
||||
tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
|
||||
tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
|
||||
tf.app.run(main)
|
||||
absl.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
|
||||
absl.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
|
||||
absl.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
|
||||
absl.app.run(main)
|
||||
|
Loading…
Reference in New Issue
Block a user