Fix tf2xla error message in the ConcatOffset and MatrixDiagPart ops

PiperOrigin-RevId: 324748010
Change-Id: I22846e9fe5d30049f2272d7b8c9140c18ea00da8
This commit is contained in:
Sanjoy Das 2020-08-03 22:03:17 -07:00 committed by TensorFlower Gardener
parent 691f25bf80
commit 486296671c
4 changed files with 8 additions and 8 deletions

View File

@ -186,9 +186,11 @@ class ConcatOffsetOp : public XlaOpKernel {
const int32 inp0_element = inp0_dims[j]; const int32 inp0_element = inp0_dims[j];
const int32 inp_element = inp_dims[j]; const int32 inp_element = inp_dims[j];
OP_REQUIRES(ctx, inp0_element == inp_element, OP_REQUIRES(ctx, inp0_element == inp_element,
errors::InvalidArgument("input[", i, ",", j, errors::InvalidArgument(
"] mismatch: ", inp0_element, "All dimensions except ", axis, " must match. Input ",
" vs. ", inp_element)); i, " has shape [", absl::StrJoin(inp_dims, " "),
"] and doesn't match input 0 with shape [",
absl::StrJoin(inp0_dims, " "), "]."));
out_vec(j) = 0; out_vec(j) = 0;
} }
} }

View File

@ -243,7 +243,8 @@ class MatrixDiagOp : public XlaOpKernel {
errors::InvalidArgument("MatrixDiag op must have at least one input")); errors::InvalidArgument("MatrixDiag op must have at least one input"));
const TensorShape diag_shape = context->InputShape(0); const TensorShape diag_shape = context->InputShape(0);
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
errors::InvalidArgument("Expected >= 1 dims, got shape ", errors::InvalidArgument(
"diagonal must be at least 1-dim, received shape: ",
diag_shape.DebugString())); diag_shape.DebugString()));
const DataType dtype = context->expected_output_dtype(0); const DataType dtype = context->expected_output_dtype(0);

View File

@ -701,7 +701,6 @@ class ConcatOffsetTest(test.TestCase):
self.evaluate(off) self.evaluate(off)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testSizeMismatch(self): def testSizeMismatch(self):
cdim = constant_op.constant(1, dtypes.int32) cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32)

View File

@ -541,7 +541,6 @@ class MatrixDiagTest(test.TestCase):
array_ops.matrix_diag(0) array_ops.matrix_diag(0)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self): def testInvalidShapeAtEval(self):
with self.session(use_gpu=True): with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32) v = array_ops.placeholder(dtype=dtypes_lib.float32)
@ -891,7 +890,6 @@ class MatrixDiagPartTest(test.TestCase):
array_ops.matrix_diag_part(0) array_ops.matrix_diag_part(0)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self): def testInvalidShapeAtEval(self):
with self.session(use_gpu=True): with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32) v = array_ops.placeholder(dtype=dtypes_lib.float32)