[XLA] Fix CHECK-failure crash if a non-tuple was passed to GetTupleElement.

PiperOrigin-RevId: 168550703
This commit is contained in:
Peter Hawkins 2017-09-13 09:35:09 -07:00 committed by TensorFlower Gardener
parent 010922ed91
commit dc1eda8a6d
2 changed files with 16 additions and 0 deletions
tensorflow/compiler/xla

View File

@ -310,6 +310,11 @@ StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
LookUpRequest(get_tuple_element_request.operand()));
if (!ShapeUtil::IsTuple(operand->output_shape())) {
return InvalidArgument(
"Operand to GetTupleElement() is not a tuple; got %s",
ShapeUtil::HumanString(operand->output_shape()).c_str());
}
Shape element_shape = ShapeUtil::GetTupleElementShape(
operand->output_shape(), get_tuple_element_request.index());

View File

@ -123,6 +123,17 @@ XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
ComputationBuilder builder(client_, TestName());
auto value = builder.ConstantR1<float>({4.5f});
builder.GetTupleElement(value, 1);
auto result_status = builder.Build();
EXPECT_FALSE(result_status.ok());
EXPECT_THAT(
result_status.status().error_message(),
::testing::HasSubstr("Operand to GetTupleElement() is not a tuple"));
}
// Extracts both elements from a tuple with GetTupleElement and then adds them
// together.
XLA_TEST_F(TupleTest, AddTupleElements) {