Update test_streaming_accuracy.py

fix v2 compatibility.
This commit is contained in:
Judd 2020-02-14 14:36:43 +08:00 committed by GitHub
parent 19ecdb017a
commit 8e3fc97982
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -69,10 +69,9 @@ import sys
import numpy
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.examples.speech_commands.accuracy_utils import StreamingAccuracyStats
from tensorflow.examples.speech_commands.recognize_commands import RecognizeCommands
from tensorflow.examples.speech_commands.recognize_commands import RecognizeResult
from accuracy_utils import StreamingAccuracyStats
from recognize_commands import RecognizeCommands
from recognize_commands import RecognizeResult
from tensorflow.python.ops import io_ops
FLAGS = None
@ -82,8 +81,8 @@ def load_graph(mode_file):
"""Read a tensorflow model, and creates a default graph object."""
graph = tf.Graph()
with graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(mode_file, 'rb') as fid:
od_graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(mode_file, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
@ -101,10 +100,10 @@ def read_label_file(file_name):
def read_wav_file(filename):
"""Load a wav file and return sample_rate and numpy data of float64 type."""
with tf.Session(graph=tf.Graph()) as sess:
wav_filename_placeholder = tf.placeholder(tf.string, [])
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
wav_loader = io_ops.read_file(wav_filename_placeholder)
wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
res = sess.run(wav_decoder, feed_dict={wav_filename_placeholder: filename})
return res.sample_rate, res.audio.flatten()
@ -133,14 +132,14 @@ def main(_):
# Load model and create a tf session to process audio pieces
recognize_graph = load_graph(FLAGS.model)
with recognize_graph.as_default():
with tf.Session() as sess:
with tf.compat.v1.Session() as sess:
# Get input and output tensor
data_tensor = tf.get_default_graph().get_tensor_by_name(
data_tensor = sess.graph.get_tensor_by_name(
FLAGS.input_names[0])
sample_rate_tensor = tf.get_default_graph().get_tensor_by_name(
sample_rate_tensor = sess.graph.get_tensor_by_name(
FLAGS.input_names[1])
output_softmax_tensor = tf.get_default_graph().get_tensor_by_name(
output_softmax_tensor = sess.graph.get_tensor_by_name(
FLAGS.output_name)
# Inference along audio stream.
@ -161,7 +160,7 @@ def main(_):
recognize_commands.process_latest_result(outputs, current_time_ms,
recognize_element)
except ValueError as e:
tf.logging.error('Recognition processing failed: {}' % e)
tf.compat.v1.logging.error('Recognition processing failed: {}' % e)
return
if (recognize_element.is_new_command and
recognize_element.founded_command != '_silence_'):
@ -173,10 +172,10 @@ def main(_):
try:
recognition_state = stats.delta()
except ValueError as e:
tf.logging.error(
tf.compat.v1.logging.error(
'Statistics delta computing failed: {}'.format(e))
else:
tf.logging.info('{}ms {}:{}{}'.format(
tf.compat.v1.logging.info('{}ms {}:{}{}'.format(
current_time_ms, recognize_element.founded_command,
recognize_element.score, recognition_state))
stats.print_accuracy_stats()
@ -249,5 +248,5 @@ if __name__ == '__main__':
help='Whether to print streaming accuracy on stdout.')
FLAGS, unparsed = parser.parse_known_args()
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)