Internal clean-up: move signature handling code that's specific to WrapFunction into WrapFunction subclass.
PiperOrigin-RevId: 306041412 Change-Id: I672c54772689e7bd67dc25d8059d6cf2de8fd3e6
This commit is contained in:
parent
20a26f65d0
commit
859c4c21c0
tensorflow/python/eager
@ -1495,8 +1495,7 @@ class ConcreteFunction(object):
|
||||
is differentiable under `tf.GradientTape` objects.
|
||||
"""
|
||||
|
||||
def __init__(self, func_graph, attrs=None, signature=None,
|
||||
shared_func_graph=True):
|
||||
def __init__(self, func_graph, attrs=None, shared_func_graph=True):
|
||||
"""Initialize a `ConcreteFunction`.
|
||||
|
||||
Args:
|
||||
@ -1504,8 +1503,6 @@ class ConcreteFunction(object):
|
||||
attrs: (optional) dict mapping names of attributes to their AttrValue
|
||||
values. Attributes in `attrs` will be included in this function's
|
||||
definition.
|
||||
signature: a nested sequence of `TensorSpec` objects specifying the input
|
||||
signature of this function.
|
||||
shared_func_graph: If False, the ConcreteFunction takes ownership of
|
||||
`func_graph` and will break reference cycles when it is deleted. This
|
||||
makes the FuncGraph inoperable.
|
||||
@ -1550,7 +1547,6 @@ class ConcreteFunction(object):
|
||||
self._output_shapes = tuple(
|
||||
output.shape for output in self._func_graph.outputs)
|
||||
self._attrs = _parse_func_attrs(attrs or {})
|
||||
self._signature = signature
|
||||
|
||||
if shared_func_graph:
|
||||
self._garbage_collector = None
|
||||
@ -1607,12 +1603,6 @@ class ConcreteFunction(object):
|
||||
def _call_impl(self, args, kwargs, cancellation_manager=None):
|
||||
"""See `__call__` for details."""
|
||||
if self._arg_keywords is None or self._num_positional_args is None:
|
||||
if self._signature is not None:
|
||||
if kwargs:
|
||||
raise NotImplementedError(
|
||||
"Keyword arguments not supported when calling a "
|
||||
"wrap_function-decorated function.")
|
||||
return self._call_flat(args, self.captured_inputs)
|
||||
raise AssertionError(
|
||||
"Tried to call a concrete function obtained from an internal API "
|
||||
"through the public interface. Use get_concrete_function instead.")
|
||||
@ -1728,10 +1718,6 @@ class ConcreteFunction(object):
|
||||
arg_name, arg,
|
||||
self._func_graph.inputs[i].shape,
|
||||
arg.shape))
|
||||
elif (self._signature is not None and
|
||||
isinstance(self._signature[i], tensor_spec.DenseSpec)):
|
||||
tensor_inputs.append(
|
||||
ops.convert_to_tensor(arg, self._signature[i].dtype))
|
||||
else:
|
||||
raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; "
|
||||
"on invocation of %s, the %d-th input (%s) was not a "
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
@ -226,8 +227,24 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
# properly reflects the new captured inputs.
|
||||
for f in fn_graph.as_graph_def().library.function:
|
||||
context.context().add_function_def(f)
|
||||
super(WrappedFunction, self).__init__(
|
||||
fn_graph, attrs=attrs, signature=signature)
|
||||
self._signature = signature
|
||||
super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)
|
||||
|
||||
def _call_impl(self, args, kwargs, cancellation_manager=None):
|
||||
if self._arg_keywords is None:
|
||||
if kwargs:
|
||||
raise NotImplementedError(
|
||||
"Keyword arguments not supported when calling a "
|
||||
"wrap_function-decorated function.")
|
||||
if self._signature is not None:
|
||||
args = list(args)
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(self._signature[i], tensor_spec.DenseSpec):
|
||||
args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
|
||||
return self._call_flat(args, self.captured_inputs)
|
||||
else:
|
||||
return super(WrappedFunction, self)._call_impl(
|
||||
args, kwargs, cancellation_manager)
|
||||
|
||||
def prune(self, feeds, fetches, name=None, input_signature=None):
|
||||
"""Extract a subgraph of this function's underlying graph.
|
||||
|
Loading…
Reference in New Issue
Block a user