Adding ValueError checking for the summaries argument of optimize_loss (similar to the optimizer name checking)
Change: 144727719
commit 3c44578744
parent 18bbff2725
@@ -176,6 +176,11 @@ def optimize_loss(loss,
                                             str(type(learning_rate))))
   if summaries is None:
     summaries = ["loss", "learning_rate"]
+  else:
+    for summ in summaries:
+      if summ not in OPTIMIZER_SUMMARIES:
+        raise ValueError("Summaries should be one of [%s], you provided %s." %
+                         (", ".join(OPTIMIZER_SUMMARIES), summ))
   if learning_rate is not None and learning_rate_decay_fn is not None:
     if global_step is None:
       raise ValueError("global_step is required for learning_rate_decay_fn.")
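For reference, a minimal standalone sketch of the whitelist check this hunk introduces; the OPTIMIZER_SUMMARIES values below are assumed for illustration and may not match the module's actual list:

    # Sketch of the new validation; OPTIMIZER_SUMMARIES contents are assumed.
    OPTIMIZER_SUMMARIES = ["learning_rate", "loss", "gradients", "gradient_norm"]

    def validate_summaries(summaries):
      # Mirrors the new optimize_loss behavior: apply the default when unset,
      # otherwise reject any name outside the whitelist.
      if summaries is None:
        return ["loss", "learning_rate"]
      for summ in summaries:
        if summ not in OPTIMIZER_SUMMARIES:
          raise ValueError("Summaries should be one of [%s], you provided %s." %
                           (", ".join(OPTIMIZER_SUMMARIES), summ))
      return summaries

    print(validate_summaries(None))           # ['loss', 'learning_rate']
    print(validate_summaries(["gradients"]))  # ['gradients']
    validate_summaries(["bad_summary"])       # raises ValueError

Failing fast here surfaces typos in summary names at graph-construction time instead of silently skipping the requested summaries.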
@@ -108,6 +108,14 @@ class OptimizersTest(test.TestCase):
         optimizers_lib.optimize_loss(
             loss, global_step, learning_rate=0.1, optimizer=optimizer)
 
+  def testBadSummaries(self):
+    with ops.Graph().as_default() as g, self.test_session(graph=g):
+      _, _, loss, global_step = _setup_model()
+      with self.assertRaises(ValueError):
+        optimizers_lib.optimize_loss(
+            loss, global_step, learning_rate=0.1, optimizer="SGD",
+            summaries=["loss", "bad_summary"])
+
   def testInvalidLoss(self):
     with ops.Graph().as_default() as g, self.test_session(graph=g):
       _, _, _, global_step = _setup_model()
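The new test exercises the same check through the public API. A hedged usage sketch, assuming a TF 1.x environment where tf.contrib is available (the toy loss below is illustrative, not the test file's _setup_model helper):

    # Assumes TF 1.x; tf.contrib was removed in TF 2.x.
    import tensorflow as tf
    from tensorflow.contrib import layers
    from tensorflow.contrib.framework import get_or_create_global_step

    with tf.Graph().as_default():
      w = tf.get_variable("w", [], initializer=tf.constant_initializer(3.0))
      loss = tf.square(w - 1.0)
      global_step = get_or_create_global_step()
      try:
        layers.optimize_loss(loss, global_step, learning_rate=0.1,
                             optimizer="SGD",
                             summaries=["loss", "bad_summary"])
      except ValueError as e:
        print(e)  # Summaries should be one of [...], you provided bad_summary.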