[rollback] Use self.handle inside ResourceVariable to allow tf.distribute to customize handle behavior

PiperOrigin-RevId: 356541183
Change-Id: If4dbfc32a834c464bc94ce1c3ae71b3fb72e1e55
Author:    Ran Chen
Date:      2021-02-09 10:55:20 -08:00
Committer: TensorFlower Gardener
Parent:    417c2d448c
Commit:    beab125d24

2 changed files with 12 additions and 13 deletions
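
Why the property indirection matters, in a minimal plain-Python sketch (stand-in classes and names, not TensorFlow code): internal methods that read the overridable handle property pick up subclass customizations such as tf.distribute's, while methods that read the private _handle attribute, which this rollback restores, bypass them.

    class Base(object):
      def __init__(self):
        self._handle = "default-handle"  # private attribute

      @property
      def handle(self):  # overridable accessor
        return self._handle

      def read_via_property(self):
        # Goes through the property, so a subclass can customize it.
        return "read(%s)" % self.handle

      def read_via_attribute(self):
        # Touches the attribute directly, bypassing any override.
        return "read(%s)" % self._handle


    class DistributedVar(Base):
      @property
      def handle(self):
        return "replica-local-handle"  # e.g. picked per replica


    d = DistributedVar()
    print(d.read_via_property())   # read(replica-local-handle)
    print(d.read_via_attribute())  # read(default-handle)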

--- a/tensorflow/python/distribute/packed_distributed_variable.py
+++ b/tensorflow/python/distribute/packed_distributed_variable.py

@@ -252,7 +252,6 @@ class PackedVarAndDevice(object):
     self._device = device
 
   def __getattr__(self, name):
-    with ops.device(self._device):
-      return getattr(self._var, name)
+    return getattr(self._var, name)
 
   def var(self):
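
The dropped device scope only ever covered the attribute lookup itself, which is where a property on the wrapped variable would create ops; a returned bound method runs after the scope exits. A rough stand-in sketch of the pre-rollback delegation (contextlib stands in for ops.device; names are illustrative):

    import contextlib

    @contextlib.contextmanager
    def device(name):  # stand-in for ops.device
      print("entering scope", name)
      yield
      print("leaving scope", name)


    class PackedVar(object):
      @property
      def value(self):
        print("computing value")  # ops created here land in the scope
        return 42


    class VarAndDevice(object):
      def __init__(self, var, device_name):
        self._var = var
        self._device = device_name

      def __getattr__(self, name):
        # Pre-rollback: the lookup runs under the device scope. After
        # this commit it is a bare getattr(self._var, name).
        with device(self._device):
          return getattr(self._var, name)


    v = VarAndDevice(PackedVar(), "/gpu:0")
    print(v.value)  # "computing value" is printed inside the scope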

--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py

@@ -516,12 +516,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
   @property
   def device(self):
     """The device this variable is on."""
-    return self.handle.device
+    return self._handle.device
 
   @property
   def graph(self):
     """The `Graph` of this variable."""
-    return self.handle.graph
+    return self._handle.graph
 
   @property
   def name(self):
@@ -596,7 +596,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
   @property
   def op(self):
     """The op for this variable."""
-    return self.handle.op
+    return self._handle.op
 
   @property
   def trainable(self):
@@ -655,7 +655,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
     else:
       new_variable = copy_to_graph_uninitialized(self)
     obj_map = {self: new_variable}
-    resource_map = {self.handle: new_variable.handle}
+    resource_map = {self._handle: new_variable.handle}
     return obj_map, resource_map
 
   def _read_variable_op(self):
@@ -663,8 +663,8 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
 
     def read_and_set_handle():
       result = gen_resource_variable_ops.read_variable_op(
-          self.handle, self._dtype)
-      _maybe_set_handle_data(self._dtype, self.handle, result)
+          self._handle, self._dtype)
+      _maybe_set_handle_data(self._dtype, self._handle, result)
       return result
 
     if getattr(self, "_caching_device", None) is not None:
@@ -678,7 +678,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
     # Note that if a control flow context is active the input of the read op
     # might not actually be the handle. This line bypasses it.
     tape.record_operation(
-        "ReadVariableOp", [result], [self.handle],
+        "ReadVariableOp", [result], [self._handle],
         backward_function=lambda x: [x],
         forward_function=lambda x: [x])
     return result
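
The identity backward_function recorded against the handle is what makes a variable read act as a pass-through for gradients. A small public-API illustration, assuming TensorFlow 2.x (not part of this diff):

    import tensorflow as tf

    v = tf.Variable(3.0)
    with tf.GradientTape() as tape:
      y = v * v  # the multiply reads v via a ReadVariableOp
    # The read contributes identity, so d(v*v)/dv = 2v = 6.0.
    print(tape.gradient(y, v).numpy())  # 6.0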
@@ -703,12 +703,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
     with ops.name_scope("Gather" if name is None else name) as name:
       variable_accessed(self)
       value = gen_resource_variable_ops.resource_gather(
-          self.handle, indices, dtype=self._dtype, name=name)
+          self._handle, indices, dtype=self._dtype, name=name)
 
       if self._dtype == dtypes.variant:
         # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
         # variant's handle data. Extract it.
-        handle_data = get_eager_safe_handle_data(self.handle)
+        handle_data = get_eager_safe_handle_data(self._handle)
         if handle_data.is_set and len(handle_data.shape_and_type) > 1:
           value._handle_data = (  # pylint: disable=protected-access
               cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
@@ -722,7 +722,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
     if self.trainable:
       variable_accessed(self)
     value = gen_resource_variable_ops.resource_gather_nd(
-        self.handle, indices, dtype=self._dtype, name=name)
+        self._handle, indices, dtype=self._dtype, name=name)
 
     return array_ops.identity(value)
@@ -855,7 +855,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
   def _lazy_read(self, op):
     variable_accessed(self)
     return _UnreadVariable(
-        handle=self.handle,
+        handle=self._handle,
         dtype=self.dtype,
         shape=self._shape,
         in_graph_mode=self._in_graph_mode,
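
All of the touched methods bottom out in resource ops applied to the variable's handle tensor. A hedged demonstration with the public tf.raw_ops wrappers, assuming TensorFlow 2.x (raw_ops mirrors the generated gen_resource_variable_ops kernels referenced in the diff):

    import tensorflow as tf

    v = tf.Variable([10.0, 20.0, 30.0])

    # v.handle is the DT_RESOURCE tensor that these methods pass around.
    value = tf.raw_ops.ReadVariableOp(resource=v.handle, dtype=tf.float32)
    picked = tf.raw_ops.ResourceGather(
        resource=v.handle, indices=tf.constant([2, 0]), dtype=tf.float32)

    print(value.numpy())   # [10. 20. 30.]
    print(picked.numpy())  # [30. 10.]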