When constructing a RaggedTensor with a numpy row-partitioning tensor, don't convert int32->int64.
PiperOrigin-RevId: 261408643
This commit is contained in:
parent
39e7715eb0
commit
8628f75ee1
@ -789,9 +789,12 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
name=name)
|
||||
else:
|
||||
values = ops.convert_to_tensor(values, name="values")
|
||||
partition = ops.convert_to_tensor(
|
||||
partition, preferred_dtype=dtypes.int64,
|
||||
name=name)
|
||||
if isinstance(partition, np.ndarray) and partition.dtype == np.int32:
|
||||
partition = ops.convert_to_tensor(partition, name=name)
|
||||
else:
|
||||
partition = ops.convert_to_tensor(
|
||||
partition, preferred_dtype=dtypes.int64,
|
||||
name=name)
|
||||
if partition.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise ValueError("%s must have dtype int32 or int64" % name)
|
||||
|
||||
|
@ -376,6 +376,24 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowSplitsWithDifferentSplitTypes(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
splits1 = [0, 2, 2, 5, 6, 7]
|
||||
splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
|
||||
splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
|
||||
splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
|
||||
rt1 = RaggedTensor.from_row_splits(values, splits1)
|
||||
rt2 = RaggedTensor.from_row_splits(values, splits2)
|
||||
rt3 = RaggedTensor.from_row_splits(values, splits3)
|
||||
rt4 = RaggedTensor.from_row_splits(values, splits4)
|
||||
rt5 = RaggedTensor.from_row_splits(values, splits5)
|
||||
self.assertEqual(rt1.row_splits.dtype, dtypes.int64)
|
||||
self.assertEqual(rt2.row_splits.dtype, dtypes.int64)
|
||||
self.assertEqual(rt3.row_splits.dtype, dtypes.int32)
|
||||
self.assertEqual(rt4.row_splits.dtype, dtypes.int64)
|
||||
self.assertEqual(rt5.row_splits.dtype, dtypes.int32)
|
||||
|
||||
def testFromRowSplitsWithEmptySplits(self):
|
||||
err_msg = 'row_splits tensor may not be empty'
|
||||
with self.assertRaisesRegexp(ValueError, err_msg):
|
||||
|
Loading…
x
Reference in New Issue
Block a user