[XLA] Fix CHECK-failure crash if a non-tuple was passed to GetTupleElement.
PiperOrigin-RevId: 168550703
This commit is contained in:
parent
010922ed91
commit
dc1eda8a6d
tensorflow/compiler/xla
@ -310,6 +310,11 @@ StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
|
||||
LookUpRequest(get_tuple_element_request.operand()));
|
||||
if (!ShapeUtil::IsTuple(operand->output_shape())) {
|
||||
return InvalidArgument(
|
||||
"Operand to GetTupleElement() is not a tuple; got %s",
|
||||
ShapeUtil::HumanString(operand->output_shape()).c_str());
|
||||
}
|
||||
Shape element_shape = ShapeUtil::GetTupleElementShape(
|
||||
operand->output_shape(), get_tuple_element_request.index());
|
||||
|
||||
|
@ -123,6 +123,17 @@ XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
|
||||
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto value = builder.ConstantR1<float>({4.5f});
|
||||
builder.GetTupleElement(value, 1);
|
||||
auto result_status = builder.Build();
|
||||
EXPECT_FALSE(result_status.ok());
|
||||
EXPECT_THAT(
|
||||
result_status.status().error_message(),
|
||||
::testing::HasSubstr("Operand to GetTupleElement() is not a tuple"));
|
||||
}
|
||||
|
||||
// Extracts both elements from a tuple with GetTupleElement and then adds them
|
||||
// together.
|
||||
XLA_TEST_F(TupleTest, AddTupleElements) {
|
||||
|
Loading…
Reference in New Issue
Block a user