[rollback] Use self.handle inside ResourceVariable to allow tf.distribute to customize handle behavior

PiperOrigin-RevId: 356541183
Change-Id: If4dbfc32a834c464bc94ce1c3ae71b3fb72e1e55
parent 417c2d448c
commit beab125d24
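Context for the change below: `self._handle` is the attribute captured at construction time, while `self.handle` is a property that a subclass (for example a tf.distribute variable wrapper) can override to return a different handle. Reading through the property lets such subclasses customize handle behavior; reading `_handle` directly bypasses them, which is the behavior this rollback restores. The following is a minimal, self-contained sketch of that mechanism; the class names (`BaseVar`, `DistributedVar`) and handle strings are hypothetical stand-ins, not TensorFlow's actual implementation.

class BaseVar(object):
  """Toy analogue of a resource variable that stores a raw handle."""

  def __init__(self, handle):
    self._handle = handle

  @property
  def handle(self):
    # Subclasses may override this property to customize the handle.
    return self._handle

  def read_via_property(self):
    # Goes through the (possibly overridden) property.
    return self.handle

  def read_via_attribute(self):
    # Bypasses any override, as the rolled-back code does.
    return self._handle


class DistributedVar(BaseVar):
  """Hypothetical subclass standing in for a tf.distribute wrapper."""

  def __init__(self, per_replica_handles):
    super(DistributedVar, self).__init__(per_replica_handles[0])
    self._per_replica = per_replica_handles

  @property
  def handle(self):
    # A real wrapper might select a per-replica or packed handle; here we
    # simply return the last one for illustration.
    return self._per_replica[-1]


v = DistributedVar(["handle:GPU:0", "handle:GPU:1"])
print(v.read_via_property())   # handle:GPU:1  (override takes effect)
print(v.read_via_attribute())  # handle:GPU:0  (override bypassed)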
@@ -252,8 +252,7 @@ class PackedVarAndDevice(object):
     self._device = device
 
   def __getattr__(self, name):
-    with ops.device(self._device):
-      return getattr(self._var, name)
+    return getattr(self._var, name)
 
   def var(self):
     return self._var
@@ -516,12 +516,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
   @property
   def device(self):
     """The device this variable is on."""
-    return self.handle.device
+    return self._handle.device
 
   @property
   def graph(self):
     """The `Graph` of this variable."""
-    return self.handle.graph
+    return self._handle.graph
 
   @property
   def name(self):
@@ -596,7 +596,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
   @property
   def op(self):
     """The op for this variable."""
-    return self.handle.op
+    return self._handle.op
 
   @property
   def trainable(self):
@@ -655,7 +655,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
     else:
       new_variable = copy_to_graph_uninitialized(self)
     obj_map = {self: new_variable}
-    resource_map = {self.handle: new_variable.handle}
+    resource_map = {self._handle: new_variable.handle}
     return obj_map, resource_map
 
   def _read_variable_op(self):
@@ -663,8 +663,8 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
 
     def read_and_set_handle():
       result = gen_resource_variable_ops.read_variable_op(
-          self.handle, self._dtype)
-      _maybe_set_handle_data(self._dtype, self.handle, result)
+          self._handle, self._dtype)
+      _maybe_set_handle_data(self._dtype, self._handle, result)
       return result
 
     if getattr(self, "_caching_device", None) is not None:
@@ -678,7 +678,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
       # Note that if a control flow context is active the input of the read op
       # might not actually be the handle. This line bypasses it.
       tape.record_operation(
-          "ReadVariableOp", [result], [self.handle],
+          "ReadVariableOp", [result], [self._handle],
          backward_function=lambda x: [x],
          forward_function=lambda x: [x])
     return result
@@ -703,12 +703,12 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
     with ops.name_scope("Gather" if name is None else name) as name:
       variable_accessed(self)
       value = gen_resource_variable_ops.resource_gather(
-          self.handle, indices, dtype=self._dtype, name=name)
+          self._handle, indices, dtype=self._dtype, name=name)
 
       if self._dtype == dtypes.variant:
         # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
         # variant's handle data. Extract it.
-        handle_data = get_eager_safe_handle_data(self.handle)
+        handle_data = get_eager_safe_handle_data(self._handle)
         if handle_data.is_set and len(handle_data.shape_and_type) > 1:
           value._handle_data = (  # pylint: disable=protected-access
               cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
@@ -722,7 +722,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
       if self.trainable:
         variable_accessed(self)
       value = gen_resource_variable_ops.resource_gather_nd(
-          self.handle, indices, dtype=self._dtype, name=name)
+          self._handle, indices, dtype=self._dtype, name=name)
 
     return array_ops.identity(value)
 
@@ -855,7 +855,7 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
   def _lazy_read(self, op):
     variable_accessed(self)
     return _UnreadVariable(
-        handle=self.handle,
+        handle=self._handle,
         dtype=self.dtype,
         shape=self._shape,
         in_graph_mode=self._in_graph_mode,