PR #32792: Add momentum optimizer option to speech_commands
Imported from GitHub PR #32792 - Add command line argument to choose momentum optimizer option to speech_commands. - Default optimizer for speech commands is unchanged to still use GradientDescentOptimizer. - Default values for momentum optimizer are to use a momentum of .9 and use_nesterov=True. Copybara import of the project: - 6a2c6007fff318ba9f6f8e4ff335c3089335c774 Add momentum optimizer option to speech_commands by Michael Wang <michael@companionlabs.ai> - ffdebcfa55fb9b82fba3add6884b6da0bf9cbde0 Use compat.v1.MomentumOptimizer by Mark Daoust <markdaoust@google.com> - 0168e4a6ab93596d2010caa458eb715d16a2a24a Fix formatting by Michael Wang <michael@companionlabs.ai> - 19b930f7d42a327bed9cb560b5577f50532a76d7 Merge 0168e4a6ab93596d2010caa458eb715d16a2a24a into b5c48... by mswang12 <40969844+mswang12@users.noreply.github.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/32792 from Companion-Labs:momentum_speech_commands 0168e4a6ab93596d2010caa458eb715d16a2a24a PiperOrigin-RevId: 274079593
This commit is contained in:
parent
4a8171ecd7
commit
b4c72f0d5e
@ -169,8 +169,15 @@ def main(_):
|
||||
control_dependencies):
|
||||
learning_rate_input = tf.compat.v1.placeholder(
|
||||
tf.float32, [], name='learning_rate_input')
|
||||
train_step = tf.compat.v1.train.GradientDescentOptimizer(
|
||||
learning_rate_input).minimize(cross_entropy_mean)
|
||||
if FLAGS.optimizer == 'gradient_descent':
|
||||
train_step = tf.compat.v1.train.GradientDescentOptimizer(
|
||||
learning_rate_input).minimize(cross_entropy_mean)
|
||||
elif FLAGS.optimizer == 'momentum':
|
||||
train_step = tf.compat.v1.train.MomentumOptimizer(
|
||||
learning_rate_input, .9,
|
||||
use_nesterov=True).minimize(cross_entropy_mean)
|
||||
else:
|
||||
raise Exception('Invalid Optimizer')
|
||||
predicted_indices = tf.argmax(input=logits, axis=1)
|
||||
correct_prediction = tf.equal(predicted_indices, ground_truth_input)
|
||||
confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
|
||||
@ -491,6 +498,11 @@ if __name__ == '__main__':
|
||||
type=verbosity_arg,
|
||||
default=tf.compat.v1.logging.INFO,
|
||||
help='Log verbosity. Can be "INFO", "DEBUG", "ERROR", "FATAL", or "WARN"')
|
||||
parser.add_argument(
|
||||
'--optimizer',
|
||||
type=str,
|
||||
default='gradient_descent',
|
||||
help='Optimizer (gradient_descent or momentum)')
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -111,7 +111,8 @@ class TrainTest(test.TestCase):
|
||||
'background_frequency': 0.8,
|
||||
'eval_step_interval': 1,
|
||||
'save_step_interval': 1,
|
||||
'verbosity': tf.compat.v1.logging.INFO
|
||||
'verbosity': tf.compat.v1.logging.INFO,
|
||||
'optimizer': 'gradient_descent'
|
||||
}
|
||||
return DictStruct(**flags)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user