diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index dfdcc54411c..1733b9817b3 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -322,6 +322,8 @@ py_test( "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", "//tensorflow/python:session", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/experimental/ops:map_defun", "//tensorflow/python/data/kernel_tests:test_base", ], diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py index a48f0808a6a..4e99189279c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py @@ -27,12 +27,14 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test @@ -254,6 +256,70 @@ class MapDefunTest(test_base.DatasetTestBase): expected = x + c self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op)) + def testMapDefunWithVariantTensor(self): + + @function.defun( + input_signature=[tensor_spec.TensorSpec([], dtypes.variant)]) + def fn(x): + return x + + st = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) + + serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.variant) + serialized = array_ops.stack([serialized, serialized]) + map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.variant], + [None])[0] + deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]], + values=[1, 2, 1, 2], + dense_shape=[2, 3, 4]) + actual = self.evaluate(deserialized) + self.assertSparseValuesEqual(expected, actual) + + def testMapDefunWithVariantTensorAsCaptured(self): + + st = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) + serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.variant) + + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) + def fn(x): + del x + return serialized + + x = constant_op.constant([0, 0]) + map_defun_op = map_defun.map_defun(fn, [x], [dtypes.variant], [None])[0] + deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]], + values=[1, 2, 1, 2], + dense_shape=[2, 3, 4]) + actual = self.evaluate(deserialized) + self.assertSparseValuesEqual(expected, actual) + + def testMapDefunWithStrTensor(self): + + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) + def fn(x): + return x + + st = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) + + serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.string) + serialized = array_ops.stack([serialized, serialized]) + map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.string], + [None])[0] + deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]], + values=[1, 2, 1, 2], + dense_shape=[2, 3, 4]) + actual = self.evaluate(deserialized) + self.assertSparseValuesEqual(expected, actual) + if __name__ == "__main__": test.main()