[rollback] Use self.handle inside ResourceVariable to allow tf.distribute to customize

handle behavior

PiperOrigin-RevId: 356541183
Change-Id: If4dbfc32a834c464bc94ce1c3ae71b3fb72e1e55
This commit is contained in:
Ran Chen 2021-02-09 10:55:20 -08:00 committed by TensorFlower Gardener
parent 417c2d448c
commit beab125d24
2 changed files with 12 additions and 13 deletions

View File

@ -252,8 +252,7 @@ class PackedVarAndDevice(object):
self._device = device self._device = device
def __getattr__(self, name): def __getattr__(self, name):
with ops.device(self._device): return getattr(self._var, name)
return getattr(self._var, name)
def var(self): def var(self):
return self._var return self._var

View File

@ -516,12 +516,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
@property @property
def device(self): def device(self):
"""The device this variable is on.""" """The device this variable is on."""
return self.handle.device return self._handle.device
@property @property
def graph(self): def graph(self):
"""The `Graph` of this variable.""" """The `Graph` of this variable."""
return self.handle.graph return self._handle.graph
@property @property
def name(self): def name(self):
@ -596,7 +596,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
@property @property
def op(self): def op(self):
"""The op for this variable.""" """The op for this variable."""
return self.handle.op return self._handle.op
@property @property
def trainable(self): def trainable(self):
@ -655,7 +655,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
else: else:
new_variable = copy_to_graph_uninitialized(self) new_variable = copy_to_graph_uninitialized(self)
obj_map = {self: new_variable} obj_map = {self: new_variable}
resource_map = {self.handle: new_variable.handle} resource_map = {self._handle: new_variable.handle}
return obj_map, resource_map return obj_map, resource_map
def _read_variable_op(self): def _read_variable_op(self):
@ -663,8 +663,8 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
def read_and_set_handle(): def read_and_set_handle():
result = gen_resource_variable_ops.read_variable_op( result = gen_resource_variable_ops.read_variable_op(
self.handle, self._dtype) self._handle, self._dtype)
_maybe_set_handle_data(self._dtype, self.handle, result) _maybe_set_handle_data(self._dtype, self._handle, result)
return result return result
if getattr(self, "_caching_device", None) is not None: if getattr(self, "_caching_device", None) is not None:
@ -678,7 +678,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
# Note that if a control flow context is active the input of the read op # Note that if a control flow context is active the input of the read op
# might not actually be the handle. This line bypasses it. # might not actually be the handle. This line bypasses it.
tape.record_operation( tape.record_operation(
"ReadVariableOp", [result], [self.handle], "ReadVariableOp", [result], [self._handle],
backward_function=lambda x: [x], backward_function=lambda x: [x],
forward_function=lambda x: [x]) forward_function=lambda x: [x])
return result return result
@ -703,12 +703,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
with ops.name_scope("Gather" if name is None else name) as name: with ops.name_scope("Gather" if name is None else name) as name:
variable_accessed(self) variable_accessed(self)
value = gen_resource_variable_ops.resource_gather( value = gen_resource_variable_ops.resource_gather(
self.handle, indices, dtype=self._dtype, name=name) self._handle, indices, dtype=self._dtype, name=name)
if self._dtype == dtypes.variant: if self._dtype == dtypes.variant:
# For DT_VARIANT types, the handle's shape_and_type[1:] stores the # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
# variant's handle data. Extract it. # variant's handle data. Extract it.
handle_data = get_eager_safe_handle_data(self.handle) handle_data = get_eager_safe_handle_data(self._handle)
if handle_data.is_set and len(handle_data.shape_and_type) > 1: if handle_data.is_set and len(handle_data.shape_and_type) > 1:
value._handle_data = ( # pylint: disable=protected-access value._handle_data = ( # pylint: disable=protected-access
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData( cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
@ -722,7 +722,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
if self.trainable: if self.trainable:
variable_accessed(self) variable_accessed(self)
value = gen_resource_variable_ops.resource_gather_nd( value = gen_resource_variable_ops.resource_gather_nd(
self.handle, indices, dtype=self._dtype, name=name) self._handle, indices, dtype=self._dtype, name=name)
return array_ops.identity(value) return array_ops.identity(value)
@ -855,7 +855,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
def _lazy_read(self, op): def _lazy_read(self, op):
variable_accessed(self) variable_accessed(self)
return _UnreadVariable( return _UnreadVariable(
handle=self.handle, handle=self._handle,
dtype=self.dtype, dtype=self.dtype,
shape=self._shape, shape=self._shape,
in_graph_mode=self._in_graph_mode, in_graph_mode=self._in_graph_mode,