[TF:XLA] Migrate unit tests to use the HLO verifier (only tests where the conversion is mostly automated).
PiperOrigin-RevId: 212303594
This commit is contained in:
parent a8b2dd9f72
commit 96b77a647b
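The diff below applies one mechanical pattern across the affected tests. For context, here is a minimal sketch of that pattern (illustration only, not part of the diff; FooPass and FooPassTest are hypothetical placeholder names — the real fixtures in this commit, such as BFloat16ConversionFoldingTest or CpuFusionTest, follow the same shape): the test fixture switches its base class from HloTestBase to HloVerifiedTestBase, which runs the HLO verifier over the test's module, and pass-running helpers now take the raw HloModule* that the verified base class owns instead of calling module.get() on a std::unique_ptr<HloModule>.

// Minimal sketch of the migration pattern (hypothetical FooPass/FooPassTest
// names; assumes the XLA test headers referenced in the diff below).
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"

namespace xla {
namespace {

// Before: class FooPassTest : public HloTestBase { ... };
// After: derive from HloVerifiedTestBase so the HLO verifier checks the
// module built by the test.
class FooPassTest : public HloVerifiedTestBase {
 protected:
  bool RunPass(HloModule* module) {
    FooPass pass;
    // The verified test base owns the module, so the raw pointer is passed
    // directly (previously: pass.Run(module.get()) on a
    // std::unique_ptr<HloModule>).
    return pass.Run(module).ValueOrDie();
  }
};

}  // namespace
}  // namespace xla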
@@ -87,6 +87,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
@@ -123,6 +124,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
@@ -352,6 +354,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -402,6 +405,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
@@ -498,6 +502,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -568,6 +573,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -1131,6 +1137,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
@@ -1709,6 +1716,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/core:test",
     ],
 )
@@ -2237,6 +2245,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
@@ -2315,6 +2324,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/core:test",
     ],
 )
@@ -2428,6 +2438,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -2888,6 +2899,7 @@ tf_cc_test(
     deps = [
         ":hlo_tfgraph_builder",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:protos_all_cc",
     ],

@@ -22,7 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
   }
 };
 
-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
  protected:
+  BFloat16ConversionFoldingTest()
+      : HloVerifiedTestBase(/*layout_sensitive=*/false,
+                            /*allow_mixed_precision=*/true) {}
+
   bool FoldConversions(HloModule* module) {
     TestBFloat16Support bfloat16_support_;
     BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(FoldConversions(module.get()));
+  EXPECT_TRUE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), add1);
   EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(FoldConversions(module.get()));
+  EXPECT_FALSE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), convert2);
   EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(FoldConversions(module.get()));
+  EXPECT_FALSE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), convert2);
   EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(FoldConversions(module.get()));
+  EXPECT_FALSE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), convert1);
   EXPECT_EQ(gte->shape().element_type(), F32);
@@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(FoldConversions(module.get()));
+  EXPECT_TRUE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), tuple);
   EXPECT_EQ(tuple->operand(0), gte_a);

@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support {
   }
 };
 
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
  protected:
+  BFloat16NormalizationTest()
+      : HloVerifiedTestBase(/*layout_sensitive=*/false,
+                            /*allow_mixed_precision=*/true) {}
+
   bool Normalize(HloModule* module) {
     TestBFloat16Support bfloat16_support_;
     BFloat16Normalization normalization(&bfloat16_support_);
@@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(Normalize(module.get()));
+  EXPECT_FALSE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), add1);
   EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
   EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
   EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), reduce);
   EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), gte);
   EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), gte);
   EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
   EXPECT_EQ(dot->shape().element_type(), F32);

@@ -21,7 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +31,7 @@ namespace {
 
 using ::testing::UnorderedElementsAre;
 
-class CallGraphTest : public HloTestBase {
+class CallGraphTest : public HloVerifiedTestBase {
  protected:
   // Build and return a trivial computation taking and returning a scalar.
   std::unique_ptr<HloComputation> MakeScalarComputation(
@@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) {
   auto module = CreateNewModule();
   HloComputation* computation =
       module->AddEntryComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(1, call_graph->nodes().size());
   EXPECT_TRUE(call_graph->IsFlattened());
 
@@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) {
   HloComputation* unreachable_computation =
       module->AddEmbeddedComputation(MakeScalarComputation());
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) {
   HloComputation* entry_computation = module->AddEntryComputation(
       MakeMappingComputation(map_computation, /*callsites=*/5));
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) {
   HloComputation* entry_computation = module->AddEntryComputation(
      MakeCallingComputation(called_computation, /*callsites=*/3));
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   // The called computation is only called from one other computation, but there
@@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) {
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   EXPECT_FALSE(call_graph->IsFlattened());
@@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   EXPECT_EQ(3, call_graph->nodes().size());
 
@@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
     entry_computation = module->AddEntryComputation(builder.Build());
   }
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(5, call_graph->nodes().size());
   EXPECT_FALSE(call_graph->IsFlattened());
 
@@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
     entry_computation = module->AddEntryComputation(builder.Build());
   }
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(5, call_graph->nodes().size());
 
   // Verify NearestAncestorsInSameComputation for various instructions in the
@@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
   auto module = CreateNewModule();
   HloComputation* computation =
      module->AddEntryComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   std::vector<HloComputation*> visited;
   TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
   module->AddEntryComputation(MakeScalarComputation());
   HloComputation* unreachable_computation =
       module->AddEmbeddedComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   // Test visitation of only reachable nodes.
   {
@@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) {
   // Test that the call graph visitor properly propagates errors.
   auto module = CreateNewModule();
   module->AddEntryComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   Status status = call_graph->VisitNodes(
       [](const CallGraphNode&) { return InternalError("Visitation failed"); });

@@ -801,6 +801,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
@@ -822,6 +823,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
@@ -946,6 +948,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service:hlo_graph_dumper",
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -971,6 +974,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",

@@ -22,7 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 
 #include "tensorflow/compiler/xla/test_helpers.h"
@@ -32,7 +32,7 @@ namespace cpu {
 
 using ::testing::ElementsAre;
 
-class ConvCanonicalizationTest : public HloTestBase {
+class ConvCanonicalizationTest : public HloVerifiedTestBase {
  public:
   ConvCanonicalizationTest() {
     for (int i = 0; i < 2; ++i) {
@@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
         return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
       });
   ConvCanonicalization conv_canonicalization(&target_machine_features);
-  EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
 
   const HloInstruction* output_reshape = entry_computation->root_instruction();
   EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
@@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
         return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
       });
   ConvCanonicalization conv_canonicalization(&target_machine_features);
-  EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
 }
 
 }  // namespace cpu

@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
@@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
   return count;
 }
 
-class CpuCopyInsertionTest : public HloTestBase {
+class CpuCopyInsertionTest : public HloVerifiedTestBase {
  protected:
   void InsertCopies(HloModule* module) {
     CpuCopyInsertion copy_insertion;
@@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
 
   module->AddEntryComputation(builder.Build());
 
-  InsertCopies(module.get());
+  InsertCopies(module);
 
   EXPECT_EQ(CountCopies(*module), 3);
 
@@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
 
   module->AddEntryComputation(builder.Build());
 
-  InsertCopies(module.get());
+  InsertCopies(module);
 
   EXPECT_EQ(CountCopies(*subcomputation), 2);
   EXPECT_THAT(subcomputation->root_instruction(),

@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
@@ -25,7 +25,7 @@ namespace {
 
 using ::testing::HasSubstr;
 
-class CpuHloSupportCheckerTest : public HloTestBase {
+class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
  protected:
   CpuHloSupportChecker& checker() { return checker_; }
 
@@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) {
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK(checker().Run(module.get()).status());
+  TF_ASSERT_OK(checker().Run(module).status());
 }
 
 TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Status status = checker().Run(module.get()).status();
+  Status status = checker().Run(module).status();
   ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
   EXPECT_THAT(status.error_message(),
               HasSubstr("CPU backend does not support"));

@@ -19,14 +19,14 @@ limitations under the License.
 #include <random>
 
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 
 namespace xla {
 namespace cpu {
 namespace {
 
-class ShapePartitionAssignerTest : public HloTestBase {
+class ShapePartitionAssignerTest : public HloVerifiedTestBase {
  protected:
   typedef std::vector<int64> Vec;
 
@@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
                  expected_partitions);
 }
 
-class ShapePartitionIteratorTest : public HloTestBase {
+class ShapePartitionIteratorTest : public HloVerifiedTestBase {
  protected:
   typedef std::vector<std::pair<int64, int64>> Partition;
 };
@@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
   }
 }
 
-class RandomShapePartitionIteratorTest : public HloTestBase {
+class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
  protected:
   typedef std::vector<std::pair<int64, int64>> Partition;
   RandomShapePartitionIteratorTest()

@@ -48,6 +48,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",

@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/test.h"
@@ -34,7 +34,7 @@ namespace xla {
 namespace cpu {
 namespace {
 
-class CpuFusionTest : public HloTestBase {
+class CpuFusionTest : public HloVerifiedTestBase {
  protected:
   CpuFusionTest() {}
 
@@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The computation root instruction was fused. Verify the fusion instruction
   // is now the root.
@@ -75,7 +75,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
   EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
 
   // Compile and execute the computation.
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
 
   // Check the output correctness.
   LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
@@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The computation root instruction was fused. Verify the fusion instruction
   // is now the root.
@@ -122,7 +122,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
   EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
 
   // Compile and execute the computation.
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
 
   // Check the output correctness.
   LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
@@ -184,7 +184,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The computation root instruction was fused. Verify the fusion instruction
   // is now the root.
@@ -209,7 +209,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
       << fusion_instruction2->fused_instructions_computation()->ToString();
 
   // Compile and execute the computation.
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
 
   // Check the output correctness.
   LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
@@ -256,7 +256,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
 
   // Run fusion.
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   auto fusion1 = result->operand(0);
   auto fusion2 = result->operand(1);
@@ -315,7 +315,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The only fusion instruction should be operand 0 of the tuple (formerly
   // negate1).

@@ -22,7 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@ limitations under the License.
 namespace xla {
 namespace {
 
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
  protected:
   // Build and return a trivial computation taking and returning a scalar.
   std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
   }
 
   {
-    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
     EXPECT_TRUE(result);
-    std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+    std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
     const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
     EXPECT_EQ(1, c_node.caller_callsites().size());
   }
@@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
   }
 
   {
-    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
     const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
     EXPECT_EQ(2, cond_node.caller_callsites().size());
   }
 
   {
-    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
     EXPECT_TRUE(result);
-    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
     const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
     EXPECT_EQ(1, cond_node.caller_callsites().size());
   }
@@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
   module->AddEntryComputation(
       MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
 
-  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
   EXPECT_TRUE(result);
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(7, module->computation_count());
 
   const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
   module->AddEntryComputation(builder.Build());
   EXPECT_EQ(2, module->computation_count());
 
-  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
   EXPECT_TRUE(result);
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   // The true and false computations must now be different.
   EXPECT_EQ(3, module->computation_count());
 

@@ -108,6 +108,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
@@ -832,6 +833,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/memory",
@@ -901,6 +903,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",

@@ -24,14 +24,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
 namespace gpu {
 
-class GpuHloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloVerifiedTestBase {
  protected:
   using HloVec = std::vector<const HloInstruction*>;
 

@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
@@ -25,7 +25,7 @@ namespace {
 
 using ::testing::HasSubstr;
 
-class GpuHloSupportCheckerTest : public HloTestBase {
+class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
  protected:
   GpuHloSupportChecker& checker() { return checker_; }
 
@@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) {
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK(checker().Run(module.get()).status());
+  TF_ASSERT_OK(checker().Run(module).status());
 }
 
 TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Status status = checker().Run(module.get()).status();
+  Status status = checker().Run(module).status();
   ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
   EXPECT_THAT(status.error_message(),
               HasSubstr("GPU backend does not support"));

@@ -21,14 +21,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
 namespace gpu {
 
-class StreamAssignmentTest : public HloTestBase {
+class StreamAssignmentTest : public HloVerifiedTestBase {
  protected:
   std::unique_ptr<HloModule> CreateNewModule() {
     HloModuleConfig config;

@@ -29,14 +29,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_value.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace xla {
 namespace {
 
-class MinimumMemoryForSequenceTest : public HloTestBase {};
+class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
 
 TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
   auto module = CreateNewModule();
@@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
   };
 
-  HloSchedule schedule(module.get());
+  HloSchedule schedule(module);
   schedule.set_sequence(cond_computation,
                         {cond_param, cond_iter, cond_data, cond_lt});
   schedule.set_sequence(body_computation, {body_param});
@@ -233,7 +233,7 @@ class HeapSimulatorTracker {
   HeapSimulator::Result result_;
 };
 
-class HeapSimulatorTest : public HloTestBase {
+class HeapSimulatorTest : public HloVerifiedTestBase {
  protected:
   HeapSimulatorTest() {}
   ~HeapSimulatorTest() override {}

@@ -20,13 +20,13 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 
 namespace xla {
 
 namespace {
 
-class HloReachabilityTest : public HloTestBase {};
+class HloReachabilityTest : public HloVerifiedTestBase {};
 
 TEST_F(HloReachabilityTest, Reachability) {
   // Construct and test a reachability graph of the following form:

@@ -24,7 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
 
 using ::testing::_;
 
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
  protected:
   // Creates and returns a computation which can benefit from
   // rematerialization. The computation looks like:
@@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
   // with rematerialization so pick a memory limit between these values (14KB).
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
-                              /*memory_limit_bytes=*/14 * 1024, module.get()));
+                              /*memory_limit_bytes=*/14 * 1024, module));
   EXPECT_TRUE(changed);
 
   // Root should not have changed.
@@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
 
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
-                              /*memory_limit_bytes=*/20 * 1024, module.get()));
+                              /*memory_limit_bytes=*/20 * 1024, module));
 
   // No instructions should have been materialized.
   EXPECT_FALSE(changed);
@@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
   // bit lower (17KB) to force rematerialization of the entry computation.
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
-                              /*memory_limit_bytes=*/17 * 1024, module.get()));
+                              /*memory_limit_bytes=*/17 * 1024, module));
   EXPECT_TRUE(changed);
 
   // Only the entry computation should have a rematerialized instruction added.
@@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
 
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
-                              /*memory_limit_bytes=*/15 * 1024, module.get()));
+                              /*memory_limit_bytes=*/15 * 1024, module));
   EXPECT_TRUE(changed);
 
   // Both computations should have rematerialized instructions added.
@@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
   // ~12K so pick something slightly larger.
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
-                              /*memory_limit_bytes=*/13 * 1024, module.get()));
+                              /*memory_limit_bytes=*/13 * 1024, module));
   EXPECT_TRUE(changed);
 
   // All computations should have rematerialized instructions added.
@@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
   TF_ASSERT_OK_AND_ASSIGN(
       bool changed,
       RunHloRematerialization(
-          /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
+          /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
   EXPECT_TRUE(changed);
   // The rng should not have been rematerialized.
   EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
   // rematerialization).
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
-                              /*memory_limit_bytes=*/22 * 1024, module.get()));
+                              /*memory_limit_bytes=*/22 * 1024, module));
   EXPECT_TRUE(changed);
 
   // The broadcast should have been rematerialized 3 times.
@@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
   // rematerialization).
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
-                              /*memory_limit_bytes=*/22 * 1024, module.get()));
+                              /*memory_limit_bytes=*/22 * 1024, module));
   // Rematerialization should only occur if the rematerializable instruction has
   // no indirect uses.
   if (indirectly_used) {

@@ -14,7 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 
@@ -24,7 +24,7 @@ namespace {
 
 using ::tensorflow::GraphDef;
 
-class HloTfGraphBuilderTest : public HloTestBase {
+class HloTfGraphBuilderTest : public HloVerifiedTestBase {
  protected:
   HloTfGraphBuilderTest() {}
   HloTfGraphBuilder generator_;

@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
@@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
 namespace xla {
 namespace {
 
-class TupleSimplifierTest : public HloTestBase {
+class TupleSimplifierTest : public HloVerifiedTestBase {
  protected:
   void Run(HloModule* module, bool change_expected) {
     TupleSimplifier simplifier;
@@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) {
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Run(module.get(), /*change_expected=*/false);
+  Run(module, /*change_expected=*/false);
 }
 
 TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
@@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Run(module.get(), /*change_expected=*/false);
+  Run(module, /*change_expected=*/false);
 }
 
 TEST_F(TupleSimplifierTest, GteOfTuple) {
@@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) {
 
   EXPECT_THAT(computation->root_instruction(), gte);
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), param1);
 }
@@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) {
   EXPECT_THAT(computation->root_instruction(),
               op::Negate(op::GetTupleElement(op::Tuple())));
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
 }
@@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
 
   EXPECT_THAT(computation->root_instruction(), element);
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), param);
 }
@@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
 
   EXPECT_THAT(computation->root_instruction(), tuple);
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), tuple_param);
 }
@@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
 
   EXPECT_THAT(computation->root_instruction(), tuple);
 
-  Run(module.get(), /*change_expected=*/false);
+  Run(module, /*change_expected=*/false);
 
   EXPECT_THAT(computation->root_instruction(), tuple);
 }
@@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
     entry = module->AddEntryComputation(builder.Build());
   }
 
-  Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+  Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
 
   EXPECT_THAT(c0->root_instruction(), p0);
   EXPECT_THAT(c1->root_instruction(), p1);