Add Variable encoding so that functions with variable arguments can be saved to SavedModel.
PiperOrigin-RevId: 277826082 Change-Id: I38ab1cdf7990f449785271a0f37a10614efc7426
This commit is contained in:
		
							parent
							
								
									6b66d924e8
								
							
						
					
					
						commit
						e784a2202b
					
				| @ -1,10 +1,10 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
| 
 | 
 | ||||||
|  | package tensorflow; | ||||||
|  | 
 | ||||||
| import "tensorflow/core/framework/tensor_shape.proto"; | import "tensorflow/core/framework/tensor_shape.proto"; | ||||||
| import "tensorflow/core/framework/types.proto"; | import "tensorflow/core/framework/types.proto"; | ||||||
| 
 | 
 | ||||||
| package tensorflow; |  | ||||||
| 
 |  | ||||||
| // `StructuredValue` represents a dynamically typed value representing various | // `StructuredValue` represents a dynamically typed value representing various | ||||||
| // data structures that are inspired by Python data structures typically used in | // data structures that are inspired by Python data structures typically used in | ||||||
| // TensorFlow functions as inputs and outputs. | // TensorFlow functions as inputs and outputs. | ||||||
| @ -120,6 +120,7 @@ message TypeSpecProto { | |||||||
|     DATA_ITERATOR_SPEC = 6;   // IteratorSpec from data/ops/iterator_ops.py |     DATA_ITERATOR_SPEC = 6;   // IteratorSpec from data/ops/iterator_ops.py | ||||||
|     OPTIONAL_SPEC = 7;        // tf.OptionalSpec |     OPTIONAL_SPEC = 7;        // tf.OptionalSpec | ||||||
|     PER_REPLICA_SPEC = 8;     // PerReplicaSpec from distribute/values.py |     PER_REPLICA_SPEC = 8;     // PerReplicaSpec from distribute/values.py | ||||||
|  |     VARIABLE_SPEC = 9;        // tf.VariableSpec | ||||||
|   } |   } | ||||||
|   TypeSpecClass type_spec_class = 1; |   TypeSpecClass type_spec_class = 1; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -4672,6 +4672,18 @@ cuda_py_test( | |||||||
|     tags = ["no_windows_gpu"], |     tags = ["no_windows_gpu"], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | tf_py_test( | ||||||
|  |     name = "variable_spec_test", | ||||||
|  |     size = "small", | ||||||
|  |     srcs = ["ops/variable_spec_test.py"], | ||||||
|  |     additional_deps = [ | ||||||
|  |         ":framework_for_generated_wrappers", | ||||||
|  |         ":framework_test_lib", | ||||||
|  |         ":platform_test", | ||||||
|  |         "//third_party/py/numpy", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| py_library( | py_library( | ||||||
|     name = "training_lib", |     name = "training_lib", | ||||||
|     srcs = glob( |     srcs = glob( | ||||||
|  | |||||||
| @ -143,9 +143,11 @@ def _flat_shape_list(*params): | |||||||
|   Returns: |   Returns: | ||||||
|     A list of entries containing either `None` or `TensorShape`. |     A list of entries containing either `None` or `TensorShape`. | ||||||
|   """ |   """ | ||||||
|   return [tensor_shape.TensorShape(x.shape) |   return [ | ||||||
|           if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None |       tensor_shape.TensorShape(x.shape) | ||||||
|           for x in nest.flatten(params, expand_composites=True)] |       if isinstance(x, (ops.Tensor, tensor_spec.DenseSpec)) else None | ||||||
|  |       for x in nest.flatten(params, expand_composites=True) | ||||||
|  |   ] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _shape_less_specific_than(relaxed, to_check): | def _shape_less_specific_than(relaxed, to_check): | ||||||
| @ -1651,7 +1653,7 @@ class ConcreteFunction(object): | |||||||
|                      self._func_graph.inputs[i].shape, |                      self._func_graph.inputs[i].shape, | ||||||
|                      arg.shape)) |                      arg.shape)) | ||||||
|       elif (self._signature is not None and |       elif (self._signature is not None and | ||||||
|             isinstance(self._signature[i], tensor_spec.TensorSpec)): |             isinstance(self._signature[i], tensor_spec.DenseSpec)): | ||||||
|         tensor_inputs.append( |         tensor_inputs.append( | ||||||
|             ops.convert_to_tensor(arg, self._signature[i].dtype)) |             ops.convert_to_tensor(arg, self._signature[i].dtype)) | ||||||
|       else: |       else: | ||||||
| @ -2208,7 +2210,8 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature): | |||||||
|   need_packing = False |   need_packing = False | ||||||
|   for index, (value, spec) in enumerate(zip(flatten_inputs, |   for index, (value, spec) in enumerate(zip(flatten_inputs, | ||||||
|                                             flat_input_signature)): |                                             flat_input_signature)): | ||||||
|     if not _pywrap_utils.IsTensor(value): |     if (isinstance(spec, tensor_spec.TensorSpec) and | ||||||
|  |         not _pywrap_utils.IsTensor(value)): | ||||||
|       try: |       try: | ||||||
|         flatten_inputs[index] = ops.convert_to_tensor( |         flatten_inputs[index] = ops.convert_to_tensor( | ||||||
|             value, dtype_hint=spec.dtype) |             value, dtype_hint=spec.dtype) | ||||||
| @ -2392,11 +2395,12 @@ class Function(object): | |||||||
|           raise ValueError("Structure of Python function inputs does not match " |           raise ValueError("Structure of Python function inputs does not match " | ||||||
|                            "input_signature.") |                            "input_signature.") | ||||||
|         flat_inputs = nest.flatten(args, expand_composites=True) |         flat_inputs = nest.flatten(args, expand_composites=True) | ||||||
|         if any(not isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)) |         if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec, | ||||||
|  |                                     resource_variable_ops.BaseResourceVariable)) | ||||||
|                for arg in flat_inputs): |                for arg in flat_inputs): | ||||||
|           raise ValueError("When input_signature is provided, all inputs to " |           raise ValueError("When input_signature is provided, all inputs to " | ||||||
|                            "the Python function must be Tensors or " |                            "the Python function must be Tensors, Variables, " | ||||||
|                            "tf.TensorSpec objects.") |                            "tf.TensorSpec or tf.VariableSpec objects.") | ||||||
|         if any(not spec.is_compatible_with(other) |         if any(not spec.is_compatible_with(other) | ||||||
|                for spec, other in zip(self.flat_input_signature, flat_inputs)): |                for spec, other in zip(self.flat_input_signature, flat_inputs)): | ||||||
|           raise ValueError("Python inputs incompatible with input_signature: " |           raise ValueError("Python inputs incompatible with input_signature: " | ||||||
| @ -2701,7 +2705,7 @@ def register(func, *args, **kwargs): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def validate_signature(signature): | def validate_signature(signature): | ||||||
|   if any(not isinstance(arg, tensor_spec.TensorSpec) |   if any(not isinstance(arg, tensor_spec.DenseSpec) | ||||||
|          for arg in nest.flatten(signature, expand_composites=True)): |          for arg in nest.flatten(signature, expand_composites=True)): | ||||||
|     raise TypeError("Invalid input_signature {}; input_signature must be " |     raise TypeError("Invalid input_signature {}; input_signature must be " | ||||||
|                     "a possibly nested sequence of TensorSpec objects." |                     "a possibly nested sequence of TensorSpec objects." | ||||||
|  | |||||||
| @ -2055,6 +2055,27 @@ class FunctionTest(test.TestCase, parameterized.TestCase): | |||||||
|     with self.assertRaisesRegexp(ValueError, 'does not match'): |     with self.assertRaisesRegexp(ValueError, 'does not match'): | ||||||
|       defined(rt5) |       defined(rt5) | ||||||
| 
 | 
 | ||||||
|  |   def testInputSignatureWithVariableArgs(self): | ||||||
|  | 
 | ||||||
|  |     def f(v): | ||||||
|  |       v.assign_add(1) | ||||||
|  | 
 | ||||||
|  |     signature = [ | ||||||
|  |         resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) | ||||||
|  |     ] | ||||||
|  |     defined = function.defun(f, input_signature=signature) | ||||||
|  | 
 | ||||||
|  |     v1 = variables.Variable(0) | ||||||
|  |     v2 = variables.Variable(0) | ||||||
|  | 
 | ||||||
|  |     defined(v1) | ||||||
|  |     self.assertEqual(v1.numpy(), 1) | ||||||
|  |     self.assertEqual(v2.numpy(), 0) | ||||||
|  | 
 | ||||||
|  |     defined(v=v2) | ||||||
|  |     self.assertEqual(v1.numpy(), 1) | ||||||
|  |     self.assertEqual(v2.numpy(), 1) | ||||||
|  | 
 | ||||||
|   def testTensorKeywordArguments(self): |   def testTensorKeywordArguments(self): | ||||||
| 
 | 
 | ||||||
|     def foo(a, b): |     def foo(a, b): | ||||||
|  | |||||||
| @ -99,6 +99,9 @@ def convert_structure_to_signature(structure, arg_names=None): | |||||||
|     if isinstance(arg, composite_tensor.CompositeTensor): |     if isinstance(arg, composite_tensor.CompositeTensor): | ||||||
|       # TODO(b/133606651) Do we need to inject arg_name? |       # TODO(b/133606651) Do we need to inject arg_name? | ||||||
|       return arg._type_spec  # pylint: disable=protected-access |       return arg._type_spec  # pylint: disable=protected-access | ||||||
|  |     if isinstance(arg, resource_variable_ops.BaseResourceVariable): | ||||||
|  |       name = "/".join([str(p) for p in path]) | ||||||
|  |       return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name) | ||||||
|     if isinstance(arg, ( |     if isinstance(arg, ( | ||||||
|         int, |         int, | ||||||
|         float, |         float, | ||||||
| @ -292,7 +295,7 @@ class FuncGraph(ops.Graph): | |||||||
|     if key not in self._deferred_captures: |     if key not in self._deferred_captures: | ||||||
| 
 | 
 | ||||||
|       def convert_to_placeholder(s): |       def convert_to_placeholder(s): | ||||||
|         if not isinstance(s, tensor_spec.TensorSpec): |         if not isinstance(s, tensor_spec.DenseSpec): | ||||||
|           raise TypeError( |           raise TypeError( | ||||||
|               "Expected a nest of `TypeSpec` objects, found %s of type %s." % |               "Expected a nest of `TypeSpec` objects, found %s of type %s." % | ||||||
|               (s, type(s))) |               (s, type(s))) | ||||||
| @ -1177,7 +1180,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None): | |||||||
| 
 | 
 | ||||||
|     flattened = nest.flatten(arg_value, expand_composites=True) |     flattened = nest.flatten(arg_value, expand_composites=True) | ||||||
|     tensor_specs = [ |     tensor_specs = [ | ||||||
|         arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec) |         arg for arg in flattened if isinstance(arg, tensor_spec.DenseSpec) | ||||||
|     ] |     ] | ||||||
|     specified_names = [arg.name for arg in tensor_specs if arg.name] |     specified_names = [arg.name for arg in tensor_specs if arg.name] | ||||||
|     if specified_names and len(specified_names) < len(tensor_specs): |     if specified_names and len(specified_names) < len(tensor_specs): | ||||||
| @ -1209,7 +1212,20 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None): | |||||||
|               "_user_specified_name", |               "_user_specified_name", | ||||||
|               attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name))) |               attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name))) | ||||||
|         function_inputs.append(placeholder) |         function_inputs.append(placeholder) | ||||||
|       elif isinstance(arg, resource_variable_ops.BaseResourceVariable): |       elif isinstance(arg, (resource_variable_ops.BaseResourceVariable, | ||||||
|  |                             resource_variable_ops.VariableSpec)): | ||||||
|  |         if isinstance(arg, resource_variable_ops.VariableSpec): | ||||||
|  |           name = arg.name or name | ||||||
|  |           with func_graph.outer_graph.as_default(): | ||||||
|  |             placeholder = graph_placeholder(dtypes.resource, arg.shape, | ||||||
|  |                                             name=name) | ||||||
|  | 
 | ||||||
