Use self._in_graph_mode consistently in ResourceVariable
instead of sometimes getting it from the context. Also: fix formatting of a comment and use a more precise test to detect if initial_value is set. PiperOrigin-RevId: 168047258
This commit is contained in:
parent
f331f528b8
commit
ff6dd474a6
@ -42,14 +42,14 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
|
|||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
|
|
||||||
def _eager_safe_variable_handle(shape, dtype, shared_name, name,
|
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode,
|
||||||
container=None):
|
container=None):
|
||||||
"""Creates a variable handle with information to do shape inference."""
|
"""Creates a variable handle with information to do shape inference."""
|
||||||
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
||||||
shared_name=shared_name,
|
shared_name=shared_name,
|
||||||
name=name,
|
name=name,
|
||||||
container=container)
|
container=container)
|
||||||
if context.in_graph_mode():
|
if graph_mode:
|
||||||
return handle
|
return handle
|
||||||
with context.graph_mode(), ops.Graph().as_default():
|
with context.graph_mode(), ops.Graph().as_default():
|
||||||
h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
|
||||||
@ -172,7 +172,7 @@ class ResourceVariable(variables.Variable):
|
|||||||
shape and `validate_shape` is `True`.
|
shape and `validate_shape` is `True`.
|
||||||
"""
|
"""
|
||||||
if variable_def:
|
if variable_def:
|
||||||
if initial_value:
|
if initial_value is not None:
|
||||||
raise ValueError("variable_def and initial_value are mutually "
|
raise ValueError("variable_def and initial_value are mutually "
|
||||||
"exclusive.")
|
"exclusive.")
|
||||||
if not context.in_graph_mode():
|
if not context.in_graph_mode():
|
||||||
@ -277,7 +277,8 @@ class ResourceVariable(variables.Variable):
|
|||||||
shape=initial_value.get_shape(),
|
shape=initial_value.get_shape(),
|
||||||
dtype=initial_value.dtype.base_dtype,
|
dtype=initial_value.dtype.base_dtype,
|
||||||
shared_name=handle_name,
|
shared_name=handle_name,
|
||||||
name=name)
|
name=name,
|
||||||
|
graph_mode=self._in_graph_mode)
|
||||||
self._handle_device = (
|
self._handle_device = (
|
||||||
self._handle.device if self._in_graph_mode else
|
self._handle.device if self._in_graph_mode else
|
||||||
context.get_default_context().device_name)
|
context.get_default_context().device_name)
|
||||||
@ -291,6 +292,7 @@ class ResourceVariable(variables.Variable):
|
|||||||
dtype=initial_value.dtype.base_dtype,
|
dtype=initial_value.dtype.base_dtype,
|
||||||
shared_name=handle_name,
|
shared_name=handle_name,
|
||||||
name=name,
|
name=name,
|
||||||
|
graph_mode=False,
|
||||||
container="")
|
container="")
|
||||||
self._handle_device = (
|
self._handle_device = (
|
||||||
self._handle.device if self._in_graph_mode else
|
self._handle.device if self._in_graph_mode else
|
||||||
@ -316,6 +318,7 @@ class ResourceVariable(variables.Variable):
|
|||||||
dtype=initial_value.dtype.base_dtype,
|
dtype=initial_value.dtype.base_dtype,
|
||||||
shared_name=handle_name,
|
shared_name=handle_name,
|
||||||
name=name,
|
name=name,
|
||||||
|
graph_mode=self._in_graph_mode,
|
||||||
container="")
|
container="")
|
||||||
self._handle_device = (self._handle.device if self._in_graph_mode else
|
self._handle_device = (self._handle.device if self._in_graph_mode else
|
||||||
context.get_default_context().device_name)
|
context.get_default_context().device_name)
|
||||||
@ -372,6 +375,7 @@ class ResourceVariable(variables.Variable):
|
|||||||
"""Initializes from `VariableDef` proto."""
|
"""Initializes from `VariableDef` proto."""
|
||||||
# Note that init_from_proto is currently not supported in Eager mode.
|
# Note that init_from_proto is currently not supported in Eager mode.
|
||||||
assert context.in_graph_mode()
|
assert context.in_graph_mode()
|
||||||
|
self._in_graph_mode = True
|
||||||
assert isinstance(variable_def, variable_pb2.VariableDef)
|
assert isinstance(variable_def, variable_pb2.VariableDef)
|
||||||
if not variable_def.is_resource:
|
if not variable_def.is_resource:
|
||||||
raise ValueError("Trying to restore Variable as ResourceVariable.")
|
raise ValueError("Trying to restore Variable as ResourceVariable.")
|
||||||
@ -434,7 +438,7 @@ class ResourceVariable(variables.Variable):
|
|||||||
@property
|
@property
|
||||||
def create(self):
|
def create(self):
|
||||||
"""The op responsible for initializing this variable."""
|
"""The op responsible for initializing this variable."""
|
||||||
if not context.in_graph_mode():
|
if not self._in_graph_mode:
|
||||||
raise RuntimeError("Calling create in EAGER mode not supported.")
|
raise RuntimeError("Calling create in EAGER mode not supported.")
|
||||||
return self._initializer_op
|
return self._initializer_op
|
||||||
|
|
||||||
@ -520,7 +524,7 @@ class ResourceVariable(variables.Variable):
|
|||||||
# In graph mode, ensure we read the variable in the same device as the
|
# In graph mode, ensure we read the variable in the same device as the
|
||||||
# handle. In eager mode, however, this sometimes tries to read a GPU
|
# handle. In eager mode, however, this sometimes tries to read a GPU
|
||||||
# variable in the CPU because the handle is host memory. For now, then, we
|
# variable in the CPU because the handle is host memory. For now, then, we
|
||||||
# need to skip the device block in eager. TODO(apassos) eager should have
|
# need to skip the device block in eager. TODO(apassos): eager should have
|
||||||
# separate notions of device and memory, so handle.device can be GPU while
|
# separate notions of device and memory, so handle.device can be GPU while
|
||||||
# handle.memory_space is always CPU.
|
# handle.memory_space is always CPU.
|
||||||
if context.in_graph_mode():
|
if context.in_graph_mode():
|
||||||
|
Loading…
Reference in New Issue
Block a user