From 8628f75ee10efb4047611b0b7c0c47ac53dc39d3 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 2 Aug 2019 15:58:32 -0700 Subject: [PATCH] When constructing a RaggedTensor with a numpy row-partitioning tensor, don't convert int32->int64. PiperOrigin-RevId: 261408643 --- tensorflow/python/ops/ragged/ragged_tensor.py | 9 ++++++--- .../python/ops/ragged/ragged_tensor_test.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index b9c3193c286..3556707f139 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -789,9 +789,12 @@ class RaggedTensor(composite_tensor.CompositeTensor): name=name) else: values = ops.convert_to_tensor(values, name="values") - partition = ops.convert_to_tensor( - partition, preferred_dtype=dtypes.int64, - name=name) + if isinstance(partition, np.ndarray) and partition.dtype == np.int32: + partition = ops.convert_to_tensor(partition, name=name) + else: + partition = ops.convert_to_tensor( + partition, preferred_dtype=dtypes.int64, + name=name) if partition.dtype not in (dtypes.int32, dtypes.int64): raise ValueError("%s must have dtype int32 or int64" % name) diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index edbd84414da..06338725b26 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -376,6 +376,24 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + def testFromRowSplitsWithDifferentSplitTypes(self): + values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + splits1 = [0, 2, 2, 5, 6, 7] + splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64) + splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32) + splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32) + rt1 = RaggedTensor.from_row_splits(values, splits1) + rt2 = RaggedTensor.from_row_splits(values, splits2) + rt3 = RaggedTensor.from_row_splits(values, splits3) + rt4 = RaggedTensor.from_row_splits(values, splits4) + rt5 = RaggedTensor.from_row_splits(values, splits5) + self.assertEqual(rt1.row_splits.dtype, dtypes.int64) + self.assertEqual(rt2.row_splits.dtype, dtypes.int64) + self.assertEqual(rt3.row_splits.dtype, dtypes.int32) + self.assertEqual(rt4.row_splits.dtype, dtypes.int64) + self.assertEqual(rt5.row_splits.dtype, dtypes.int32) + def testFromRowSplitsWithEmptySplits(self): err_msg = 'row_splits tensor may not be empty' with self.assertRaisesRegexp(ValueError, err_msg):