Simplifies capturing code in graph_callable to use recent function improvements.

PiperOrigin-RevId: 172937003
Alexandre Passos 2017-10-20 14:49:56 -07:00 committed by TensorFlower Gardener
parent d1e7382af7
commit 37fd951790
2 changed files with 14 additions and 58 deletions
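
For context: graph_callable traces a Python function that may create variables and turns it into a graph function with separate initialization and call paths. A minimal usage sketch, written against the internal modules of this era (the decorator and ShapeAndDtype names exist in graph_callable.py, but treat the exact import paths as an assumption; eager execution must already be enabled):

    # Hedged sketch: illustrates the API whose internals this commit refactors.
    from tensorflow.python.eager import graph_callable
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import init_ops
    from tensorflow.python.ops import variable_scope

    # The input signature is declared up front; variables are created on the
    # first (initializing) trace and captured on subsequent calls.
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def f(x):
      v = variable_scope.get_variable(
          'v', shape=(), initializer=init_ops.ones_initializer())
      return v + x

    print(f(constant_op.constant(2.0)))  # expected: 3.0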

tensorflow/python/eager/graph_callable.py

@@ -45,28 +45,6 @@ def _default_initializer(name, shape, dtype):
   return initializer[0]
 
 
-class _VariableFromResource(resource_variable_ops.ResourceVariable):
-  """Variable object from a preexisting resource.
-
-  Required because the ResourceVariable constructor creates the resource handle,
-  and here we want to use a preexisting one.
-  """
-
-  def __init__(self, resource, dtype, name, shape):
-    self._handle = resource
-    self._graph_shape = tensor_shape.as_shape(shape)
-    self._handle_device = resource.device
-    self._handle_name = name
-    self._cached_value = None
-    self._initializer_op = None
-    self._caching_device = None
-    self._dtype = dtype
-    self._constraint = None
-    self._in_graph_mode = context.in_graph_mode()
-    if self._in_graph_mode:
-      self._graph_element = self.read_value()
-
-
 class _CapturedVariable(object):
   """Variable captured by graph_callable.
@@ -137,17 +115,11 @@ class _VariableCapturingScope(object):
         trainable=True, collections=None, caching_device=None,  # pylint: disable=redefined-outer-name
         partitioner=None, validate_shape=True,
         use_resource=None):
-      del getter, regularizer, partitioner, validate_shape, use_resource
-      del collections, initializer, trainable, reuse, caching_device
+      del getter, regularizer, partitioner, validate_shape, use_resource, dtype
+      del collections, initializer, trainable, reuse, caching_device, shape,
       assert name in self.variables
       v = self.variables[name]
-      v.placeholder = array_ops.placeholder(dtype=dtypes.resource, shape=shape)
-      # TODO(apassos) remove the need for this by correctly dealing with shape
-      # inference.
-      v.placeholder._handle_data = v.variable.handle._handle_data  # pylint: disable=protected-access
-      return _VariableFromResource(
-          v.placeholder, dtype=dtypes.as_dtype(dtype), name=name,
-          shape=v.shape)
+      return v.variable
 
     scope = variable_scope.get_variable_scope()
     with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
@@ -181,14 +153,12 @@ class _VariableCapturingScope(object):
       v = _CapturedVariable(name, initializer, shape, dtype, trainable)
       self.variables[name] = v
 
-      graph_mode_resource = resource_variable_ops.var_handle_op(
-          shared_name=name, shape=shape, dtype=dtype)
+      graph_mode_resource = v.variable.handle
       if initializer is None:
         initializer = _default_initializer(name, shape, dtype)
       resource_variable_ops.assign_variable_op(
           graph_mode_resource, initializer(shape, dtype))
-      return _VariableFromResource(
-          graph_mode_resource, dtype, name, shape=v.shape)
+      return v.variable
 
     scope = variable_scope.get_variable_scope()
     with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
@@ -220,13 +190,6 @@ class _FunctionObject(function._GraphModeFunction):  # pylint: disable=protected-access
   def variables(self):
     return [x.variable for x in self._variables]
 
-  def __call__(self, *args, **kwds):
-    kwds.pop("want_gradients", False)
-    if kwds:
-      raise ValueError("graph_callable functions do not take keyword args")
-    values = [x.variable.handle for x in self._variables]
-    return super(_FunctionObject, self).__call__(*(values + list(args)))
-
 
 class _InitializingFunctionObject(object):
   """Responsible for deciding which version of func-to-object to call.
@@ -318,7 +281,8 @@ def _graph_callable_internal(func, shape_and_dtypes):
   # This graph will store both the initialization and the call version of the
   # wrapped function. It will later be used by the backprop code to build the
   # backprop graph, if necessary.
-  tmp_graph = tf_ops.Graph()
+  captures = {}
+  tmp_graph = function.CapturingGraph(captures)
   # Inherit the container from the original graph to create resources at user
   # expected containers. Also inherits the container prefix, since this is
   # used for error checking when isolating Eager execution (the container
@@ -342,7 +306,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
   # variables. As a side-effect this will populate the variable capturing
   # scope's view of which variables exist.
   variable_captures = _VariableCapturingScope()
-  captures = {}
   with variable_captures.initializing_scope(), function.capture_tensors(
       captures):
     func_outputs = func(*func_inputs)
@@ -366,7 +329,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
   sorted_variables = sorted(variable_captures.variables.values(),
                             key=lambda x: x.name)
-  variable_placeholders = [x.placeholder for x in sorted_variables]
   ids = list(sorted(captures.keys()))
   if ids:
     extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
@@ -377,7 +339,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
   flat_inputs = [x for x in nest.flatten(func_inputs)
                  if isinstance(x, tf_ops.Tensor)]
   placeholder_inputs = flat_inputs + list(extra_placeholders)
-  all_inputs = variable_placeholders + placeholder_inputs
 
   func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
   initializer_function_def = function.make_function_def(
@@ -407,13 +368,13 @@ def _graph_callable_internal(func, shape_and_dtypes):
   captured_function_def = function.make_function_def(
       tmp_graph,
       capturing_operations,
-      all_inputs,
+      placeholder_inputs,
       capture_func_def_outputs)
   function._register_with_name(function._inference_name(func.__name__),  # pylint: disable=protected-access
                                captured_function_def)
   captured_function = _FunctionObject(
       sorted_variables,
-      all_inputs,
+      placeholder_inputs,
       extra_inputs,
       captured_function_def,
       tmp_graph,
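
Net effect of the graph_callable.py changes above: variable handles no longer flow through hand-built resource placeholders (variable_placeholders and all_inputs are gone); function.CapturingGraph records every external tensor touched during tracing into the captures dict, so variables ride along as ordinary captured inputs. A toy sketch of that capture-by-dict pattern (illustrative only, not the real CapturingGraph implementation):

    # Toy model: map id(external tensor) -> (tensor, placeholder standing in for it).
    captures = {}

    def capture(tensor, make_placeholder):
      # First sighting mints a placeholder; later sightings reuse it.
      if id(tensor) not in captures:
        captures[id(tensor)] = (tensor, make_placeholder(tensor))
      return captures[id(tensor)][1]

    # After tracing, split the pairs exactly as _graph_callable_internal does:
    ids = list(sorted(captures.keys()))
    if ids:
      extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
    else:
      extra_inputs, extra_placeholders = (), ()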

tensorflow/python/ops/resource_variable_ops.py

@@ -26,7 +26,6 @@ from tensorflow.python.eager import tape
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_resource_variable_ops
@@ -315,7 +314,7 @@ class ResourceVariable(variables.Variable):
           self._handle_device = (
               self._handle.device if self._in_graph_mode else
               context.get_default_context().device_name)
-          self._graph_shape = initial_value.get_shape()
+          self._shape = initial_value.get_shape()
         else:
           initial_value = initial_value()
           with ops.name_scope("Initializer"):
@@ -330,7 +329,7 @@ class ResourceVariable(variables.Variable):
             self._handle_device = (
                 self._handle.device if self._in_graph_mode else
                 context.get_default_context().device_name)
-            self._graph_shape = initial_value.get_shape()
+            self._shape = initial_value.get_shape()
       # pylint: enable=protected-access
 
       # Or get the initial value from a Tensor or Python object.
@@ -355,7 +354,7 @@ class ResourceVariable(variables.Variable):
               graph_mode=self._in_graph_mode)
         self._handle_device = (self._handle.device if self._in_graph_mode else
                                context.get_default_context().device_name)
-        self._graph_shape = initial_value.get_shape()
+        self._shape = initial_value.get_shape()
 
       self._initial_value = initial_value if self._in_graph_mode else None
       self._handle_name = handle_name + ":0"
@@ -422,7 +421,7 @@ class ResourceVariable(variables.Variable):
       self._handle = g.as_graph_element(
           ops.prepend_name_scope(
               variable_def.variable_name, import_scope=import_scope))
-      self._graph_shape = tensor_shape.TensorShape(
+      self._shape = tensor_shape.TensorShape(
           self._handle.op.get_attr("shape"))
       self._handle_device = self._handle.device
       self._handle_name = self._handle.name
@@ -502,11 +501,7 @@ class ResourceVariable(variables.Variable):
   @property
   def shape(self):
     """The shape of this variable."""
-    if self._in_graph_mode:
-      return self._graph_shape
-    return tensor_shape.TensorShape(
-        tensor_util.constant_value(
-            gen_resource_variable_ops.variable_shape(self._handle)))
+    return self._shape
 
   @property
   def create(self):
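
The resource_variable_ops.py half of the change records the static shape once, at construction, in self._shape for both graph and eager mode, so the shape property becomes a plain attribute read instead of (in eager mode) running a VariableShape op and constant-folding the result. A small sketch of the observable behavior, which works in either mode (standard ResourceVariable construction; nothing here is specific to this commit beyond the cheaper property):

    from tensorflow.python.ops import resource_variable_ops

    v = resource_variable_ops.ResourceVariable([[1., 2., 3.], [4., 5., 6.]])
    print(v.shape)  # TensorShape([2, 3]); now simply `return self._shape`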