diff --git a/tensorflow/core/protobuf/struct.proto b/tensorflow/core/protobuf/struct.proto index ecf48776c56..8983db02eeb 100644 --- a/tensorflow/core/protobuf/struct.proto +++ b/tensorflow/core/protobuf/struct.proto @@ -1,10 +1,10 @@ syntax = "proto3"; +package tensorflow; + import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; -package tensorflow; - // `StructuredValue` represents a dynamically typed value representing various // data structures that are inspired by Python data structures typically used in // TensorFlow functions as inputs and outputs. @@ -120,6 +120,7 @@ message TypeSpecProto { DATA_ITERATOR_SPEC = 6; // IteratorSpec from data/ops/iterator_ops.py OPTIONAL_SPEC = 7; // tf.OptionalSpec PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py + VARIABLE_SPEC = 9; // tf.VariableSpec } TypeSpecClass type_spec_class = 1; diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 86c7528c34a..b1cd03f35b9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4672,6 +4672,18 @@ cuda_py_test( tags = ["no_windows_gpu"], ) +tf_py_test( + name = "variable_spec_test", + size = "small", + srcs = ["ops/variable_spec_test.py"], + additional_deps = [ + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":platform_test", + "//third_party/py/numpy", + ], +) + py_library( name = "training_lib", srcs = glob( diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index f804747d64e..c87260e1c23 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -143,9 +143,11 @@ def _flat_shape_list(*params): Returns: A list of entries containing either `None` or `TensorShape`. """ - return [tensor_shape.TensorShape(x.shape) - if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None - for x in nest.flatten(params, expand_composites=True)] + return [ + tensor_shape.TensorShape(x.shape) + if isinstance(x, (ops.Tensor, tensor_spec.DenseSpec)) else None + for x in nest.flatten(params, expand_composites=True) + ] def _shape_less_specific_than(relaxed, to_check): @@ -1651,7 +1653,7 @@ class ConcreteFunction(object): self._func_graph.inputs[i].shape, arg.shape)) elif (self._signature is not None and - isinstance(self._signature[i], tensor_spec.TensorSpec)): + isinstance(self._signature[i], tensor_spec.DenseSpec)): tensor_inputs.append( ops.convert_to_tensor(arg, self._signature[i].dtype)) else: @@ -2208,7 +2210,8 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature): need_packing = False for index, (value, spec) in enumerate(zip(flatten_inputs, flat_input_signature)): - if not _pywrap_utils.IsTensor(value): + if (isinstance(spec, tensor_spec.TensorSpec) and + not _pywrap_utils.IsTensor(value)): try: flatten_inputs[index] = ops.convert_to_tensor( value, dtype_hint=spec.dtype) @@ -2392,11 +2395,12 @@ class Function(object): raise ValueError("Structure of Python function inputs does not match " "input_signature.") flat_inputs = nest.flatten(args, expand_composites=True) - if any(not isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)) + if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec, + resource_variable_ops.BaseResourceVariable)) for arg in flat_inputs): raise ValueError("When input_signature is provided, all inputs to " - "the Python function must be Tensors or " - "tf.TensorSpec objects.") + "the Python function must be Tensors, Variables, " + "tf.TensorSpec or tf.VariableSpec objects.") if any(not spec.is_compatible_with(other) for spec, other in zip(self.flat_input_signature, flat_inputs)): raise ValueError("Python inputs incompatible with input_signature: " @@ -2701,7 +2705,7 @@ def register(func, *args, **kwargs): def validate_signature(signature): - if any(not isinstance(arg, tensor_spec.TensorSpec) + if any(not isinstance(arg, tensor_spec.DenseSpec) for arg in nest.flatten(signature, expand_composites=True)): raise TypeError("Invalid input_signature {}; input_signature must be " "a possibly nested sequence of TensorSpec objects." diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 93160e8cdf9..ca9fb7b68da 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -2055,6 +2055,27 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp(ValueError, 'does not match'): defined(rt5) + def testInputSignatureWithVariableArgs(self): + + def f(v): + v.assign_add(1) + + signature = [ + resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) + ] + defined = function.defun(f, input_signature=signature) + + v1 = variables.Variable(0) + v2 = variables.Variable(0) + + defined(v1) + self.assertEqual(v1.numpy(), 1) + self.assertEqual(v2.numpy(), 0) + + defined(v=v2) + self.assertEqual(v1.numpy(), 1) + self.assertEqual(v2.numpy(), 1) + def testTensorKeywordArguments(self): def foo(a, b): diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 2c27e96c4ad..f288867f657 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -99,6 +99,9 @@ def convert_structure_to_signature(structure, arg_names=None): if isinstance(arg, composite_tensor.CompositeTensor): # TODO(b/133606651) Do we need to inject arg_name? return arg._type_spec # pylint: disable=protected-access + if isinstance(arg, resource_variable_ops.BaseResourceVariable): + name = "/".join([str(p) for p in path]) + return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name) if isinstance(arg, ( int, float, @@ -292,7 +295,7 @@ class FuncGraph(ops.Graph): if key not in self._deferred_captures: def convert_to_placeholder(s): - if not isinstance(s, tensor_spec.TensorSpec): + if not isinstance(s, tensor_spec.DenseSpec): raise TypeError( "Expected a nest of `TypeSpec` objects, found %s of type %s." % (s, type(s))) @@ -1177,7 +1180,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None): flattened = nest.flatten(arg_value, expand_composites=True) tensor_specs = [ - arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec) + arg for arg in flattened if isinstance(arg, tensor_spec.DenseSpec) ] specified_names = [arg.name for arg in tensor_specs if arg.name] if specified_names and len(specified_names) < len(tensor_specs): @@ -1209,7 +1212,20 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None): "_user_specified_name", attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name))) function_inputs.append(placeholder) - elif isinstance(arg, resource_variable_ops.BaseResourceVariable): + elif isinstance(arg, (resource_variable_ops.BaseResourceVariable, + resource_variable_ops.VariableSpec)): + if isinstance(arg, resource_variable_ops.VariableSpec): + name = arg.name or name + with func_graph.outer_graph.as_default(): + placeholder = graph_placeholder(dtypes.resource, arg.shape, + name=name) + + arg = resource_variable_ops.BaseResourceVariable( + name=name, + shape=arg.shape, + dtype=arg.dtype, + handle=placeholder, + handle_name=name) # Capture arg variables to create placeholders for them. These will be # removed as captures after the function is traced (since otherwise we'd # just add it back with a new placeholder when the variable was diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py index 264cc2052d1..3b12af46555 100644 --- a/tensorflow/python/framework/tensor_spec.py +++ b/tensorflow/python/framework/tensor_spec.py @@ -29,16 +29,13 @@ from tensorflow.python.framework import type_spec from tensorflow.python.util.tf_export import tf_export -@tf_export("TensorSpec") -class TensorSpec(type_spec.BatchableTypeSpec): - """Describes a tf.Tensor. - - Metadata for describing the `tf.Tensor` objects accepted or returned - by some TensorFlow APIs. - """ +class DenseSpec(type_spec.TypeSpec): + """Describes a dense object with shape, dtype, and name.""" __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"] + _component_specs = property(lambda self: self) + def __init__(self, shape, dtype=dtypes.float32, name=None): """Creates a TensorSpec. @@ -63,15 +60,6 @@ class TensorSpec(type_spec.BatchableTypeSpec): def from_spec(cls, spec, name=None): return cls(spec.shape, spec.dtype, name or spec.name) - @classmethod - def from_tensor(cls, tensor, name=None): - if isinstance(tensor, ops.EagerTensor): - return TensorSpec(tensor.shape, tensor.dtype, name) - elif isinstance(tensor, ops.Tensor): - return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) - else: - raise ValueError("`tensor` should be a tf.Tensor") - @property def shape(self): """Returns the `TensorShape` that represents the shape of the tensor.""" @@ -87,25 +75,14 @@ class TensorSpec(type_spec.BatchableTypeSpec): """Returns the (optionally provided) name of the described tensor.""" return self._name - def is_compatible_with(self, spec_or_tensor): - """Returns True if spec_or_tensor is compatible with this TensorSpec. - - Two tensors are considered compatible if they have the same dtype - and their shapes are compatible (see `tf.TensorShape.is_compatible_with`). - - Args: - spec_or_tensor: A tf.TensorSpec or a tf.Tensor - - Returns: - True if spec_or_tensor is compatible with self. - """ - return (isinstance(spec_or_tensor, (TensorSpec, ops.Tensor)) and - self._dtype.is_compatible_with(spec_or_tensor.dtype) and - self._shape.is_compatible_with(spec_or_tensor.shape)) + def is_compatible_with(self, spec_or_value): + return (isinstance(spec_or_value, (type(self), self.value_type)) and + self._dtype.is_compatible_with(spec_or_value.dtype) and + self._shape.is_compatible_with(spec_or_value.shape)) def __repr__(self): - return "TensorSpec(shape={}, dtype={}, name={})".format( - self.shape, repr(self.dtype), repr(self.name)) + return "{}(shape={}, dtype={}, name={})".format( + type(self).__name__, self.shape, repr(self.dtype), repr(self.name)) def __hash__(self): return hash((self._shape_tuple, self.dtype)) @@ -120,19 +97,60 @@ class TensorSpec(type_spec.BatchableTypeSpec): def __ne__(self, other): return not self == other - value_type = property(lambda self: ops.Tensor) - def most_specific_compatible_type(self, other): if (type(self) is not type(other)) or (self._dtype != other.dtype): raise ValueError("Types are not compatible: %r vs %r" % (self, other)) shape = self._shape.most_specific_compatible_shape(other.shape) name = self._name if self._name == other.name else None - return TensorSpec(shape, self._dtype, name) + return type(self)(shape, self._dtype, name) def _serialize(self): return (self._shape, self._dtype, self._name) - _component_specs = property(lambda self: self) + def _to_legacy_output_types(self): + return self._dtype + + def _to_legacy_output_shapes(self): + return self._shape + + def _to_legacy_output_classes(self): + return self.value_type + + +@tf_export("TensorSpec") +class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec): + """Describes a tf.Tensor. + + Metadata for describing the `tf.Tensor` objects accepted or returned + by some TensorFlow APIs. + """ + + __slots__ = [] + + def is_compatible_with(self, spec_or_tensor): # pylint:disable=useless-super-delegation + """Returns True if spec_or_tensor is compatible with this TensorSpec. + + Two tensors are considered compatible if they have the same dtype + and their shapes are compatible (see `tf.TensorShape.is_compatible_with`). + + Args: + spec_or_tensor: A tf.TensorSpec or a tf.Tensor + + Returns: + True if spec_or_tensor is compatible with self. + """ + return super(TensorSpec, self).is_compatible_with(spec_or_tensor) + + @classmethod + def from_tensor(cls, tensor, name=None): + if isinstance(tensor, ops.EagerTensor): + return TensorSpec(tensor.shape, tensor.dtype, name) + elif isinstance(tensor, ops.Tensor): + return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) + else: + raise ValueError("`tensor` should be a tf.Tensor") + + value_type = property(lambda self: ops.Tensor) def _to_components(self, value): try: @@ -174,15 +192,6 @@ class TensorSpec(type_spec.BatchableTypeSpec): raise ValueError("Unbatching a tensor is only supported for rank >= 1") return TensorSpec(self._shape[1:], self._dtype) - def _to_legacy_output_types(self): - return self._dtype - - def _to_legacy_output_shapes(self): - return self._shape - - def _to_legacy_output_classes(self): - return ops.Tensor - # TODO(b/133606651): Should is_compatible_with should check min/max bounds? class BoundedTensorSpec(TensorSpec): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index f7aaee331a4..b0ef3d5d341 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_logging_ops @@ -1964,3 +1965,24 @@ def copy_to_graph_uninitialized(var): ops.NotDifferentiable("Assert") ops.NotDifferentiable("VarIsInitializedOp") ops.NotDifferentiable("VariableShape") + + +class VariableSpec(tensor_spec.DenseSpec): + """Describes a tf.Variable.""" + + __slots__ = [] + + value_type = property(lambda self: BaseResourceVariable) + + def _to_components(self, value): + raise NotImplementedError + + def _from_components(self, components): + raise NotImplementedError + + def _from_compatible_tensor_list(self, tensor_list): + assert len(tensor_list) == 1 + return tensor_list[0] + + +_pywrap_utils.RegisterType("VariableSpec", VariableSpec) diff --git a/tensorflow/python/ops/variable_spec_test.py b/tensorflow/python/ops/variable_spec_test.py new file mode 100644 index 00000000000..7a79d59ca19 --- /dev/null +++ b/tensorflow/python/ops/variable_spec_test.py @@ -0,0 +1,66 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for VariableSpec.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + +VariableSpec = resource_variable_ops.VariableSpec + + +class VariableSpecTest(test.TestCase): + + def test_properties(self): + spec = VariableSpec(shape=(1, 2, 3), dtype=dtypes.float64, name='vs') + self.assertEqual('vs', spec.name) + self.assertEqual(tensor_shape.TensorShape((1, 2, 3)), spec.shape) + self.assertEqual(dtypes.float64, spec.dtype) + + def test_compatibility(self): + spec = VariableSpec(shape=None) + spec2 = VariableSpec(shape=[None, 15]) + spec3 = VariableSpec(shape=[None]) + + self.assertTrue(spec.is_compatible_with(spec2)) + self.assertFalse(spec2.is_compatible_with(spec3)) + + var = resource_variable_ops.UninitializedVariable( + shape=[3, 15], dtype=dtypes.float32) + var2 = resource_variable_ops.UninitializedVariable( + shape=[3], dtype=dtypes.int32) + + self.assertTrue(spec2.is_compatible_with(var)) + self.assertFalse(spec3.is_compatible_with(var2)) + + spec4 = VariableSpec(shape=None, dtype=dtypes.int32) + spec5 = VariableSpec(shape=[None], dtype=dtypes.int32) + + self.assertFalse(spec.is_compatible_with(spec4)) + self.assertTrue(spec4.is_compatible_with(spec5)) + self.assertTrue(spec4.is_compatible_with(var2)) + + tensor = constant_op.constant([1, 2, 3]) + self.assertFalse(spec4.is_compatible_with(tensor)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index c2606243475..5d049b5470e 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -54,9 +54,9 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 -from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_factory_ops @@ -975,8 +975,6 @@ class LoadTest(test.TestCase, parameterized.TestCase): x=constant_op.constant(2.)).numpy()) def test_concrete_function_variable_argument(self, cycles): - # TODO(allenl): Fix variables in input signatures. - self.skipTest("Need to fix encoding of variables in inputs signatures") capture = variables.Variable(0) @def_function.function @@ -984,14 +982,29 @@ class LoadTest(test.TestCase, parameterized.TestCase): v.assign_add(1) capture.assign_sub(1) + @def_function.function(input_signature=[ + resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) + ]) + def func_with_input_signature(v): + v.assign_add(5) + capture.assign_sub(5) + return 1 + vsave = variables.Variable(1) root = tracking.AutoTrackable() root.f = func.get_concrete_function(vsave) + root.f_sig = func_with_input_signature.get_concrete_function() root.capture = capture + self.assertEqual(1, vsave.numpy()) root.f(vsave) self.assertEqual(2, vsave.numpy()) self.assertEqual(-1, capture.numpy()) + + root.f_sig(vsave) + self.assertEqual(7, vsave.numpy()) + self.assertEqual(-6, capture.numpy()) + imported = cycle(root, cycles) vload = variables.Variable(1) @@ -999,8 +1012,13 @@ class LoadTest(test.TestCase, parameterized.TestCase): self.assertEqual(2, vload.numpy()) imported.f(v=vload) self.assertEqual(3, vload.numpy()) - self.assertEqual(-3, imported.capture.numpy()) - self.assertEqual(-1, capture.numpy()) + self.assertEqual(-8, imported.capture.numpy()) + + imported.f_sig(v=vload) + self.assertEqual(8, vload.numpy()) + self.assertEqual(-13, imported.capture.numpy()) + + self.assertEqual(-6, capture.numpy()) def test_function_and_component(self, cycles): @@ -1644,7 +1662,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): def test_destroy_resource(self, cycles): def get_handle(): - return gen_resource_variable_ops.var_handle_op( + return resource_variable_ops.var_handle_op( shape=tensor_shape.as_shape([]), dtype=dtypes.float32, shared_name="my_var_name", @@ -1655,7 +1673,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): def destroy_resource(self): handle = get_handle() - gen_resource_variable_ops.destroy_resource_op( + resource_variable_ops.destroy_resource_op( handle, ignore_lookup_error=True) class MyResource(tracking.TrackableResource): @@ -1669,7 +1687,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): return get_handle() def _initialize(self): - gen_resource_variable_ops.assign_variable_op( + resource_variable_ops.assign_variable_op( self.resource_handle, 1.0, name="assign") class MyModel(tracking.AutoTrackable): @@ -1681,10 +1699,9 @@ class LoadTest(test.TestCase, parameterized.TestCase): @def_function.function(input_signature=[]) def increase(self): handle = self.resource.resource_handle - gen_resource_variable_ops.assign_add_variable_op( + resource_variable_ops.assign_add_variable_op( handle, 10.0, name="assign_add") - return gen_resource_variable_ops.read_variable_op( - handle, dtypes.float32) + return resource_variable_ops.read_variable_op(handle, dtypes.float32) root = MyModel() imported = cycle(root, cycles) @@ -1699,7 +1716,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): # Try to destroy the resource again, should fail. with self.assertRaisesRegexp(errors.NotFoundError, r"Resource .* does not exist."): - gen_resource_variable_ops.destroy_resource_op( + resource_variable_ops.destroy_resource_op( handle, ignore_lookup_error=False) def test_function_called_as_operation(self, cycles): diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py index 3144bbdf942..570e913c5dd 100644 --- a/tensorflow/python/saved_model/nested_structure_coder.py +++ b/tensorflow/python/saved_model/nested_structure_coder.py @@ -44,6 +44,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import compat @@ -467,6 +468,8 @@ class _TypeSpecCodec(object): optional_ops.OptionalSpec, struct_pb2.TypeSpecProto.PER_REPLICA_SPEC: values.PerReplicaSpec, + struct_pb2.TypeSpecProto.VARIABLE_SPEC: + resource_variable_ops.VariableSpec, } # Mapping from type (TypeSpec subclass) to enum value. diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 4f3dd20ad43..e178c362d04 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -45,6 +45,7 @@ from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader_impl @@ -429,6 +430,18 @@ class SaveTest(test.TestCase): self.assertAllClose({"output_0": 3 * (1 + 4 + 9 + 16)}, _import_and_infer(save_dir, {"x": 3})) + def test_variable_args_cannot_be_used_as_signature(self): + @def_function.function(input_signature=[ + resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)]) + def f(unused_v): + return 1 + root = tracking.AutoTrackable() + root.f = f.get_concrete_function() + with self.assertRaisesRegexp(ValueError, + "tf.Variable inputs cannot be exported"): + save.save(root, os.path.join(self.get_temp_dir(), "saved_model"), + signatures=root.f) + class SavingOptionsTest(test.TestCase): diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py index 3f3725f39c9..b31bbaa7fcf 100644 --- a/tensorflow/python/saved_model/signature_serialization.py +++ b/tensorflow/python/saved_model/signature_serialization.py @@ -22,6 +22,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import function as defun from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.saved_model import revived_types from tensorflow.python.saved_model import signature_constants from tensorflow.python.training.tracking import base @@ -51,12 +52,21 @@ def _valid_signature(concrete_function): # 1.x style. return False try: + _validate_inputs(concrete_function) _normalize_outputs(concrete_function.structured_outputs, "unused", "unused") except ValueError: return False return True +def _validate_inputs(concrete_function): + if any(isinstance(inp, resource_variable_ops.VariableSpec) + for inp in nest.flatten( + concrete_function.structured_input_signature)): + raise ValueError(("Functions that expect tf.Variable inputs cannot be " + "exported as signatures.")) + + def find_function_to_export(saveable_view): """Function to export, None if no suitable function was found.""" # If the user did not specify signatures, check the root object for a function @@ -98,6 +108,8 @@ def canonicalize_signatures(signatures): "got {}. Only `tf.functions` with an input signature or " "concrete functions can be used as a signature.").format(function)) + _validate_inputs(signature_function) + # Re-wrap the function so that it returns a dictionary of Tensors. This # matches the format of 1.x-style signatures. # pylint: disable=cell-var-from-loop diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 9ecd6152316..270a582783e 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -512,21 +512,23 @@ bool IsCompositeTensorHelper(PyObject* o) { return check_cache->CachedLookup(o); } -// Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec. +// Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or +// VariableSpec. // Returns 0 otherwise. // Returns -1 if an error occurred. bool IsTypeSpecHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec"); - int is_tensor_spec = IsInstanceOfRegisteredType(to_check, "TensorSpec"); - if ((is_type_spec == -1) || (is_tensor_spec == -1)) return -1; - return static_cast<int>(is_type_spec && !is_tensor_spec); + int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") || + IsInstanceOfRegisteredType(to_check, "VariableSpec")); + if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1; + return static_cast<int>(is_type_spec && !is_dense_spec); }); return check_cache->CachedLookup(o); } // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or -// (non-TensorSpec) TypeSpec. +// (non-TensorSpec and non-VariableSpec) TypeSpec. // Returns 0 otherwise. // Returns -1 if an error occurred. int IsSequenceOrCompositeHelper(PyObject* o) { diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-spec.pbtxt index 0594607e0aa..55ec596cc25 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-spec.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.TensorSpec" tf_class { is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>" + is_instance: "<class \'tensorflow.python.framework.tensor_spec.DenseSpec\'>" is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>" is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>" is_instance: "<type \'object\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-spec.pbtxt index 0594607e0aa..55ec596cc25 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-spec.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.TensorSpec" tf_class { is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>" + is_instance: "<class \'tensorflow.python.framework.tensor_spec.DenseSpec\'>" is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>" is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>" is_instance: "<type \'object\'>"