From e4e9409b3de9a8d12a56fc0e2fa7270bffd0d41a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Dec 2018 18:56:53 -0800 Subject: [PATCH] Fix internal type mismatch in ragged.map_fn PiperOrigin-RevId: 225110815 --- tensorflow/python/ops/ragged/BUILD | 4 ++-- .../python/ops/ragged/ragged_map_fn_op_test.py | 13 +++++++++++++ tensorflow/python/ops/ragged/ragged_map_ops.py | 4 +++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index c0db8bfbb5c..440d9db8246 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -263,17 +263,17 @@ py_library( srcs = ["ragged_map_ops.py"], srcs_version = "PY2AND3", deps = [ - ":ragged_array_ops", - ":ragged_factory_ops", ":ragged_tensor", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py index 49c0996b24f..171cb347de0 100644 --- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.keras import backend from tensorflow.python.ops import array_ops @@ -270,6 +271,18 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase, elems, dtype=ragged.RaggedTensorType(dtype=dtypes.int64, ragged_rank=10)) + def testMapOnSparseTensor(self): + s = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + values=[0, 5, 0, 4], + dense_shape=[2, 2], + ) + t2 = ragged.RaggedTensor.from_sparse(s) + id_t2 = ragged.map_fn( + lambda x: x, t2, + ) + self.assertRaggedEqual(id_t2, [[0, 5], [0, 4]]) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/ops/ragged/ragged_map_ops.py b/tensorflow/python/ops/ragged/ragged_map_ops.py index af40352b1d0..fbe188bd1a3 100644 --- a/tensorflow/python/ops/ragged/ragged_map_ops.py +++ b/tensorflow/python/ops/ragged/ragged_map_ops.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.ragged import ragged_tensor @@ -238,6 +239,7 @@ def map_fn(fn, n = (tensor_shape.dimension_value(static_shape[0]) or array_ops.shape(elems_flat[0])[0]) + n = math_ops.cast(n, dtype=dtypes.int32) # Create a flat list of TAs. # Flatten the dtype structure to a list. @@ -254,7 +256,7 @@ def map_fn(fn, for t in dtype_components_flat ] - i = constant_op.constant(0) + i = constant_op.constant(0, dtype=dtypes.int32) def compute(i, tas): """The loop body of map_fn.