Merge pull request #32399 from jaingaurav/cherry-2.0
[r2.0 CherryPick]: Use experimental_ref() in moving_averages
This commit is contained in:
commit
4096702326
@ -28,7 +28,6 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import slot_creator
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -369,7 +368,7 @@ class ExponentialMovingAverage(object):
|
||||
self._num_updates = num_updates
|
||||
self._zero_debias = zero_debias
|
||||
self._name = name
|
||||
self._averages = object_identity.ObjectIdentityDictionary()
|
||||
self._averages = {}
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -423,7 +422,7 @@ class ExponentialMovingAverage(object):
|
||||
raise TypeError("The variables must be half, float, or double: %s" %
|
||||
var.name)
|
||||
|
||||
if var not in self._averages:
|
||||
if var.experimental_ref() not in self._averages:
|
||||
# For variables: to lower communication bandwidth across devices we keep
|
||||
# the moving averages on the same device as the variables. For other
|
||||
# tensors, we rely on the existing device allocation mechanism.
|
||||
@ -445,8 +444,8 @@ class ExponentialMovingAverage(object):
|
||||
"Variable", "VariableV2", "VarHandleOp"
|
||||
]))
|
||||
if self._zero_debias:
|
||||
zero_debias_true.add(avg)
|
||||
self._averages[var] = avg
|
||||
zero_debias_true.add(avg.experimental_ref())
|
||||
self._averages[var.experimental_ref()] = avg
|
||||
|
||||
with ops.name_scope(self.name) as scope:
|
||||
decay = ops.convert_to_tensor(self._decay, name="decay")
|
||||
@ -457,10 +456,9 @@ class ExponentialMovingAverage(object):
|
||||
(1.0 + num_updates) / (10.0 + num_updates))
|
||||
updates = []
|
||||
for var in var_list:
|
||||
zero_debias = any(self._averages[var] is v for v in zero_debias_true)
|
||||
updates.append(
|
||||
assign_moving_average(
|
||||
self._averages[var], var, decay, zero_debias=zero_debias))
|
||||
avg = self._averages[var.experimental_ref()]
|
||||
zero_debias = avg.experimental_ref() in zero_debias_true
|
||||
updates.append(assign_moving_average(avg, var, decay, zero_debias))
|
||||
return control_flow_ops.group(*updates, name=scope)
|
||||
|
||||
def average(self, var):
|
||||
@ -473,7 +471,7 @@ class ExponentialMovingAverage(object):
|
||||
A `Variable` object or `None` if the moving average of `var`
|
||||
is not maintained.
|
||||
"""
|
||||
return self._averages.get(var, None)
|
||||
return self._averages.get(var.experimental_ref(), None)
|
||||
|
||||
def average_name(self, var):
|
||||
"""Returns the name of the `Variable` holding the average for `var`.
|
||||
@ -497,8 +495,8 @@ class ExponentialMovingAverage(object):
|
||||
by the `ExponentialMovingAverage class` to hold the moving average of
|
||||
`var`.
|
||||
"""
|
||||
if var in self._averages:
|
||||
return self._averages[var].op.name
|
||||
if var.experimental_ref() in self._averages:
|
||||
return self._averages[var.experimental_ref()].op.name
|
||||
return ops.get_default_graph().unique_name(
|
||||
var.op.name + "/" + self.name, mark_as_used=False)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user