Merge pull request #32399 from jaingaurav/cherry-2.0

[r2.0 CherryPick]: Use experimental_ref() in moving_averages
This commit is contained in:
Goldie Gadde 2019-09-15 11:35:15 -07:00 committed by GitHub
commit 4096702326
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,7 +28,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator from tensorflow.python.training import slot_creator
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -369,7 +368,7 @@ class ExponentialMovingAverage(object):
self._num_updates = num_updates self._num_updates = num_updates
self._zero_debias = zero_debias self._zero_debias = zero_debias
self._name = name self._name = name
self._averages = object_identity.ObjectIdentityDictionary() self._averages = {}
@property @property
def name(self): def name(self):
@ -423,7 +422,7 @@ class ExponentialMovingAverage(object):
raise TypeError("The variables must be half, float, or double: %s" % raise TypeError("The variables must be half, float, or double: %s" %
var.name) var.name)
if var not in self._averages: if var.experimental_ref() not in self._averages:
# For variables: to lower communication bandwidth across devices we keep # For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other # the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism. # tensors, we rely on the existing device allocation mechanism.
@ -445,8 +444,8 @@ class ExponentialMovingAverage(object):
"Variable", "VariableV2", "VarHandleOp" "Variable", "VariableV2", "VarHandleOp"
])) ]))
if self._zero_debias: if self._zero_debias:
zero_debias_true.add(avg) zero_debias_true.add(avg.experimental_ref())
self._averages[var] = avg self._averages[var.experimental_ref()] = avg
with ops.name_scope(self.name) as scope: with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay") decay = ops.convert_to_tensor(self._decay, name="decay")
@ -457,10 +456,9 @@ class ExponentialMovingAverage(object):
(1.0 + num_updates) / (10.0 + num_updates)) (1.0 + num_updates) / (10.0 + num_updates))
updates = [] updates = []
for var in var_list: for var in var_list:
zero_debias = any(self._averages[var] is v for v in zero_debias_true) avg = self._averages[var.experimental_ref()]
updates.append( zero_debias = avg.experimental_ref() in zero_debias_true
assign_moving_average( updates.append(assign_moving_average(avg, var, decay, zero_debias))
self._averages[var], var, decay, zero_debias=zero_debias))
return control_flow_ops.group(*updates, name=scope) return control_flow_ops.group(*updates, name=scope)
def average(self, var): def average(self, var):
@ -473,7 +471,7 @@ class ExponentialMovingAverage(object):
A `Variable` object or `None` if the moving average of `var` A `Variable` object or `None` if the moving average of `var`
is not maintained. is not maintained.
""" """
return self._averages.get(var, None) return self._averages.get(var.experimental_ref(), None)
def average_name(self, var): def average_name(self, var):
"""Returns the name of the `Variable` holding the average for `var`. """Returns the name of the `Variable` holding the average for `var`.
@ -497,8 +495,8 @@ class ExponentialMovingAverage(object):
by the `ExponentialMovingAverage class` to hold the moving average of by the `ExponentialMovingAverage class` to hold the moving average of
`var`. `var`.
""" """
if var in self._averages: if var.experimental_ref() in self._averages:
return self._averages[var].op.name return self._averages[var.experimental_ref()].op.name
return ops.get_default_graph().unique_name( return ops.get_default_graph().unique_name(
var.op.name + "/" + self.name, mark_as_used=False) var.op.name + "/" + self.name, mark_as_used=False)