Merge pull request #36743 from foldl:make-speech-command-v2-compat

PiperOrigin-RevId: 296056103
Change-Id: Ibee2e8731ff7e8ebf19e5fa4e53f62d95634e4b8
commit 23b4405ace
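In brief, this change migrates the streaming-accuracy example off TF 1.x endpoints that were removed or renamed in TF 2.x, and switches the example's own modules to local imports. A summary of the mapping, gathered from the hunks below (illustrative comment block, not part of the commit):

# TF 1.x endpoint                          -> replacement used in this change
# tf.logging.*                             -> tf.compat.v1.logging.*
# tf.Session                               -> tf.compat.v1.Session
# tf.placeholder                           -> tf.compat.v1.placeholder
# tf.GraphDef                              -> tf.compat.v1.GraphDef
# tf.gfile.GFile                           -> tf.io.gfile.GFile
# tf.contrib ... audio_ops.decode_wav      -> tf.audio.decode_wav
# tf.get_default_graph().get_tensor_by_name -> sess.graph.get_tensor_by_name
# tf.app.run                               -> tf.compat.v1.app.run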
tensorflow/examples/speech_commands/accuracy_utils.py
@@ -137,14 +137,16 @@ class StreamingAccuracyStats(object):
   def print_accuracy_stats(self):
     """Write a human-readable description of the statistics to stdout."""
     if self._how_many_gt == 0:
-      tf.logging.info('No ground truth yet, {}false positives'.format(
+      tf.compat.v1.logging.info('No ground truth yet, {}false positives'.format(
           self._how_many_fp))
     else:
       any_match_percentage = self._how_many_gt_matched / self._how_many_gt * 100
       correct_match_percentage = self._how_many_c / self._how_many_gt * 100
       wrong_match_percentage = self._how_many_w / self._how_many_gt * 100
       false_positive_percentage = self._how_many_fp / self._how_many_gt * 100
-      tf.logging.info('{:.1f}% matched, {:.1f}% correct, {:.1f}% wrong, '
-                      '{:.1f}% false positive'.format(
-                          any_match_percentage, correct_match_percentage,
-                          wrong_match_percentage, false_positive_percentage))
+      tf.compat.v1.logging.info(
+          '{:.1f}% matched, {:.1f}% correct, {:.1f}% wrong, '
+          '{:.1f}% false positive'.format(any_match_percentage,
+                                          correct_match_percentage,
+                                          wrong_match_percentage,
+                                          false_positive_percentage))
tensorflow/examples/speech_commands/test_streaming_accuracy.py
@@ -69,10 +69,9 @@ import sys
 import numpy
 import tensorflow as tf

-from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
-from tensorflow.examples.speech_commands.accuracy_utils import StreamingAccuracyStats
-from tensorflow.examples.speech_commands.recognize_commands import RecognizeCommands
-from tensorflow.examples.speech_commands.recognize_commands import RecognizeResult
+from accuracy_utils import StreamingAccuracyStats
+from recognize_commands import RecognizeCommands
+from recognize_commands import RecognizeResult
 from tensorflow.python.ops import io_ops

 FLAGS = None
@@ -82,8 +81,8 @@ def load_graph(mode_file):
   """Read a tensorflow model, and creates a default graph object."""
   graph = tf.Graph()
   with graph.as_default():
-    od_graph_def = tf.GraphDef()
-    with tf.gfile.GFile(mode_file, 'rb') as fid:
+    od_graph_def = tf.compat.v1.GraphDef()
+    with tf.io.gfile.GFile(mode_file, 'rb') as fid:
       serialized_graph = fid.read()
       od_graph_def.ParseFromString(serialized_graph)
       tf.import_graph_def(od_graph_def, name='')
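A hypothetical usage sketch of the updated load_graph (the path below is an assumption, not taken from the commit):

# Sketch only; '/tmp/frozen_speech_model.pb' is a made-up path.
graph = load_graph('/tmp/frozen_speech_model.pb')
with graph.as_default():
  with tf.compat.v1.Session() as sess:
    # Tensors can now be looked up via sess.graph, as main() does below.
    pass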
@@ -101,10 +100,10 @@ def read_label_file(file_name):

 def read_wav_file(filename):
   """Load a wav file and return sample_rate and numpy data of float64 type."""
-  with tf.Session(graph=tf.Graph()) as sess:
-    wav_filename_placeholder = tf.placeholder(tf.string, [])
+  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
     wav_loader = io_ops.read_file(wav_filename_placeholder)
-    wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
+    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
     res = sess.run(wav_decoder, feed_dict={wav_filename_placeholder: filename})
     return res.sample_rate, res.audio.flatten()

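For contrast, a pure eager TF 2.x equivalent of read_wav_file needs no Session or placeholder at all; a minimal sketch, assuming TF 2.x (the function name is ours, not the commit's):

# Sketch: eager-mode WAV reading with the same public ops.
import tensorflow as tf

def read_wav_file_eager(filename):
  wav_data = tf.io.read_file(filename)  # raw file bytes
  # decode_wav returns (audio, sample_rate); audio is float32 in [-1, 1].
  audio, sample_rate = tf.audio.decode_wav(wav_data, desired_channels=1)
  return sample_rate.numpy(), audio.numpy().flatten()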
@@ -133,15 +132,12 @@ def main(_):
   # Load model and create a tf session to process audio pieces
   recognize_graph = load_graph(FLAGS.model)
   with recognize_graph.as_default():
-    with tf.Session() as sess:
+    with tf.compat.v1.Session() as sess:

       # Get input and output tensor
-      data_tensor = tf.get_default_graph().get_tensor_by_name(
-          FLAGS.input_names[0])
-      sample_rate_tensor = tf.get_default_graph().get_tensor_by_name(
-          FLAGS.input_names[1])
-      output_softmax_tensor = tf.get_default_graph().get_tensor_by_name(
-          FLAGS.output_name)
+      data_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[0])
+      sample_rate_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[1])
+      output_softmax_tensor = sess.graph.get_tensor_by_name(FLAGS.output_name)

       # Inference along audio stream.
       for audio_data_offset in range(0, audio_data_end, clip_stride_samples):
@@ -161,7 +157,7 @@ def main(_):
           recognize_commands.process_latest_result(outputs, current_time_ms,
                                                    recognize_element)
         except ValueError as e:
-          tf.logging.error('Recognition processing failed: {}' % e)
+          tf.compat.v1.logging.error('Recognition processing failed: {}' % e)
           return
         if (recognize_element.is_new_command and
             recognize_element.founded_command != '_silence_'):
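Note an oddity the migration preserves: '...failed: {}' % e mixes a str.format placeholder with the % operator, so this error path itself raises TypeError at runtime. The likely intent (our inference, not part of the commit) would be:

tf.compat.v1.logging.error('Recognition processing failed: {}'.format(e))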
@@ -173,10 +169,10 @@ def main(_):
           try:
             recognition_state = stats.delta()
           except ValueError as e:
-            tf.logging.error(
+            tf.compat.v1.logging.error(
                 'Statistics delta computing failed: {}'.format(e))
           else:
-            tf.logging.info('{}ms {}:{}{}'.format(
+            tf.compat.v1.logging.info('{}ms {}:{}{}'.format(
                 current_time_ms, recognize_element.founded_command,
                 recognize_element.score, recognition_state))
             stats.print_accuracy_stats()
@@ -249,5 +245,5 @@ if __name__ == '__main__':
       help='Whether to print streaming accuracy on stdout.')

   FLAGS, unparsed = parser.parse_known_args()
-  tf.logging.set_verbosity(tf.logging.INFO)
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
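As an aside, a pure TF 2.x entry point could drop the compat shim entirely; a sketch under that assumption (not what the commit does):

if __name__ == '__main__':
  FLAGS, unparsed = parser.parse_known_args()
  tf.get_logger().setLevel('INFO')  # TF2 counterpart of set_verbosity(INFO)
  main([sys.argv[0]] + unparsed)    # app.run(main=main, argv=...) just calls main(argv)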