Update speech_commands example: Code cleanup by resolving warnings and errors.
PiperOrigin-RevId: 296559221 Change-Id: I07442f9296dd02f2a6ef46969a3efc475c2235e5
This commit is contained in:
parent
900764b474
commit
3fd71fca2c
@ -48,7 +48,6 @@ py_binary(
|
|||||||
":recognize_commands_py",
|
":recognize_commands_py",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@six_archive//:six",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -172,8 +171,6 @@ py_library(
|
|||||||
":input_data",
|
":input_data",
|
||||||
":models",
|
":models",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
"//third_party/py/numpy",
|
|
||||||
"@six_archive//:six",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -216,8 +213,6 @@ py_library(
|
|||||||
":input_data",
|
":input_data",
|
||||||
":models",
|
":models",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
"//third_party/py/numpy",
|
|
||||||
"@six_archive//:six",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -261,7 +256,6 @@ py_library(
|
|||||||
":models",
|
":models",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@six_archive//:six",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,10 +44,10 @@ import sys
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.python.ops import gen_audio_ops as audio_ops
|
|
||||||
import input_data
|
import input_data
|
||||||
import models
|
import models
|
||||||
from tensorflow.python.framework import graph_util
|
from tensorflow.python.framework import graph_util
|
||||||
|
from tensorflow.python.ops import gen_audio_ops as audio_ops
|
||||||
|
|
||||||
# If it's available, load the specialized feature generator. If this doesn't
|
# If it's available, load the specialized feature generator. If this doesn't
|
||||||
# work, try building with bazel instead of running the Python script directly.
|
# work, try building with bazel instead of running the Python script directly.
|
||||||
|
@ -233,15 +233,15 @@ class AudioProcessor(object):
|
|||||||
filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
|
filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
|
||||||
except:
|
except:
|
||||||
tf.compat.v1.logging.error(
|
tf.compat.v1.logging.error(
|
||||||
'Failed to download URL: %s to folder: %s', data_url, filepath)
|
'Failed to download URL: {0} to folder: {1}. Please make sure you '
|
||||||
tf.compat.v1.logging.error(
|
'have enough free space and an internet connection'.format(
|
||||||
'Please make sure you have enough free space and'
|
data_url, filepath))
|
||||||
' an internet connection')
|
|
||||||
raise
|
raise
|
||||||
print()
|
print()
|
||||||
statinfo = os.stat(filepath)
|
statinfo = os.stat(filepath)
|
||||||
tf.compat.v1.logging.info('Successfully downloaded %s (%d bytes)',
|
tf.compat.v1.logging.info(
|
||||||
filename, statinfo.st_size)
|
'Successfully downloaded {0} ({1} bytes)'.format(
|
||||||
|
filename, statinfo.st_size))
|
||||||
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
|
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
|
||||||
|
|
||||||
def prepare_data_index(self, silence_percentage, unknown_percentage,
|
def prepare_data_index(self, silence_percentage, unknown_percentage,
|
||||||
|
@ -33,7 +33,7 @@ from tensorflow.python.platform import test
|
|||||||
class InputDataTest(test.TestCase):
|
class InputDataTest(test.TestCase):
|
||||||
|
|
||||||
def _getWavData(self):
|
def _getWavData(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session():
|
||||||
sample_data = tf.zeros([32000, 2])
|
sample_data = tf.zeros([32000, 2])
|
||||||
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
||||||
wav_data = self.evaluate(wav_encoder)
|
wav_data = self.evaluate(wav_encoder)
|
||||||
@ -105,10 +105,10 @@ class InputDataTest(test.TestCase):
|
|||||||
["a", "b"], 10, 10,
|
["a", "b"], 10, 10,
|
||||||
self._model_settings(), tmp_dir)
|
self._model_settings(), tmp_dir)
|
||||||
self.assertLess(0, audio_processor.set_size("training"))
|
self.assertLess(0, audio_processor.set_size("training"))
|
||||||
self.assertTrue("training" in audio_processor.data_index)
|
self.assertIn("training", audio_processor.data_index)
|
||||||
self.assertTrue("validation" in audio_processor.data_index)
|
self.assertIn("validation", audio_processor.data_index)
|
||||||
self.assertTrue("testing" in audio_processor.data_index)
|
self.assertIn("testing", audio_processor.data_index)
|
||||||
self.assertEquals(input_data.UNKNOWN_WORD_INDEX,
|
self.assertEqual(input_data.UNKNOWN_WORD_INDEX,
|
||||||
audio_processor.word_to_index["c"])
|
audio_processor.word_to_index["c"])
|
||||||
|
|
||||||
def testPrepareDataIndexEmpty(self):
|
def testPrepareDataIndexEmpty(self):
|
||||||
@ -117,7 +117,7 @@ class InputDataTest(test.TestCase):
|
|||||||
with self.assertRaises(Exception) as e:
|
with self.assertRaises(Exception) as e:
|
||||||
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
|
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
|
||||||
self._model_settings(), tmp_dir)
|
self._model_settings(), tmp_dir)
|
||||||
self.assertTrue("No .wavs found" in str(e.exception))
|
self.assertIn("No .wavs found", str(e.exception))
|
||||||
|
|
||||||
def testPrepareDataIndexMissing(self):
|
def testPrepareDataIndexMissing(self):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
@ -125,7 +125,7 @@ class InputDataTest(test.TestCase):
|
|||||||
with self.assertRaises(Exception) as e:
|
with self.assertRaises(Exception) as e:
|
||||||
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
|
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
|
||||||
10, self._model_settings(), tmp_dir)
|
10, self._model_settings(), tmp_dir)
|
||||||
self.assertTrue("Expected to find" in str(e.exception))
|
self.assertIn("Expected to find", str(e.exception))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testPrepareBackgroundData(self):
|
def testPrepareBackgroundData(self):
|
||||||
|
@ -77,13 +77,12 @@ def run_graph(wav_data, labels, input_layer_name, output_layer_name,
|
|||||||
def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
|
def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
|
||||||
"""Loads the model and labels, and runs the inference to print predictions."""
|
"""Loads the model and labels, and runs the inference to print predictions."""
|
||||||
if not wav or not tf.io.gfile.exists(wav):
|
if not wav or not tf.io.gfile.exists(wav):
|
||||||
tf.compat.v1.logging.fatal('Audio file does not exist %s', wav)
|
raise ValueError('Audio file does not exist at {0}'.format(wav))
|
||||||
|
|
||||||
if not labels or not tf.io.gfile.exists(labels):
|
if not labels or not tf.io.gfile.exists(labels):
|
||||||
tf.compat.v1.logging.fatal('Labels file does not exist %s', labels)
|
raise ValueError('Labels file does not exist at {0}'.format(labels))
|
||||||
|
|
||||||
if not graph or not tf.io.gfile.exists(graph):
|
if not graph or not tf.io.gfile.exists(graph):
|
||||||
tf.compat.v1.logging.fatal('Graph file does not exist %s', graph)
|
raise ValueError('Graph file does not exist at {0}'.format(graph))
|
||||||
|
|
||||||
labels_list = load_labels(labels)
|
labels_list = load_labels(labels)
|
||||||
|
|
||||||
|
@ -64,8 +64,7 @@ def run_graph(wav_dir, labels, input_layer_name, output_layer_name,
|
|||||||
# predictions per class
|
# predictions per class
|
||||||
for wav_path in glob.glob(wav_dir + '/*.wav'):
|
for wav_path in glob.glob(wav_dir + '/*.wav'):
|
||||||
if not wav_path or not tf.io.gfile.exists(wav_path):
|
if not wav_path or not tf.io.gfile.exists(wav_path):
|
||||||
tf.compat.v1.logging.fatal('Audio file does not exist %s', wav_path)
|
raise ValueError('Audio file does not exist at {0}'.format(wav_path))
|
||||||
|
|
||||||
with open(wav_path, 'rb') as wav_file:
|
with open(wav_path, 'rb') as wav_file:
|
||||||
wav_data = wav_file.read()
|
wav_data = wav_file.read()
|
||||||
|
|
||||||
@ -86,10 +85,10 @@ def run_graph(wav_dir, labels, input_layer_name, output_layer_name,
|
|||||||
def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels):
|
def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels):
|
||||||
"""Loads the model and labels, and runs the inference to print predictions."""
|
"""Loads the model and labels, and runs the inference to print predictions."""
|
||||||
if not labels or not tf.io.gfile.exists(labels):
|
if not labels or not tf.io.gfile.exists(labels):
|
||||||
tf.compat.v1.logging.fatal('Labels file does not exist %s', labels)
|
raise ValueError('Labels file does not exist at {0}'.format(labels))
|
||||||
|
|
||||||
if not graph or not tf.io.gfile.exists(graph):
|
if not graph or not tf.io.gfile.exists(graph):
|
||||||
tf.compat.v1.logging.fatal('Graph file does not exist %s', graph)
|
raise ValueError('Graph file does not exist at {0}'.format(graph))
|
||||||
|
|
||||||
labels_list = load_labels(labels)
|
labels_list = load_labels(labels)
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ from tensorflow.python.platform import test
|
|||||||
class LabelWavTest(test.TestCase):
|
class LabelWavTest(test.TestCase):
|
||||||
|
|
||||||
def _getWavData(self):
|
def _getWavData(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session():
|
||||||
sample_data = tf.zeros([1000, 2])
|
sample_data = tf.zeros([1000, 2])
|
||||||
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
||||||
wav_data = self.evaluate(wav_encoder)
|
wav_data = self.evaluate(wav_encoder)
|
||||||
|
@ -251,12 +251,16 @@ def main(_):
|
|||||||
dropout_rate: 0.5
|
dropout_rate: 0.5
|
||||||
})
|
})
|
||||||
train_writer.add_summary(train_summary, training_step)
|
train_writer.add_summary(train_summary, training_step)
|
||||||
tf.compat.v1.logging.info(
|
tf.compat.v1.logging.debug(
|
||||||
'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
|
'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
|
||||||
(training_step, learning_rate_value, train_accuracy * 100,
|
(training_step, learning_rate_value, train_accuracy * 100,
|
||||||
cross_entropy_value))
|
cross_entropy_value))
|
||||||
is_last_step = (training_step == training_steps_max)
|
is_last_step = (training_step == training_steps_max)
|
||||||
if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
|
if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
|
||||||
|
tf.compat.v1.logging.info(
|
||||||
|
'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
|
||||||
|
(training_step, learning_rate_value, train_accuracy * 100,
|
||||||
|
cross_entropy_value))
|
||||||
set_size = audio_processor.set_size('validation')
|
set_size = audio_processor.set_size('validation')
|
||||||
total_accuracy = 0
|
total_accuracy = 0
|
||||||
total_conf_matrix = None
|
total_conf_matrix = None
|
||||||
|
@ -30,7 +30,7 @@ from tensorflow.python.platform import test
|
|||||||
class WavToFeaturesTest(test.TestCase):
|
class WavToFeaturesTest(test.TestCase):
|
||||||
|
|
||||||
def _getWavData(self):
|
def _getWavData(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session():
|
||||||
sample_data = tf.zeros([32000, 2])
|
sample_data = tf.zeros([32000, 2])
|
||||||
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
wav_encoder = tf.audio.encode_wav(sample_data, 16000)
|
||||||
wav_data = self.evaluate(wav_encoder)
|
wav_data = self.evaluate(wav_encoder)
|
||||||
@ -63,7 +63,7 @@ class WavToFeaturesTest(test.TestCase):
|
|||||||
input_file_path, output_file_path)
|
input_file_path, output_file_path)
|
||||||
with open(output_file_path, "rb") as f:
|
with open(output_file_path, "rb") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
self.assertTrue(b"const unsigned char g_input_data" in content)
|
self.assertIn(b"const unsigned char g_input_data", content)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testWavToFeaturesMicro(self):
|
def testWavToFeaturesMicro(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user