[tf.data] Add tests for MapDefun with Variant inputs.
PiperOrigin-RevId: 234394341
This commit is contained in:
parent
f164de2f3f
commit
439f3eb035
@ -322,6 +322,8 @@ py_test(
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python/data/experimental/ops:map_defun",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
],
|
||||
|
@ -27,12 +27,14 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -254,6 +256,70 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
expected = x + c
|
||||
self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
|
||||
|
||||
def testMapDefunWithVariantTensor(self):
|
||||
|
||||
@function.defun(
|
||||
input_signature=[tensor_spec.TensorSpec([], dtypes.variant)])
|
||||
def fn(x):
|
||||
return x
|
||||
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
|
||||
|
||||
serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.variant)
|
||||
serialized = array_ops.stack([serialized, serialized])
|
||||
map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.variant],
|
||||
[None])[0]
|
||||
deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
|
||||
values=[1, 2, 1, 2],
|
||||
dense_shape=[2, 3, 4])
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
|
||||
def testMapDefunWithVariantTensorAsCaptured(self):
|
||||
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
|
||||
serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.variant)
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
|
||||
def fn(x):
|
||||
del x
|
||||
return serialized
|
||||
|
||||
x = constant_op.constant([0, 0])
|
||||
map_defun_op = map_defun.map_defun(fn, [x], [dtypes.variant], [None])[0]
|
||||
deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
|
||||
values=[1, 2, 1, 2],
|
||||
dense_shape=[2, 3, 4])
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
|
||||
def testMapDefunWithStrTensor(self):
|
||||
|
||||
@function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
|
||||
def fn(x):
|
||||
return x
|
||||
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
|
||||
|
||||
serialized = sparse_ops.serialize_sparse_v2(st, out_type=dtypes.string)
|
||||
serialized = array_ops.stack([serialized, serialized])
|
||||
map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.string],
|
||||
[None])[0]
|
||||
deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
|
||||
values=[1, 2, 1, 2],
|
||||
dense_shape=[2, 3, 4])
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user