Allow NdarraySpec to be written in saved model.

PiperOrigin-RevId: 326121293
Change-Id: I7a4351a9ab3e0381ff5616f67d0e61880f3bb649
This commit is contained in:
Akshay Modi 2020-08-11 16:05:01 -07:00 committed by TensorFlower Gardener
parent 6974852f96
commit b297140e1a
5 changed files with 44 additions and 0 deletions

View File

@ -136,6 +136,7 @@ message TypeSpecProto {
PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py
VARIABLE_SPEC = 9; // tf.VariableSpec
ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py
NDARRAY_SPEC = 11; // TF Numpy NDarray spec
}
TypeSpecClass type_spec_class = 1;

View File

@ -587,6 +587,7 @@ py_strict_library(
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:optional_ops",
"//tensorflow/python/distribute:values",
"//tensorflow/python/ops/numpy_ops:numpy",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:row_partition",
"@six_archive//:six",

View File

@ -50,9 +50,11 @@ from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numpy_ops as tnp
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import load
@ -1810,6 +1812,34 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
def test_saving_ndarray_specs(self, cycles):
class NdarrayModule(module.Module):
@def_function.function
def plain(self, x):
return tnp.add(x, 1)
@def_function.function(input_signature=[
np_arrays.NdarraySpec(tensor_spec.TensorSpec([], dtypes.float32))])
def with_signature(self, x):
return tnp.add(x, 1)
m = NdarrayModule()
c = tnp.asarray(3.0, tnp.float32)
output_plain, output_with_signature = m.plain(c), m.with_signature(c)
loaded_m = cycle(m, cycles)
load_output_plain, load_output_with_signature = (
loaded_m.plain(c), loaded_m.with_signature(c))
self.assertIsInstance(output_plain, tnp.ndarray)
self.assertIsInstance(load_output_plain, tnp.ndarray)
self.assertIsInstance(output_with_signature, tnp.ndarray)
self.assertIsInstance(load_output_with_signature, tnp.ndarray)
self.assertAllClose(output_plain, load_output_plain)
self.assertAllClose(output_with_signature, load_output_with_signature)
class SingleCycleTests(test.TestCase, parameterized.TestCase):

View File

@ -48,6 +48,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import row_partition
from tensorflow.python.util import compat
@ -516,6 +517,8 @@ class _TypeSpecCodec(object):
resource_variable_ops.VariableSpec,
struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
row_partition.RowPartitionSpec,
struct_pb2.TypeSpecProto.NDARRAY_SPEC:
np_arrays.NdarraySpec,
}
# Mapping from type (TypeSpec subclass) to enum value.

View File

@ -28,6 +28,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model import nested_structure_coder
@ -331,6 +332,14 @@ class NestedStructureTest(test.TestCase):
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)
def testEncodeDecodeNdarraySpec(self):
structure = [np_arrays.NdarraySpec(
tensor_spec.TensorSpec([4, 2], dtypes.float32))]
self.assertTrue(self._coder.can_encode(structure))
encoded = self._coder.encode_structure(structure)
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)
def testNotEncodable(self):
class NotEncodable(object):