Update transcribe.py to TF2
This commit is contained in:
parent
159697738c
commit
9b738dd70d
@ -3,15 +3,19 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.compat.v1.logging as tflogging
|
import tensorflow.compat.v1 as tfv1
|
||||||
tflogging.set_verbosity(tflogging.ERROR)
|
tfv1.logging.set_verbosity(tfv1.logging.ERROR)
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger('sox').setLevel(logging.ERROR)
|
logging.getLogger('sox').setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
import absl.flags
|
||||||
import glob
|
import glob
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
from deepspeech_training.util.audio import AudioFile
|
from deepspeech_training.util.audio import AudioFile
|
||||||
from deepspeech_training.util.config import Config, initialize_globals
|
from deepspeech_training.util.config import Config, initialize_globals
|
||||||
@ -28,7 +32,7 @@ def fail(message, code=1):
|
|||||||
|
|
||||||
|
|
||||||
def transcribe_file(audio_path, tlog_path):
|
def transcribe_file(audio_path, tlog_path):
|
||||||
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
from deepspeech_training.train import Model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||||
from deepspeech_training.util.checkpoints import load_graph_for_evaluation
|
from deepspeech_training.util.checkpoints import load_graph_for_evaluation
|
||||||
initialize_globals()
|
initialize_globals()
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||||
@ -37,21 +41,29 @@ def transcribe_file(audio_path, tlog_path):
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
num_processes = 1
|
num_processes = 1
|
||||||
with AudioFile(audio_path, as_path=True) as wav_path:
|
with AudioFile(audio_path, as_path=True) as wav_path:
|
||||||
|
with tfv1.Session(config=Config.session_config) as session:
|
||||||
data_set = split_audio_file(wav_path,
|
data_set = split_audio_file(wav_path,
|
||||||
batch_size=FLAGS.batch_size,
|
batch_size=FLAGS.batch_size,
|
||||||
aggressiveness=FLAGS.vad_aggressiveness,
|
aggressiveness=FLAGS.vad_aggressiveness,
|
||||||
outlier_duration_ms=FLAGS.outlier_duration_ms,
|
outlier_duration_ms=FLAGS.outlier_duration_ms,
|
||||||
outlier_batch_size=FLAGS.outlier_batch_size)
|
outlier_batch_size=FLAGS.outlier_batch_size)
|
||||||
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
|
|
||||||
output_classes=data_set.output_classes)
|
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(data_set),
|
||||||
|
tfv1.data.get_output_shapes(data_set),
|
||||||
|
output_classes=tfv1.data.get_output_classes(data_set))
|
||||||
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
|
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
|
||||||
no_dropout = [None] * 6
|
|
||||||
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
model = Model()
|
||||||
|
logits, _ = model(batch_x)
|
||||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||||
tf.train.get_or_create_global_step()
|
|
||||||
with tf.Session(config=Config.session_config) as session:
|
# Make sure global step variable is created in this graph so that
|
||||||
|
# checkpoint loading doesn't fail
|
||||||
|
tfv1.train.get_or_create_global_step()
|
||||||
|
|
||||||
load_graph_for_evaluation(session)
|
load_graph_for_evaluation(session)
|
||||||
session.run(iterator.make_initializer(data_set))
|
session.run(iterator.make_initializer(data_set))
|
||||||
|
|
||||||
transcripts = []
|
transcripts = []
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -149,20 +161,20 @@ def main(_):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
create_flags()
|
create_flags()
|
||||||
tf.app.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
|
absl.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.'
|
||||||
'Catalog files should be formatted from DSAlign. A directory will'
|
'Catalog files should be formatted from DSAlign. A directory will'
|
||||||
'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be '
|
'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be '
|
||||||
'written in-place using the source filenames with '
|
'written in-place using the source filenames with '
|
||||||
'suffix ".tlog" instead of ".wav".')
|
'suffix ".tlog" instead of ".wav".')
|
||||||
tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
|
absl.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
|
||||||
'If --src is a directory, this one also has to be a directory '
|
'If --src is a directory, this one also has to be a directory '
|
||||||
'and the required sub-dir tree of --src will get replicated.')
|
'and the required sub-dir tree of --src will get replicated.')
|
||||||
tf.app.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
|
absl.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively')
|
||||||
tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
|
absl.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
|
||||||
'transcription logs (.tlog)')
|
'transcription logs (.tlog)')
|
||||||
tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
|
absl.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
|
||||||
'split audio')
|
'split audio')
|
||||||
tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
|
absl.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
|
||||||
tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
|
absl.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
|
||||||
tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
|
absl.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
|
||||||
tf.app.run(main)
|
absl.app.run(main)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user