Add ValueError checking for the `summaries` argument of `optimize_loss`, mirroring the existing validation of the optimizer name.

Change: 144727719
This commit is contained in:
A. Unique TensorFlower 2017-01-17 10:44:18 -08:00 committed by TensorFlower Gardener
parent 18bbff2725
commit 3c44578744
2 changed files with 13 additions and 0 deletions
tensorflow/contrib/layers/python/layers

View File

@ -176,6 +176,11 @@ def optimize_loss(loss,
str(type(learning_rate))))
if summaries is None:
summaries = ["loss", "learning_rate"]
else:
for summ in summaries:
if summ not in OPTIMIZER_SUMMARIES:
raise ValueError("Summaries should be one of [%s], you provided %s." %
(", ".join(OPTIMIZER_SUMMARIES), summ))
if learning_rate is not None and learning_rate_decay_fn is not None:
if global_step is None:
raise ValueError("global_step is required for learning_rate_decay_fn.")

View File

@ -108,6 +108,14 @@ class OptimizersTest(test.TestCase):
optimizers_lib.optimize_loss(
loss, global_step, learning_rate=0.1, optimizer=optimizer)
def testBadSummaries(self):
    """optimize_loss should raise ValueError for an unrecognized summary name."""
    graph = ops.Graph()
    with graph.as_default():
        with self.test_session(graph=graph):
            # Only `loss` and `global_step` from the model fixture are needed here.
            _, _, loss, global_step = _setup_model()
            # "bad_summary" is not in OPTIMIZER_SUMMARIES, so the call must fail.
            with self.assertRaises(ValueError):
                optimizers_lib.optimize_loss(
                    loss, global_step, learning_rate=0.1, optimizer="SGD",
                    summaries=["loss", "bad_summary"])
def testInvalidLoss(self):
with ops.Graph().as_default() as g, self.test_session(graph=g):
_, _, _, global_step = _setup_model()