Add `VariableSpec` encoding so that functions with `tf.Variable` arguments can be saved to SavedModel.
PiperOrigin-RevId: 277826082
Change-Id: I38ab1cdf7990f449785271a0f37a10614efc7426
parent 6b66d924e8
commit e784a2202b
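As context for the diff below, a minimal sketch of what this change enables, mirroring testInputSignatureWithVariableArgs and the SavedModel tests in this commit. Note that VariableSpec is internal at this revision, reachable via resource_variable_ops rather than a public tf.VariableSpec:

import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

# A scalar-int Variable signature entry, as used by the tests in this commit.
signature = [resource_variable_ops.VariableSpec(shape=[], dtype=tf.int32)]

@tf.function(input_signature=signature)
def bump(v):
  v.assign_add(1)  # mutates the tf.Variable argument in place
  return v.read_value()

v = tf.Variable(0)
bump(v)  # traces against the VariableSpec entry rather than a TensorSpec
assert v.numpy() == 1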
@@ -1,10 +1,10 @@
 syntax = "proto3";
 
-package tensorflow;
-
 import "tensorflow/core/framework/tensor_shape.proto";
 import "tensorflow/core/framework/types.proto";
 
+package tensorflow;
+
 // `StructuredValue` represents a dynamically typed value representing various
 // data structures that are inspired by Python data structures typically used in
 // TensorFlow functions as inputs and outputs.
@@ -120,6 +120,7 @@ message TypeSpecProto {
     DATA_ITERATOR_SPEC = 6;  // IteratorSpec from data/ops/iterator_ops.py
     OPTIONAL_SPEC = 7;  // tf.OptionalSpec
     PER_REPLICA_SPEC = 8;  // PerReplicaSpec from distribute/values.py
+    VARIABLE_SPEC = 9;  // tf.VariableSpec
   }
   TypeSpecClass type_spec_class = 1;
 
@@ -4672,6 +4672,18 @@ cuda_py_test(
     tags = ["no_windows_gpu"],
 )
 
+tf_py_test(
+    name = "variable_spec_test",
+    size = "small",
+    srcs = ["ops/variable_spec_test.py"],
+    additional_deps = [
+        ":framework_for_generated_wrappers",
+        ":framework_test_lib",
+        ":platform_test",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_library(
     name = "training_lib",
     srcs = glob(
@@ -143,9 +143,11 @@ def _flat_shape_list(*params):
   Returns:
     A list of entries containing either `None` or `TensorShape`.
   """
-  return [tensor_shape.TensorShape(x.shape)
-          if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None
-          for x in nest.flatten(params, expand_composites=True)]
+  return [
+      tensor_shape.TensorShape(x.shape)
+      if isinstance(x, (ops.Tensor, tensor_spec.DenseSpec)) else None
+      for x in nest.flatten(params, expand_composites=True)
+  ]
 
 
 def _shape_less_specific_than(relaxed, to_check):
@@ -1651,7 +1653,7 @@ class ConcreteFunction(object):
                 self._func_graph.inputs[i].shape,
                 arg.shape))
       elif (self._signature is not None and
-            isinstance(self._signature[i], tensor_spec.TensorSpec)):
+            isinstance(self._signature[i], tensor_spec.DenseSpec)):
         tensor_inputs.append(
             ops.convert_to_tensor(arg, self._signature[i].dtype))
       else:
@@ -2208,7 +2210,8 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
   need_packing = False
   for index, (value, spec) in enumerate(zip(flatten_inputs,
                                             flat_input_signature)):
-    if not _pywrap_utils.IsTensor(value):
+    if (isinstance(spec, tensor_spec.TensorSpec) and
+        not _pywrap_utils.IsTensor(value)):
       try:
         flatten_inputs[index] = ops.convert_to_tensor(
             value, dtype_hint=spec.dtype)
@@ -2392,11 +2395,12 @@ class Function(object):
       raise ValueError("Structure of Python function inputs does not match "
                        "input_signature.")
     flat_inputs = nest.flatten(args, expand_composites=True)
-    if any(not isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
+    if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec,
+                                resource_variable_ops.BaseResourceVariable))
            for arg in flat_inputs):
       raise ValueError("When input_signature is provided, all inputs to "
-                       "the Python function must be Tensors or "
-                       "tf.TensorSpec objects.")
+                       "the Python function must be Tensors, Variables, "
+                       "tf.TensorSpec or tf.VariableSpec objects.")
     if any(not spec.is_compatible_with(other)
            for spec, other in zip(self.flat_input_signature, flat_inputs)):
       raise ValueError("Python inputs incompatible with input_signature: "
@@ -2701,7 +2705,7 @@ def register(func, *args, **kwargs):
 
 
 def validate_signature(signature):
-  if any(not isinstance(arg, tensor_spec.TensorSpec)
+  if any(not isinstance(arg, tensor_spec.DenseSpec)
          for arg in nest.flatten(signature, expand_composites=True)):
     raise TypeError("Invalid input_signature {}; input_signature must be "
                     "a possibly nested sequence of TensorSpec objects."
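Editor's note: the thrust of the function.py changes above is that each isinstance check against tensor_spec.TensorSpec in signature validation is widened to the new tensor_spec.DenseSpec base class introduced later in this diff. A small illustrative check, assuming this revision:

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops

# Both spec kinds pass the single widened isinstance check, so an
# input_signature may now mix Tensor-valued and Variable-valued entries.
assert isinstance(tensor_spec.TensorSpec([], dtypes.int32),
                  tensor_spec.DenseSpec)
assert isinstance(resource_variable_ops.VariableSpec([], dtypes.int32),
                  tensor_spec.DenseSpec)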
@@ -2055,6 +2055,27 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, 'does not match'):
       defined(rt5)
 
+  def testInputSignatureWithVariableArgs(self):
+
+    def f(v):
+      v.assign_add(1)
+
+    signature = [
+        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
+    ]
+    defined = function.defun(f, input_signature=signature)
+
+    v1 = variables.Variable(0)
+    v2 = variables.Variable(0)
+
+    defined(v1)
+    self.assertEqual(v1.numpy(), 1)
+    self.assertEqual(v2.numpy(), 0)
+
+    defined(v=v2)
+    self.assertEqual(v1.numpy(), 1)
+    self.assertEqual(v2.numpy(), 1)
+
   def testTensorKeywordArguments(self):
 
     def foo(a, b):
@@ -99,6 +99,9 @@ def convert_structure_to_signature(structure, arg_names=None):
     if isinstance(arg, composite_tensor.CompositeTensor):
       # TODO(b/133606651) Do we need to inject arg_name?
       return arg._type_spec  # pylint: disable=protected-access
+    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
+      name = "/".join([str(p) for p in path])
+      return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name)
     if isinstance(arg, (
         int,
         float,
@@ -292,7 +295,7 @@ class FuncGraph(ops.Graph):
     if key not in self._deferred_captures:
 
       def convert_to_placeholder(s):
-        if not isinstance(s, tensor_spec.TensorSpec):
+        if not isinstance(s, tensor_spec.DenseSpec):
          raise TypeError(
              "Expected a nest of `TypeSpec` objects, found %s of type %s." %
              (s, type(s)))
@@ -1177,7 +1180,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
 
       flattened = nest.flatten(arg_value, expand_composites=True)
       tensor_specs = [
-          arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec)
+          arg for arg in flattened if isinstance(arg, tensor_spec.DenseSpec)
       ]
       specified_names = [arg.name for arg in tensor_specs if arg.name]
       if specified_names and len(specified_names) < len(tensor_specs):
@@ -1209,7 +1212,20 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
             "_user_specified_name",
             attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
       function_inputs.append(placeholder)
-    elif isinstance(arg, resource_variable_ops.BaseResourceVariable):
+    elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
+                          resource_variable_ops.VariableSpec)):
+      if isinstance(arg, resource_variable_ops.VariableSpec):
+        name = arg.name or name
+        with func_graph.outer_graph.as_default():
+          placeholder = graph_placeholder(dtypes.resource, arg.shape,
+                                          name=name)
+
+        arg = resource_variable_ops.BaseResourceVariable(
+            name=name,
+            shape=arg.shape,
+            dtype=arg.dtype,
+            handle=placeholder,
+            handle_name=name)
       # Capture arg variables to create placeholders for them. These will be
       # removed as captures after the function is traced (since otherwise we'd
       # just add it back with a new placeholder when the variable was
@@ -29,16 +29,13 @@ from tensorflow.python.framework import type_spec
 from tensorflow.python.util.tf_export import tf_export
 
 
-@tf_export("TensorSpec")
-class TensorSpec(type_spec.BatchableTypeSpec):
-  """Describes a tf.Tensor.
-
-  Metadata for describing the `tf.Tensor` objects accepted or returned
-  by some TensorFlow APIs.
-  """
+class DenseSpec(type_spec.TypeSpec):
+  """Describes a dense object with shape, dtype, and name."""
 
   __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]
 
+  _component_specs = property(lambda self: self)
+
   def __init__(self, shape, dtype=dtypes.float32, name=None):
     """Creates a TensorSpec.
 
@@ -63,15 +60,6 @@ class TensorSpec(type_spec.BatchableTypeSpec):
   def from_spec(cls, spec, name=None):
     return cls(spec.shape, spec.dtype, name or spec.name)
 
-  @classmethod
-  def from_tensor(cls, tensor, name=None):
-    if isinstance(tensor, ops.EagerTensor):
-      return TensorSpec(tensor.shape, tensor.dtype, name)
-    elif isinstance(tensor, ops.Tensor):
-      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
-    else:
-      raise ValueError("`tensor` should be a tf.Tensor")
-
   @property
   def shape(self):
     """Returns the `TensorShape` that represents the shape of the tensor."""
@@ -87,25 +75,14 @@ class TensorSpec(type_spec.BatchableTypeSpec):
     """Returns the (optionally provided) name of the described tensor."""
     return self._name
 
-  def is_compatible_with(self, spec_or_tensor):
-    """Returns True if spec_or_tensor is compatible with this TensorSpec.
-
-    Two tensors are considered compatible if they have the same dtype
-    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).
-
-    Args:
-      spec_or_tensor: A tf.TensorSpec or a tf.Tensor
-
-    Returns:
-      True if spec_or_tensor is compatible with self.
-    """
-    return (isinstance(spec_or_tensor, (TensorSpec, ops.Tensor)) and
-            self._dtype.is_compatible_with(spec_or_tensor.dtype) and
-            self._shape.is_compatible_with(spec_or_tensor.shape))
+  def is_compatible_with(self, spec_or_value):
+    return (isinstance(spec_or_value, (type(self), self.value_type)) and
+            self._dtype.is_compatible_with(spec_or_value.dtype) and
+            self._shape.is_compatible_with(spec_or_value.shape))
 
   def __repr__(self):
-    return "TensorSpec(shape={}, dtype={}, name={})".format(
-        self.shape, repr(self.dtype), repr(self.name))
+    return "{}(shape={}, dtype={}, name={})".format(
+        type(self).__name__, self.shape, repr(self.dtype), repr(self.name))
 
   def __hash__(self):
     return hash((self._shape_tuple, self.dtype))
@@ -120,19 +97,60 @@ class TensorSpec(type_spec.BatchableTypeSpec):
   def __ne__(self, other):
     return not self == other
 
-  value_type = property(lambda self: ops.Tensor)
-
   def most_specific_compatible_type(self, other):
     if (type(self) is not type(other)) or (self._dtype != other.dtype):
       raise ValueError("Types are not compatible: %r vs %r" % (self, other))
     shape = self._shape.most_specific_compatible_shape(other.shape)
     name = self._name if self._name == other.name else None
-    return TensorSpec(shape, self._dtype, name)
+    return type(self)(shape, self._dtype, name)
 
   def _serialize(self):
     return (self._shape, self._dtype, self._name)
 
-  _component_specs = property(lambda self: self)
+  def _to_legacy_output_types(self):
+    return self._dtype
+
+  def _to_legacy_output_shapes(self):
+    return self._shape
+
+  def _to_legacy_output_classes(self):
+    return self.value_type
+
+
+@tf_export("TensorSpec")
+class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
+  """Describes a tf.Tensor.
+
+  Metadata for describing the `tf.Tensor` objects accepted or returned
+  by some TensorFlow APIs.
+  """
+
+  __slots__ = []
+
+  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation
+    """Returns True if spec_or_tensor is compatible with this TensorSpec.
+
+    Two tensors are considered compatible if they have the same dtype
+    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).
+
+    Args:
+      spec_or_tensor: A tf.TensorSpec or a tf.Tensor
+
+    Returns:
+      True if spec_or_tensor is compatible with self.
+    """
+    return super(TensorSpec, self).is_compatible_with(spec_or_tensor)
+
+  @classmethod
+  def from_tensor(cls, tensor, name=None):
+    if isinstance(tensor, ops.EagerTensor):
+      return TensorSpec(tensor.shape, tensor.dtype, name)
+    elif isinstance(tensor, ops.Tensor):
+      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
+    else:
+      raise ValueError("`tensor` should be a tf.Tensor")
+
+  value_type = property(lambda self: ops.Tensor)
 
   def _to_components(self, value):
     try:
@@ -174,15 +192,6 @@ class TensorSpec(type_spec.BatchableTypeSpec):
       raise ValueError("Unbatching a tensor is only supported for rank >= 1")
     return TensorSpec(self._shape[1:], self._dtype)
 
-  def _to_legacy_output_types(self):
-    return self._dtype
-
-  def _to_legacy_output_shapes(self):
-    return self._shape
-
-  def _to_legacy_output_classes(self):
-    return ops.Tensor
-
 
 # TODO(b/133606651): Should is_compatible_with should check min/max bounds?
 class BoundedTensorSpec(TensorSpec):
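After this refactor, DenseSpec owns the shared shape/dtype/name machinery and the equality/compatibility logic, while TensorSpec and the VariableSpec added in the next file differ mainly in value_type. A hedged sketch of the resulting behavior (illustrative, not an excerpt from the change):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops

t_spec = tensor_spec.TensorSpec(shape=[None, 3], dtype=dtypes.float32)
v_spec = resource_variable_ops.VariableSpec(shape=[None, 3],
                                            dtype=dtypes.float32)

# DenseSpec.is_compatible_with accepts another spec of the same class or an
# instance of the spec's value_type (ops.Tensor for TensorSpec,
# BaseResourceVariable for VariableSpec), so the two kinds never cross-match.
assert not t_spec.is_compatible_with(v_spec)
assert not v_spec.is_compatible_with(t_spec)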
@@ -33,6 +33,7 @@ from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_logging_ops
@@ -1964,3 +1965,24 @@ def copy_to_graph_uninitialized(var):
 ops.NotDifferentiable("Assert")
 ops.NotDifferentiable("VarIsInitializedOp")
 ops.NotDifferentiable("VariableShape")
+
+
+class VariableSpec(tensor_spec.DenseSpec):
+  """Describes a tf.Variable."""
+
+  __slots__ = []
+
+  value_type = property(lambda self: BaseResourceVariable)
+
+  def _to_components(self, value):
+    raise NotImplementedError
+
+  def _from_components(self, components):
+    raise NotImplementedError
+
+  def _from_compatible_tensor_list(self, tensor_list):
+    assert len(tensor_list) == 1
+    return tensor_list[0]
+
+
+_pywrap_utils.RegisterType("VariableSpec", VariableSpec)
tensorflow/python/ops/variable_spec_test.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for VariableSpec."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+VariableSpec = resource_variable_ops.VariableSpec
+
+
+class VariableSpecTest(test.TestCase):
+
+  def test_properties(self):
+    spec = VariableSpec(shape=(1, 2, 3), dtype=dtypes.float64, name='vs')
+    self.assertEqual('vs', spec.name)
+    self.assertEqual(tensor_shape.TensorShape((1, 2, 3)), spec.shape)
+    self.assertEqual(dtypes.float64, spec.dtype)
+
+  def test_compatibility(self):
+    spec = VariableSpec(shape=None)
+    spec2 = VariableSpec(shape=[None, 15])
+    spec3 = VariableSpec(shape=[None])
+
+    self.assertTrue(spec.is_compatible_with(spec2))
+    self.assertFalse(spec2.is_compatible_with(spec3))
+
+    var = resource_variable_ops.UninitializedVariable(
+        shape=[3, 15], dtype=dtypes.float32)
+    var2 = resource_variable_ops.UninitializedVariable(
+        shape=[3], dtype=dtypes.int32)
+
+    self.assertTrue(spec2.is_compatible_with(var))
+    self.assertFalse(spec3.is_compatible_with(var2))
+
+    spec4 = VariableSpec(shape=None, dtype=dtypes.int32)
+    spec5 = VariableSpec(shape=[None], dtype=dtypes.int32)
+
+    self.assertFalse(spec.is_compatible_with(spec4))
+    self.assertTrue(spec4.is_compatible_with(spec5))
+    self.assertTrue(spec4.is_compatible_with(var2))
+
+    tensor = constant_op.constant([1, 2, 3])
+    self.assertFalse(spec4.is_compatible_with(tensor))
+
+
+if __name__ == '__main__':
+  test.main()
@@ -54,9 +54,9 @@ from tensorflow.python.lib.io import file_io
 from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2
-from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -975,8 +975,6 @@ class LoadTest(test.TestCase, parameterized.TestCase):
                     x=constant_op.constant(2.)).numpy())
 
   def test_concrete_function_variable_argument(self, cycles):
-    # TODO(allenl): Fix variables in input signatures.
-    self.skipTest("Need to fix encoding of variables in inputs signatures")
     capture = variables.Variable(0)
 
     @def_function.function
@@ -984,14 +982,29 @@ class LoadTest(test.TestCase, parameterized.TestCase):
       v.assign_add(1)
       capture.assign_sub(1)
 
+    @def_function.function(input_signature=[
+        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
+    ])
+    def func_with_input_signature(v):
+      v.assign_add(5)
+      capture.assign_sub(5)
+      return 1
+
     vsave = variables.Variable(1)
     root = tracking.AutoTrackable()
     root.f = func.get_concrete_function(vsave)
+    root.f_sig = func_with_input_signature.get_concrete_function()
     root.capture = capture
+
     self.assertEqual(1, vsave.numpy())
     root.f(vsave)
     self.assertEqual(2, vsave.numpy())
+    self.assertEqual(-1, capture.numpy())
+
+    root.f_sig(vsave)
+    self.assertEqual(7, vsave.numpy())
+    self.assertEqual(-6, capture.numpy())
 
     imported = cycle(root, cycles)
 
     vload = variables.Variable(1)
@@ -999,8 +1012,13 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(2, vload.numpy())
     imported.f(v=vload)
     self.assertEqual(3, vload.numpy())
-    self.assertEqual(-3, imported.capture.numpy())
-    self.assertEqual(-1, capture.numpy())
+    self.assertEqual(-8, imported.capture.numpy())
+
+    imported.f_sig(v=vload)
+    self.assertEqual(8, vload.numpy())
+    self.assertEqual(-13, imported.capture.numpy())
+
+    self.assertEqual(-6, capture.numpy())
 
   def test_function_and_component(self, cycles):
 
@@ -1644,7 +1662,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
   def test_destroy_resource(self, cycles):
 
     def get_handle():
-      return gen_resource_variable_ops.var_handle_op(
+      return resource_variable_ops.var_handle_op(
          shape=tensor_shape.as_shape([]),
          dtype=dtypes.float32,
          shared_name="my_var_name",
@@ -1655,7 +1673,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
 
     def destroy_resource(self):
       handle = get_handle()
-      gen_resource_variable_ops.destroy_resource_op(
+      resource_variable_ops.destroy_resource_op(
          handle, ignore_lookup_error=True)
 
     class MyResource(tracking.TrackableResource):
@@ -1669,7 +1687,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
       return get_handle()
 
     def _initialize(self):
-      gen_resource_variable_ops.assign_variable_op(
+      resource_variable_ops.assign_variable_op(
          self.resource_handle, 1.0, name="assign")
 
     class MyModel(tracking.AutoTrackable):
@@ -1681,10 +1699,9 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     @def_function.function(input_signature=[])
     def increase(self):
       handle = self.resource.resource_handle
-      gen_resource_variable_ops.assign_add_variable_op(
+      resource_variable_ops.assign_add_variable_op(
          handle, 10.0, name="assign_add")
-      return gen_resource_variable_ops.read_variable_op(
-          handle, dtypes.float32)
+      return resource_variable_ops.read_variable_op(handle, dtypes.float32)
 
     root = MyModel()
     imported = cycle(root, cycles)
@@ -1699,7 +1716,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     # Try to destroy the resource again, should fail.
     with self.assertRaisesRegexp(errors.NotFoundError,
                                  r"Resource .* does not exist."):
-      gen_resource_variable_ops.destroy_resource_op(
+      resource_variable_ops.destroy_resource_op(
          handle, ignore_lookup_error=False)
 
   def test_function_called_as_operation(self, cycles):
@@ -44,6 +44,7 @@ from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import compat
@@ -467,6 +468,8 @@ class _TypeSpecCodec(object):
           optional_ops.OptionalSpec,
       struct_pb2.TypeSpecProto.PER_REPLICA_SPEC:
           values.PerReplicaSpec,
+      struct_pb2.TypeSpecProto.VARIABLE_SPEC:
+          resource_variable_ops.VariableSpec,
   }
 
   # Mapping from type (TypeSpec subclass) to enum value.
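With the VARIABLE_SPEC enum entry and this codec mapping in place, a VariableSpec should round-trip through the StructuredValue proto. A sketch using the internal StructureCoder as it exists around this revision (encode_structure/decode_proto; later releases reorganize this API, so treat the names as assumptions):

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder

coder = nested_structure_coder.StructureCoder()
spec = resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)

proto = coder.encode_structure(spec)  # StructuredValue carrying a TypeSpecProto
decoded = coder.decode_proto(proto)   # reconstructed via the VARIABLE_SPEC entry
assert decoded.is_compatible_with(spec)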
@@ -45,6 +45,7 @@ from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import loader_impl
@@ -429,6 +430,18 @@ class SaveTest(test.TestCase):
     self.assertAllClose({"output_0": 3 * (1 + 4 + 9 + 16)},
                         _import_and_infer(save_dir, {"x": 3}))
 
+  def test_variable_args_cannot_be_used_as_signature(self):
+    @def_function.function(input_signature=[
+        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)])
+    def f(unused_v):
+      return 1
+    root = tracking.AutoTrackable()
+    root.f = f.get_concrete_function()
+    with self.assertRaisesRegexp(ValueError,
+                                 "tf.Variable inputs cannot be exported"):
+      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
+                signatures=root.f)
+
 
 class SavingOptionsTest(test.TestCase):
@@ -22,6 +22,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.saved_model import revived_types
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.training.tracking import base
@@ -51,12 +52,21 @@ def _valid_signature(concrete_function):
     # 1.x style.
     return False
   try:
+    _validate_inputs(concrete_function)
     _normalize_outputs(concrete_function.structured_outputs, "unused", "unused")
   except ValueError:
     return False
   return True
 
 
+def _validate_inputs(concrete_function):
+  if any(isinstance(inp, resource_variable_ops.VariableSpec)
+         for inp in nest.flatten(
+             concrete_function.structured_input_signature)):
+    raise ValueError(("Functions that expect tf.Variable inputs cannot be "
+                      "exported as signatures."))
+
+
 def find_function_to_export(saveable_view):
   """Function to export, None if no suitable function was found."""
   # If the user did not specify signatures, check the root object for a function
@@ -98,6 +108,8 @@ def canonicalize_signatures(signatures):
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))
 
+    _validate_inputs(signature_function)
+
     # Re-wrap the function so that it returns a dictionary of Tensors. This
     # matches the format of 1.x-style signatures.
     # pylint: disable=cell-var-from-loop
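The net effect of _validate_inputs: a concrete function that takes a tf.Variable can be saved and reloaded (see the load test above), but it is rejected as a serving signature. A sketch mirroring test_variable_args_cannot_be_used_as_signature, with tf.Module standing in for the test's AutoTrackable:

import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

@tf.function(input_signature=[
    resource_variable_ops.VariableSpec(shape=[], dtype=tf.int32)])
def f(unused_v):
  return 1

root = tf.Module()
root.f = f.get_concrete_function()

try:
  tf.saved_model.save(root, "/tmp/saved_model", signatures=root.f)
except ValueError as err:
  print(err)  # "Functions that expect tf.Variable inputs cannot be exported..."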
@@ -512,21 +512,23 @@ bool IsCompositeTensorHelper(PyObject* o) {
   return check_cache->CachedLookup(o);
 }
 
-// Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec.
+// Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
+// VariableSpec.
 // Returns 0 otherwise.
 // Returns -1 if an error occurred.
 bool IsTypeSpecHelper(PyObject* o) {
   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
     int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
-    int is_tensor_spec = IsInstanceOfRegisteredType(to_check, "TensorSpec");
-    if ((is_type_spec == -1) || (is_tensor_spec == -1)) return -1;
-    return static_cast<int>(is_type_spec && !is_tensor_spec);
+    int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
+                         IsInstanceOfRegisteredType(to_check, "VariableSpec"));
+    if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
+    return static_cast<int>(is_type_spec && !is_dense_spec);
   });
   return check_cache->CachedLookup(o);
 }
 
 // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
-// (non-TensorSpec) TypeSpec.
+// (non-TensorSpec and non-VariableSpec) TypeSpec.
 // Returns 0 otherwise.
 // Returns -1 if an error occurred.
 int IsSequenceOrCompositeHelper(PyObject* o) {
@@ -1,6 +1,7 @@
 path: "tensorflow.TensorSpec"
 tf_class {
   is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>"
+  is_instance: "<class \'tensorflow.python.framework.tensor_spec.DenseSpec\'>"
   is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
   is_instance: "<type \'object\'>"
@@ -1,6 +1,7 @@
 path: "tensorflow.TensorSpec"
 tf_class {
   is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>"
+  is_instance: "<class \'tensorflow.python.framework.tensor_spec.DenseSpec\'>"
   is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
   is_instance: "<type \'object\'>"