Fix tf2xla error message in the ConcatOffset and MatrixDiagPart ops
PiperOrigin-RevId: 324748010 Change-Id: I22846e9fe5d30049f2272d7b8c9140c18ea00da8
This commit is contained in:
parent
691f25bf80
commit
486296671c
tensorflow
compiler/tf2xla/kernels
python/kernel_tests
@ -186,9 +186,11 @@ class ConcatOffsetOp : public XlaOpKernel {
|
||||
const int32 inp0_element = inp0_dims[j];
|
||||
const int32 inp_element = inp_dims[j];
|
||||
OP_REQUIRES(ctx, inp0_element == inp_element,
|
||||
errors::InvalidArgument("input[", i, ",", j,
|
||||
"] mismatch: ", inp0_element,
|
||||
" vs. ", inp_element));
|
||||
errors::InvalidArgument(
|
||||
"All dimensions except ", axis, " must match. Input ",
|
||||
i, " has shape [", absl::StrJoin(inp_dims, " "),
|
||||
"] and doesn't match input 0 with shape [",
|
||||
absl::StrJoin(inp0_dims, " "), "]."));
|
||||
out_vec(j) = 0;
|
||||
}
|
||||
}
|
||||
|
@ -243,8 +243,9 @@ class MatrixDiagOp : public XlaOpKernel {
|
||||
errors::InvalidArgument("MatrixDiag op must have at least one input"));
|
||||
const TensorShape diag_shape = context->InputShape(0);
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
|
||||
errors::InvalidArgument("Expected >= 1 dims, got shape ",
|
||||
diag_shape.DebugString()));
|
||||
errors::InvalidArgument(
|
||||
"diagonal must be at least 1-dim, received shape: ",
|
||||
diag_shape.DebugString()));
|
||||
|
||||
const DataType dtype = context->expected_output_dtype(0);
|
||||
const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);
|
||||
|
@ -701,7 +701,6 @@ class ConcatOffsetTest(test.TestCase):
|
||||
self.evaluate(off)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("b/123337890") # Error messages differ
|
||||
def testSizeMismatch(self):
|
||||
cdim = constant_op.constant(1, dtypes.int32)
|
||||
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
|
||||
|
@ -541,7 +541,6 @@ class MatrixDiagTest(test.TestCase):
|
||||
array_ops.matrix_diag(0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("b/123337890") # Error messages differ
|
||||
def testInvalidShapeAtEval(self):
|
||||
with self.session(use_gpu=True):
|
||||
v = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||
@ -891,7 +890,6 @@ class MatrixDiagPartTest(test.TestCase):
|
||||
array_ops.matrix_diag_part(0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("b/123337890") # Error messages differ
|
||||
def testInvalidShapeAtEval(self):
|
||||
with self.session(use_gpu=True):
|
||||
v = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user