Fix internal type mismatch in ragged.map_fn

PiperOrigin-RevId: 225110815
This commit is contained in:
A. Unique TensorFlower 2018-12-11 18:56:53 -08:00 committed by TensorFlower Gardener
parent 7b9865971e
commit e4e9409b3d
3 changed files with 18 additions and 3 deletions

View File

@ -263,17 +263,17 @@ py_library(
srcs = ["ragged_map_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",

View File

@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
@ -270,6 +271,18 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
elems,
dtype=ragged.RaggedTensorType(dtype=dtypes.int64, ragged_rank=10))
def testMapOnSparseTensor(self):
s = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
values=[0, 5, 0, 4],
dense_shape=[2, 2],
)
t2 = ragged.RaggedTensor.from_sparse(s)
id_t2 = ragged.map_fn(
lambda x: x, t2,
)
self.assertRaggedEqual(id_t2, [[0, 5], [0, 4]])
if __name__ == '__main__':
googletest.main()

View File

@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_tensor
@ -238,6 +239,7 @@ def map_fn(fn,
n = (tensor_shape.dimension_value(static_shape[0]) or
array_ops.shape(elems_flat[0])[0])
n = math_ops.cast(n, dtype=dtypes.int32)
# Create a flat list of TAs.
# Flatten the dtype structure to a list.
@ -254,7 +256,7 @@ def map_fn(fn,
for t in dtype_components_flat
]
i = constant_op.constant(0)
i = constant_op.constant(0, dtype=dtypes.int32)
def compute(i, tas):
"""The loop body of map_fn.