PR #32792: Add momentum optimizer option to speech_commands

Imported from GitHub PR #32792

- Add a command-line argument to speech_commands for selecting a momentum optimizer.
- The default optimizer is unchanged; speech_commands still uses GradientDescentOptimizer.
- The momentum optimizer defaults to a momentum of 0.9 with use_nesterov=True (see the sketch below).
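
For reference, a minimal sketch of the selection logic this PR adds, using the TF 1.x compat API; the helper name build_train_step is hypothetical and not part of the diff:

import tensorflow as tf

def build_train_step(optimizer_name, learning_rate_input, loss):
  # 'gradient_descent' keeps the original behaviour; 'momentum' applies the
  # new defaults of momentum=0.9 with Nesterov updates.
  if optimizer_name == 'gradient_descent':
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate_input)
  elif optimizer_name == 'momentum':
    optimizer = tf.compat.v1.train.MomentumOptimizer(
        learning_rate_input, 0.9, use_nesterov=True)
  else:
    raise ValueError('Unknown optimizer: %s' % optimizer_name)
  return optimizer.minimize(loss)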

Copybara import of the project:

  - 6a2c6007fff318ba9f6f8e4ff335c3089335c774 Add momentum optimizer option to speech_commands by Michael Wang <michael@companionlabs.ai>
  - ffdebcfa55fb9b82fba3add6884b6da0bf9cbde0 Use compat.v1.MomentumOptimizer by Mark Daoust <markdaoust@google.com>
  - 0168e4a6ab93596d2010caa458eb715d16a2a24a Fix formatting by Michael Wang <michael@companionlabs.ai>
  - 19b930f7d42a327bed9cb560b5577f50532a76d7 Merge 0168e4a6ab93596d2010caa458eb715d16a2a24a into b5c48... by mswang12 <40969844+mswang12@users.noreply.github.com>

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/32792 from Companion-Labs:momentum_speech_commands 0168e4a6ab93596d2010caa458eb715d16a2a24a
PiperOrigin-RevId: 274079593
Authored by A. Unique TensorFlower on 2019-10-10 18:05:20 -07:00; committed by TensorFlower Gardener
parent 4a8171ecd7
commit b4c72f0d5e
2 changed files with 16 additions and 3 deletions

tensorflow/examples/speech_commands/train.py

@@ -169,8 +169,15 @@ def main(_):
       control_dependencies):
     learning_rate_input = tf.compat.v1.placeholder(
         tf.float32, [], name='learning_rate_input')
-    train_step = tf.compat.v1.train.GradientDescentOptimizer(
-        learning_rate_input).minimize(cross_entropy_mean)
+    if FLAGS.optimizer == 'gradient_descent':
+      train_step = tf.compat.v1.train.GradientDescentOptimizer(
+          learning_rate_input).minimize(cross_entropy_mean)
+    elif FLAGS.optimizer == 'momentum':
+      train_step = tf.compat.v1.train.MomentumOptimizer(
+          learning_rate_input, .9,
+          use_nesterov=True).minimize(cross_entropy_mean)
+    else:
+      raise Exception('Invalid Optimizer')
   predicted_indices = tf.argmax(input=logits, axis=1)
   correct_prediction = tf.equal(predicted_indices, ground_truth_input)
   confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
@@ -491,6 +498,11 @@ if __name__ == '__main__':
       type=verbosity_arg,
       default=tf.compat.v1.logging.INFO,
       help='Log verbosity. Can be "INFO", "DEBUG", "ERROR", "FATAL", or "WARN"')
+  parser.add_argument(
+      '--optimizer',
+      type=str,
+      default='gradient_descent',
+      help='Optimizer (gradient_descent or momentum)')
   FLAGS, unparsed = parser.parse_known_args()
   tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
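
As a usage note, a self-contained sketch of how the added --optimizer flag parses (the example argv list is illustrative only); the default preserves the original gradient descent behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--optimizer',
    type=str,
    default='gradient_descent',
    help='Optimizer (gradient_descent or momentum)')

# parse_known_args() mirrors the script above: unknown flags pass through.
flags, unparsed = parser.parse_known_args(['--optimizer', 'momentum'])
print(flags.optimizer)  # -> momentum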

tensorflow/examples/speech_commands/train_test.py

@@ -111,7 +111,8 @@ class TrainTest(test.TestCase):
         'background_frequency': 0.8,
         'eval_step_interval': 1,
         'save_step_interval': 1,
-        'verbosity': tf.compat.v1.logging.INFO
+        'verbosity': tf.compat.v1.logging.INFO,
+        'optimizer': 'gradient_descent'
     }
     return DictStruct(**flags)