[TF:XLA] Migrate unit tests to use the HLO verifier (only tests where the conversion is mostly automated).
PiperOrigin-RevId: 212303594
This commit is contained in:
parent
a8b2dd9f72
commit
96b77a647b
@ -87,6 +87,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
@ -123,6 +124,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
@ -352,6 +354,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
],
|
],
|
||||||
@ -402,6 +405,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -498,6 +502,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
],
|
],
|
||||||
@ -568,6 +573,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
],
|
],
|
||||||
@ -1131,6 +1137,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
@ -1709,6 +1716,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -2237,6 +2245,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
@ -2315,6 +2324,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
|
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -2428,6 +2438,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
],
|
],
|
||||||
@ -2888,6 +2899,7 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":hlo_tfgraph_builder",
|
":hlo_tfgraph_builder",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class BFloat16ConversionFoldingTest : public HloTestBase {
|
class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
|
BFloat16ConversionFoldingTest()
|
||||||
|
: HloVerifiedTestBase(/*layout_sensitive=*/false,
|
||||||
|
/*allow_mixed_precision=*/true) {}
|
||||||
|
|
||||||
bool FoldConversions(HloModule* module) {
|
bool FoldConversions(HloModule* module) {
|
||||||
TestBFloat16Support bfloat16_support_;
|
TestBFloat16Support bfloat16_support_;
|
||||||
BFloat16ConversionFolding fold(&bfloat16_support_);
|
BFloat16ConversionFolding fold(&bfloat16_support_);
|
||||||
@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(FoldConversions(module.get()));
|
EXPECT_TRUE(FoldConversions(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), add1);
|
EXPECT_EQ(computation->root_instruction(), add1);
|
||||||
EXPECT_EQ(add0->shape().element_type(), BF16);
|
EXPECT_EQ(add0->shape().element_type(), BF16);
|
||||||
@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_FALSE(FoldConversions(module.get()));
|
EXPECT_FALSE(FoldConversions(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), convert2);
|
EXPECT_EQ(computation->root_instruction(), convert2);
|
||||||
EXPECT_EQ(mul0->shape().element_type(), F32);
|
EXPECT_EQ(mul0->shape().element_type(), F32);
|
||||||
@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_FALSE(FoldConversions(module.get()));
|
EXPECT_FALSE(FoldConversions(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), convert2);
|
EXPECT_EQ(computation->root_instruction(), convert2);
|
||||||
EXPECT_EQ(sub0->shape().element_type(), F32);
|
EXPECT_EQ(sub0->shape().element_type(), F32);
|
||||||
@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_FALSE(FoldConversions(module.get()));
|
EXPECT_FALSE(FoldConversions(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), convert1);
|
EXPECT_EQ(computation->root_instruction(), convert1);
|
||||||
EXPECT_EQ(gte->shape().element_type(), F32);
|
EXPECT_EQ(gte->shape().element_type(), F32);
|
||||||
@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
|
|||||||
|
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(FoldConversions(module.get()));
|
EXPECT_TRUE(FoldConversions(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), tuple);
|
EXPECT_EQ(computation->root_instruction(), tuple);
|
||||||
EXPECT_EQ(tuple->operand(0), gte_a);
|
EXPECT_EQ(tuple->operand(0), gte_a);
|
||||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class BFloat16NormalizationTest : public HloTestBase {
|
class BFloat16NormalizationTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
|
BFloat16NormalizationTest()
|
||||||
|
: HloVerifiedTestBase(/*layout_sensitive=*/false,
|
||||||
|
/*allow_mixed_precision=*/true) {}
|
||||||
|
|
||||||
bool Normalize(HloModule* module) {
|
bool Normalize(HloModule* module) {
|
||||||
TestBFloat16Support bfloat16_support_;
|
TestBFloat16Support bfloat16_support_;
|
||||||
BFloat16Normalization normalization(&bfloat16_support_);
|
BFloat16Normalization normalization(&bfloat16_support_);
|
||||||
@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_FALSE(Normalize(module.get()));
|
EXPECT_FALSE(Normalize(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), add1);
|
EXPECT_EQ(computation->root_instruction(), add1);
|
||||||
EXPECT_EQ(add0->shape().element_type(), BF16);
|
EXPECT_EQ(add0->shape().element_type(), BF16);
|
||||||
@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(Normalize(module.get()));
|
EXPECT_TRUE(Normalize(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
|
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
|
||||||
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
|
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
|
||||||
@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(Normalize(module.get()));
|
EXPECT_TRUE(Normalize(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
|
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
|
||||||
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
|
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
|
||||||
@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
|
|||||||
|
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(Normalize(module.get()));
|
EXPECT_TRUE(Normalize(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), reduce);
|
EXPECT_EQ(computation->root_instruction(), reduce);
|
||||||
EXPECT_EQ(reduce->called_computations().size(), 1);
|
EXPECT_EQ(reduce->called_computations().size(), 1);
|
||||||
@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
|
|||||||
|
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(Normalize(module.get()));
|
EXPECT_TRUE(Normalize(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), gte);
|
EXPECT_EQ(computation->root_instruction(), gte);
|
||||||
EXPECT_EQ(gte->shape().element_type(), BF16);
|
EXPECT_EQ(gte->shape().element_type(), BF16);
|
||||||
@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
|
|||||||
|
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(Normalize(module.get()));
|
EXPECT_TRUE(Normalize(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction(), gte);
|
EXPECT_EQ(computation->root_instruction(), gte);
|
||||||
EXPECT_EQ(gte->shape().element_type(), BF16);
|
EXPECT_EQ(gte->shape().element_type(), BF16);
|
||||||
@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
EXPECT_TRUE(Normalize(module.get()));
|
EXPECT_TRUE(Normalize(module));
|
||||||
|
|
||||||
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
|
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
|
||||||
EXPECT_EQ(dot->shape().element_type(), F32);
|
EXPECT_EQ(dot->shape().element_type(), F32);
|
||||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -31,7 +31,7 @@ namespace {
|
|||||||
|
|
||||||
using ::testing::UnorderedElementsAre;
|
using ::testing::UnorderedElementsAre;
|
||||||
|
|
||||||
class CallGraphTest : public HloTestBase {
|
class CallGraphTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
// Build and return a trivial computation taking and returning a scalar.
|
// Build and return a trivial computation taking and returning a scalar.
|
||||||
std::unique_ptr<HloComputation> MakeScalarComputation(
|
std::unique_ptr<HloComputation> MakeScalarComputation(
|
||||||
@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
HloComputation* computation =
|
HloComputation* computation =
|
||||||
module->AddEntryComputation(MakeScalarComputation());
|
module->AddEntryComputation(MakeScalarComputation());
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(1, call_graph->nodes().size());
|
EXPECT_EQ(1, call_graph->nodes().size());
|
||||||
EXPECT_TRUE(call_graph->IsFlattened());
|
EXPECT_TRUE(call_graph->IsFlattened());
|
||||||
|
|
||||||
@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) {
|
|||||||
HloComputation* unreachable_computation =
|
HloComputation* unreachable_computation =
|
||||||
module->AddEmbeddedComputation(MakeScalarComputation());
|
module->AddEmbeddedComputation(MakeScalarComputation());
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(2, call_graph->nodes().size());
|
EXPECT_EQ(2, call_graph->nodes().size());
|
||||||
|
|
||||||
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
|
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
|
||||||
@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) {
|
|||||||
HloComputation* entry_computation = module->AddEntryComputation(
|
HloComputation* entry_computation = module->AddEntryComputation(
|
||||||
MakeMappingComputation(map_computation, /*callsites=*/5));
|
MakeMappingComputation(map_computation, /*callsites=*/5));
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(2, call_graph->nodes().size());
|
EXPECT_EQ(2, call_graph->nodes().size());
|
||||||
|
|
||||||
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
|
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
|
||||||
@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) {
|
|||||||
HloComputation* entry_computation = module->AddEntryComputation(
|
HloComputation* entry_computation = module->AddEntryComputation(
|
||||||
MakeCallingComputation(called_computation, /*callsites=*/3));
|
MakeCallingComputation(called_computation, /*callsites=*/3));
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(2, call_graph->nodes().size());
|
EXPECT_EQ(2, call_graph->nodes().size());
|
||||||
|
|
||||||
// The called computation is only called from one other computation, but there
|
// The called computation is only called from one other computation, but there
|
||||||
@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) {
|
|||||||
HloComputation* entry_computation =
|
HloComputation* entry_computation =
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(2, call_graph->nodes().size());
|
EXPECT_EQ(2, call_graph->nodes().size());
|
||||||
|
|
||||||
EXPECT_FALSE(call_graph->IsFlattened());
|
EXPECT_FALSE(call_graph->IsFlattened());
|
||||||
@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
|
|||||||
HloComputation* entry_computation =
|
HloComputation* entry_computation =
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
|
|
||||||
EXPECT_EQ(3, call_graph->nodes().size());
|
EXPECT_EQ(3, call_graph->nodes().size());
|
||||||
|
|
||||||
@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
|
|||||||
entry_computation = module->AddEntryComputation(builder.Build());
|
entry_computation = module->AddEntryComputation(builder.Build());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(5, call_graph->nodes().size());
|
EXPECT_EQ(5, call_graph->nodes().size());
|
||||||
EXPECT_FALSE(call_graph->IsFlattened());
|
EXPECT_FALSE(call_graph->IsFlattened());
|
||||||
|
|
||||||
@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
|
|||||||
entry_computation = module->AddEntryComputation(builder.Build());
|
entry_computation = module->AddEntryComputation(builder.Build());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(5, call_graph->nodes().size());
|
EXPECT_EQ(5, call_graph->nodes().size());
|
||||||
|
|
||||||
// Verify NearestAncestorsInSameComputation for various instructions in the
|
// Verify NearestAncestorsInSameComputation for various instructions in the
|
||||||
@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
HloComputation* computation =
|
HloComputation* computation =
|
||||||
module->AddEntryComputation(MakeScalarComputation());
|
module->AddEntryComputation(MakeScalarComputation());
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
|
|
||||||
std::vector<HloComputation*> visited;
|
std::vector<HloComputation*> visited;
|
||||||
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
|
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
|
||||||
@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
|
|||||||
module->AddEntryComputation(MakeScalarComputation());
|
module->AddEntryComputation(MakeScalarComputation());
|
||||||
HloComputation* unreachable_computation =
|
HloComputation* unreachable_computation =
|
||||||
module->AddEmbeddedComputation(MakeScalarComputation());
|
module->AddEmbeddedComputation(MakeScalarComputation());
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
|
|
||||||
// Test visitation of only reachable nodes.
|
// Test visitation of only reachable nodes.
|
||||||
{
|
{
|
||||||
@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) {
|
|||||||
// Test that the call graph visitor properly propagates errors.
|
// Test that the call graph visitor properly propagates errors.
|
||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
module->AddEntryComputation(MakeScalarComputation());
|
module->AddEntryComputation(MakeScalarComputation());
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
|
|
||||||
Status status = call_graph->VisitNodes(
|
Status status = call_graph->VisitNodes(
|
||||||
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
|
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
|
||||||
|
@ -801,6 +801,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -822,6 +823,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -946,6 +948,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
||||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
],
|
],
|
||||||
@ -971,6 +974,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
@ -32,7 +32,7 @@ namespace cpu {
|
|||||||
|
|
||||||
using ::testing::ElementsAre;
|
using ::testing::ElementsAre;
|
||||||
|
|
||||||
class ConvCanonicalizationTest : public HloTestBase {
|
class ConvCanonicalizationTest : public HloVerifiedTestBase {
|
||||||
public:
|
public:
|
||||||
ConvCanonicalizationTest() {
|
ConvCanonicalizationTest() {
|
||||||
for (int i = 0; i < 2; ++i) {
|
for (int i = 0; i < 2; ++i) {
|
||||||
@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
|
|||||||
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
|
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
|
||||||
});
|
});
|
||||||
ConvCanonicalization conv_canonicalization(&target_machine_features);
|
ConvCanonicalization conv_canonicalization(&target_machine_features);
|
||||||
EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
|
||||||
|
|
||||||
const HloInstruction* output_reshape = entry_computation->root_instruction();
|
const HloInstruction* output_reshape = entry_computation->root_instruction();
|
||||||
EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
|
EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
|
||||||
@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
|
|||||||
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
|
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
|
||||||
});
|
});
|
||||||
ConvCanonicalization conv_canonicalization(&target_machine_features);
|
ConvCanonicalization conv_canonicalization(&target_machine_features);
|
||||||
EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
|
EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
|
|||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
class CpuCopyInsertionTest : public HloTestBase {
|
class CpuCopyInsertionTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
void InsertCopies(HloModule* module) {
|
void InsertCopies(HloModule* module) {
|
||||||
CpuCopyInsertion copy_insertion;
|
CpuCopyInsertion copy_insertion;
|
||||||
@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
|
|||||||
|
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
InsertCopies(module.get());
|
InsertCopies(module);
|
||||||
|
|
||||||
EXPECT_EQ(CountCopies(*module), 3);
|
EXPECT_EQ(CountCopies(*module), 3);
|
||||||
|
|
||||||
@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
|
|||||||
|
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
InsertCopies(module.get());
|
InsertCopies(module);
|
||||||
|
|
||||||
EXPECT_EQ(CountCopies(*subcomputation), 2);
|
EXPECT_EQ(CountCopies(*subcomputation), 2);
|
||||||
EXPECT_THAT(subcomputation->root_instruction(),
|
EXPECT_THAT(subcomputation->root_instruction(),
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ namespace {
|
|||||||
|
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
class CpuHloSupportCheckerTest : public HloTestBase {
|
class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
CpuHloSupportChecker& checker() { return checker_; }
|
CpuHloSupportChecker& checker() { return checker_; }
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
TF_ASSERT_OK(checker().Run(module.get()).status());
|
TF_ASSERT_OK(checker().Run(module).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
|
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||||
@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
Status status = checker().Run(module.get()).status();
|
Status status = checker().Run(module).status();
|
||||||
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
||||||
EXPECT_THAT(status.error_message(),
|
EXPECT_THAT(status.error_message(),
|
||||||
HasSubstr("CPU backend does not support"));
|
HasSubstr("CPU backend does not support"));
|
||||||
|
@ -19,14 +19,14 @@ limitations under the License.
|
|||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class ShapePartitionAssignerTest : public HloTestBase {
|
class ShapePartitionAssignerTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
typedef std::vector<int64> Vec;
|
typedef std::vector<int64> Vec;
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
|
|||||||
expected_partitions);
|
expected_partitions);
|
||||||
}
|
}
|
||||||
|
|
||||||
class ShapePartitionIteratorTest : public HloTestBase {
|
class ShapePartitionIteratorTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
typedef std::vector<std::pair<int64, int64>> Partition;
|
typedef std::vector<std::pair<int64, int64>> Partition;
|
||||||
};
|
};
|
||||||
@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class RandomShapePartitionIteratorTest : public HloTestBase {
|
class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
typedef std::vector<std::pair<int64, int64>> Partition;
|
typedef std::vector<std::pair<int64, int64>> Partition;
|
||||||
RandomShapePartitionIteratorTest()
|
RandomShapePartitionIteratorTest()
|
||||||
|
@ -48,6 +48,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
|
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
@ -34,7 +34,7 @@ namespace xla {
|
|||||||
namespace cpu {
|
namespace cpu {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class CpuFusionTest : public HloTestBase {
|
class CpuFusionTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
CpuFusionTest() {}
|
CpuFusionTest() {}
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
|
|||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
CpuInstructionFusion fusion;
|
CpuInstructionFusion fusion;
|
||||||
EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(fusion.Run(module).ValueOrDie());
|
||||||
|
|
||||||
// The computation root instruction was fused. Verify the fusion instruction
|
// The computation root instruction was fused. Verify the fusion instruction
|
||||||
// is now the root.
|
// is now the root.
|
||||||
@ -75,7 +75,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
|
|||||||
EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
|
EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
|
||||||
|
|
||||||
// Compile and execute the computation.
|
// Compile and execute the computation.
|
||||||
auto result = ExecuteAndTransfer(std::move(module), {});
|
auto result = ExecuteAndTransfer(module->Clone(), {});
|
||||||
|
|
||||||
// Check the output correctness.
|
// Check the output correctness.
|
||||||
LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
|
LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
|
||||||
@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
|
|||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
CpuInstructionFusion fusion;
|
CpuInstructionFusion fusion;
|
||||||
EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(fusion.Run(module).ValueOrDie());
|
||||||
|
|
||||||
// The computation root instruction was fused. Verify the fusion instruction
|
// The computation root instruction was fused. Verify the fusion instruction
|
||||||
// is now the root.
|
// is now the root.
|
||||||
@ -122,7 +122,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
|
|||||||
EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
|
EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
|
||||||
|
|
||||||
// Compile and execute the computation.
|
// Compile and execute the computation.
|
||||||
auto result = ExecuteAndTransfer(std::move(module), {});
|
auto result = ExecuteAndTransfer(module->Clone(), {});
|
||||||
|
|
||||||
// Check the output correctness.
|
// Check the output correctness.
|
||||||
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
|
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
|
||||||
@ -184,7 +184,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
|
|||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
CpuInstructionFusion fusion;
|
CpuInstructionFusion fusion;
|
||||||
EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(fusion.Run(module).ValueOrDie());
|
||||||
|
|
||||||
// The computation root instruction was fused. Verify the fusion instruction
|
// The computation root instruction was fused. Verify the fusion instruction
|
||||||
// is now the root.
|
// is now the root.
|
||||||
@ -209,7 +209,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
|
|||||||
<< fusion_instruction2->fused_instructions_computation()->ToString();
|
<< fusion_instruction2->fused_instructions_computation()->ToString();
|
||||||
|
|
||||||
// Compile and execute the computation.
|
// Compile and execute the computation.
|
||||||
auto result = ExecuteAndTransfer(std::move(module), {});
|
auto result = ExecuteAndTransfer(module->Clone(), {});
|
||||||
|
|
||||||
// Check the output correctness.
|
// Check the output correctness.
|
||||||
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
|
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
|
||||||
@ -256,7 +256,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
|
|||||||
|
|
||||||
// Run fusion.
|
// Run fusion.
|
||||||
CpuInstructionFusion fusion;
|
CpuInstructionFusion fusion;
|
||||||
EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(fusion.Run(module).ValueOrDie());
|
||||||
|
|
||||||
auto fusion1 = result->operand(0);
|
auto fusion1 = result->operand(0);
|
||||||
auto fusion2 = result->operand(1);
|
auto fusion2 = result->operand(1);
|
||||||
@ -315,7 +315,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
|
|||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
CpuInstructionFusion fusion;
|
CpuInstructionFusion fusion;
|
||||||
EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(fusion.Run(module).ValueOrDie());
|
||||||
|
|
||||||
// The only fusion instruction should be operand 0 of the tuple (formerly
|
// The only fusion instruction should be operand 0 of the tuple (formerly
|
||||||
// negate1).
|
// negate1).
|
||||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -30,7 +30,7 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class FlattenCallGraphTest : public HloTestBase {
|
class FlattenCallGraphTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
// Build and return a trivial computation taking and returning a scalar.
|
// Build and return a trivial computation taking and returning a scalar.
|
||||||
std::unique_ptr<HloComputation> MakeScalarComputation() {
|
std::unique_ptr<HloComputation> MakeScalarComputation() {
|
||||||
@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
|
||||||
EXPECT_TRUE(result);
|
EXPECT_TRUE(result);
|
||||||
std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
|
||||||
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
|
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
|
||||||
EXPECT_EQ(1, c_node.caller_callsites().size());
|
EXPECT_EQ(1, c_node.caller_callsites().size());
|
||||||
}
|
}
|
||||||
@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
|
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
|
||||||
EXPECT_EQ(2, cond_node.caller_callsites().size());
|
EXPECT_EQ(2, cond_node.caller_callsites().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
|
||||||
EXPECT_TRUE(result);
|
EXPECT_TRUE(result);
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
|
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
|
||||||
EXPECT_EQ(1, cond_node.caller_callsites().size());
|
EXPECT_EQ(1, cond_node.caller_callsites().size());
|
||||||
}
|
}
|
||||||
@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
|
|||||||
module->AddEntryComputation(
|
module->AddEntryComputation(
|
||||||
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
|
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
|
||||||
EXPECT_TRUE(result);
|
EXPECT_TRUE(result);
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
EXPECT_EQ(7, module->computation_count());
|
EXPECT_EQ(7, module->computation_count());
|
||||||
|
|
||||||
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
|
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
|
||||||
@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
|
|||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
EXPECT_EQ(2, module->computation_count());
|
EXPECT_EQ(2, module->computation_count());
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
|
||||||
EXPECT_TRUE(result);
|
EXPECT_TRUE(result);
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
// The true and false computations must now be different.
|
// The true and false computations must now be different.
|
||||||
EXPECT_EQ(3, module->computation_count());
|
EXPECT_EQ(3, module->computation_count());
|
||||||
|
|
||||||
|
@ -108,6 +108,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:test_utils",
|
"//tensorflow/compiler/xla/tests:test_utils",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -832,6 +833,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:test_utils",
|
"//tensorflow/compiler/xla/tests:test_utils",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
@ -901,6 +903,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
@ -24,14 +24,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
class GpuHloScheduleTest : public HloTestBase {
|
class GpuHloScheduleTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
using HloVec = std::vector<const HloInstruction*>;
|
using HloVec = std::vector<const HloInstruction*>;
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ namespace {
|
|||||||
|
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
class GpuHloSupportCheckerTest : public HloTestBase {
|
class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
GpuHloSupportChecker& checker() { return checker_; }
|
GpuHloSupportChecker& checker() { return checker_; }
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
TF_ASSERT_OK(checker().Run(module.get()).status());
|
TF_ASSERT_OK(checker().Run(module).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
|
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||||
@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
Status status = checker().Run(module.get()).status();
|
Status status = checker().Run(module).status();
|
||||||
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
||||||
EXPECT_THAT(status.error_message(),
|
EXPECT_THAT(status.error_message(),
|
||||||
HasSubstr("GPU backend does not support"));
|
HasSubstr("GPU backend does not support"));
|
||||||
|
@ -21,14 +21,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
class StreamAssignmentTest : public HloTestBase {
|
class StreamAssignmentTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
std::unique_ptr<HloModule> CreateNewModule() {
|
std::unique_ptr<HloModule> CreateNewModule() {
|
||||||
HloModuleConfig config;
|
HloModuleConfig config;
|
||||||
|
@ -29,14 +29,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
||||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class MinimumMemoryForSequenceTest : public HloTestBase {};
|
class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
|
||||||
|
|
||||||
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
|
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
|
||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
|
|||||||
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
|
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
|
||||||
};
|
};
|
||||||
|
|
||||||
HloSchedule schedule(module.get());
|
HloSchedule schedule(module);
|
||||||
schedule.set_sequence(cond_computation,
|
schedule.set_sequence(cond_computation,
|
||||||
{cond_param, cond_iter, cond_data, cond_lt});
|
{cond_param, cond_iter, cond_data, cond_lt});
|
||||||
schedule.set_sequence(body_computation, {body_param});
|
schedule.set_sequence(body_computation, {body_param});
|
||||||
@ -233,7 +233,7 @@ class HeapSimulatorTracker {
|
|||||||
HeapSimulator::Result result_;
|
HeapSimulator::Result result_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class HeapSimulatorTest : public HloTestBase {
|
class HeapSimulatorTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
HeapSimulatorTest() {}
|
HeapSimulatorTest() {}
|
||||||
~HeapSimulatorTest() override {}
|
~HeapSimulatorTest() override {}
|
||||||
|
@ -20,13 +20,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class HloReachabilityTest : public HloTestBase {};
|
class HloReachabilityTest : public HloVerifiedTestBase {};
|
||||||
|
|
||||||
TEST_F(HloReachabilityTest, Reachability) {
|
TEST_F(HloReachabilityTest, Reachability) {
|
||||||
// Construct and test a reachability graph of the following form:
|
// Construct and test a reachability graph of the following form:
|
||||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
|
|||||||
|
|
||||||
using ::testing::_;
|
using ::testing::_;
|
||||||
|
|
||||||
class HloRematerializationTest : public HloTestBase {
|
class HloRematerializationTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
// Creates and returns a computation which can benefit from
|
// Creates and returns a computation which can benefit from
|
||||||
// rematerialization. The computation looks like:
|
// rematerialization. The computation looks like:
|
||||||
@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
|
|||||||
// with rematerialization so pick a memory limit between these values (14KB).
|
// with rematerialization so pick a memory limit between these values (14KB).
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/14 * 1024, module.get()));
|
/*memory_limit_bytes=*/14 * 1024, module));
|
||||||
EXPECT_TRUE(changed);
|
EXPECT_TRUE(changed);
|
||||||
|
|
||||||
// Root should not have changed.
|
// Root should not have changed.
|
||||||
@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
|
|||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/20 * 1024, module.get()));
|
/*memory_limit_bytes=*/20 * 1024, module));
|
||||||
|
|
||||||
// No instructions should have been materialized.
|
// No instructions should have been materialized.
|
||||||
EXPECT_FALSE(changed);
|
EXPECT_FALSE(changed);
|
||||||
@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
|
|||||||
// bit lower (17KB) to force rematerialization of the entry computation.
|
// bit lower (17KB) to force rematerialization of the entry computation.
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/17 * 1024, module.get()));
|
/*memory_limit_bytes=*/17 * 1024, module));
|
||||||
EXPECT_TRUE(changed);
|
EXPECT_TRUE(changed);
|
||||||
|
|
||||||
// Only the entry computation should have a rematerialized instruction added.
|
// Only the entry computation should have a rematerialized instruction added.
|
||||||
@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
|
|||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/15 * 1024, module.get()));
|
/*memory_limit_bytes=*/15 * 1024, module));
|
||||||
EXPECT_TRUE(changed);
|
EXPECT_TRUE(changed);
|
||||||
|
|
||||||
// Both computations should have rematerialized instructions added.
|
// Both computations should have rematerialized instructions added.
|
||||||
@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
|
|||||||
// ~12K so pick something slightly larger.
|
// ~12K so pick something slightly larger.
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/13 * 1024, module.get()));
|
/*memory_limit_bytes=*/13 * 1024, module));
|
||||||
EXPECT_TRUE(changed);
|
EXPECT_TRUE(changed);
|
||||||
|
|
||||||
// All computations should have rematerialized instructions added.
|
// All computations should have rematerialized instructions added.
|
||||||
@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
|
|||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
bool changed,
|
bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
|
/*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
|
||||||
EXPECT_TRUE(changed);
|
EXPECT_TRUE(changed);
|
||||||
// The rng should not have been rematerialized.
|
// The rng should not have been rematerialized.
|
||||||
EXPECT_EQ(count_rngs(entry_computation), 1);
|
EXPECT_EQ(count_rngs(entry_computation), 1);
|
||||||
@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
|
|||||||
// rematerialization).
|
// rematerialization).
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/22 * 1024, module.get()));
|
/*memory_limit_bytes=*/22 * 1024, module));
|
||||||
EXPECT_TRUE(changed);
|
EXPECT_TRUE(changed);
|
||||||
|
|
||||||
// The broadcast should have been rematerialized 3 times.
|
// The broadcast should have been rematerialized 3 times.
|
||||||
@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
|
|||||||
// rematerialization).
|
// rematerialization).
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
/*memory_limit_bytes=*/22 * 1024, module.get()));
|
/*memory_limit_bytes=*/22 * 1024, module));
|
||||||
// Rematerialization should only occur if the rematerializable instruction has
|
// Rematerialization should only occur if the rematerializable instruction has
|
||||||
// no indirect uses.
|
// no indirect uses.
|
||||||
if (indirectly_used) {
|
if (indirectly_used) {
|
||||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
|
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ namespace {
|
|||||||
|
|
||||||
using ::tensorflow::GraphDef;
|
using ::tensorflow::GraphDef;
|
||||||
|
|
||||||
class HloTfGraphBuilderTest : public HloTestBase {
|
class HloTfGraphBuilderTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
HloTfGraphBuilderTest() {}
|
HloTfGraphBuilderTest() {}
|
||||||
HloTfGraphBuilder generator_;
|
HloTfGraphBuilder generator_;
|
||||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class TupleSimplifierTest : public HloTestBase {
|
class TupleSimplifierTest : public HloVerifiedTestBase {
|
||||||
protected:
|
protected:
|
||||||
void Run(HloModule* module, bool change_expected) {
|
void Run(HloModule* module, bool change_expected) {
|
||||||
TupleSimplifier simplifier;
|
TupleSimplifier simplifier;
|
||||||
@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/false);
|
Run(module, /*change_expected=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
|
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
|
||||||
@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
|
|||||||
auto module = CreateNewModule();
|
auto module = CreateNewModule();
|
||||||
module->AddEntryComputation(builder.Build());
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/false);
|
Run(module, /*change_expected=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TupleSimplifierTest, GteOfTuple) {
|
TEST_F(TupleSimplifierTest, GteOfTuple) {
|
||||||
@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) {
|
|||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), gte);
|
EXPECT_THAT(computation->root_instruction(), gte);
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/true);
|
Run(module, /*change_expected=*/true);
|
||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), param1);
|
EXPECT_THAT(computation->root_instruction(), param1);
|
||||||
}
|
}
|
||||||
@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) {
|
|||||||
EXPECT_THAT(computation->root_instruction(),
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
op::Negate(op::GetTupleElement(op::Tuple())));
|
op::Negate(op::GetTupleElement(op::Tuple())));
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/true);
|
Run(module, /*change_expected=*/true);
|
||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
|
EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
|
||||||
}
|
}
|
||||||
@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
|
|||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), element);
|
EXPECT_THAT(computation->root_instruction(), element);
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/true);
|
Run(module, /*change_expected=*/true);
|
||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), param);
|
EXPECT_THAT(computation->root_instruction(), param);
|
||||||
}
|
}
|
||||||
@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
|
|||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), tuple);
|
EXPECT_THAT(computation->root_instruction(), tuple);
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/true);
|
Run(module, /*change_expected=*/true);
|
||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), tuple_param);
|
EXPECT_THAT(computation->root_instruction(), tuple_param);
|
||||||
}
|
}
|
||||||
@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
|
|||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), tuple);
|
EXPECT_THAT(computation->root_instruction(), tuple);
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/false);
|
Run(module, /*change_expected=*/false);
|
||||||
|
|
||||||
EXPECT_THAT(computation->root_instruction(), tuple);
|
EXPECT_THAT(computation->root_instruction(), tuple);
|
||||||
}
|
}
|
||||||
@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
|
|||||||
entry = module->AddEntryComputation(builder.Build());
|
entry = module->AddEntryComputation(builder.Build());
|
||||||
}
|
}
|
||||||
|
|
||||||
Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
|
Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
|
||||||
|
|
||||||
EXPECT_THAT(c0->root_instruction(), p0);
|
EXPECT_THAT(c0->root_instruction(), p0);
|
||||||
EXPECT_THAT(c1->root_instruction(), p1);
|
EXPECT_THAT(c1->root_instruction(), p1);
|
||||||
|
Loading…
Reference in New Issue
Block a user