STT-tensorflow/tensorflow/compiler/xla/service/call_graph_test.cc
A. Unique TensorFlower 57467ada28 [XLA] Replace individual comparison HLO ops with a single compare op.
PiperOrigin-RevId: 237318727
2019-03-07 14:05:30 -08:00

570 lines
24 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
using ::testing::UnorderedElementsAre;
class CallGraphTest : public HloTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation(
HloOpcode opcode = HloOpcode::kNegate) {
HloComputation::Builder builder(TestName() + ".ScalarComputation");
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
builder.AddInstruction(
HloInstruction::CreateUnary(kScalarShape, opcode, param0));
return builder.Build();
}
// Build and return a computation which takes a scalar and maps (kMap) the
// given computation to the value 'callsites' number of times.
std::unique_ptr<HloComputation> MakeMappingComputation(
HloComputation* map_computation, int64 callsites) {
HloComputation::Builder builder(TestName() + ".MappingComputation");
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* last_value = param0;
for (int64 i = 0; i < callsites; ++i) {
last_value = builder.AddInstruction(HloInstruction::CreateMap(
kScalarShape, {last_value}, map_computation));
}
return builder.Build();
}
// Build and return a computation which takes a scalar and calls (kCall) the
// given computation with value 'callsites' number of times.
std::unique_ptr<HloComputation> MakeCallingComputation(
HloComputation* callee_computation, int64 callsites,
const string& suffix = ".CallingComputation") {
HloComputation::Builder builder(TestName() + suffix);
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* last_value = param0;
for (int64 i = 0; i < callsites; ++i) {
last_value = builder.AddInstruction(HloInstruction::CreateCall(
kScalarShape, {last_value}, callee_computation));
}
return builder.Build();
}
// Build and return a computation which takes a scalar and returns a PRED
// value.
std::unique_ptr<HloComputation> MakeConditionComputation() {
HloComputation::Builder builder(TestName() + ".ConditionComputation");
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
zero, ComparisonDirection::kGt));
return builder.Build();
}
const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
};
TEST_F(CallGraphTest, SingletonComputation) {
  // A module holding nothing but its entry computation should yield a
  // single-node, flattened call graph without any call edges.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(MakeScalarComputation());

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());

  // Whole-graph properties.
  EXPECT_EQ(1, graph->nodes().size());
  EXPECT_TRUE(graph->IsFlattened());

  // Properties of the only node.
  const CallGraphNode& only_node = graph->GetNode(entry);
  EXPECT_EQ(entry, only_node.computation());
  EXPECT_EQ(only_node.depth(), 0);
  EXPECT_EQ(CallContext::kSequential, only_node.context());

  // Nothing is called, and nothing calls the entry.
  EXPECT_TRUE(only_node.callsites().empty());
  EXPECT_TRUE(only_node.callees().empty());
  EXPECT_TRUE(only_node.caller_callsites().empty());
  EXPECT_TRUE(only_node.callers().empty());
}
TEST_F(CallGraphTest, UnreachableComputation) {
  // An embedded computation that nothing calls must still get its own node in
  // the call graph, next to the entry computation's node.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(MakeScalarComputation());
  HloComputation* orphan =
      module->AddEmbeddedComputation(MakeScalarComputation());

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Entry node.
  const CallGraphNode& entry_node = graph->GetNode(entry);
  EXPECT_EQ(entry, entry_node.computation());
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(CallContext::kSequential, entry_node.context());

  // Node of the computation without callers; expected to look like a root.
  const CallGraphNode& orphan_node = graph->GetNode(orphan);
  EXPECT_EQ(orphan, orphan_node.computation());
  EXPECT_EQ(orphan_node.depth(), 0);
  EXPECT_EQ(CallContext::kSequential, orphan_node.context());
}
TEST_F(CallGraphTest, ParallelComputation) {
  // The entry computation applies one embedded computation through five kMap
  // instructions, so the callee is expected to be in a parallel context.
  auto module = CreateNewVerifiedModule();
  HloComputation* mapped =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry = module->AddEntryComputation(
      MakeMappingComputation(mapped, /*callsites=*/5));

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Caller side: five callsites that all target the same single callee.
  const CallGraphNode& caller = graph->GetNode(entry);
  EXPECT_EQ(entry, caller.computation());
  EXPECT_EQ(caller.depth(), 0);
  EXPECT_EQ(CallContext::kSequential, caller.context());
  EXPECT_EQ(5, caller.callsites().size());
  EXPECT_EQ(1, caller.callees().size());
  EXPECT_TRUE(caller.caller_callsites().empty());
  EXPECT_TRUE(caller.callers().empty());

  // Callee side: reached from five callsites that belong to one caller.
  const CallGraphNode& callee = graph->GetNode(mapped);
  EXPECT_EQ(mapped, callee.computation());
  EXPECT_EQ(callee.depth(), 1);
  EXPECT_EQ(CallContext::kParallel, callee.context());
  EXPECT_TRUE(callee.callsites().empty());
  EXPECT_TRUE(callee.callees().empty());
  EXPECT_EQ(5, callee.caller_callsites().size());
  EXPECT_EQ(1, callee.callers().size());
}
TEST_F(CallGraphTest, SequentialComputations) {
  // The entry computation invokes one embedded computation through three kCall
  // instructions, so the callee is expected to stay in a sequential context.
  auto module = CreateNewVerifiedModule();
  HloComputation* callee_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry = module->AddEntryComputation(
      MakeCallingComputation(callee_computation, /*callsites=*/3));

  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Only one computation calls the callee, but it does so from several
  // callsites, which is expected to make the graph non-flattened.
  EXPECT_FALSE(graph->IsFlattened());

  // Caller side.
  const CallGraphNode& caller = graph->GetNode(entry);
  EXPECT_EQ(entry, caller.computation());
  EXPECT_EQ(CallContext::kSequential, caller.context());
  EXPECT_EQ(3, caller.callsites().size());
  EXPECT_EQ(1, caller.callees().size());
  EXPECT_TRUE(caller.caller_callsites().empty());
  EXPECT_TRUE(caller.callers().empty());

  // Callee side.
  const CallGraphNode& callee = graph->GetNode(callee_computation);
  EXPECT_EQ(callee_computation, callee.computation());
  EXPECT_EQ(CallContext::kSequential, callee.context());
  EXPECT_TRUE(callee.callsites().empty());
  EXPECT_TRUE(callee.callees().empty());
  EXPECT_EQ(3, callee.caller_callsites().size());
  EXPECT_EQ(1, callee.callers().size());
}
TEST_F(CallGraphTest, ContextBothComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in both a parallel and sequential context.
  auto module = CreateNewVerifiedModule();
  HloComputation* subcomputation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  // The entry computation invokes the same subcomputation twice: first through
  // a kCall instruction and then through a kMap instruction.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, kScalarShape, "param0"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());
  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  // Fatal assertion: callsites() is indexed at [0] and [1] below, so a wrong
  // count must stop the test here rather than read out of bounds.
  ASSERT_EQ(2, entry_node.callsites().size());

  // First callsite: the kCall instruction, in a sequential context.
  const CallSite& call_callsite = entry_node.callsites()[0];
  EXPECT_EQ(call, call_callsite.instruction());
  EXPECT_THAT(call_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kSequential, call_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);

  // Second callsite: the kMap instruction, in a parallel context.
  const CallSite& map_callsite = entry_node.callsites()[1];
  EXPECT_EQ(map, map_callsite.instruction());
  EXPECT_THAT(map_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kParallel, map_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);

  // Being called from both kinds of callsite gives the callee context kBoth.
  const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
  EXPECT_EQ(sub_node.depth(), 1);
  EXPECT_EQ(CallContext::kBoth, sub_node.context());
}
TEST_F(CallGraphTest, ComputationWithConditional) {
  // Test a call graph of a module with a conditional.
  auto module = CreateNewVerifiedModule();
  HloComputation* true_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor));

  // The entry computation is a single kConditional that selects between the
  // two branch computations above, each fed its own scalar constant.
  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          kScalarShape, pred, const1, true_computation, const2,
          false_computation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(3, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry_computation, entry_node.computation());
  // Fatal assertion: callsites()[0] is read below, so a wrong count must stop
  // the test here rather than read out of bounds.
  ASSERT_EQ(1, entry_node.callsites().size());

  // The only callsite is the conditional, which calls both branches.
  const CallSite& conditional_callsite = entry_node.callsites()[0];
  EXPECT_EQ(conditional, conditional_callsite.instruction());
  EXPECT_THAT(conditional_callsite.called_computations(),
              UnorderedElementsAre(true_computation, false_computation));
  EXPECT_EQ(CallContext::kSequential, conditional_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite);

  const CallGraphNode& true_node = call_graph->GetNode(true_computation);
  EXPECT_EQ(true_node.depth(), 1);
  EXPECT_TRUE(true_node.callees().empty());
  // Fatal assertion: callers()[0] is read on the next line.
  ASSERT_EQ(1, true_node.callers().size());
  EXPECT_EQ(entry_computation, true_node.callers()[0]);

  const CallGraphNode& false_node = call_graph->GetNode(false_computation);
  EXPECT_EQ(false_node.depth(), 1);
  EXPECT_TRUE(false_node.callees().empty());
  // Fatal assertion: callers()[0] is read on the next line.
  ASSERT_EQ(1, false_node.callers().size());
  EXPECT_EQ(entry_computation, false_node.callers()[0]);
}
TEST_F(CallGraphTest, ComplexGraph) {
// Test a call graph of a module with several computations called in various
// contexts. The call graph looks like:
//
//      entry
//      /  |
//     a   |
//   / | \ |
//  b  |  cond
//   \ |
//    c
//
// Calls are made via kCall, kWhile, and kMap instructions.
auto module = CreateNewVerifiedModule();
// Leaf computations first: the while-loop condition and the scalar
// computation 'c'.
HloComputation* cond_computation =
module->AddEmbeddedComputation(MakeConditionComputation());
HloComputation* c_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
// 'b' invokes 'c' once through a kMap instruction.
HloComputation* b_computation = module->AddEmbeddedComputation(
MakeMappingComputation(c_computation, /*callsites=*/1));
// 'a' calls 'c' through kCall, then runs a while loop whose condition is
// 'cond' and whose body is 'b'.
HloComputation* a_computation;
{
HloComputation::Builder builder(TestName() + ".a");
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* call = builder.AddInstruction(
HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, b_computation, call));
a_computation = module->AddEmbeddedComputation(builder.Build());
}
// The entry computation runs a while loop whose condition is 'cond' and whose
// body is 'a'.
HloComputation* entry_computation;
{
HloComputation::Builder builder(TestName() + ".entry");
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, a_computation, param0));
entry_computation = module->AddEntryComputation(builder.Build());
}
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(5, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
const CallGraphNode& a_node = call_graph->GetNode(a_computation);
const CallGraphNode& b_node = call_graph->GetNode(b_computation);
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
// Verify depths. 'c' and 'cond' are each reachable along more than one path;
// the expected values correspond to the longest path from the entry
// (entry -> a -> b -> c and entry -> a -> cond).
EXPECT_EQ(entry_node.depth(), 0);
EXPECT_EQ(a_node.depth(), 1);
EXPECT_EQ(b_node.depth(), 2);
EXPECT_EQ(c_node.depth(), 3);
EXPECT_EQ(cond_node.depth(), 2);
// Entry computation has one while instruction calling two computations
// (cond_computation and a_computation).
// ASSERT rather than EXPECT because callsites()[0] is read right below.
ASSERT_EQ(1, entry_node.callsites().size());
const std::vector<HloComputation*>& called_computations =
entry_node.callsites()[0].called_computations();
EXPECT_THAT(called_computations,
UnorderedElementsAre(cond_computation, a_computation));
EXPECT_EQ(CallContext::kSequential, entry_node.context());
// 'c' calls nothing. It is called through kCall in 'a' and through kMap in
// 'b', so its context is expected to be kBoth.
EXPECT_TRUE(c_node.callsites().empty());
EXPECT_THAT(c_node.callers(),
UnorderedElementsAre(a_computation, b_computation));
EXPECT_EQ(CallContext::kBoth, c_node.context());
// Visit the graph and verify nodes were visited in callee-before-caller
// order.
std::vector<const HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
visited.push_back(node.computation());
return Status::OK();
}));
EXPECT_EQ(visited.size(), 5);
// All values in visited should be unique.
EXPECT_EQ(
std::unordered_set<const HloComputation*>(visited.begin(), visited.end())
.size(),
5);
// Verify visitation order of some computations in the graph.
// index_of returns the position of `comp` in the visitation order and records
// a (non-fatal) failure if `comp` was never visited.
auto index_of = [&visited](const HloComputation* comp) {
auto it = absl::c_find(visited, comp);
EXPECT_NE(it, visited.end());
return std::distance(visited.begin(), it);
};
// The entry computation is expected to be the last of the five visited.
EXPECT_EQ(4, index_of(entry_computation));
EXPECT_LT(index_of(cond_computation), index_of(a_computation));
EXPECT_LT(index_of(c_computation), index_of(b_computation));
EXPECT_LT(index_of(b_computation), index_of(a_computation));
// Verify dominance relations between computation in the graph.
// Entry dominates everybody, and is dominated by no one except itself.
EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation));
EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation));
EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation));
EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation));
EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation));
EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation));
EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation));
EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation));
EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation));
// 'a' only dominates 'b' and 'c'.
// ('cond' is also called directly from the entry, so 'a' does not dominate
// it.)
EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation));
EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation));
EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation));
EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation));
EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation));
EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation));
// 'b' dominates only itself: 'c' is also called directly from 'a', and 'b'
// never calls 'cond'.
EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation));
EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation));
EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation));
// 'c' and 'cond' are leaves of the graph; each dominates only itself.
EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation));
EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation));
EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation));
EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation));
}
TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
// Test NearestAncestorsInSameComputation on a call graph of a module with
// several computations called in various contexts. The call graph looks like:
//
//      entry
//      /  |
//     a   |
//   / | \ |
//  b  |  cond
//   \ |
//    c
//
// Calls are made via kCall, kWhile, and kMap instructions.
auto module = CreateNewVerifiedModule();
// Leaf computations first: the while-loop condition and the scalar
// computation 'c'.
HloComputation* cond_computation =
module->AddEmbeddedComputation(MakeConditionComputation());
HloComputation* c_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
// 'b' invokes 'c' once through a kMap instruction; that map is kept as b_map.
HloComputation* b_computation = module->AddEmbeddedComputation(
MakeMappingComputation(c_computation, /*callsites=*/1));
HloInstruction* b_map = b_computation->root_instruction();
// 'a' calls 'c' through kCall (a_call), then runs a while loop (a_while) whose
// condition is 'cond' and whose body is 'b'.
HloComputation* a_computation;
HloInstruction* a_call;
HloInstruction* a_while;
{
HloComputation::Builder builder(TestName() + ".a");
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
a_call = builder.AddInstruction(
HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
a_while = builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, b_computation, a_call));
a_computation = module->AddEmbeddedComputation(builder.Build());
}
// The entry computation runs a while loop (entry_while) whose condition is
// 'cond' and whose body is 'a'.
HloComputation* entry_computation;
HloInstruction* entry_while;
{
HloComputation::Builder builder(TestName() + ".entry");
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
entry_while = builder.AddInstruction(HloInstruction::CreateWhile(
kScalarShape, cond_computation, a_computation, param0));
entry_computation = module->AddEntryComputation(builder.Build());
}
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
EXPECT_EQ(5, call_graph->nodes().size());
// Verify NearestAncestorsInSameComputation for various instructions in the
// module.
// An instruction paired with itself is expected to come back unchanged.
EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_call, a_call),
std::make_pair(a_call, a_call));
// c_computation is called from more than one site, so
// NearestAncestorsInSameComputation bails and returns nullptrs.
std::pair<HloInstruction*, HloInstruction*> null_pair = {nullptr, nullptr};
EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(
b_map, c_computation->root_instruction()),
null_pair);
// b_map lives in 'b', the body of a_while in 'a', which in turn is the body
// of entry_while; so its ancestor in the entry computation is entry_while.
EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, entry_while),
std::make_pair(entry_while, entry_while));
// Within 'a', b_map is represented by a_while (whose body contains it).
EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, a_call),
std::make_pair(a_while, a_call));
// Both instructions are already in 'a', so they are expected back as-is.
EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, a_call),
std::make_pair(a_while, a_call));
// Same relationship as the (b_map, a_call) case with b_map on the right:
// its ancestor in 'a' is a_while.
EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, b_map),
std::make_pair(a_while, a_while));
}
TEST_F(CallGraphTest, VisitSingletonComputation) {
  // Visiting a one-node call graph should hand exactly that node's
  // computation to the visitor.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());

  // Collects every computation passed to the visitor.
  std::vector<HloComputation*> seen;
  auto record = [&seen](const CallGraphNode& node) {
    seen.push_back(node.computation());
    return Status::OK();
  };
  TF_ASSERT_OK(graph->VisitNodes(record));

  EXPECT_THAT(seen, UnorderedElementsAre(entry));
}
TEST_F(CallGraphTest, VisitUnreachableComputation) {
  // Test the call graph visitor with a call graph with an unreachable node.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  // Test visitation of only reachable nodes.
  {
    std::vector<const HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/false));
    // Fatal assertion: visited[0] is read on the next line, so an empty or
    // oversized result must stop the test here rather than read out of
    // bounds.
    ASSERT_EQ(visited.size(), 1);
    EXPECT_EQ(visited[0], entry_computation);
  }

  // Test visitation of all nodes (reachable and unreachable).
  {
    std::vector<HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/true));
    EXPECT_EQ(visited.size(), 2);
    EXPECT_THAT(visited, UnorderedElementsAre(entry_computation,
                                              unreachable_computation));
  }
}
TEST_F(CallGraphTest, VisitWithError) {
  // An error returned by the visitor callback must come back out of
  // VisitNodes with its code and message intact.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> graph = CallGraph::Build(module.get());

  // Visitor that fails on whichever node it is handed.
  auto failing_visitor = [](const CallGraphNode&) {
    return InternalError("Visitation failed");
  };
  Status result = graph->VisitNodes(failing_visitor);

  ASSERT_FALSE(result.ok());
  ASSERT_EQ(result.code(), tensorflow::error::INTERNAL);
  ASSERT_THAT(result.error_message(),
              ::testing::HasSubstr("Visitation failed"));
}
} // namespace
} // namespace xla