Adding ValueError checking for summaries argument of optimize_loss (similar to optimizer name checking)
Change: 144727719
This commit is contained in:
parent
18bbff2725
commit
3c44578744
tensorflow/contrib/layers/python/layers
@ -176,6 +176,11 @@ def optimize_loss(loss,
|
||||
str(type(learning_rate))))
|
||||
if summaries is None:
|
||||
summaries = ["loss", "learning_rate"]
|
||||
else:
|
||||
for summ in summaries:
|
||||
if summ not in OPTIMIZER_SUMMARIES:
|
||||
raise ValueError("Summaries should be one of [%s], you provided %s." %
|
||||
(", ".join(OPTIMIZER_SUMMARIES), summ))
|
||||
if learning_rate is not None and learning_rate_decay_fn is not None:
|
||||
if global_step is None:
|
||||
raise ValueError("global_step is required for learning_rate_decay_fn.")
|
||||
|
@ -108,6 +108,14 @@ class OptimizersTest(test.TestCase):
|
||||
optimizers_lib.optimize_loss(
|
||||
loss, global_step, learning_rate=0.1, optimizer=optimizer)
|
||||
|
||||
def testBadSummaries(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(graph=g):
|
||||
_, _, loss, global_step = _setup_model()
|
||||
with self.assertRaises(ValueError):
|
||||
optimizers_lib.optimize_loss(
|
||||
loss, global_step, learning_rate=0.1, optimizer="SGD",
|
||||
summaries=["loss", "bad_summary"])
|
||||
|
||||
def testInvalidLoss(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(graph=g):
|
||||
_, _, _, global_step = _setup_model()
|
||||
|
Loading…
Reference in New Issue
Block a user