Fix tf2xla error message in the ConcatOffset and MatrixDiagPart ops

PiperOrigin-RevId: 324748010
Change-Id: I22846e9fe5d30049f2272d7b8c9140c18ea00da8
This commit is contained in:
Sanjoy Das 2020-08-03 22:03:17 -07:00 committed by TensorFlower Gardener
parent 691f25bf80
commit 486296671c
4 changed files with 8 additions and 8 deletions
tensorflow
compiler/tf2xla/kernels
python/kernel_tests

View File

@ -186,9 +186,11 @@ class ConcatOffsetOp : public XlaOpKernel {
const int32 inp0_element = inp0_dims[j];
const int32 inp_element = inp_dims[j];
OP_REQUIRES(ctx, inp0_element == inp_element,
errors::InvalidArgument("input[", i, ",", j,
"] mismatch: ", inp0_element,
" vs. ", inp_element));
errors::InvalidArgument(
"All dimensions except ", axis, " must match. Input ",
i, " has shape [", absl::StrJoin(inp_dims, " "),
"] and doesn't match input 0 with shape [",
absl::StrJoin(inp0_dims, " "), "]."));
out_vec(j) = 0;
}
}

View File

@ -243,8 +243,9 @@ class MatrixDiagOp : public XlaOpKernel {
errors::InvalidArgument("MatrixDiag op must have at least one input"));
const TensorShape diag_shape = context->InputShape(0);
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
errors::InvalidArgument("Expected >= 1 dims, got shape ",
diag_shape.DebugString()));
errors::InvalidArgument(
"diagonal must be at least 1-dim, received shape: ",
diag_shape.DebugString()));
const DataType dtype = context->expected_output_dtype(0);
const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

View File

@ -701,7 +701,6 @@ class ConcatOffsetTest(test.TestCase):
self.evaluate(off)
@test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testSizeMismatch(self):
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)

View File

@ -541,7 +541,6 @@ class MatrixDiagTest(test.TestCase):
array_ops.matrix_diag(0)
@test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self):
with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32)
@ -891,7 +890,6 @@ class MatrixDiagPartTest(test.TestCase):
array_ops.matrix_diag_part(0)
@test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self):
with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32)