Fix internal type mismatch in ragged.map_fn
PiperOrigin-RevId: 225110815
This commit is contained in:
parent
7b9865971e
commit
e4e9409b3d
@ -263,17 +263,17 @@ py_library(
|
||||
srcs = ["ragged_map_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:context",
|
||||
|
||||
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -270,6 +271,18 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
elems,
|
||||
dtype=ragged.RaggedTensorType(dtype=dtypes.int64, ragged_rank=10))
|
||||
|
||||
def testMapOnSparseTensor(self):
|
||||
s = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
|
||||
values=[0, 5, 0, 4],
|
||||
dense_shape=[2, 2],
|
||||
)
|
||||
t2 = ragged.RaggedTensor.from_sparse(s)
|
||||
id_t2 = ragged.map_fn(
|
||||
lambda x: x, t2,
|
||||
)
|
||||
self.assertRaggedEqual(id_t2, [[0, 5], [0, 4]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
||||
@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -238,6 +239,7 @@ def map_fn(fn,
|
||||
n = (tensor_shape.dimension_value(static_shape[0]) or
|
||||
array_ops.shape(elems_flat[0])[0])
|
||||
|
||||
n = math_ops.cast(n, dtype=dtypes.int32)
|
||||
# Create a flat list of TAs.
|
||||
|
||||
# Flatten the dtype structure to a list.
|
||||
@ -254,7 +256,7 @@ def map_fn(fn,
|
||||
for t in dtype_components_flat
|
||||
]
|
||||
|
||||
i = constant_op.constant(0)
|
||||
i = constant_op.constant(0, dtype=dtypes.int32)
|
||||
|
||||
def compute(i, tas):
|
||||
"""The loop body of map_fn.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user