When constructing a RaggedTensor with a numpy row-partitioning tensor, don't convert int32->int64.

PiperOrigin-RevId: 261408643
This commit is contained in:
Edward Loper 2019-08-02 15:58:32 -07:00 committed by TensorFlower Gardener
parent 39e7715eb0
commit 8628f75ee1
2 changed files with 24 additions and 3 deletions

View File

@ -789,9 +789,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
name=name)
else:
values = ops.convert_to_tensor(values, name="values")
partition = ops.convert_to_tensor(
partition, preferred_dtype=dtypes.int64,
name=name)
if isinstance(partition, np.ndarray) and partition.dtype == np.int32:
partition = ops.convert_to_tensor(partition, name=name)
else:
partition = ops.convert_to_tensor(
partition, preferred_dtype=dtypes.int64,
name=name)
if partition.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("%s must have dtype int32 or int64" % name)

View File

@ -376,6 +376,24 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowSplitsWithDifferentSplitTypes(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
splits1 = [0, 2, 2, 5, 6, 7]
splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
rt1 = RaggedTensor.from_row_splits(values, splits1)
rt2 = RaggedTensor.from_row_splits(values, splits2)
rt3 = RaggedTensor.from_row_splits(values, splits3)
rt4 = RaggedTensor.from_row_splits(values, splits4)
rt5 = RaggedTensor.from_row_splits(values, splits5)
self.assertEqual(rt1.row_splits.dtype, dtypes.int64)
self.assertEqual(rt2.row_splits.dtype, dtypes.int64)
self.assertEqual(rt3.row_splits.dtype, dtypes.int32)
self.assertEqual(rt4.row_splits.dtype, dtypes.int64)
self.assertEqual(rt5.row_splits.dtype, dtypes.int32)
def testFromRowSplitsWithEmptySplits(self):
err_msg = 'row_splits tensor may not be empty'
with self.assertRaisesRegexp(ValueError, err_msg):