[tf.data] Add tests for MapDefun with Variant inputs.

PiperOrigin-RevId: 234394341
This commit is contained in:
Rachel Lim 2019-02-17 14:30:45 -08:00 committed by TensorFlower Gardener
parent f164de2f3f
commit 439f3eb035
2 changed files with 68 additions and 0 deletions

View File

@ -322,6 +322,8 @@ py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/experimental/ops:map_defun",
"//tensorflow/python/data/kernel_tests:test_base",
],

View File

@ -27,12 +27,14 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@ -254,6 +256,70 @@ class MapDefunTest(test_base.DatasetTestBase):
expected = x + c
self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
def testMapDefunWithVariantTensor(self):
@function.defun(
input_signature=[tensor_spec.TensorSpec([], dtypes.variant)])
def fn(x):
return x
st = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.variant)
serialized = array_ops.stack([serialized, serialized])
map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.variant],
[None])[0]
deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
values=[1, 2, 1, 2],
dense_shape=[2, 3, 4])
actual = self.evaluate(deserialized)
self.assertSparseValuesEqual(expected, actual)
def testMapDefunWithVariantTensorAsCaptured(self):
st = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.variant)
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
def fn(x):
del x
return serialized
x = constant_op.constant([0, 0])
map_defun_op = map_defun.map_defun(fn, [x], [dtypes.variant], [None])[0]
deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
values=[1, 2, 1, 2],
dense_shape=[2, 3, 4])
actual = self.evaluate(deserialized)
self.assertSparseValuesEqual(expected, actual)
def testMapDefunWithStrTensor(self):
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
def fn(x):
return x
st = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.string)
serialized = array_ops.stack([serialized, serialized])
map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.string],
[None])[0]
deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
values=[1, 2, 1, 2],
dense_shape=[2, 3, 4])
actual = self.evaluate(deserialized)
self.assertSparseValuesEqual(expected, actual)
if __name__ == "__main__":
test.main()