Unify TF and XLA error messages for the SplitV op

PiperOrigin-RevId: 324740588
Change-Id: I8bb30059cbf2c474087a040d786b632360f1ecb4
This commit is contained in:
Sanjoy Das 2020-08-03 20:46:54 -07:00 committed by TensorFlower Gardener
parent 66dc5f61c9
commit d46875f207
2 changed files with 4 additions and 1 deletions
tensorflow
compiler/tf2xla/kernels
python/kernel_tests

View File

@ -105,6 +105,10 @@ class SplitVOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape index_shape = ctx->InputShape(2);
OP_REQUIRES(ctx, index_shape.num_elements() == 1,
errors::InvalidArgument(
"split_dim_tensor must have exactly one element."));
int64 split_dim_orig;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig));
int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()

View File

@ -373,7 +373,6 @@ class SplitOpTest(test.TestCase):
assert s1.shape.as_list() == [1]
@test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testNonexistentDimTensor(self):
x = array_ops.placeholder(dtypes.int32)
values = np.zeros([5, 30])