Unify TF and XLA error messages for the SplitV op
PiperOrigin-RevId: 324740588 Change-Id: I8bb30059cbf2c474087a040d786b632360f1ecb4
This commit is contained in:
parent
66dc5f61c9
commit
d46875f207
tensorflow
@ -105,6 +105,10 @@ class SplitVOp : public XlaOpKernel {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
const TensorShape index_shape = ctx->InputShape(2);
|
||||
|
||||
OP_REQUIRES(ctx, index_shape.num_elements() == 1,
|
||||
errors::InvalidArgument(
|
||||
"split_dim_tensor must have exactly one element."));
|
||||
|
||||
int64 split_dim_orig;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig));
|
||||
int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
|
||||
|
@ -373,7 +373,6 @@ class SplitOpTest(test.TestCase):
|
||||
assert s1.shape.as_list() == [1]
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("b/123337890") # Error messages differ
|
||||
def testNonexistentDimTensor(self):
|
||||
x = array_ops.placeholder(dtypes.int32)
|
||||
values = np.zeros([5, 30])
|
||||
|
Loading…
Reference in New Issue
Block a user