[rollback] Use self.handle inside ResourceVariable to allow tf.distribute to customize
handle behavior PiperOrigin-RevId: 356541183 Change-Id: If4dbfc32a834c464bc94ce1c3ae71b3fb72e1e55
This commit is contained in:
parent
417c2d448c
commit
beab125d24
@ -252,7 +252,6 @@ class PackedVarAndDevice(object):
|
||||
self._device = device
|
||||
|
||||
def __getattr__(self, name):
|
||||
with ops.device(self._device):
|
||||
return getattr(self._var, name)
|
||||
|
||||
def var(self):
|
||||
|
@ -516,12 +516,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
@property
|
||||
def device(self):
|
||||
"""The device this variable is on."""
|
||||
return self.handle.device
|
||||
return self._handle.device
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
"""The `Graph` of this variable."""
|
||||
return self.handle.graph
|
||||
return self._handle.graph
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -596,7 +596,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
@property
|
||||
def op(self):
|
||||
"""The op for this variable."""
|
||||
return self.handle.op
|
||||
return self._handle.op
|
||||
|
||||
@property
|
||||
def trainable(self):
|
||||
@ -655,7 +655,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
else:
|
||||
new_variable = copy_to_graph_uninitialized(self)
|
||||
obj_map = {self: new_variable}
|
||||
resource_map = {self.handle: new_variable.handle}
|
||||
resource_map = {self._handle: new_variable.handle}
|
||||
return obj_map, resource_map
|
||||
|
||||
def _read_variable_op(self):
|
||||
@ -663,8 +663,8 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
|
||||
def read_and_set_handle():
|
||||
result = gen_resource_variable_ops.read_variable_op(
|
||||
self.handle, self._dtype)
|
||||
_maybe_set_handle_data(self._dtype, self.handle, result)
|
||||
self._handle, self._dtype)
|
||||
_maybe_set_handle_data(self._dtype, self._handle, result)
|
||||
return result
|
||||
|
||||
if getattr(self, "_caching_device", None) is not None:
|
||||
@ -678,7 +678,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
# Note that if a control flow context is active the input of the read op
|
||||
# might not actually be the handle. This line bypasses it.
|
||||
tape.record_operation(
|
||||
"ReadVariableOp", [result], [self.handle],
|
||||
"ReadVariableOp", [result], [self._handle],
|
||||
backward_function=lambda x: [x],
|
||||
forward_function=lambda x: [x])
|
||||
return result
|
||||
@ -703,12 +703,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
with ops.name_scope("Gather" if name is None else name) as name:
|
||||
variable_accessed(self)
|
||||
value = gen_resource_variable_ops.resource_gather(
|
||||
self.handle, indices, dtype=self._dtype, name=name)
|
||||
self._handle, indices, dtype=self._dtype, name=name)
|
||||
|
||||
if self._dtype == dtypes.variant:
|
||||
# For DT_VARIANT types, the handle's shape_and_type[1:] stores the
|
||||
# variant's handle data. Extract it.
|
||||
handle_data = get_eager_safe_handle_data(self.handle)
|
||||
handle_data = get_eager_safe_handle_data(self._handle)
|
||||
if handle_data.is_set and len(handle_data.shape_and_type) > 1:
|
||||
value._handle_data = ( # pylint: disable=protected-access
|
||||
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
|
||||
@ -722,7 +722,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
if self.trainable:
|
||||
variable_accessed(self)
|
||||
value = gen_resource_variable_ops.resource_gather_nd(
|
||||
self.handle, indices, dtype=self._dtype, name=name)
|
||||
self._handle, indices, dtype=self._dtype, name=name)
|
||||
|
||||
return array_ops.identity(value)
|
||||
|
||||
@ -855,7 +855,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
def _lazy_read(self, op):
|
||||
variable_accessed(self)
|
||||
return _UnreadVariable(
|
||||
handle=self.handle,
|
||||
handle=self._handle,
|
||||
dtype=self.dtype,
|
||||
shape=self._shape,
|
||||
in_graph_mode=self._in_graph_mode,
|
||||
|
Loading…
x
Reference in New Issue
Block a user