Removed caching of the optimizer, since the optimizer may depend on a graph element (global_step).

Change: 126415442
commit 242fe922e4 (parent a00fa7b701)
Author: Mustafa Ispir
Date: 2016-07-01 08:09:22 -08:00
Committed by: TensorFlower Gardener

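In the tf.learn Estimator setup, a callable optimizer is a common way to tie the learning rate to global_step. The minimal sketch below (hypothetical names, written against the TF 1.x-era graph API, not code from this commit) illustrates why such an optimizer is bound to the graph that is current when it is built, and therefore cannot be cached across the separate graphs that fit and evaluate construct.

import tensorflow as tf  # TF 1.x-era graph API assumed

def make_optimizer():
  # The decayed learning rate reads `global_step`, a graph element, so the
  # returned optimizer holds references into whichever graph is current here.
  global_step = tf.Variable(0, trainable=False, name="global_step")
  learning_rate = tf.train.exponential_decay(
      0.1, global_step, decay_steps=1000, decay_rate=0.96)
  return tf.train.GradientDescentOptimizer(learning_rate)

# Each Estimator call builds its own graph, so the factory has to run again
# inside that graph; a cached optimizer created in an earlier graph would mix
# tensors from two graphs.
with tf.Graph().as_default():
  train_opt = make_optimizer()
with tf.Graph().as_default():
  eval_opt = make_optimizer()  # rebuilt here instead of reusing train_opt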

@@ -110,8 +110,7 @@ class _ComposableModel(object):
     grads = gradients.gradients(loss, my_vars)
     if self._gradient_clip_norm:
       grads, _ = clip_ops.clip_by_global_norm(grads, self._gradient_clip_norm)
-    self._optimizer = self._get_optimizer()
-    return [self._optimizer.apply_gradients(zip(grads, my_vars))]
+    return [self._get_optimizer().apply_gradients(zip(grads, my_vars))]
 
   def _get_feature_columns(self):
     if not self._feature_columns:
@@ -132,10 +131,12 @@ class _ComposableModel(object):
   def _get_optimizer(self):
     if (self._optimizer is None or isinstance(self._optimizer,
                                               six.string_types)):
-      self._optimizer = self._get_default_optimizer(self._optimizer)
+      optimizer = self._get_default_optimizer(self._optimizer)
     elif callable(self._optimizer):
-      self._optimizer = self._optimizer()
-    return self._optimizer
+      optimizer = self._optimizer()
+    else:
+      optimizer = self._optimizer
+    return optimizer
 
   def _get_default_optimizer(self, optimizer_name=None):
     raise NotImplementedError
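For readability, the resolved _get_optimizer behavior after this change can be read as the standalone sketch below (simplified names; get_default_optimizer stands in for the subclass hook _get_default_optimizer). The point is that each call returns a freshly resolved optimizer instead of writing the result back to self._optimizer.

import six

def resolve_optimizer(optimizer, get_default_optimizer):
  """Resolve an optimizer spec without caching the result on the model."""
  if optimizer is None or isinstance(optimizer, six.string_types):
    # None or a string name: defer to the subclass-provided default.
    return get_default_optimizer(optimizer)
  if callable(optimizer):
    # A factory: invoke it in the current graph so any graph elements it
    # uses (e.g. global_step) belong to that graph.
    return optimizer()
  # Already an optimizer instance: use it as-is.
  return optimizer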