|  |             arg = resource_variable_ops.BaseResourceVariable( | ||||||
|  |                 name=name, | ||||||
|  |                 shape=arg.shape, | ||||||
|  |                 dtype=arg.dtype, | ||||||
|  |                 handle=placeholder, | ||||||
|  |                 handle_name=name) | ||||||
|         # Capture arg variables to create placeholders for them. These will be |         # Capture arg variables to create placeholders for them. These will be | ||||||
|         # removed as captures after the function is traced (since otherwise we'd |         # removed as captures after the function is traced (since otherwise we'd | ||||||
|         # just add it back with a new placeholder when the variable was |         # just add it back with a new placeholder when the variable was | ||||||
|  | |||||||
| @ -29,16 +29,13 @@ from tensorflow.python.framework import type_spec | |||||||
| from tensorflow.python.util.tf_export import tf_export | from tensorflow.python.util.tf_export import tf_export | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @tf_export("TensorSpec") | class DenseSpec(type_spec.TypeSpec): | ||||||
| class TensorSpec(type_spec.BatchableTypeSpec): |   """Describes a dense object with shape, dtype, and name.""" | ||||||
|   """Describes a tf.Tensor. |  | ||||||
| 
 |  | ||||||
|   Metadata for describing the `tf.Tensor` objects accepted or returned |  | ||||||
|   by some TensorFlow APIs. |  | ||||||
|   """ |  | ||||||
| 
 | 
 | ||||||
|   __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"] |   __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"] | ||||||
| 
 | 
 | ||||||
|  |   _component_specs = property(lambda self: self) | ||||||
|  | 
 | ||||||
|   def __init__(self, shape, dtype=dtypes.float32, name=None): |   def __init__(self, shape, dtype=dtypes.float32, name=None): | ||||||
|     """Creates a TensorSpec. |     """Creates a TensorSpec. | ||||||
| 
 | 
 | ||||||
| @ -63,15 +60,6 @@ class TensorSpec(type_spec.BatchableTypeSpec): | |||||||
|   def from_spec(cls, spec, name=None): |   def from_spec(cls, spec, name=None): | ||||||
|     return cls(spec.shape, spec.dtype, name or spec.name) |     return cls(spec.shape, spec.dtype, name or spec.name) | ||||||
| 
 | 
 | ||||||
|   @classmethod |  | ||||||
|   def from_tensor(cls, tensor, name=None): |  | ||||||
|     if isinstance(tensor, ops.EagerTensor): |  | ||||||
|       return TensorSpec(tensor.shape, tensor.dtype, name) |  | ||||||
|     elif isinstance(tensor, ops.Tensor): |  | ||||||
|       return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) |  | ||||||
|     else: |  | ||||||
|       raise ValueError("`tensor` should be a tf.Tensor") |  | ||||||
| 
 |  | ||||||
|   @property |   @property | ||||||
|   def shape(self): |   def shape(self): | ||||||
|     """Returns the `TensorShape` that represents the shape of the tensor.""" |     """Returns the `TensorShape` that represents the shape of the tensor.""" | ||||||
| @ -87,25 +75,14 @@ class TensorSpec(type_spec.BatchableTypeSpec): | |||||||
|     """Returns the (optionally provided) name of the described tensor.""" |     """Returns the (optionally provided) name of the described tensor.""" | ||||||
|     return self._name |     return self._name | ||||||
| 
 | 
 | ||||||
|   def is_compatible_with(self, spec_or_tensor): |   def is_compatible_with(self, spec_or_value): | ||||||
|     """Returns True if spec_or_tensor is compatible with this TensorSpec. |     return (isinstance(spec_or_value, (type(self), self.value_type)) and | ||||||
| 
 |             self._dtype.is_compatible_with(spec_or_value.dtype) and | ||||||
|     Two tensors are considered compatible if they have the same dtype |             self._shape.is_compatible_with(spec_or_value.shape)) | ||||||
|     and their shapes are compatible (see `tf.TensorShape.is_compatible_with`). |  | ||||||
| 
 |  | ||||||
|     Args: |  | ||||||
|       spec_or_tensor: A tf.TensorSpec or a tf.Tensor |  | ||||||
| 
 |  | ||||||
|     Returns: |  | ||||||
|       True if spec_or_tensor is compatible with self. |  | ||||||
|     """ |  | ||||||
|     return (isinstance(spec_or_tensor, (TensorSpec, ops.Tensor)) and |  | ||||||
|             self._dtype.is_compatible_with(spec_or_tensor.dtype) and |  | ||||||
|             self._shape.is_compatible_with(spec_or_tensor.shape)) |  | ||||||
| 
 | 
 | ||||||
|   def __repr__(self): |   def __repr__(self): | ||||||
|     return "TensorSpec(shape={}, dtype={}, name={})".format( |     return "{}(shape={}, dtype={}, name={})".format( | ||||||
|         self.shape, repr(self.dtype), repr(self.name)) |         type(self).__name__, self.shape, repr(self.dtype), repr(self.name)) | ||||||
| 
 | 
 | ||||||
|   def __hash__(self): |   def __hash__(self): | ||||||
|     return hash((self._shape_tuple, self.dtype)) |     return hash((self._shape_tuple, self.dtype)) | ||||||
| @ -120,19 +97,60 @@ class TensorSpec(type_spec.BatchableTypeSpec): | |||||||
|   def __ne__(self, other): |   def __ne__(self, other): | ||||||
|     return not self == other |     return not self == other | ||||||
| 
 | 
 | ||||||
|   value_type = property(lambda self: ops.Tensor) |  | ||||||
| 
 |  | ||||||
|   def most_specific_compatible_type(self, other): |   def most_specific_compatible_type(self, other): | ||||||
|     if (type(self) is not type(other)) or (self._dtype != other.dtype): |     if (type(self) is not type(other)) or (self._dtype != other.dtype): | ||||||
|       raise ValueError("Types are not compatible: %r vs %r" % (self, other)) |       raise ValueError("Types are not compatible: %r vs %r" % (self, other)) | ||||||
|     shape = self._shape.most_specific_compatible_shape(other.shape) |     shape = self._shape.most_specific_compatible_shape(other.shape) | ||||||
|     name = self._name if self._name == other.name else None |     name = self._name if self._name == other.name else None | ||||||
|     return TensorSpec(shape, self._dtype, name) |     return type(self)(shape, self._dtype, name) | ||||||
| 
 | 
 | ||||||
|   def _serialize(self): |   def _serialize(self): | ||||||
|     return (self._shape, self._dtype, self._name) |     return (self._shape, self._dtype, self._name) | ||||||
| 
 | 
 | ||||||
|   _component_specs = property(lambda self: self) |   def _to_legacy_output_types(self): | ||||||
|  |     return self._dtype | ||||||
|  | 
 | ||||||
|  |   def _to_legacy_output_shapes(self): | ||||||
|  |     return self._shape | ||||||
|  | 
 | ||||||
|  |   def _to_legacy_output_classes(self): | ||||||
|  |     return self.value_type | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @tf_export("TensorSpec") | ||||||
|  | class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec): | ||||||
|  |   """Describes a tf.Tensor. | ||||||
|  | 
 | ||||||
|  |   Metadata for describing the `tf.Tensor` objects accepted or returned | ||||||
|  |   by some TensorFlow APIs. | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   __slots__ = [] | ||||||
|  | 
 | ||||||
|  |   def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation | ||||||
|  |     """Returns True if spec_or_tensor is compatible with this TensorSpec. | ||||||
|  | 
 | ||||||
|  |     Two tensors are considered compatible if they have the same dtype | ||||||
|  |     and their shapes are compatible (see `tf.TensorShape.is_compatible_with`). | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |       spec_or_tensor: A tf.TensorSpec or a tf.Tensor | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |       True if spec_or_tensor is compatible with self. | ||||||
|  |     """ | ||||||
|  |     return super(TensorSpec, self).is_compatible_with(spec_or_tensor) | ||||||
|  | 
 | ||||||
|  |   @classmethod | ||||||
|  |   def from_tensor(cls, tensor, name=None): | ||||||
|  |     if isinstance(tensor, ops.EagerTensor): | ||||||
|  |       return TensorSpec(tensor.shape, tensor.dtype, name) | ||||||
|  |     elif isinstance(tensor, ops.Tensor): | ||||||
|  |       return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) | ||||||
|  |     else: | ||||||
|  |       raise ValueError("`tensor` should be a tf.Tensor") | ||||||
|  | 
 | ||||||
|  |   value_type = property(lambda self: ops.Tensor) | ||||||
| 
 | 
 | ||||||
|   def _to_components(self, value): |   def _to_components(self, value): | ||||||
|     try: |     try: | ||||||
| @ -174,15 +192,6 @@ class TensorSpec(type_spec.BatchableTypeSpec): | |||||||
|       raise ValueError("Unbatching a tensor is only supported for rank >= 1") |       raise ValueError("Unbatching a tensor is only supported for rank >= 1") | ||||||
|     return TensorSpec(self._shape[1:], self._dtype) |     return TensorSpec(self._shape[1:], self._dtype) | ||||||
| 
 | 
 | ||||||
|   def _to_legacy_output_types(self): |  | ||||||
|     return self._dtype |  | ||||||
| 
 |  | ||||||
|   def _to_legacy_output_shapes(self): |  | ||||||
|     return self._shape |  | ||||||
| 
 |  | ||||||
|   def _to_legacy_output_classes(self): |  | ||||||
|     return ops.Tensor |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| # TODO(b/133606651): Should is_compatible_with should check min/max bounds? | # TODO(b/133606651): Should is_compatible_with should check min/max bounds? | ||||||
| class BoundedTensorSpec(TensorSpec): | class BoundedTensorSpec(TensorSpec): | ||||||
|  | |||||||
| @ -33,6 +33,7 @@ from tensorflow.python.framework import cpp_shape_inference_pb2 | |||||||
| from tensorflow.python.framework import dtypes | from tensorflow.python.framework import dtypes | ||||||
| from tensorflow.python.framework import ops | from tensorflow.python.framework import ops | ||||||
| from tensorflow.python.framework import tensor_shape | from tensorflow.python.framework import tensor_shape | ||||||
|  | from tensorflow.python.framework import tensor_spec | ||||||
| from tensorflow.python.ops import array_ops | from tensorflow.python.ops import array_ops | ||||||
| from tensorflow.python.ops import gen_array_ops | from tensorflow.python.ops import gen_array_ops | ||||||
| from tensorflow.python.ops import gen_logging_ops | from tensorflow.python.ops import gen_logging_ops | ||||||
| @ -1964,3 +1965,24 @@ def copy_to_graph_uninitialized(var): | |||||||
| ops.NotDifferentiable("Assert") | ops.NotDifferentiable("Assert") | ||||||
| ops.NotDifferentiable("VarIsInitializedOp") | ops.NotDifferentiable("VarIsInitializedOp") | ||||||
| ops.NotDifferentiable("VariableShape") | ops.NotDifferentiable("VariableShape") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class VariableSpec(tensor_spec.DenseSpec): | ||||||
|  |   """Describes a tf.Variable.""" | ||||||
|  | 
 | ||||||
|  |   __slots__ = [] | ||||||
|  | 
 | ||||||
|  |   value_type = property(lambda self: BaseResourceVariable) | ||||||
|  | 
 | ||||||
|  |   def _to_components(self, value): | ||||||
|  |     raise NotImplementedError | ||||||
|  | 
 | ||||||
|  |   def _from_components(self, components): | ||||||
|  |     raise NotImplementedError | ||||||
|  | 
 | ||||||
|  |   def _from_compatible_tensor_list(self, tensor_list): | ||||||
|  |     assert len(tensor_list) == 1 | ||||||
|  |     return tensor_list[0] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _pywrap_utils.RegisterType("VariableSpec", VariableSpec) | ||||||
|  | |||||||
							
								
								
									
										66
									
								
								tensorflow/python/ops/variable_spec_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								tensorflow/python/ops/variable_spec_test.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,66 @@ | |||||||
|  | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================== | ||||||
|  | """Tests for VariableSpec.""" | ||||||
|  | 
 | ||||||
|  | from __future__ import absolute_import | ||||||
|  | from __future__ import division | ||||||
|  | from __future__ import print_function | ||||||
|  | 
 | ||||||
|  | from tensorflow.python.framework import constant_op | ||||||
|  | from tensorflow.python.framework import dtypes | ||||||
|  | from tensorflow.python.framework import tensor_shape | ||||||
|  | from tensorflow.python.ops import resource_variable_ops | ||||||
|  | from tensorflow.python.platform import test | ||||||
|  | 
 | ||||||
|  | VariableSpec = resource_variable_ops.VariableSpec | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class VariableSpecTest(test.TestCase): | ||||||
|  | 
 | ||||||
|  |   def test_properties(self): | ||||||
|  |     spec = VariableSpec(shape=(1, 2, 3), dtype=dtypes.float64, name='vs') | ||||||
|  |     self.assertEqual('vs', spec.name) | ||||||
|  |     self.assertEqual(tensor_shape.TensorShape((1, 2, 3)), spec.shape) | ||||||
|  |     self.assertEqual(dtypes.float64, spec.dtype) | ||||||
|  | 
 | ||||||
|  |   def test_compatibility(self): | ||||||
|  |     spec = VariableSpec(shape=None) | ||||||
|  |     spec2 = VariableSpec(shape=[None, 15]) | ||||||
|  |     spec3 = VariableSpec(shape=[None]) | ||||||
|  | 
 | ||||||
|  |     self.assertTrue(spec.is_compatible_with(spec2)) | ||||||
|  |     self.assertFalse(spec2.is_compatible_with(spec3)) | ||||||
|  | 
 | ||||||
|  |     var = resource_variable_ops.UninitializedVariable( | ||||||
|  |         shape=[3, 15], dtype=dtypes.float32) | ||||||
|  |     var2 = resource_variable_ops.UninitializedVariable( | ||||||
|  |         shape=[3], dtype=dtypes.int32) | ||||||
|  | 
 | ||||||
|  |     self.assertTrue(spec2.is_compatible_with(var)) | ||||||
|  |     self.assertFalse(spec3.is_compatible_with(var2)) | ||||||
|  | 
 | ||||||
|  |     spec4 = VariableSpec(shape=None, dtype=dtypes.int32) | ||||||
|  |     spec5 = VariableSpec(shape=[None], dtype=dtypes.int32) | ||||||
|  | 
 | ||||||
|  |     self.assertFalse(spec.is_compatible_with(spec4)) | ||||||
|  |     self.assertTrue(spec4.is_compatible_with(spec5)) | ||||||
|  |     self.assertTrue(spec4.is_compatible_with(var2)) | ||||||
|  | 
 | ||||||
|  |     tensor = constant_op.constant([1, 2, 3]) | ||||||
|  |     self.assertFalse(spec4.is_compatible_with(tensor)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |   test.main() | ||||||
| @ -54,9 +54,9 @@ from tensorflow.python.lib.io import file_io | |||||||
| from tensorflow.python.module import module | from tensorflow.python.module import module | ||||||
| from tensorflow.python.ops import array_ops | from tensorflow.python.ops import array_ops | ||||||
| from tensorflow.python.ops import cond_v2 | from tensorflow.python.ops import cond_v2 | ||||||
| from tensorflow.python.ops import gen_resource_variable_ops |  | ||||||
| from tensorflow.python.ops import lookup_ops | from tensorflow.python.ops import lookup_ops | ||||||
| from tensorflow.python.ops import math_ops | from tensorflow.python.ops import math_ops | ||||||
|  | from tensorflow.python.ops import resource_variable_ops | ||||||
| from tensorflow.python.ops import variable_scope | from tensorflow.python.ops import variable_scope | ||||||
| from tensorflow.python.ops import variables | from tensorflow.python.ops import variables | ||||||
| from tensorflow.python.ops.ragged import ragged_factory_ops | from tensorflow.python.ops.ragged import ragged_factory_ops | ||||||
| @ -975,8 +975,6 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
|                                     x=constant_op.constant(2.)).numpy()) |                                     x=constant_op.constant(2.)).numpy()) | ||||||
| 
 | 
 | ||||||
|   def test_concrete_function_variable_argument(self, cycles): |   def test_concrete_function_variable_argument(self, cycles): | ||||||
|     # TODO(allenl): Fix variables in input signatures. |  | ||||||
|     self.skipTest("Need to fix encoding of variables in inputs signatures") |  | ||||||
|     capture = variables.Variable(0) |     capture = variables.Variable(0) | ||||||
| 
 | 
 | ||||||
|     @def_function.function |     @def_function.function | ||||||
| @ -984,14 +982,29 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
|       v.assign_add(1) |       v.assign_add(1) | ||||||
|       capture.assign_sub(1) |       capture.assign_sub(1) | ||||||
| 
 | 
 | ||||||
|  |     @def_function.function(input_signature=[ | ||||||
|  |         resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) | ||||||
|  |     ]) | ||||||
|  |     def func_with_input_signature(v): | ||||||
|  |       v.assign_add(5) | ||||||
|  |       capture.assign_sub(5) | ||||||
|  |       return 1 | ||||||
|  | 
 | ||||||
|     vsave = variables.Variable(1) |     vsave = variables.Variable(1) | ||||||
|     root = tracking.AutoTrackable() |     root = tracking.AutoTrackable() | ||||||
|     root.f = func.get_concrete_function(vsave) |     root.f = func.get_concrete_function(vsave) | ||||||
|  |     root.f_sig = func_with_input_signature.get_concrete_function() | ||||||
|     root.capture = capture |     root.capture = capture | ||||||
|  | 
 | ||||||
|     self.assertEqual(1, vsave.numpy()) |     self.assertEqual(1, vsave.numpy()) | ||||||
|     root.f(vsave) |     root.f(vsave) | ||||||
|     self.assertEqual(2, vsave.numpy()) |     self.assertEqual(2, vsave.numpy()) | ||||||
|     self.assertEqual(-1, capture.numpy()) |     self.assertEqual(-1, capture.numpy()) | ||||||
|  | 
 | ||||||
|  |     root.f_sig(vsave) | ||||||
|  |     self.assertEqual(7, vsave.numpy()) | ||||||
|  |     self.assertEqual(-6, capture.numpy()) | ||||||
|  | 
 | ||||||
|     imported = cycle(root, cycles) |     imported = cycle(root, cycles) | ||||||
| 
 | 
 | ||||||
|     vload = variables.Variable(1) |     vload = variables.Variable(1) | ||||||
| @ -999,8 +1012,13 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
|     self.assertEqual(2, vload.numpy()) |     self.assertEqual(2, vload.numpy()) | ||||||
|     imported.f(v=vload) |     imported.f(v=vload) | ||||||
|     self.assertEqual(3, vload.numpy()) |     self.assertEqual(3, vload.numpy()) | ||||||
|     self.assertEqual(-3, imported.capture.numpy()) |     self.assertEqual(-8, imported.capture.numpy()) | ||||||
|     self.assertEqual(-1, capture.numpy()) | 
 | ||||||
|  |     imported.f_sig(v=vload) | ||||||
|  |     self.assertEqual(8, vload.numpy()) | ||||||
|  |     self.assertEqual(-13, imported.capture.numpy()) | ||||||
|  | 
 | ||||||
|  |     self.assertEqual(-6, capture.numpy()) | ||||||
| 
 | 
 | ||||||
|   def test_function_and_component(self, cycles): |   def test_function_and_component(self, cycles): | ||||||
| 
 | 
 | ||||||
| @ -1644,7 +1662,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
|   def test_destroy_resource(self, cycles): |   def test_destroy_resource(self, cycles): | ||||||
| 
 | 
 | ||||||
|     def get_handle(): |     def get_handle(): | ||||||
|       return gen_resource_variable_ops.var_handle_op( |       return resource_variable_ops.var_handle_op( | ||||||
|           shape=tensor_shape.as_shape([]), |           shape=tensor_shape.as_shape([]), | ||||||
|           dtype=dtypes.float32, |           dtype=dtypes.float32, | ||||||
|           shared_name="my_var_name", |           shared_name="my_var_name", | ||||||
| @ -1655,7 +1673,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
| 
 | 
 | ||||||
|       def destroy_resource(self): |       def destroy_resource(self): | ||||||
|         handle = get_handle() |         handle = get_handle() | ||||||
|         gen_resource_variable_ops.destroy_resource_op( |         resource_variable_ops.destroy_resource_op( | ||||||
|             handle, ignore_lookup_error=True) |             handle, ignore_lookup_error=True) | ||||||
| 
 | 
 | ||||||
|     class MyResource(tracking.TrackableResource): |     class MyResource(tracking.TrackableResource): | ||||||
| @ -1669,7 +1687,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
|         return get_handle() |         return get_handle() | ||||||
| 
 | 
 | ||||||
|       def _initialize(self): |       def _initialize(self): | ||||||
|         gen_resource_variable_ops.assign_variable_op( |         resource_variable_ops.assign_variable_op( | ||||||
|             self.resource_handle, 1.0, name="assign") |             self.resource_handle, 1.0, name="assign") | ||||||
| 
 | 
 | ||||||
|     class MyModel(tracking.AutoTrackable): |     class MyModel(tracking.AutoTrackable): | ||||||
| @ -1681,10 +1699,9 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
|       @def_function.function(input_signature=[]) |       @def_function.function(input_signature=[]) | ||||||
|       def increase(self): |       def increase(self): | ||||||
|         handle = self.resource.resource_handle |         handle = self.resource.resource_handle | ||||||
|         gen_resource_variable_ops.assign_add_variable_op( |         resource_variable_ops.assign_add_variable_op( | ||||||
|             handle, 10.0, name="assign_add") |             handle, 10.0, name="assign_add") | ||||||
|         return gen_resource_variable_ops.read_variable_op( |         return resource_variable_ops.read_variable_op(handle, dtypes.float32) | ||||||
|             handle, dtypes.float32) |  | ||||||
| 
 | 
 | ||||||
|     root = MyModel() |     root = MyModel() | ||||||
|     imported = cycle(root, cycles) |     imported = cycle(root, cycles) | ||||||
| @ -1699,7 +1716,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): | |||||||
|     # Try to destroy the resource again, should fail. |     # Try to destroy the resource again, should fail. | ||||||
|     with self.assertRaisesRegexp(errors.NotFoundError, |     with self.assertRaisesRegexp(errors.NotFoundError, | ||||||
|                                  r"Resource .* does not exist."): |                                  r"Resource .* does not exist."): | ||||||
|       gen_resource_variable_ops.destroy_resource_op( |       resource_variable_ops.destroy_resource_op( | ||||||
|           handle, ignore_lookup_error=False) |           handle, ignore_lookup_error=False) | ||||||
| 
 | 
 | ||||||
|   def test_function_called_as_operation(self, cycles): |   def test_function_called_as_operation(self, cycles): | ||||||
|  | |||||||
| @ -44,6 +44,7 @@ from tensorflow.python.framework import indexed_slices | |||||||
| from tensorflow.python.framework import sparse_tensor | from tensorflow.python.framework import sparse_tensor | ||||||
| from tensorflow.python.framework import tensor_shape | from tensorflow.python.framework import tensor_shape | ||||||
| from tensorflow.python.framework import tensor_spec | from tensorflow.python.framework import tensor_spec | ||||||
|  | from tensorflow.python.ops import resource_variable_ops | ||||||
| from tensorflow.python.ops import tensor_array_ops | from tensorflow.python.ops import tensor_array_ops | ||||||
| from tensorflow.python.ops.ragged import ragged_tensor | from tensorflow.python.ops.ragged import ragged_tensor | ||||||
| from tensorflow.python.util import compat | from tensorflow.python.util import compat | ||||||
| @ -467,6 +468,8 @@ class _TypeSpecCodec(object): | |||||||
|           optional_ops.OptionalSpec, |           optional_ops.OptionalSpec, | ||||||
|       struct_pb2.TypeSpecProto.PER_REPLICA_SPEC: |       struct_pb2.TypeSpecProto.PER_REPLICA_SPEC: | ||||||
|           values.PerReplicaSpec, |           values.PerReplicaSpec, | ||||||
|  |       struct_pb2.TypeSpecProto.VARIABLE_SPEC: | ||||||
|  |           resource_variable_ops.VariableSpec, | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   # Mapping from type (TypeSpec subclass) to enum value. |   # Mapping from type (TypeSpec subclass) to enum value. | ||||||
|  | |||||||
| @ -45,6 +45,7 @@ from tensorflow.python.module import module | |||||||
| from tensorflow.python.ops import array_ops | from tensorflow.python.ops import array_ops | ||||||
| from tensorflow.python.ops import lookup_ops | from tensorflow.python.ops import lookup_ops | ||||||
| from tensorflow.python.ops import math_ops | from tensorflow.python.ops import math_ops | ||||||
|  | from tensorflow.python.ops import resource_variable_ops | ||||||
| from tensorflow.python.ops import variables | from tensorflow.python.ops import variables | ||||||
| from tensorflow.python.saved_model import loader | from tensorflow.python.saved_model import loader | ||||||
| from tensorflow.python.saved_model import loader_impl | from tensorflow.python.saved_model import loader_impl | ||||||
| @ -429,6 +430,18 @@ class SaveTest(test.TestCase): | |||||||
|     self.assertAllClose({"output_0": 3 * (1 + 4 + 9 + 16)}, |     self.assertAllClose({"output_0": 3 * (1 + 4 + 9 + 16)}, | ||||||
|                         _import_and_infer(save_dir, {"x": 3})) |                         _import_and_infer(save_dir, {"x": 3})) | ||||||
| 
 | 
 | ||||||
|  |   def test_variable_args_cannot_be_used_as_signature(self): | ||||||
|  |     @def_function.function(input_signature=[ | ||||||
|  |         resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)]) | ||||||
|  |     def f(unused_v): | ||||||
|  |       return 1 | ||||||
|  |     root = tracking.AutoTrackable() | ||||||
|  |     root.f = f.get_concrete_function() | ||||||
|  |     with self.assertRaisesRegexp(ValueError, | ||||||
|  |                                  "tf.Variable inputs cannot be exported"): | ||||||
|  |       save.save(root, os.path.join(self.get_temp_dir(), "saved_model"), | ||||||
|  |                 signatures=root.f) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class SavingOptionsTest(test.TestCase): | class SavingOptionsTest(test.TestCase): | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -22,6 +22,7 @@ from tensorflow.python.eager import def_function | |||||||
| from tensorflow.python.eager import function as defun | from tensorflow.python.eager import function as defun | ||||||
| from tensorflow.python.framework import ops | from tensorflow.python.framework import ops | ||||||
| from tensorflow.python.framework import tensor_spec | from tensorflow.python.framework import tensor_spec | ||||||
|  | from tensorflow.python.ops import resource_variable_ops | ||||||
| from tensorflow.python.saved_model import revived_types | from tensorflow.python.saved_model import revived_types | ||||||
| from tensorflow.python.saved_model import signature_constants | from tensorflow.python.saved_model import signature_constants | ||||||
| from tensorflow.python.training.tracking import base | from tensorflow.python.training.tracking import base | ||||||
| @ -51,12 +52,21 @@ def _valid_signature(concrete_function): | |||||||
|     # 1.x style. |     # 1.x style. | ||||||
|     return False |     return False | ||||||
|   try: |   try: | ||||||
|  |     _validate_inputs(concrete_function) | ||||||
|     _normalize_outputs(concrete_function.structured_outputs, "unused", "unused") |     _normalize_outputs(concrete_function.structured_outputs, "unused", "unused") | ||||||
|   except ValueError: |   except ValueError: | ||||||
|     return False |     return False | ||||||
|   return True |   return True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def _validate_inputs(concrete_function): | ||||||
|  |   if any(isinstance(inp, resource_variable_ops.VariableSpec) | ||||||
|  |          for inp in nest.flatten( | ||||||
|  |              concrete_function.structured_input_signature)): | ||||||
|  |     raise ValueError(("Functions that expect tf.Variable inputs cannot be " | ||||||
|  |                       "exported as signatures.")) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def find_function_to_export(saveable_view): | def find_function_to_export(saveable_view): | ||||||
|   """Function to export, None if no suitable function was found.""" |   """Function to export, None if no suitable function was found.""" | ||||||
|   # If the user did not specify signatures, check the root object for a function |   # If the user did not specify signatures, check the root object for a function | ||||||
| @ -98,6 +108,8 @@ def canonicalize_signatures(signatures): | |||||||
|            "got {}. Only `tf.functions` with an input signature or " |            "got {}. Only `tf.functions` with an input signature or " | ||||||
|            "concrete functions can be used as a signature.").format(function)) |            "concrete functions can be used as a signature.").format(function)) | ||||||
| 
 | 
 | ||||||
|  |     _validate_inputs(signature_function) | ||||||
|  | 
 | ||||||
|     # Re-wrap the function so that it returns a dictionary of Tensors. This |     # Re-wrap the function so that it returns a dictionary of Tensors. This | ||||||
|     # matches the format of 1.x-style signatures. |     # matches the format of 1.x-style signatures. | ||||||
|     # pylint: disable=cell-var-from-loop |     # pylint: disable=cell-var-from-loop | ||||||
|  | |||||||
| @ -512,21 +512,23 @@ bool IsCompositeTensorHelper(PyObject* o) { | |||||||
|   return check_cache->CachedLookup(o); |   return check_cache->CachedLookup(o); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec.
 | // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
 | ||||||
|  | // VariableSpec.
 | ||||||
| // Returns 0 otherwise.
 | // Returns 0 otherwise.
 | ||||||
| // Returns -1 if an error occurred.
 | // Returns -1 if an error occurred.
 | ||||||
| bool IsTypeSpecHelper(PyObject* o) { | bool IsTypeSpecHelper(PyObject* o) { | ||||||
|   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { | ||||||
|     int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec"); |     int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec"); | ||||||
|     int is_tensor_spec = IsInstanceOfRegisteredType(to_check, "TensorSpec"); |     int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") || | ||||||
|     if ((is_type_spec == -1) || (is_tensor_spec == -1)) return -1; |                          IsInstanceOfRegisteredType(to_check, "VariableSpec")); | ||||||
|     return static_cast<int>(is_type_spec && !is_tensor_spec); |     if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1; | ||||||
|  |     return static_cast<int>(is_type_spec && !is_dense_spec); | ||||||
|   }); |   }); | ||||||
|   return check_cache->CachedLookup(o); |   return check_cache->CachedLookup(o); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
 | // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
 | ||||||
| // (non-TensorSpec) TypeSpec.
 | // (non-TensorSpec and non-VariableSpec) TypeSpec.
 | ||||||
| // Returns 0 otherwise.
 | // Returns 0 otherwise.
 | ||||||
| // Returns -1 if an error occurred.
 | // Returns -1 if an error occurred.
 | ||||||
| int IsSequenceOrCompositeHelper(PyObject* o) { | int IsSequenceOrCompositeHelper(PyObject* o) { | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| path: "tensorflow.TensorSpec" | path: "tensorflow.TensorSpec" | ||||||
| tf_class { | tf_class { | ||||||
|   is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>" |   is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>" | ||||||
|  |   is_instance: "<class \'tensorflow.python.framework.tensor_spec.DenseSpec\'>" | ||||||
|   is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>" |   is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>" | ||||||
|   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>" |   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>" | ||||||
|   is_instance: "<type \'object\'>" |   is_instance: "<type \'object\'>" | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| path: "tensorflow.TensorSpec" | path: "tensorflow.TensorSpec" | ||||||
| tf_class { | tf_class { | ||||||
|   is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>" |   is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>" | ||||||
|  |   is_instance: "<class \'tensorflow.python.framework.tensor_spec.DenseSpec\'>" | ||||||
|   is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>" |   is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>" | ||||||
|   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>" |   is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>" | ||||||
|   is_instance: "<type \'object\'>" |   is_instance: "<type \'object\'>" | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user