Simplifies capturing code in graph_callable to use recent function improvements.
PiperOrigin-RevId: 172937003
This commit is contained in:
parent
d1e7382af7
commit
37fd951790
tensorflow/python
@ -45,28 +45,6 @@ def _default_initializer(name, shape, dtype):
|
||||
return initializer[0]
|
||||
|
||||
|
||||
class _VariableFromResource(resource_variable_ops.ResourceVariable):
|
||||
"""Variable object from a preexisting resource.
|
||||
|
||||
Required because the ResourceVariable constructor creates the resource handle,
|
||||
and here we want to use a preexisting one.
|
||||
"""
|
||||
|
||||
def __init__(self, resource, dtype, name, shape):
|
||||
self._handle = resource
|
||||
self._graph_shape = tensor_shape.as_shape(shape)
|
||||
self._handle_device = resource.device
|
||||
self._handle_name = name
|
||||
self._cached_value = None
|
||||
self._initializer_op = None
|
||||
self._caching_device = None
|
||||
self._dtype = dtype
|
||||
self._constraint = None
|
||||
self._in_graph_mode = context.in_graph_mode()
|
||||
if self._in_graph_mode:
|
||||
self._graph_element = self.read_value()
|
||||
|
||||
|
||||
class _CapturedVariable(object):
|
||||
"""Variable captured by graph_callable.
|
||||
|
||||
@ -137,17 +115,11 @@ class _VariableCapturingScope(object):
|
||||
trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
|
||||
partitioner=None, validate_shape=True,
|
||||
use_resource=None):
|
||||
del getter, regularizer, partitioner, validate_shape, use_resource
|
||||
del collections, initializer, trainable, reuse, caching_device
|
||||
del getter, regularizer, partitioner, validate_shape, use_resource, dtype
|
||||
del collections, initializer, trainable, reuse, caching_device, shape,
|
||||
assert name in self.variables
|
||||
v = self.variables[name]
|
||||
v.placeholder = array_ops.placeholder(dtype=dtypes.resource, shape=shape)
|
||||
# TODO(apassos) remove the need for this by correctly dealing with shape
|
||||
# inference.
|
||||
v.placeholder._handle_data = v.variable.handle._handle_data # pylint: disable=protected-access
|
||||
return _VariableFromResource(
|
||||
v.placeholder, dtype=dtypes.as_dtype(dtype), name=name,
|
||||
shape=v.shape)
|
||||
return v.variable
|
||||
|
||||
scope = variable_scope.get_variable_scope()
|
||||
with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
|
||||
@ -181,14 +153,12 @@ class _VariableCapturingScope(object):
|
||||
v = _CapturedVariable(name, initializer, shape, dtype, trainable)
|
||||
self.variables[name] = v
|
||||
|
||||
graph_mode_resource = resource_variable_ops.var_handle_op(
|
||||
shared_name=name, shape=shape, dtype=dtype)
|
||||
graph_mode_resource = v.variable.handle
|
||||
if initializer is None:
|
||||
initializer = _default_initializer(name, shape, dtype)
|
||||
resource_variable_ops.assign_variable_op(
|
||||
graph_mode_resource, initializer(shape, dtype))
|
||||
return _VariableFromResource(
|
||||
graph_mode_resource, dtype, name, shape=v.shape)
|
||||
return v.variable
|
||||
|
||||
scope = variable_scope.get_variable_scope()
|
||||
with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
|
||||
@ -220,13 +190,6 @@ class _FunctionObject(function._GraphModeFunction): # pylint: disable=protected
|
||||
def variables(self):
|
||||
return [x.variable for x in self._variables]
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
kwds.pop("want_gradients", False)
|
||||
if kwds:
|
||||
raise ValueError("graph_callable functions do not take keyword args")
|
||||
values = [x.variable.handle for x in self._variables]
|
||||
return super(_FunctionObject, self).__call__(*(values + list(args)))
|
||||
|
||||
|
||||
class _InitializingFunctionObject(object):
|
||||
"""Responsible for deciding which version of func-to-object to call.
|
||||
@ -318,7 +281,8 @@ def _graph_callable_internal(func, shape_and_dtypes):
|
||||
# This graph will store both the initialization and the call version of the
|
||||
# wrapped function. It will later be used by the backprop code to build the
|
||||
# backprop graph, if necessary.
|
||||
tmp_graph = tf_ops.Graph()
|
||||
captures = {}
|
||||
tmp_graph = function.CapturingGraph(captures)
|
||||
# Inherit the container from the original graph to create resources at user
|
||||
# expected containers. Also inherits the container prefix, since this is
|
||||
# used for error checking when isolating Eager execution (the container
|
||||
@ -342,7 +306,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
|
||||
# variables. As a side-effect this will populate the variable capturing
|
||||
# scope's view of which variables exist.
|
||||
variable_captures = _VariableCapturingScope()
|
||||
captures = {}
|
||||
with variable_captures.initializing_scope(), function.capture_tensors(
|
||||
captures):
|
||||
func_outputs = func(*func_inputs)
|
||||
@ -366,7 +329,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
|
||||
|
||||
sorted_variables = sorted(variable_captures.variables.values(),
|
||||
key=lambda x: x.name)
|
||||
variable_placeholders = [x.placeholder for x in sorted_variables]
|
||||
ids = list(sorted(captures.keys()))
|
||||
if ids:
|
||||
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
|
||||
@ -377,7 +339,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
|
||||
flat_inputs = [x for x in nest.flatten(func_inputs)
|
||||
if isinstance(x, tf_ops.Tensor)]
|
||||
placeholder_inputs = flat_inputs+ list(extra_placeholders)
|
||||
all_inputs = variable_placeholders + placeholder_inputs
|
||||
|
||||
func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
|
||||
initializer_function_def = function.make_function_def(
|
||||
@ -407,13 +368,13 @@ def _graph_callable_internal(func, shape_and_dtypes):
|
||||
captured_function_def = function.make_function_def(
|
||||
tmp_graph,
|
||||
capturing_operations,
|
||||
all_inputs,
|
||||
placeholder_inputs,
|
||||
capture_func_def_outputs)
|
||||
function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access
|
||||
captured_function_def)
|
||||
captured_function = _FunctionObject(
|
||||
sorted_variables,
|
||||
all_inputs,
|
||||
placeholder_inputs,
|
||||
extra_inputs,
|
||||
captured_function_def,
|
||||
tmp_graph,
|
||||
|
@ -26,7 +26,6 @@ from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
@ -315,7 +314,7 @@ class ResourceVariable(variables.Variable):
|
||||
self._handle_device = (
|
||||
self._handle.device if self._in_graph_mode else
|
||||
context.get_default_context().device_name)
|
||||
self._graph_shape = initial_value.get_shape()
|
||||
self._shape = initial_value.get_shape()
|
||||
else:
|
||||
initial_value = initial_value()
|
||||
with ops.name_scope("Initializer"):
|
||||
@ -330,7 +329,7 @@ class ResourceVariable(variables.Variable):
|
||||
self._handle_device = (
|
||||
self._handle.device if self._in_graph_mode else
|
||||
context.get_default_context().device_name)
|
||||
self._graph_shape = initial_value.get_shape()
|
||||
self._shape = initial_value.get_shape()
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Or get the initial value from a Tensor or Python object.
|
||||
@ -355,7 +354,7 @@ class ResourceVariable(variables.Variable):
|
||||
graph_mode=self._in_graph_mode)
|
||||
self._handle_device = (self._handle.device if self._in_graph_mode else
|
||||
context.get_default_context().device_name)
|
||||
self._graph_shape = initial_value.get_shape()
|
||||
self._shape = initial_value.get_shape()
|
||||
|
||||
self._initial_value = initial_value if self._in_graph_mode else None
|
||||
self._handle_name = handle_name + ":0"
|
||||
@ -422,7 +421,7 @@ class ResourceVariable(variables.Variable):
|
||||
self._handle = g.as_graph_element(
|
||||
ops.prepend_name_scope(
|
||||
variable_def.variable_name, import_scope=import_scope))
|
||||
self._graph_shape = tensor_shape.TensorShape(
|
||||
self._shape = tensor_shape.TensorShape(
|
||||
self._handle.op.get_attr("shape"))
|
||||
self._handle_device = self._handle.device
|
||||
self._handle_name = self._handle.name
|
||||
@ -502,11 +501,7 @@ class ResourceVariable(variables.Variable):
|
||||
@property
|
||||
def shape(self):
|
||||
"""The shape of this variable."""
|
||||
if self._in_graph_mode:
|
||||
return self._graph_shape
|
||||
return tensor_shape.TensorShape(
|
||||
tensor_util.constant_value(
|
||||
gen_resource_variable_ops.variable_shape(self._handle)))
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def create(self):
|
||||
|
Loading…
Reference in New Issue
Block a user