Update speech_commands example: Code cleanup by resolving warnings and errors.
PiperOrigin-RevId: 296559221 Change-Id: I07442f9296dd02f2a6ef46969a3efc475c2235e5
This commit is contained in:
parent
900764b474
commit
3fd71fca2c
@ -48,7 +48,6 @@ py_binary(
|
||||
":recognize_commands_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -172,8 +171,6 @@ py_library(
|
||||
":input_data",
|
||||
":models",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -216,8 +213,6 @@ py_library(
|
||||
":input_data",
|
||||
":models",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -261,7 +256,6 @@ py_library(
|
||||
":models",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -44,10 +44,10 @@ import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.ops import gen_audio_ops as audio_ops
|
||||
import input_data
|
||||
import models
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.ops import gen_audio_ops as audio_ops
|
||||
|
||||
# If it's available, load the specialized feature generator. If this doesn't
|
||||
# work, try building with bazel instead of running the Python script directly.
|
||||
|
@ -233,15 +233,15 @@ class AudioProcessor(object):
|
||||
filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
|
||||
except:
|
||||
tf.compat.v1.logging.error(
|
||||
'Failed to download URL: %s to folder: %s', data_url, filepath)
|
||||
tf.compat.v1.logging.error(
|
||||
'Please make sure you have enough free space and'
|
||||
' an internet connection')
|
||||
'Failed to download URL: {0} to folder: {1}. Please make sure you '
|
||||
'have enough free space and an internet connection'.format(
|
||||
data_url, filepath))
|
||||
raise
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
tf.compat.v1.logging.info('Successfully downloaded %s (%d bytes)',
|
||||
filename, statinfo.st_size)
|
||||
tf.compat.v1.logging.info(
|
||||
'Successfully downloaded {0} ({1} bytes)'.format(
|
||||
filename, statinfo.st_size))
|
||||
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
|
||||
|
||||
def prepare_data_index(self, silence_percentage, unknown_percentage,
|
||||
|
@ -33,7 +33,7 @@ from tensorflow.python.platform import test
|
||||
class InputDataTest(test.TestCase):
|
||||
|
||||
def _getWavData(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
@ -105,11 +105,11 @@ class InputDataTest(test.TestCase):
|
||||
["a", "b"], 10, 10,
|
||||
self._model_settings(), tmp_dir)
|
||||
self.assertLess(0, audio_processor.set_size("training"))
|
||||
self.assertTrue("training" in audio_processor.data_index)
|
||||
self.assertTrue("validation" in audio_processor.data_index)
|
||||
self.assertTrue("testing" in audio_processor.data_index)
|
||||
self.assertEquals(input_data.UNKNOWN_WORD_INDEX,
|
||||
audio_processor.word_to_index["c"])
|
||||
self.assertIn("training", audio_processor.data_index)
|
||||
self.assertIn("validation", audio_processor.data_index)
|
||||
self.assertIn("testing", audio_processor.data_index)
|
||||
self.assertEqual(input_data.UNKNOWN_WORD_INDEX,
|
||||
audio_processor.word_to_index["c"])
|
||||
|
||||
def testPrepareDataIndexEmpty(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
@ -117,7 +117,7 @@ class InputDataTest(test.TestCase):
|
||||
with self.assertRaises(Exception) as e:
|
||||
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
|
||||
self._model_settings(), tmp_dir)
|
||||
self.assertTrue("No .wavs found" in str(e.exception))
|
||||
self.assertIn("No .wavs found", str(e.exception))
|
||||
|
||||
def testPrepareDataIndexMissing(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
@ -125,7 +125,7 @@ class InputDataTest(test.TestCase):
|
||||
with self.assertRaises(Exception) as e:
|
||||
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
|
||||
10, self._model_settings(), tmp_dir)
|
||||
self.assertTrue("Expected to find" in str(e.exception))
|
||||
self.assertIn("Expected to find", str(e.exception))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testPrepareBackgroundData(self):
|
||||
|
@ -77,13 +77,12 @@ def run_graph(wav_data, labels, input_layer_name, output_layer_name,
|
||||
def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
|
||||
"""Loads the model and labels, and runs the inference to print predictions."""
|
||||
if not wav or not tf.io.gfile.exists(wav):
|
||||
tf.compat.v1.logging.fatal('Audio file does not exist %s', wav)
|
||||
|
||||
raise ValueError('Audio file does not exist at {0}'.format(wav))
|
||||
if not labels or not tf.io.gfile.exists(labels):
|
||||
tf.compat.v1.logging.fatal('Labels file does not exist %s', labels)
|
||||
raise ValueError('Labels file does not exist at {0}'.format(labels))
|
||||
|
||||
if not graph or not tf.io.gfile.exists(graph):
|
||||
tf.compat.v1.logging.fatal('Graph file does not exist %s', graph)
|
||||
raise ValueError('Graph file does not exist at {0}'.format(graph))
|
||||
|
||||
labels_list = load_labels(labels)
|
||||
|
||||
|
@ -64,8 +64,7 @@ def run_graph(wav_dir, labels, input_layer_name, output_layer_name,
|
||||
# predictions per class
|
||||
for wav_path in glob.glob(wav_dir + '/*.wav'):
|
||||
if not wav_path or not tf.io.gfile.exists(wav_path):
|
||||
tf.compat.v1.logging.fatal('Audio file does not exist %s', wav_path)
|
||||
|
||||
raise ValueError('Audio file does not exist at {0}'.format(wav_path))
|
||||
with open(wav_path, 'rb') as wav_file:
|
||||
wav_data = wav_file.read()
|
||||
|
||||
@ -86,10 +85,10 @@ def run_graph(wav_dir, labels, input_layer_name, output_layer_name,
|
||||
def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels):
|
||||
"""Loads the model and labels, and runs the inference to print predictions."""
|
||||
if not labels or not tf.io.gfile.exists(labels):
|
||||
tf.compat.v1.logging.fatal('Labels file does not exist %s', labels)
|
||||
raise ValueError('Labels file does not exist at {0}'.format(labels))
|
||||
|
||||
if not graph or not tf.io.gfile.exists(graph):
|
||||
tf.compat.v1.logging.fatal('Graph file does not exist %s', graph)
|
||||
raise ValueError('Graph file does not exist at {0}'.format(graph))
|
||||
|
||||
labels_list = load_labels(labels)
|
||||
|
||||
|
@ -29,7 +29,7 @@ from tensorflow.python.platform import test
|
||||
class LabelWavTest(test.TestCase):
|
||||
|
||||
def _getWavData(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
sample_data = tf.zeros([1000, 2])
|
||||
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
|
@ -251,12 +251,16 @@ def main(_):
|
||||
dropout_rate: 0.5
|
||||
})
|
||||
train_writer.add_summary(train_summary, training_step)
|
||||
tf.compat.v1.logging.info(
|
||||
tf.compat.v1.logging.debug(
|
||||
'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
|
||||
(training_step, learning_rate_value, train_accuracy * 100,
|
||||
cross_entropy_value))
|
||||
is_last_step = (training_step == training_steps_max)
|
||||
if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
|
||||
tf.compat.v1.logging.info(
|
||||
'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
|
||||
(training_step, learning_rate_value, train_accuracy * 100,
|
||||
cross_entropy_value))
|
||||
set_size = audio_processor.set_size('validation')
|
||||
total_accuracy = 0
|
||||
total_conf_matrix = None
|
||||
|
@ -30,7 +30,7 @@ from tensorflow.python.platform import test
|
||||
class WavToFeaturesTest(test.TestCase):
|
||||
|
||||
def _getWavData(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
@ -63,7 +63,7 @@ class WavToFeaturesTest(test.TestCase):
|
||||
input_file_path, output_file_path)
|
||||
with open(output_file_path, "rb") as f:
|
||||
content = f.read()
|
||||
self.assertTrue(b"const unsigned char g_input_data" in content)
|
||||
self.assertIn(b"const unsigned char g_input_data", content)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWavToFeaturesMicro(self):
|
||||
|
Loading…
Reference in New Issue
Block a user