Merge pull request #32399 from jaingaurav/cherry-2.0
[r2.0 CherryPick]: Use experimental_ref() in moving_averages
This commit is contained in:
commit
4096702326
@ -28,7 +28,6 @@ from tensorflow.python.ops import state_ops
|
|||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training import slot_creator
|
from tensorflow.python.training import slot_creator
|
||||||
from tensorflow.python.util import object_identity
|
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -369,7 +368,7 @@ class ExponentialMovingAverage(object):
|
|||||||
self._num_updates = num_updates
|
self._num_updates = num_updates
|
||||||
self._zero_debias = zero_debias
|
self._zero_debias = zero_debias
|
||||||
self._name = name
|
self._name = name
|
||||||
self._averages = object_identity.ObjectIdentityDictionary()
|
self._averages = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -423,7 +422,7 @@ class ExponentialMovingAverage(object):
|
|||||||
raise TypeError("The variables must be half, float, or double: %s" %
|
raise TypeError("The variables must be half, float, or double: %s" %
|
||||||
var.name)
|
var.name)
|
||||||
|
|
||||||
if var not in self._averages:
|
if var.experimental_ref() not in self._averages:
|
||||||
# For variables: to lower communication bandwidth across devices we keep
|
# For variables: to lower communication bandwidth across devices we keep
|
||||||
# the moving averages on the same device as the variables. For other
|
# the moving averages on the same device as the variables. For other
|
||||||
# tensors, we rely on the existing device allocation mechanism.
|
# tensors, we rely on the existing device allocation mechanism.
|
||||||
@ -445,8 +444,8 @@ class ExponentialMovingAverage(object):
|
|||||||
"Variable", "VariableV2", "VarHandleOp"
|
"Variable", "VariableV2", "VarHandleOp"
|
||||||
]))
|
]))
|
||||||
if self._zero_debias:
|
if self._zero_debias:
|
||||||
zero_debias_true.add(avg)
|
zero_debias_true.add(avg.experimental_ref())
|
||||||
self._averages[var] = avg
|
self._averages[var.experimental_ref()] = avg
|
||||||
|
|
||||||
with ops.name_scope(self.name) as scope:
|
with ops.name_scope(self.name) as scope:
|
||||||
decay = ops.convert_to_tensor(self._decay, name="decay")
|
decay = ops.convert_to_tensor(self._decay, name="decay")
|
||||||
@ -457,10 +456,9 @@ class ExponentialMovingAverage(object):
|
|||||||
(1.0 + num_updates) / (10.0 + num_updates))
|
(1.0 + num_updates) / (10.0 + num_updates))
|
||||||
updates = []
|
updates = []
|
||||||
for var in var_list:
|
for var in var_list:
|
||||||
zero_debias = any(self._averages[var] is v for v in zero_debias_true)
|
avg = self._averages[var.experimental_ref()]
|
||||||
updates.append(
|
zero_debias = avg.experimental_ref() in zero_debias_true
|
||||||
assign_moving_average(
|
updates.append(assign_moving_average(avg, var, decay, zero_debias))
|
||||||
self._averages[var], var, decay, zero_debias=zero_debias))
|
|
||||||
return control_flow_ops.group(*updates, name=scope)
|
return control_flow_ops.group(*updates, name=scope)
|
||||||
|
|
||||||
def average(self, var):
|
def average(self, var):
|
||||||
@ -473,7 +471,7 @@ class ExponentialMovingAverage(object):
|
|||||||
A `Variable` object or `None` if the moving average of `var`
|
A `Variable` object or `None` if the moving average of `var`
|
||||||
is not maintained.
|
is not maintained.
|
||||||
"""
|
"""
|
||||||
return self._averages.get(var, None)
|
return self._averages.get(var.experimental_ref(), None)
|
||||||
|
|
||||||
def average_name(self, var):
|
def average_name(self, var):
|
||||||
"""Returns the name of the `Variable` holding the average for `var`.
|
"""Returns the name of the `Variable` holding the average for `var`.
|
||||||
@ -497,8 +495,8 @@ class ExponentialMovingAverage(object):
|
|||||||
by the `ExponentialMovingAverage class` to hold the moving average of
|
by the `ExponentialMovingAverage class` to hold the moving average of
|
||||||
`var`.
|
`var`.
|
||||||
"""
|
"""
|
||||||
if var in self._averages:
|
if var.experimental_ref() in self._averages:
|
||||||
return self._averages[var].op.name
|
return self._averages[var.experimental_ref()].op.name
|
||||||
return ops.get_default_graph().unique_name(
|
return ops.get_default_graph().unique_name(
|
||||||
var.op.name + "/" + self.name, mark_as_used=False)
|
var.op.name + "/" + self.name, mark_as_used=False)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user