Remove _id field in DistributedVariable

The _id field was added to DistributedVariable solely to support
capturing of distributed values in FuncGraph when tensor equality is
enabled. It is unnecessary: we can instead key the captures dictionary
with id(), since the captured tensor is stored as part of the value
tuple anyway. We don't use experimental_ref() here because we want to
easily return tuples of tensors and placeholders.
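
A minimal sketch of the keying scheme described above, in plain Python with
illustrative names (not the actual FuncGraph code): the map is keyed by the
opaque id() of the tensor, but the tensor itself still lives in the value
tuple, so lists of (tensor, placeholder) pairs can be returned directly,
without the deref() step that experimental_ref() keys would require.

# Illustrative sketch only; names do not come from the TensorFlow sources.
import collections

class CaptureMap(object):

  def __init__(self):
    self._captures = collections.OrderedDict()

  def capture(self, tensor, placeholder):
    # Keyed by object identity, so an overloaded __eq__/__hash__ on the
    # tensor (e.g. with tensor equality enabled) never comes into play.
    self._captures[id(tensor)] = (tensor, placeholder)

  def pop_capture(self, tensor):
    # The tensor is recoverable from the value tuple, so nothing is lost
    # by using the opaque integer id as the key.
    entry = self._captures.pop(id(tensor), None)
    return None if entry is None else entry[1]

  def capture_list(self):
    # Plain (tensor, placeholder) tuples come straight from the values;
    # with experimental_ref() keys, each key would need a deref() first.
    return list(self._captures.values())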

PiperOrigin-RevId: 274077594
Gaurav Jain, 2019-10-10 17:51:32 -07:00; committed by TensorFlower Gardener
parent 3bd66c35b0
commit c3ec575f83
6 changed files with 21 additions and 16 deletions

@@ -652,7 +652,6 @@ class MirroredVariableUpdateTest(test.TestCase):
 self.assertIsInstance(mirrored_var, values.MirroredVariable)
 self.evaluate(variables.global_variables_initializer())
 self.assertEqual(1.0, self.evaluate(mirrored_var))
-self.assertIsNotNone(ops.tensor_id(mirrored_var))
 mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
 self.assertEqual(6.0, mirrored_var_result)

@@ -617,7 +617,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 # We need to make _keras_initialized a member of DistributedVariable because
 # without this it will use `__getattr__` which will delegate to a component
 # variable.
-self._id = ops.uid()
 self._keras_initialized = False
 # Typically, a `DistributedVariable`'s initializer is composed of the
 # initializers of the components variables. However, in some cases, such as

@@ -844,7 +844,7 @@ class _TapeGradientFunctions(object):
 if backprop_util.IsTrainable(output):
 # Swap in the Variable object for resource handles if we can so
 # sparse gradients work.
-output = handles_to_variables.get(ops.tensor_id(output), output)
+output = handles_to_variables.get(id(output), output)
 trainable_outputs.append(output)
 trainable_indices.append(index)

@@ -622,7 +622,7 @@ class FuncGraph(ops.Graph):
 return tensor
 def _capture_helper(self, tensor, name, shape=None):
-capture = self._captures.get(ops.tensor_id(tensor))
+capture = self._captures.get(id(tensor))
 if capture is None:
 placeholder = _create_substitute_placeholder(
 tensor, name=name, dtype=tensor.dtype, shape=shape)
@@ -646,18 +646,22 @@
 tensor: Tensor to captures.
 placeholder: Provided placeholder for the tensor.
 """
-self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
+self._captures[id(tensor)] = (tensor, placeholder)
 self.inputs.append(placeholder)
+def replace_capture(self, tensor, placeholder):
+"""Replace already existing capture."""
+self._captures[id(tensor)] = (tensor, placeholder)
 def reset_captures(self, capture_list):
 """Set the captures with the provided list of captures & placeholder."""
 self._captures = py_collections.OrderedDict()
 for tensor, placeholder in capture_list:
-self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
+self._captures[id(tensor)] = (tensor, placeholder)
 def pop_capture(self, tensor):
 """Remove the capture and return the generated placeholder."""
-capture = self._captures.pop(ops.tensor_id(tensor), None)
+capture = self._captures.pop(id(tensor), None)
 if capture is None:
 return None
@@ -675,13 +679,13 @@
 def capture_distributed_variable(self, variable, placeholder):
 """Add given distributed variable to captures with given placeholder."""
-self._captures[ops.tensor_id(variable)] = (variable, placeholder)
+self._captures[id(variable)] = (variable, placeholder)
 tape.record_operation("captured_value", [placeholder], [variable],
 backward_function=lambda x: [x],
 forward_function=lambda x: [x])
 def capture_eager_tensor(self, tensor, name):
-capture = self._captures.get(ops.tensor_id(tensor))
+capture = self._captures.get(id(tensor))
 if capture is None:
 # We clear all control dependencies and place the Const op on the same
 # device as the source tensor. The device placement may be relaxed at
@@ -697,6 +701,10 @@
 forward_function=lambda x: [x])
 return graph_const
+def captured(self, tensor):
+"""Check if the specified tensor has been captured."""
+return id(tensor) in self._captures
 @property
 def external_captures(self):
 """External tensors captured by this function."""
@@ -719,11 +727,11 @@
 @property
 def variable_captures(self):
-"""Map of tensor ids of variable handles to variables which are captured."""
+"""Map of python object ids of variables to variables which are captured."""
 return {
-ops.tensor_id(self._captures[ops.tensor_id(v.handle)][1]): v
+id(self._captures[id(v)][1]): v
 for v in self.variables
-if ops.tensor_id(v.handle) in self._captures
+if id(v) in self._captures
 }
 def mark_as_unsaveable(self, error_message):
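
For orientation, a hedged usage sketch of the two helpers these hunks
introduce (replace_capture and captured); the graph, tensor, and placeholder
names below are stand-ins, and the pattern mirrors the call sites updated in
the files that follow.

# Hypothetical caller-side sketch; `graph`, `tensor`, and `placeholder` are
# stand-ins for a FuncGraph and the values being captured.
if not graph.captured(tensor):  # was: ops.tensor_id(tensor) in graph._captures
  placeholder = graph.capture(tensor)
else:
  graph.replace_capture(tensor, placeholder)  # was: direct write to graph._captures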

@@ -914,8 +914,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
 with self._forward_cond_graph.as_default():
 self._forward_cond_graph.capture(tensor)
 with self._forward_graph.as_default():
-already_captured = ops.tensor_id(
-tensor) in self._forward_graph._captures
+already_captured = self._forward_graph.captured(tensor)
 if not already_captured:
 self.extra_inputs.append(tensor)
 tensor = self._forward_graph.capture(tensor)

@@ -180,8 +180,8 @@ class Loader(object):
 concrete_function.graph.capture_distributed_variable(
 bound_input, internal_capture)
 else:
-concrete_function.graph._captures[ops.tensor_id(bound_input)] = ( # pylint: disable=protected-access
-bound_input, internal_capture)
+concrete_function.graph.replace_capture(bound_input,
+internal_capture)
 if internal_capture.dtype == dtypes.resource:
 if resource_variable_ops.is_resource_variable(bound_input):
 try: