Allow NdarraySpec to be written in saved model.
PiperOrigin-RevId: 326121293 Change-Id: I7a4351a9ab3e0381ff5616f67d0e61880f3bb649
This commit is contained in:
parent
6974852f96
commit
b297140e1a
tensorflow
core/protobuf
python/saved_model
@ -136,6 +136,7 @@ message TypeSpecProto {
|
||||
PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py
|
||||
VARIABLE_SPEC = 9; // tf.VariableSpec
|
||||
ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py
|
||||
NDARRAY_SPEC = 11; // TF Numpy NDarray spec
|
||||
}
|
||||
TypeSpecClass type_spec_class = 1;
|
||||
|
||||
|
@ -587,6 +587,7 @@ py_strict_library(
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/ops:optional_ops",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/ops/numpy_ops:numpy",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:row_partition",
|
||||
"@six_archive//:six",
|
||||
|
@ -50,9 +50,11 @@ from tensorflow.python.ops import cond_v2
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import numpy_ops as tnp
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.saved_model import load
|
||||
@ -1810,6 +1812,34 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
|
||||
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
|
||||
|
||||
def test_saving_ndarray_specs(self, cycles):
|
||||
class NdarrayModule(module.Module):
|
||||
|
||||
@def_function.function
|
||||
def plain(self, x):
|
||||
return tnp.add(x, 1)
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
np_arrays.NdarraySpec(tensor_spec.TensorSpec([], dtypes.float32))])
|
||||
def with_signature(self, x):
|
||||
return tnp.add(x, 1)
|
||||
|
||||
m = NdarrayModule()
|
||||
c = tnp.asarray(3.0, tnp.float32)
|
||||
output_plain, output_with_signature = m.plain(c), m.with_signature(c)
|
||||
|
||||
loaded_m = cycle(m, cycles)
|
||||
|
||||
load_output_plain, load_output_with_signature = (
|
||||
loaded_m.plain(c), loaded_m.with_signature(c))
|
||||
|
||||
self.assertIsInstance(output_plain, tnp.ndarray)
|
||||
self.assertIsInstance(load_output_plain, tnp.ndarray)
|
||||
self.assertIsInstance(output_with_signature, tnp.ndarray)
|
||||
self.assertIsInstance(load_output_with_signature, tnp.ndarray)
|
||||
self.assertAllClose(output_plain, load_output_plain)
|
||||
self.assertAllClose(output_with_signature, load_output_with_signature)
|
||||
|
||||
|
||||
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -48,6 +48,7 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import row_partition
|
||||
from tensorflow.python.util import compat
|
||||
@ -516,6 +517,8 @@ class _TypeSpecCodec(object):
|
||||
resource_variable_ops.VariableSpec,
|
||||
struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
|
||||
row_partition.RowPartitionSpec,
|
||||
struct_pb2.TypeSpecProto.NDARRAY_SPEC:
|
||||
np_arrays.NdarraySpec,
|
||||
}
|
||||
|
||||
# Mapping from type (TypeSpec subclass) to enum value.
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import nested_structure_coder
|
||||
@ -331,6 +332,14 @@ class NestedStructureTest(test.TestCase):
|
||||
decoded = self._coder.decode_proto(encoded)
|
||||
self.assertEqual(structure, decoded)
|
||||
|
||||
def testEncodeDecodeNdarraySpec(self):
|
||||
structure = [np_arrays.NdarraySpec(
|
||||
tensor_spec.TensorSpec([4, 2], dtypes.float32))]
|
||||
self.assertTrue(self._coder.can_encode(structure))
|
||||
encoded = self._coder.encode_structure(structure)
|
||||
decoded = self._coder.decode_proto(encoded)
|
||||
self.assertEqual(structure, decoded)
|
||||
|
||||
def testNotEncodable(self):
|
||||
|
||||
class NotEncodable(object):
|
||||
|
Loading…
Reference in New Issue
Block a user