From 37fd951790d7ad27c679c925c28b01ca73875738 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Fri, 20 Oct 2017 14:49:56 -0700 Subject: [PATCH] Simplifies capturing code in graph_callable to use recent function improvements. PiperOrigin-RevId: 172937003 --- tensorflow/python/eager/graph_callable.py | 57 +++---------------- .../python/ops/resource_variable_ops.py | 15 ++--- 2 files changed, 14 insertions(+), 58 deletions(-) diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 0ec83636a0f..7f7a8c4a88c 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -45,28 +45,6 @@ def _default_initializer(name, shape, dtype): return initializer[0] -class _VariableFromResource(resource_variable_ops.ResourceVariable): - """Variable object from a preexisting resource. - - Required because the ResourceVariable constructor creates the resource handle, - and here we want to use a preexisting one. - """ - - def __init__(self, resource, dtype, name, shape): - self._handle = resource - self._graph_shape = tensor_shape.as_shape(shape) - self._handle_device = resource.device - self._handle_name = name - self._cached_value = None - self._initializer_op = None - self._caching_device = None - self._dtype = dtype - self._constraint = None - self._in_graph_mode = context.in_graph_mode() - if self._in_graph_mode: - self._graph_element = self.read_value() - - class _CapturedVariable(object): """Variable captured by graph_callable. @@ -137,17 +115,11 @@ class _VariableCapturingScope(object): trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name partitioner=None, validate_shape=True, use_resource=None): - del getter, regularizer, partitioner, validate_shape, use_resource - del collections, initializer, trainable, reuse, caching_device + del getter, regularizer, partitioner, validate_shape, use_resource, dtype + del collections, initializer, trainable, reuse, caching_device, shape, assert name in self.variables v = self.variables[name] - v.placeholder = array_ops.placeholder(dtype=dtypes.resource, shape=shape) - # TODO(apassos) remove the need for this by correctly dealing with shape - # inference. - v.placeholder._handle_data = v.variable.handle._handle_data # pylint: disable=protected-access - return _VariableFromResource( - v.placeholder, dtype=dtypes.as_dtype(dtype), name=name, - shape=v.shape) + return v.variable scope = variable_scope.get_variable_scope() with variable_scope.variable_scope(scope, custom_getter=_custom_getter): @@ -181,14 +153,12 @@ class _VariableCapturingScope(object): v = _CapturedVariable(name, initializer, shape, dtype, trainable) self.variables[name] = v - graph_mode_resource = resource_variable_ops.var_handle_op( - shared_name=name, shape=shape, dtype=dtype) + graph_mode_resource = v.variable.handle if initializer is None: initializer = _default_initializer(name, shape, dtype) resource_variable_ops.assign_variable_op( graph_mode_resource, initializer(shape, dtype)) - return _VariableFromResource( - graph_mode_resource, dtype, name, shape=v.shape) + return v.variable scope = variable_scope.get_variable_scope() with variable_scope.variable_scope(scope, custom_getter=_custom_getter): @@ -220,13 +190,6 @@ class _FunctionObject(function._GraphModeFunction): # pylint: disable=protected def variables(self): return [x.variable for x in self._variables] - def __call__(self, *args, **kwds): - kwds.pop("want_gradients", False) - if kwds: - raise ValueError("graph_callable functions do not take keyword args") - values = [x.variable.handle for x in self._variables] - return super(_FunctionObject, self).__call__(*(values + list(args))) - class _InitializingFunctionObject(object): """Responsible for deciding which version of func-to-object to call. @@ -318,7 +281,8 @@ def _graph_callable_internal(func, shape_and_dtypes): # This graph will store both the initialization and the call version of the # wrapped function. It will later be used by the backprop code to build the # backprop graph, if necessary. - tmp_graph = tf_ops.Graph() + captures = {} + tmp_graph = function.CapturingGraph(captures) # Inherit the container from the original graph to create resources at user # expected containers. Also inherits the container prefix, since this is # used for error checking when isolating Eager execution (the container @@ -342,7 +306,6 @@ def _graph_callable_internal(func, shape_and_dtypes): # variables. As a side-effect this will populate the variable capturing # scope's view of which variables exist. variable_captures = _VariableCapturingScope() - captures = {} with variable_captures.initializing_scope(), function.capture_tensors( captures): func_outputs = func(*func_inputs) @@ -366,7 +329,6 @@ def _graph_callable_internal(func, shape_and_dtypes): sorted_variables = sorted(variable_captures.variables.values(), key=lambda x: x.name) - variable_placeholders = [x.placeholder for x in sorted_variables] ids = list(sorted(captures.keys())) if ids: extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids]) @@ -377,7 +339,6 @@ def _graph_callable_internal(func, shape_and_dtypes): flat_inputs = [x for x in nest.flatten(func_inputs) if isinstance(x, tf_ops.Tensor)] placeholder_inputs = flat_inputs+ list(extra_placeholders) - all_inputs = variable_placeholders + placeholder_inputs func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)] initializer_function_def = function.make_function_def( @@ -407,13 +368,13 @@ def _graph_callable_internal(func, shape_and_dtypes): captured_function_def = function.make_function_def( tmp_graph, capturing_operations, - all_inputs, + placeholder_inputs, capture_func_def_outputs) function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access captured_function_def) captured_function = _FunctionObject( sorted_variables, - all_inputs, + placeholder_inputs, extra_inputs, captured_function_def, tmp_graph, diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index c94ddb06275..71e1fb0297e 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -26,7 +26,6 @@ from tensorflow.python.eager import tape from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -315,7 +314,7 @@ class ResourceVariable(variables.Variable): self._handle_device = ( self._handle.device if self._in_graph_mode else context.get_default_context().device_name) - self._graph_shape = initial_value.get_shape() + self._shape = initial_value.get_shape() else: initial_value = initial_value() with ops.name_scope("Initializer"): @@ -330,7 +329,7 @@ class ResourceVariable(variables.Variable): self._handle_device = ( self._handle.device if self._in_graph_mode else context.get_default_context().device_name) - self._graph_shape = initial_value.get_shape() + self._shape = initial_value.get_shape() # pylint: enable=protected-access # Or get the initial value from a Tensor or Python object. @@ -355,7 +354,7 @@ class ResourceVariable(variables.Variable): graph_mode=self._in_graph_mode) self._handle_device = (self._handle.device if self._in_graph_mode else context.get_default_context().device_name) - self._graph_shape = initial_value.get_shape() + self._shape = initial_value.get_shape() self._initial_value = initial_value if self._in_graph_mode else None self._handle_name = handle_name + ":0" @@ -422,7 +421,7 @@ class ResourceVariable(variables.Variable): self._handle = g.as_graph_element( ops.prepend_name_scope( variable_def.variable_name, import_scope=import_scope)) - self._graph_shape = tensor_shape.TensorShape( + self._shape = tensor_shape.TensorShape( self._handle.op.get_attr("shape")) self._handle_device = self._handle.device self._handle_name = self._handle.name @@ -502,11 +501,7 @@ class ResourceVariable(variables.Variable): @property def shape(self): """The shape of this variable.""" - if self._in_graph_mode: - return self._graph_shape - return tensor_shape.TensorShape( - tensor_util.constant_value( - gen_resource_variable_ops.variable_shape(self._handle))) + return self._shape @property def create(self):