Use self._in_graph_mode consistently in ResourceVariable

instead of sometimes getting it from the context.

Also: fix formatting of a comment and use a more precise test to detect
if initial_value is set.
PiperOrigin-RevId: 168047258
A. Unique TensorFlower 2017-09-08 14:33:01 -07:00 committed by TensorFlower Gardener
parent f331f528b8
commit ff6dd474a6
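
A small sketch of the distinction the message is drawing (hypothetical classes, no TensorFlow involved): an object can snapshot the execution mode once at construction, while the global context keeps reporting whatever mode is current, so the two can disagree later.

    # Hypothetical sketch, not TensorFlow code: the variable freezes its
    # mode at creation; the context answers for "right now".
    class FakeContext(object):
      def __init__(self):
        self.graph_mode = True              # stand-in for context.in_graph_mode()

    class FakeVariable(object):
      def __init__(self, ctx):
        self._in_graph_mode = ctx.graph_mode  # mode frozen at creation

    ctx = FakeContext()
    v = FakeVariable(ctx)
    ctx.graph_mode = False                  # execution switches to eager later
    print(v._in_graph_mode)                 # True  -- the variable's own mode
    print(ctx.graph_mode)                   # False -- what the context reports now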

tensorflow/python/ops/resource_variable_ops.py

@@ -42,14 +42,14 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
 from tensorflow.python.util import compat
 
 
-def _eager_safe_variable_handle(shape, dtype, shared_name, name,
+def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode,
                                 container=None):
   """Creates a variable handle with information to do shape inference."""
   handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                    shared_name=shared_name,
                                                    name=name,
                                                    container=container)
-  if context.in_graph_mode():
+  if graph_mode:
     return handle
   with context.graph_mode(), ops.Graph().as_default():
     h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
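
The hunk above threads the mode through as an explicit parameter instead of letting the helper query global state. A minimal sketch of the pattern, with a plain dict standing in for the real handle op (names here are hypothetical):

    # Hedged sketch of the refactor pattern: the caller decides the mode
    # once and passes it in; the helper never consults a global context.
    def make_handle(shape, dtype, graph_mode, container=None):
      handle = {"shape": shape, "dtype": dtype, "container": container}
      if graph_mode:
        return handle                  # graph mode: return the handle as-is
      handle["eager_shape"] = shape    # eager: attach extra shape metadata
      return handle

    print(make_handle([2, 3], "float32", graph_mode=True))
    print(make_handle([2, 3], "float32", graph_mode=False))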
@@ -152,8 +152,8 @@ class ResourceVariable(variables.Variable):
         uniquified automatically.
       dtype: If set, initial_value will be converted to the given type.
         If None, either the datatype will be kept (if initial_value is
-       a Tensor) or float32 will be used (if it is a Python object convertible
-       to a Tensor).
+        a Tensor) or float32 will be used (if it is a Python object convertible
+        to a Tensor).
       variable_def: `VariableDef` protocol buffer. If not None, recreates the
         `ResourceVariable` object with its contents. `variable_def` and other
         arguments (except for import_scope) are mutually exclusive.
@@ -172,7 +172,7 @@ class ResourceVariable(variables.Variable):
         shape and `validate_shape` is `True`.
     """
     if variable_def:
-      if initial_value:
+      if initial_value is not None:
         raise ValueError("variable_def and initial_value are mutually "
                          "exclusive.")
       if not context.in_graph_mode():
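
The `is not None` change above is the "more precise test" from the commit message. A plain-Python sketch of why truthiness is the wrong test for "is this argument set":

    # Falsy-but-present values (0, 0.0, []) would silently skip the
    # mutual-exclusion check under `if initial_value:`, and array-like
    # values (a multi-element NumPy array, a graph-mode Tensor) can make
    # truthiness itself raise.
    for initial_value in (0, 0.0, [], None):
      truthy = bool(initial_value)
      is_set = initial_value is not None
      print("%r truthy=%s set=%s" % (initial_value, truthy, is_set))
    # Only None reports set=False; the falsy values are still set.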
@@ -277,7 +277,8 @@ class ResourceVariable(variables.Variable):
             shape=initial_value.get_shape(),
             dtype=initial_value.dtype.base_dtype,
             shared_name=handle_name,
-            name=name)
+            name=name,
+            graph_mode=self._in_graph_mode)
         self._handle_device = (
             self._handle.device if self._in_graph_mode else
             context.get_default_context().device_name)
@@ -291,6 +292,7 @@ class ResourceVariable(variables.Variable):
             dtype=initial_value.dtype.base_dtype,
             shared_name=handle_name,
             name=name,
+            graph_mode=False,
             container="")
         self._handle_device = (
             self._handle.device if self._in_graph_mode else
@@ -316,6 +318,7 @@ class ResourceVariable(variables.Variable):
             dtype=initial_value.dtype.base_dtype,
             shared_name=handle_name,
             name=name,
+            graph_mode=self._in_graph_mode,
             container="")
         self._handle_device = (self._handle.device if self._in_graph_mode else
                                context.get_default_context().device_name)
@@ -372,6 +375,7 @@ class ResourceVariable(variables.Variable):
     """Initializes from `VariableDef` proto."""
     # Note that init_from_proto is currently not supported in Eager mode.
     assert context.in_graph_mode()
+    self._in_graph_mode = True
     assert isinstance(variable_def, variable_pb2.VariableDef)
     if not variable_def.is_resource:
       raise ValueError("Trying to restore Variable as ResourceVariable.")
@@ -434,7 +438,7 @@ class ResourceVariable(variables.Variable):
 
   @property
   def create(self):
     """The op responsible for initializing this variable."""
-    if not context.in_graph_mode():
+    if not self._in_graph_mode:
       raise RuntimeError("Calling create in EAGER mode not supported.")
     return self._initializer_op
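
Together with the `self._in_graph_mode = True` assignment added to `_init_from_proto` above, this keeps one invariant: every construction path sets the flag that properties like `create` now read. A sketch with a hypothetical class (not the real `ResourceVariable`):

    # If a construction path skipped the assignment, reading the property
    # would raise AttributeError instead of the intended RuntimeError.
    class Sketch(object):
      def __init__(self):
        self._in_graph_mode = True      # normal constructor path

      @classmethod
      def from_proto(cls):
        obj = cls.__new__(cls)          # bypasses __init__, like proto restore
        obj._in_graph_mode = True       # mirror of the added assignment
        return obj

      @property
      def create(self):
        if not self._in_graph_mode:
          raise RuntimeError("Calling create in EAGER mode not supported.")
        return "initializer_op"

    print(Sketch.from_proto().create)   # initializer_op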
@@ -520,7 +524,7 @@ class ResourceVariable(variables.Variable):
     # In graph mode, ensure we read the variable in the same device as the
     # handle. In eager mode, however, this sometimes tries to read a GPU
     # variable in the CPU because the handle is host memory. For now, then, we
-    # need to skip the device block in eager. TODO(apassos) eager should have
+    # need to skip the device block in eager. TODO(apassos): eager should have
     # separate notions of device and memory, so handle.device can be GPU while
     # handle.memory_space is always CPU.
     if context.in_graph_mode():