Internal clean-up: move signature handling code that's specific to WrapFunction into WrapFunction subclass.

PiperOrigin-RevId: 306041412
Change-Id: I672c54772689e7bd67dc25d8059d6cf2de8fd3e6
This commit is contained in:
Edward Loper 2020-04-11 09:39:39 -07:00 committed by TensorFlower Gardener
parent 20a26f65d0
commit 859c4c21c0
2 changed files with 20 additions and 17 deletions
tensorflow/python/eager

View File

@ -1495,8 +1495,7 @@ class ConcreteFunction(object):
is differentiable under `tf.GradientTape` objects.
"""
def __init__(self, func_graph, attrs=None, signature=None,
shared_func_graph=True):
def __init__(self, func_graph, attrs=None, shared_func_graph=True):
"""Initialize a `ConcreteFunction`.
Args:
@ -1504,8 +1503,6 @@ class ConcreteFunction(object):
attrs: (optional) dict mapping names of attributes to their AttrValue
values. Attributes in `attrs` will be included in this function's
definition.
signature: a nested sequence of `TensorSpec` objects specifying the input
signature of this function.
shared_func_graph: If False, the ConcreteFunction takes ownership of
`func_graph` and will break reference cycles when it is deleted. This
makes the FuncGraph inoperable.
@ -1550,7 +1547,6 @@ class ConcreteFunction(object):
self._output_shapes = tuple(
output.shape for output in self._func_graph.outputs)
self._attrs = _parse_func_attrs(attrs or {})
self._signature = signature
if shared_func_graph:
self._garbage_collector = None
@ -1607,12 +1603,6 @@ class ConcreteFunction(object):
def _call_impl(self, args, kwargs, cancellation_manager=None):
"""See `__call__` for details."""
if self._arg_keywords is None or self._num_positional_args is None:
if self._signature is not None:
if kwargs:
raise NotImplementedError(
"Keyword arguments not supported when calling a "
"wrap_function-decorated function.")
return self._call_flat(args, self.captured_inputs)
raise AssertionError(
"Tried to call a concrete function obtained from an internal API "
"through the public interface. Use get_concrete_function instead.")
@ -1728,10 +1718,6 @@ class ConcreteFunction(object):
arg_name, arg,
self._func_graph.inputs[i].shape,
arg.shape))
elif (self._signature is not None and
isinstance(self._signature[i], tensor_spec.DenseSpec)):
tensor_inputs.append(
ops.convert_to_tensor(arg, self._signature[i].dtype))
else:
raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; "
"on invocation of %s, the %d-th input (%s) was not a "

View File

@ -32,6 +32,7 @@ from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@ -226,8 +227,24 @@ class WrappedFunction(function.ConcreteFunction):
# properly reflects the new captured inputs.
for f in fn_graph.as_graph_def().library.function:
context.context().add_function_def(f)
super(WrappedFunction, self).__init__(
fn_graph, attrs=attrs, signature=signature)
self._signature = signature
super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)
def _call_impl(self, args, kwargs, cancellation_manager=None):
if self._arg_keywords is None:
if kwargs:
raise NotImplementedError(
"Keyword arguments not supported when calling a "
"wrap_function-decorated function.")
if self._signature is not None:
args = list(args)
for i, arg in enumerate(args):
if isinstance(self._signature[i], tensor_spec.DenseSpec):
args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
return self._call_flat(args, self.captured_inputs)
else:
return super(WrappedFunction, self)._call_impl(
args, kwargs, cancellation_manager)
def prune(self, feeds, fetches, name=None, input_signature=None):
"""Extract a subgraph of this function's underlying graph.