Use self._in_graph_mode consistently in ResourceVariable
instead of sometimes getting it from the context. Also: fix formatting of a comment and use a more precise test to detect if initial_value is set.

PiperOrigin-RevId: 168047258
commit ff6dd474a6 (parent f331f528b8)
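The gist of the change: a ResourceVariable records whether it was created in graph or eager mode once, at construction time, and every later check reads that cached flag instead of re-querying the global context, whose mode may differ by then. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the real TensorFlow internals:

    class _FakeContext(object):
      """Stand-in for a global context whose mode can change over time."""

      def __init__(self):
        self.graph_mode = True


    _CONTEXT = _FakeContext()


    class CachedModeVariable(object):
      """Captures the execution mode once and consults only the cached flag."""

      def __init__(self):
        # Recorded at construction; later calls stay consistent with how the
        # object was actually built, even if the global mode flips afterwards.
        self._in_graph_mode = _CONTEXT.graph_mode

      @property
      def create(self):
        # Mirrors the diff below: the check uses the cached flag, not the
        # current global mode.
        if not self._in_graph_mode:
          raise RuntimeError("Calling create in EAGER mode not supported.")
        return "initializer-op"

Built while _CONTEXT.graph_mode is True, such an object keeps behaving as a graph-mode variable even if the global flag is later flipped.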
@@ -42,14 +42,14 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
 from tensorflow.python.util import compat
 
 
-def _eager_safe_variable_handle(shape, dtype, shared_name, name,
+def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode,
                                 container=None):
   """Creates a variable handle with information to do shape inference."""
   handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                    shared_name=shared_name,
                                                    name=name,
                                                    container=container)
-  if context.in_graph_mode():
+  if graph_mode:
     return handle
   with context.graph_mode(), ops.Graph().as_default():
     h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
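Passing graph_mode in explicitly makes the helper depend only on its arguments: the caller, which already knows how the variable is being constructed, decides, and the helper can be exercised in either mode without touching global state. A hedged sketch of that shape, with a hypothetical make_handle in place of the real op:

    def make_handle(name, graph_mode):
      """Builds a toy handle; behavior depends only on the arguments given."""
      handle = {"name": name, "shape_inference_attached": False}
      if graph_mode:
        # Graph mode: return the bare handle; shape metadata comes later.
        return handle
      # Eager mode: do the extra bookkeeping up front.
      handle["shape_inference_attached"] = True
      return handle

    assert make_handle("v", graph_mode=True)["shape_inference_attached"] is False
    assert make_handle("v", graph_mode=False)["shape_inference_attached"] is True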
@@ -152,8 +152,8 @@ class ResourceVariable(variables.Variable):
         uniquified automatically.
       dtype: If set, initial_value will be converted to the given type.
         If None, either the datatype will be kept (if initial_value is
-       a Tensor) or float32 will be used (if it is a Python object convertible
-       to a Tensor).
+        a Tensor) or float32 will be used (if it is a Python object convertible
+        to a Tensor).
       variable_def: `VariableDef` protocol buffer. If not None, recreates the
         `ResourceVariable` object with its contents. `variable_def` and other
         arguments (except for import_scope) are mutually exclusive.
@@ -172,7 +172,7 @@ class ResourceVariable(variables.Variable):
         shape and `validate_shape` is `True`.
     """
     if variable_def:
-      if initial_value:
+      if initial_value is not None:
         raise ValueError("variable_def and initial_value are mutually "
                          "exclusive.")
       if not context.in_graph_mode():
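The switch to "initial_value is not None" is the "more precise test" from the commit message: truthiness is a poor proxy for "was an argument supplied", since zero-like values are falsy and array- or tensor-like values may not define a truth value at all. A small illustration, assuming NumPy is available (the helper names are hypothetical):

    import numpy as np

    def was_supplied_truthy(value=None):
      # Imprecise: 0, 0.0 or [] look "unset", and a multi-element array
      # raises "truth value ... is ambiguous" when coerced to bool.
      return bool(value)

    def was_supplied(value=None):
      # Precise: only an omitted argument counts as "not set".
      return value is not None

    assert was_supplied(0)                 # 0 is a legitimate initial value...
    assert not was_supplied_truthy(0)      # ...but truthiness says otherwise.
    try:
      was_supplied_truthy(np.zeros(3))     # arrays cannot be coerced to bool
    except ValueError:
      pass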
@@ -277,7 +277,8 @@ class ResourceVariable(variables.Variable):
               shape=initial_value.get_shape(),
               dtype=initial_value.dtype.base_dtype,
               shared_name=handle_name,
-              name=name)
+              name=name,
+              graph_mode=self._in_graph_mode)
           self._handle_device = (
               self._handle.device if self._in_graph_mode else
               context.get_default_context().device_name)
@@ -291,6 +292,7 @@ class ResourceVariable(variables.Variable):
                 dtype=initial_value.dtype.base_dtype,
                 shared_name=handle_name,
                 name=name,
+                graph_mode=False,
                 container="")
             self._handle_device = (
                 self._handle.device if self._in_graph_mode else
@@ -316,6 +318,7 @@ class ResourceVariable(variables.Variable):
               dtype=initial_value.dtype.base_dtype,
               shared_name=handle_name,
               name=name,
+              graph_mode=self._in_graph_mode,
               container="")
           self._handle_device = (self._handle.device if self._in_graph_mode else
                                  context.get_default_context().device_name)
@@ -372,6 +375,7 @@ class ResourceVariable(variables.Variable):
     """Initializes from `VariableDef` proto."""
     # Note that init_from_proto is currently not supported in Eager mode.
     assert context.in_graph_mode()
+    self._in_graph_mode = True
     assert isinstance(variable_def, variable_pb2.VariableDef)
     if not variable_def.is_resource:
       raise ValueError("Trying to restore Variable as ResourceVariable.")
@@ -434,7 +438,7 @@ class ResourceVariable(variables.Variable):
   @property
   def create(self):
     """The op responsible for initializing this variable."""
-    if not context.in_graph_mode():
+    if not self._in_graph_mode:
       raise RuntimeError("Calling create in EAGER mode not supported.")
     return self._initializer_op
 
@@ -520,7 +524,7 @@ class ResourceVariable(variables.Variable):
     # In graph mode, ensure we read the variable in the same device as the
     # handle. In eager mode, however, this sometimes tries to read a GPU
     # variable in the CPU because the handle is host memory. For now, then, we
-    # need to skip the device block in eager. TODO(apassos) eager should have
+    # need to skip the device block in eager. TODO(apassos): eager should have
     # separate notions of device and memory, so handle.device can be GPU while
     # handle.memory_space is always CPU.
     if context.in_graph_mode():