Merge pull request from jaingaurav/cherry-2.0

[r2.0 CherryPick]: Use experimental_ref() in moving_averages
This commit is contained in:
Goldie Gadde 2019-09-15 11:35:15 -07:00 committed by GitHub
commit 4096702326
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,7 +28,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
@ -369,7 +368,7 @@ class ExponentialMovingAverage(object):
self._num_updates = num_updates
self._zero_debias = zero_debias
self._name = name
self._averages = object_identity.ObjectIdentityDictionary()
self._averages = {}
@property
def name(self):
@ -423,7 +422,7 @@ class ExponentialMovingAverage(object):
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
if var not in self._averages:
if var.experimental_ref() not in self._averages:
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
@ -445,8 +444,8 @@ class ExponentialMovingAverage(object):
"Variable", "VariableV2", "VarHandleOp"
]))
if self._zero_debias:
zero_debias_true.add(avg)
self._averages[var] = avg
zero_debias_true.add(avg.experimental_ref())
self._averages[var.experimental_ref()] = avg
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
@ -457,10 +456,9 @@ class ExponentialMovingAverage(object):
(1.0 + num_updates) / (10.0 + num_updates))
updates = []
for var in var_list:
zero_debias = any(self._averages[var] is v for v in zero_debias_true)
updates.append(
assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
avg = self._averages[var.experimental_ref()]
zero_debias = avg.experimental_ref() in zero_debias_true
updates.append(assign_moving_average(avg, var, decay, zero_debias))
return control_flow_ops.group(*updates, name=scope)
def average(self, var):
@ -473,7 +471,7 @@ class ExponentialMovingAverage(object):
A `Variable` object or `None` if the moving average of `var`
is not maintained.
"""
return self._averages.get(var, None)
return self._averages.get(var.experimental_ref(), None)
def average_name(self, var):
"""Returns the name of the `Variable` holding the average for `var`.
@ -497,8 +495,8 @@ class ExponentialMovingAverage(object):
by the `ExponentialMovingAverage class` to hold the moving average of
`var`.
"""
if var in self._averages:
return self._averages[var].op.name
if var.experimental_ref() in self._averages:
return self._averages[var.experimental_ref()].op.name
return ops.get_default_graph().unique_name(
var.op.name + "/" + self.name, mark_as_used=False)