Remove _id field in DistributedVariable
The _id field was added in DistributedVariable solely to support capturing of distributed values in FuncGraph when tensor equality is enabled. This is unnecessary since we can instead key the captures dictionary with id(), since the key is stored as part of the value tuple. We don't use experimental_ref() here since we want to easily return tuples of tensors & placeholders. PiperOrigin-RevId: 274077594
This commit is contained in:
parent
3bd66c35b0
commit
c3ec575f83
@ -652,7 +652,6 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
self.assertIsInstance(mirrored_var, values.MirroredVariable)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(1.0, self.evaluate(mirrored_var))
|
||||
self.assertIsNotNone(ops.tensor_id(mirrored_var))
|
||||
mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
|
||||
self.assertEqual(6.0, mirrored_var_result)
|
||||
|
||||
|
@ -617,7 +617,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
# We need to make _keras_initialized a member of DistributedVariable because
|
||||
# without this it will use `__getattr__` which will delegate to a component
|
||||
# variable.
|
||||
self._id = ops.uid()
|
||||
self._keras_initialized = False
|
||||
# Typically, a `DistributedVariable`'s initializer is composed of the
|
||||
# initializers of the components variables. However, in some cases, such as
|
||||
|
@ -844,7 +844,7 @@ class _TapeGradientFunctions(object):
|
||||
if backprop_util.IsTrainable(output):
|
||||
# Swap in the Variable object for resource handles if we can so
|
||||
# sparse gradients work.
|
||||
output = handles_to_variables.get(ops.tensor_id(output), output)
|
||||
output = handles_to_variables.get(id(output), output)
|
||||
trainable_outputs.append(output)
|
||||
trainable_indices.append(index)
|
||||
|
||||
|
@ -622,7 +622,7 @@ class FuncGraph(ops.Graph):
|
||||
return tensor
|
||||
|
||||
def _capture_helper(self, tensor, name, shape=None):
|
||||
capture = self._captures.get(ops.tensor_id(tensor))
|
||||
capture = self._captures.get(id(tensor))
|
||||
if capture is None:
|
||||
placeholder = _create_substitute_placeholder(
|
||||
tensor, name=name, dtype=tensor.dtype, shape=shape)
|
||||
@ -646,18 +646,22 @@ class FuncGraph(ops.Graph):
|
||||
tensor: Tensor to captures.
|
||||
placeholder: Provided placeholder for the tensor.
|
||||
"""
|
||||
self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
|
||||
self._captures[id(tensor)] = (tensor, placeholder)
|
||||
self.inputs.append(placeholder)
|
||||
|
||||
def replace_capture(self, tensor, placeholder):
|
||||
"""Replace already existing capture."""
|
||||
self._captures[id(tensor)] = (tensor, placeholder)
|
||||
|
||||
def reset_captures(self, capture_list):
|
||||
"""Set the captures with the provided list of captures & placeholder."""
|
||||
self._captures = py_collections.OrderedDict()
|
||||
for tensor, placeholder in capture_list:
|
||||
self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
|
||||
self._captures[id(tensor)] = (tensor, placeholder)
|
||||
|
||||
def pop_capture(self, tensor):
|
||||
"""Remove the capture and return the generated placeholder."""
|
||||
capture = self._captures.pop(ops.tensor_id(tensor), None)
|
||||
capture = self._captures.pop(id(tensor), None)
|
||||
if capture is None:
|
||||
return None
|
||||
|
||||
@ -675,13 +679,13 @@ class FuncGraph(ops.Graph):
|
||||
|
||||
def capture_distributed_variable(self, variable, placeholder):
|
||||
"""Add given distributed variable to captures with given placeholder."""
|
||||
self._captures[ops.tensor_id(variable)] = (variable, placeholder)
|
||||
self._captures[id(variable)] = (variable, placeholder)
|
||||
tape.record_operation("captured_value", [placeholder], [variable],
|
||||
backward_function=lambda x: [x],
|
||||
forward_function=lambda x: [x])
|
||||
|
||||
def capture_eager_tensor(self, tensor, name):
|
||||
capture = self._captures.get(ops.tensor_id(tensor))
|
||||
capture = self._captures.get(id(tensor))
|
||||
if capture is None:
|
||||
# We clear all control dependencies and place the Const op on the same
|
||||
# device as the source tensor. The device placement may be relaxed at
|
||||
@ -697,6 +701,10 @@ class FuncGraph(ops.Graph):
|
||||
forward_function=lambda x: [x])
|
||||
return graph_const
|
||||
|
||||
def captured(self, tensor):
|
||||
"""Check if the specified tensor has been captured."""
|
||||
return id(tensor) in self._captures
|
||||
|
||||
@property
|
||||
def external_captures(self):
|
||||
"""External tensors captured by this function."""
|
||||
@ -719,11 +727,11 @@ class FuncGraph(ops.Graph):
|
||||
|
||||
@property
|
||||
def variable_captures(self):
|
||||
"""Map of tensor ids of variable handles to variables which are captured."""
|
||||
"""Map of python object ids of variables to variables which are captured."""
|
||||
return {
|
||||
ops.tensor_id(self._captures[ops.tensor_id(v.handle)][1]): v
|
||||
id(self._captures[id(v)][1]): v
|
||||
for v in self.variables
|
||||
if ops.tensor_id(v.handle) in self._captures
|
||||
if id(v) in self._captures
|
||||
}
|
||||
|
||||
def mark_as_unsaveable(self, error_message):
|
||||
|
@ -914,8 +914,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
||||
with self._forward_cond_graph.as_default():
|
||||
self._forward_cond_graph.capture(tensor)
|
||||
with self._forward_graph.as_default():
|
||||
already_captured = ops.tensor_id(
|
||||
tensor) in self._forward_graph._captures
|
||||
already_captured = self._forward_graph.captured(tensor)
|
||||
if not already_captured:
|
||||
self.extra_inputs.append(tensor)
|
||||
tensor = self._forward_graph.capture(tensor)
|
||||
|
@ -180,8 +180,8 @@ class Loader(object):
|
||||
concrete_function.graph.capture_distributed_variable(
|
||||
bound_input, internal_capture)
|
||||
else:
|
||||
concrete_function.graph._captures[ops.tensor_id(bound_input)] = ( # pylint: disable=protected-access
|
||||
bound_input, internal_capture)
|
||||
concrete_function.graph.replace_capture(bound_input,
|
||||
internal_capture)
|
||||
if internal_capture.dtype == dtypes.resource:
|
||||
if resource_variable_ops.is_resource_variable(bound_input):
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user