Removed caching of optimizer. Since optimizer may depend on a graph element (global_step).
Change: 126415442
This commit is contained in:
parent
a00fa7b701
commit
242fe922e4
@ -110,8 +110,7 @@ class _ComposableModel(object):
|
||||
grads = gradients.gradients(loss, my_vars)
|
||||
if self._gradient_clip_norm:
|
||||
grads, _ = clip_ops.clip_by_global_norm(grads, self._gradient_clip_norm)
|
||||
self._optimizer = self._get_optimizer()
|
||||
return [self._optimizer.apply_gradients(zip(grads, my_vars))]
|
||||
return [self._get_optimizer().apply_gradients(zip(grads, my_vars))]
|
||||
|
||||
def _get_feature_columns(self):
|
||||
if not self._feature_columns:
|
||||
@ -132,10 +131,12 @@ class _ComposableModel(object):
|
||||
def _get_optimizer(self):
|
||||
if (self._optimizer is None or isinstance(self._optimizer,
|
||||
six.string_types)):
|
||||
self._optimizer = self._get_default_optimizer(self._optimizer)
|
||||
optimizer = self._get_default_optimizer(self._optimizer)
|
||||
elif callable(self._optimizer):
|
||||
self._optimizer = self._optimizer()
|
||||
return self._optimizer
|
||||
optimizer = self._optimizer()
|
||||
else:
|
||||
optimizer = self._optimizer
|
||||
return optimizer
|
||||
|
||||
def _get_default_optimizer(self, optimizer_name=None):
|
||||
raise NotImplementedError
|
||||
|
Loading…
Reference in New Issue
Block a user