[XLA] Merge HloTestBase and HloVerifiedTestBase.
No need to have this distinction any longer; you can simply call CreateNewUnverifiedModule or CreateNewVerifiedModule as you please. PiperOrigin-RevId: 221018719
This commit is contained in:
parent
a87acdb86d
commit
388eb3753b
@ -87,7 +87,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -124,7 +123,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -164,7 +162,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
],
|
||||
@ -282,7 +279,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
@ -365,7 +362,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -421,7 +417,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
@ -467,7 +462,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
@ -520,7 +514,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -569,7 +562,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -592,7 +584,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -1089,7 +1080,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -1172,7 +1162,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -1433,7 +1422,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -1509,7 +1497,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -1561,7 +1548,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -1598,7 +1584,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -1648,7 +1633,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -1707,7 +1692,7 @@ tf_cc_test(
|
||||
":while_loop_analysis",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -1740,7 +1725,7 @@ tf_cc_test(
|
||||
":hlo_matchers",
|
||||
":while_loop_simplifier",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -1771,7 +1756,7 @@ tf_cc_test(
|
||||
":hlo_matchers",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1799,7 +1784,7 @@ tf_cc_test(
|
||||
":implicit_broadcast_remover",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1844,7 +1829,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
@ -1878,7 +1862,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -2284,7 +2268,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -2347,7 +2330,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -2444,7 +2426,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
@ -2562,7 +2543,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -2629,7 +2609,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -2691,7 +2670,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
@ -2732,7 +2711,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:lib",
|
||||
@ -2771,7 +2749,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
@ -2846,7 +2823,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -3034,7 +3010,6 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo_tfgraph_builder",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
@ -3361,7 +3336,7 @@ tf_cc_test(
|
||||
":while_loop_invariant_code_motion",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
@ -3391,7 +3366,7 @@ tf_cc_test(
|
||||
":while_loop_constant_sinking",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
@ -3452,7 +3427,7 @@ tf_cc_test(
|
||||
":indexed_array_analysis",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -3549,7 +3524,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"@com_google_absl//absl/memory",
|
||||
|
@ -33,7 +33,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -54,7 +53,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
|
||||
return [](const Shape&, const Shape&) { return false; };
|
||||
}
|
||||
|
||||
class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
|
||||
class AlgebraicSimplifierTest : public HloTestBase {};
|
||||
|
||||
// Test that A + 0 is simplified to A
|
||||
TEST_F(AlgebraicSimplifierTest, AddZero) {
|
||||
@ -2906,7 +2905,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
|
||||
/*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
|
||||
|
||||
// TODO(b/80488902): verify this module.
|
||||
auto module = HloTestBase::CreateNewUnverifiedModule();
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto* computation = module->AddEntryComputation(b.Build());
|
||||
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
|
||||
@ -3084,7 +3083,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
|
||||
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
|
||||
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
|
||||
// TODO(b/80488902): verify this module.
|
||||
auto module = HloTestBase::CreateNewUnverifiedModule();
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
||||
// Create operand to the pad.
|
||||
@ -3166,7 +3165,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
|
||||
// ReduceWindow(Convert(op), x).
|
||||
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
|
||||
// TODO(b/80488902): verify this module.
|
||||
auto module = HloTestBase::CreateNewUnverifiedModule();
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
||||
// Create operand to the pad.
|
||||
@ -3846,7 +3845,7 @@ struct DotOfConcatTestSpec {
|
||||
};
|
||||
|
||||
class DotOfConcatSimplificationTest
|
||||
: public HloVerifiedTestBase,
|
||||
: public HloTestBase,
|
||||
public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
|
||||
|
||||
// Test that we transform
|
||||
@ -4022,7 +4021,7 @@ struct DotOfGatherTestSpec {
|
||||
};
|
||||
|
||||
class DotOfGatherSimplificationTest
|
||||
: public HloVerifiedTestBase,
|
||||
: public HloTestBase,
|
||||
public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
|
||||
|
||||
// input: dot(DS(ctA), ctB))
|
||||
|
@ -17,14 +17,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class BatchDotSimplificationTest : public HloVerifiedTestBase {};
|
||||
class BatchDotSimplificationTest : public HloTestBase {};
|
||||
|
||||
TEST_F(BatchDotSimplificationTest,
|
||||
ElideSingleDegenerateBatchDotDim_VectorVector) {
|
||||
|
@ -29,14 +29,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using BatchNormExpanderTest = HloVerifiedTestBase;
|
||||
using BatchNormExpanderTest = HloTestBase;
|
||||
|
||||
// Test that we expand BatchNormTraining.
|
||||
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
@ -65,11 +65,11 @@ class TestBFloat16Support : public BFloat16Support {
|
||||
}
|
||||
};
|
||||
|
||||
class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
|
||||
class BFloat16ConversionFoldingTest : public HloTestBase {
|
||||
protected:
|
||||
BFloat16ConversionFoldingTest()
|
||||
: HloVerifiedTestBase(/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/true) {}
|
||||
: HloTestBase(/*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
|
||||
|
||||
bool FoldConversions(HloModule* module) {
|
||||
TestBFloat16Support bfloat16_support_;
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
@ -68,11 +68,11 @@ class TestBFloat16Support : public BFloat16Support {
|
||||
}
|
||||
};
|
||||
|
||||
class BFloat16NormalizationTest : public HloVerifiedTestBase {
|
||||
class BFloat16NormalizationTest : public HloTestBase {
|
||||
protected:
|
||||
BFloat16NormalizationTest()
|
||||
: HloVerifiedTestBase(/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/true) {}
|
||||
: HloTestBase(/*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
|
||||
|
||||
bool Normalize(HloModule* module) {
|
||||
TestBFloat16Support bfloat16_support_;
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
@ -55,11 +55,11 @@ class TestBFloat16Support : public BFloat16Support {
|
||||
}
|
||||
};
|
||||
|
||||
class BFloat16PropagationTest : public HloVerifiedTestBase {
|
||||
class BFloat16PropagationTest : public HloTestBase {
|
||||
protected:
|
||||
BFloat16PropagationTest()
|
||||
: HloVerifiedTestBase(/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/true) {}
|
||||
: HloTestBase(/*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
|
||||
|
||||
// Runs the propagation pass on the given module, and returns whether the
|
||||
// module is changed after this pass.
|
||||
|
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -81,7 +81,7 @@ const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
|
||||
return main_list.GetInstructions();
|
||||
}
|
||||
|
||||
class BufferAssignmentTest : public HloVerifiedTestBase {
|
||||
class BufferAssignmentTest : public HloTestBase {
|
||||
protected:
|
||||
~BufferAssignmentTest() override {}
|
||||
|
||||
@ -1818,7 +1818,7 @@ ENTRY main {
|
||||
}
|
||||
}
|
||||
|
||||
class WhileBufferAssignmentTest : public HloVerifiedTestBase {
|
||||
class WhileBufferAssignmentTest : public HloTestBase {
|
||||
protected:
|
||||
std::unique_ptr<HloComputation> BuildWhileConditionComputation(
|
||||
const string& name) {
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -31,7 +31,7 @@ namespace {
|
||||
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
class CallGraphTest : public HloVerifiedTestBase {
|
||||
class CallGraphTest : public HloTestBase {
|
||||
protected:
|
||||
// Build and return a trivial computation taking and returning a scalar.
|
||||
std::unique_ptr<HloComputation> MakeScalarComputation(
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -40,7 +40,7 @@ namespace {
|
||||
|
||||
// Tests for call inlining that are most tractable at the HLO level (vs
|
||||
// ComputationBuilder API in call_test.cc).
|
||||
using CallInlinerTest = HloVerifiedTestBase;
|
||||
using CallInlinerTest = HloTestBase;
|
||||
|
||||
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
||||
// "inner" computation just has a control dependency from the "zero" value to
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -37,7 +37,7 @@ namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class ConditionalSimplifierTest : public HloVerifiedTestBase {
|
||||
class ConditionalSimplifierTest : public HloTestBase {
|
||||
public:
|
||||
// Makes a computation that contains a conditional with constant predicate.
|
||||
HloComputation* MakeConditional(HloModule* module);
|
||||
|
@ -824,7 +824,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
@ -846,7 +845,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
@ -887,7 +885,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -971,7 +968,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -997,7 +993,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
@ -32,7 +32,7 @@ namespace cpu {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
class ConvCanonicalizationTest : public HloVerifiedTestBase {
|
||||
class ConvCanonicalizationTest : public HloTestBase {
|
||||
public:
|
||||
ConvCanonicalizationTest() {
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
|
||||
return count;
|
||||
}
|
||||
|
||||
class CpuCopyInsertionTest : public HloVerifiedTestBase {
|
||||
class CpuCopyInsertionTest : public HloTestBase {
|
||||
protected:
|
||||
void InsertCopies(HloModule* module) {
|
||||
CpuCopyInsertion copy_insertion;
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
@ -25,7 +25,7 @@ namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
|
||||
class CpuHloSupportCheckerTest : public HloTestBase {
|
||||
protected:
|
||||
CpuHloSupportChecker& checker() { return checker_; }
|
||||
|
||||
@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||
// Since verifier is reporting sparse layouts as errors, we should
|
||||
// use a regular HloModule instead of VerifiedHloModule to avoid
|
||||
// verifier errors being triggered in the destructor.
|
||||
auto module = HloTestBase::CreateNewUnverifiedModule();
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
Status status = checker().Run(module.get()).status();
|
||||
|
@ -17,13 +17,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
|
||||
class ParallelTaskAssignmentTest : public HloTestBase {
|
||||
protected:
|
||||
const HloCostAnalysis::ShapeSizeFunction shape_size_func_ =
|
||||
cpu::CpuExecutable::ShapeSizeBytes;
|
||||
@ -35,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
|
||||
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
|
||||
|
||||
ParallelTaskAssignmentTest()
|
||||
: HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) {
|
||||
: HloTestBase(), target_machine_features_([](int64 shape_size) {
|
||||
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
|
||||
}) {}
|
||||
|
||||
|
@ -19,14 +19,14 @@ limitations under the License.
|
||||
#include <random>
|
||||
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace {
|
||||
|
||||
class ShapePartitionAssignerTest : public HloVerifiedTestBase {
|
||||
class ShapePartitionAssignerTest : public HloTestBase {
|
||||
protected:
|
||||
typedef std::vector<int64> Vec;
|
||||
|
||||
@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
|
||||
expected_partitions);
|
||||
}
|
||||
|
||||
class ShapePartitionIteratorTest : public HloVerifiedTestBase {
|
||||
class ShapePartitionIteratorTest : public HloTestBase {
|
||||
protected:
|
||||
typedef std::vector<std::pair<int64, int64>> Partition;
|
||||
};
|
||||
@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
|
||||
}
|
||||
}
|
||||
|
||||
class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
|
||||
class RandomShapePartitionIteratorTest : public HloTestBase {
|
||||
protected:
|
||||
typedef std::vector<std::pair<int64, int64>> Partition;
|
||||
RandomShapePartitionIteratorTest()
|
||||
|
@ -48,7 +48,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -34,7 +34,7 @@ namespace xla {
|
||||
namespace cpu {
|
||||
namespace {
|
||||
|
||||
class CpuFusionTest : public HloVerifiedTestBase {
|
||||
class CpuFusionTest : public HloTestBase {
|
||||
protected:
|
||||
CpuFusionTest() {}
|
||||
|
||||
|
@ -18,14 +18,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class DefuserTest : public HloVerifiedTestBase {
|
||||
class DefuserTest : public HloTestBase {
|
||||
protected:
|
||||
// Returns the number of fusion instructions in the module.
|
||||
int FusionCount(const HloModule* m) {
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -30,7 +30,7 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class FlattenCallGraphTest : public HloVerifiedTestBase {
|
||||
class FlattenCallGraphTest : public HloTestBase {
|
||||
protected:
|
||||
// Build and return a trivial computation taking and returning a scalar.
|
||||
std::unique_ptr<HloComputation> MakeScalarComputation() {
|
||||
|
@ -111,7 +111,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
@ -463,7 +462,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:shape_inference",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -627,7 +626,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
|
||||
],
|
||||
)
|
||||
@ -849,7 +848,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -909,7 +907,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
@ -1036,6 +1033,6 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
],
|
||||
)
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -29,7 +29,7 @@ namespace {
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
using ::testing::_;
|
||||
|
||||
class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {};
|
||||
class CudnnConvPadForTensorCoresTest : public HloTestBase {};
|
||||
|
||||
TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
@ -34,11 +34,11 @@ namespace {
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
using ::testing::_;
|
||||
|
||||
class CudnnConvRewriterTest : public HloVerifiedTestBase {
|
||||
class CudnnConvRewriterTest : public HloTestBase {
|
||||
public:
|
||||
CudnnConvRewriterTest()
|
||||
: HloVerifiedTestBase(/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false) {
|
||||
: HloTestBase(/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false) {
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
WindowDimension* window_dim = default_conv_window_.add_dimensions();
|
||||
window_dim->set_size(1);
|
||||
|
@ -24,14 +24,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
class GpuHloScheduleTest : public HloVerifiedTestBase {
|
||||
class GpuHloScheduleTest : public HloTestBase {
|
||||
protected:
|
||||
using HloVec = std::vector<const HloInstruction*>;
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
@ -25,7 +25,7 @@ namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
|
||||
class GpuHloSupportCheckerTest : public HloTestBase {
|
||||
protected:
|
||||
GpuHloSupportChecker& checker() { return checker_; }
|
||||
|
||||
@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||
// Since verifier is reporting sparse layouts as errors, we should
|
||||
// use a regular HloModule instead of VerifiedHloModule to avoid
|
||||
// verifier errors being triggered in the destructor.
|
||||
auto module = HloTestBase::CreateNewUnverifiedModule();
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
Status status = checker().Run(module.get()).status();
|
||||
|
@ -21,14 +21,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
class StreamAssignmentTest : public HloVerifiedTestBase {
|
||||
class StreamAssignmentTest : public HloTestBase {
|
||||
protected:
|
||||
std::unique_ptr<HloModule> CreateNewUnverifiedModule() {
|
||||
HloModuleConfig config;
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
@ -32,7 +32,7 @@ namespace gpu {
|
||||
namespace {
|
||||
using match::Concatenate;
|
||||
|
||||
class VariadicOpSplitterTest : public HloVerifiedTestBase {};
|
||||
class VariadicOpSplitterTest : public HloTestBase {};
|
||||
|
||||
TEST_F(VariadicOpSplitterTest, DontSplit) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
|
@ -30,13 +30,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
|
||||
class MinimumMemoryForSequenceTest : public HloTestBase {};
|
||||
|
||||
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
@ -351,7 +351,7 @@ class HeapSimulatorTracker {
|
||||
HeapSimulator::Result result_;
|
||||
};
|
||||
|
||||
class HeapSimulatorTest : public HloVerifiedTestBase {
|
||||
class HeapSimulatorTest : public HloTestBase {
|
||||
protected:
|
||||
HeapSimulatorTest() {}
|
||||
~HeapSimulatorTest() override {}
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -39,9 +39,9 @@ namespace {
|
||||
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
class HloAliasAnalysisTest : public HloVerifiedTestBase {
|
||||
class HloAliasAnalysisTest : public HloTestBase {
|
||||
protected:
|
||||
HloAliasAnalysisTest() : HloVerifiedTestBase() {
|
||||
HloAliasAnalysisTest() : HloTestBase() {
|
||||
module_ = CreateNewVerifiedModule();
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using HloConstantFoldingTest = HloVerifiedTestBase;
|
||||
using HloConstantFoldingTest = HloTestBase;
|
||||
|
||||
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
@ -19,13 +19,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class HloCreationUtilsTest : public HloVerifiedTestBase {
|
||||
class HloCreationUtilsTest : public HloTestBase {
|
||||
protected:
|
||||
std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
|
||||
PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers;
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class HloCseTest : public HloVerifiedTestBase {
|
||||
class HloCseTest : public HloTestBase {
|
||||
protected:
|
||||
HloCseTest() {}
|
||||
};
|
||||
|
@ -22,13 +22,12 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class HloDomainTest : public HloVerifiedTestBase {
|
||||
class HloDomainTest : public HloTestBase {
|
||||
protected:
|
||||
bool FindUserViaDomainPath(HloInstruction* instruction,
|
||||
HloInstruction* operand) const {
|
||||
|
@ -33,7 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -50,9 +50,9 @@ namespace {
|
||||
static std::array<bool, 2> use_bf16_params{true, false};
|
||||
|
||||
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
|
||||
public HloVerifiedTestBase {
|
||||
public HloTestBase {
|
||||
protected:
|
||||
HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) {
|
||||
HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) {
|
||||
evaluator_ = absl::make_unique<HloEvaluator>();
|
||||
}
|
||||
|
||||
@ -67,7 +67,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
|
||||
}
|
||||
|
||||
// Evaluate function that takes in a local module instead of using m_
|
||||
// that is in HloVerifiedTestBase. Once m_ in HloVerifiedTestBase is
|
||||
// that is in HloTestBase. Once m_ in HloTestBase is
|
||||
// removed, this should be the default Evaluate function.
|
||||
Literal EvaluateWithModule(
|
||||
HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
|
||||
@ -1298,7 +1298,7 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||
}
|
||||
|
||||
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
|
||||
class HloEvaluatorPreciseReduceTest : public HloTestBase {};
|
||||
|
||||
// Tests that Reduce doesn't lose precision when adding many numbers (because
|
||||
// it accumulates its result in a double).
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
|
||||
@ -39,7 +39,7 @@ namespace {
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
class HloInstructionTest : public HloVerifiedTestBase {
|
||||
class HloInstructionTest : public HloTestBase {
|
||||
protected:
|
||||
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
|
||||
};
|
||||
|
@ -19,14 +19,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class HloPassPipelineTest : public HloVerifiedTestBase {
|
||||
class HloPassPipelineTest : public HloTestBase {
|
||||
protected:
|
||||
StatusOr<HloModuleGroup> ParseModuleGroup(
|
||||
absl::Span<const string> hlo_strings) {
|
||||
|
@ -20,13 +20,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
class HloReachabilityTest : public HloVerifiedTestBase {};
|
||||
class HloReachabilityTest : public HloTestBase {};
|
||||
|
||||
TEST_F(HloReachabilityTest, Reachability) {
|
||||
// Construct and test a reachability graph of the following form:
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
using ::testing::_;
|
||||
|
||||
class HloRematerializationTest : public HloVerifiedTestBase {
|
||||
class HloRematerializationTest : public HloTestBase {
|
||||
protected:
|
||||
// Creates and returns a computation which can benefit from
|
||||
// rematerialization. The computation looks like:
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
|
||||
@ -24,7 +24,7 @@ namespace {
|
||||
|
||||
using ::tensorflow::GraphDef;
|
||||
|
||||
class HloTfGraphBuilderTest : public HloVerifiedTestBase {
|
||||
class HloTfGraphBuilderTest : public HloTestBase {
|
||||
protected:
|
||||
HloTfGraphBuilderTest() {}
|
||||
HloTfGraphBuilder generator_;
|
||||
|
@ -35,7 +35,7 @@ namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
// This class cannot be converted to use HloVerifiedTestBase. It explicitly
|
||||
// This class cannot be converted to use HloTestBase. It explicitly
|
||||
// uses HloTestBase to create and test malformed HLOs.
|
||||
class HloVerifierTest : public HloTestBase {
|
||||
public:
|
||||
|
@ -18,14 +18,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
|
||||
class ImplicitBroadcastRemoverTest : public HloTestBase {
|
||||
protected:
|
||||
ImplicitBroadcastRemover remover_;
|
||||
};
|
||||
|
@ -16,12 +16,12 @@ limitations under the License.
|
||||
#include <ctype.h>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
|
||||
class IndexedArrayAnalysisTest : public HloTestBase {
|
||||
protected:
|
||||
void AssertArrayForRootExpressionIs(const string& hlo_text,
|
||||
const string& root_expression) {
|
||||
|
@ -35,7 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -49,7 +49,7 @@ namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
class LayoutAssignmentTest : public HloVerifiedTestBase {
|
||||
class LayoutAssignmentTest : public HloTestBase {
|
||||
protected:
|
||||
void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
|
||||
ChannelLayoutConstraints* channel_constraints = nullptr) {
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers;
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using MapInlinerTest = HloVerifiedTestBase;
|
||||
using MapInlinerTest = HloTestBase;
|
||||
|
||||
// Test that `map` with `max` is transformed to `max`
|
||||
TEST_F(MapInlinerTest, MapMax) {
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
@ -34,7 +34,7 @@ namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class ReshapeMoverTest : public HloVerifiedTestBase {};
|
||||
class ReshapeMoverTest : public HloTestBase {};
|
||||
|
||||
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class TupleSimplifierTest : public HloVerifiedTestBase {
|
||||
class TupleSimplifierTest : public HloTestBase {
|
||||
protected:
|
||||
void Run(HloModule* module, bool change_expected) {
|
||||
TupleSimplifier simplifier;
|
||||
|
@ -17,13 +17,13 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class WhileLoopAnalysisTest : public HloVerifiedTestBase {};
|
||||
class WhileLoopAnalysisTest : public HloTestBase {};
|
||||
|
||||
TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) {
|
||||
const char* const kHloModule = R"(
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -26,7 +26,7 @@ namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
|
||||
class WhileLoopInvariantCodeMotionTest : public HloTestBase {
|
||||
public:
|
||||
// Makes a computation which has one parameter, of the given shape, and always
|
||||
// returns PRED[]{true}. This is useful as a dummy loop condition.
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -29,7 +29,7 @@ namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
|
||||
class WhileLoopSimplifierTest : public HloTestBase {
|
||||
protected:
|
||||
// Makes an HloModule that contains a loop with `num_iters` iteration.
|
||||
TF_MUST_USE_RESULT std::unique_ptr<VerifiedHloModule>
|
||||
|
@ -141,44 +141,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_verified_test_base",
|
||||
testonly = True,
|
||||
srcs = ["hlo_verified_test_base.cc"],
|
||||
hdrs = ["hlo_verified_test_base.h"],
|
||||
deps = [
|
||||
":hlo_test_base",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_verified_test_base_test",
|
||||
srcs = ["hlo_verified_test_base_test.cc"],
|
||||
deps = [
|
||||
":hlo_test_base",
|
||||
":hlo_verified_test_base",
|
||||
":test_macros_cpu",
|
||||
":test_utils",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "local_client_aot_test_helper",
|
||||
srcs = ["local_client_aot_test_helper.cc"],
|
||||
|
@ -85,6 +85,23 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
|
||||
|
||||
} // namespace
|
||||
|
||||
Status VerifiedHloModule::Verify() {
|
||||
if (computation_count() == 0) {
|
||||
// The computation was never built. Nothing to verify.
|
||||
return Status::OK();
|
||||
}
|
||||
return verifier_.Run(this).status();
|
||||
}
|
||||
|
||||
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
||||
Status status = Verify();
|
||||
if (!status.ok()) {
|
||||
ADD_FAILURE() << "HloVerifier failed on module " << name()
|
||||
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
|
||||
<< ": " << status;
|
||||
}
|
||||
}
|
||||
|
||||
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
@ -100,7 +117,11 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
|
||||
bool allow_mixed_precision_in_hlo_verifier,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func)
|
||||
: test_runner_(test_platform), reference_runner_(reference_platform) {
|
||||
: test_runner_(test_platform),
|
||||
reference_runner_(reference_platform),
|
||||
verifier_layout_sensitive_(verifier_layout_sensitive),
|
||||
allow_mixed_precision_in_hlo_verifier_(
|
||||
allow_mixed_precision_in_hlo_verifier) {
|
||||
hlo_verifier_ = absl::make_unique<HloVerifier>(
|
||||
/*layout_sensitive=*/verifier_layout_sensitive,
|
||||
/*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
|
||||
@ -112,6 +133,32 @@ std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
|
||||
return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
|
||||
}
|
||||
|
||||
std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
|
||||
const string& name) {
|
||||
return absl::make_unique<VerifiedHloModule>(
|
||||
name, GetModuleConfigForTest(), verifier_layout_sensitive_,
|
||||
allow_mixed_precision_in_hlo_verifier_);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>>
|
||||
HloTestBase::ParseAndReturnUnverifiedModule(absl::string_view hlo_text,
|
||||
const HloModuleConfig& config) {
|
||||
auto module = absl::make_unique<HloModule>(TestName(), config);
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
|
||||
const HloModuleConfig& config) {
|
||||
auto module = absl::make_unique<VerifiedHloModule>(
|
||||
TestName(), config, verifier_layout_sensitive_,
|
||||
allow_mixed_precision_in_hlo_verifier_);
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
|
||||
HloModule* module) {
|
||||
|
@ -38,6 +38,31 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// An HLO module derived class which verifies itself on destruction. This class
|
||||
// is intended to be used in unit tests. Any verification errors are raised via
|
||||
// ADD_FAILURE.
|
||||
class VerifiedHloModule : public HloModule {
|
||||
public:
|
||||
VerifiedHloModule(const string& name, const HloModuleConfig& config,
|
||||
bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier)
|
||||
: HloModule(name, config),
|
||||
verifier_(verifier_layout_sensitive,
|
||||
allow_mixed_precision_in_hlo_verifier) {}
|
||||
|
||||
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
|
||||
|
||||
// Verifies the module using HloVerifier and returns the status.
|
||||
Status Verify();
|
||||
|
||||
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
|
||||
// included in the failure message.
|
||||
void VerifyOrAddFailure(const string& message);
|
||||
|
||||
private:
|
||||
HloVerifier verifier_;
|
||||
};
|
||||
|
||||
// A base class for tests which build and/or run HLO code. The class includes
|
||||
// support for running an HLO module on two platforms and compare the results.
|
||||
// This is a lower level of abstraction than using the client interface and
|
||||
@ -74,11 +99,26 @@ class HloTestBase : public ::testing::Test {
|
||||
// tests.
|
||||
//
|
||||
// This returns a vanilla HloModule that doesn't run the HLO verifier on
|
||||
// destruction. If you want to run the verifier, you want
|
||||
// HloVerifiedTestBase::CreateNewVerifiedModule.
|
||||
// destruction.
|
||||
std::unique_ptr<HloModule> CreateNewUnverifiedModule(
|
||||
const string& name = TestName());
|
||||
|
||||
// Like CreateNewUnverifiedModule, except the HloModule returned here runs the
|
||||
// HLO verifier on destruction.
|
||||
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
|
||||
const string& name = TestName());
|
||||
|
||||
// Parses the given string and returns module as a vanilla, unverified
|
||||
// HloModule.
|
||||
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
|
||||
absl::string_view hlo_text,
|
||||
const HloModuleConfig& config = HloModuleConfig());
|
||||
|
||||
// Parses the given string and returns module as a VerifiedHloModule.
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text,
|
||||
const HloModuleConfig& config = HloModuleConfig());
|
||||
|
||||
// Runs the hlo_pass with the provided module and returns the result. This
|
||||
// function also verifies that the module remains unchanged when hlo_pass
|
||||
// returns false as the StatusOr value.
|
||||
@ -252,6 +292,8 @@ class HloTestBase : public ::testing::Test {
|
||||
HloRunner test_runner_;
|
||||
HloRunner reference_runner_;
|
||||
|
||||
bool verifier_layout_sensitive_;
|
||||
bool allow_mixed_precision_in_hlo_verifier_;
|
||||
std::unique_ptr<HloVerifier> hlo_verifier_;
|
||||
|
||||
ErrorSpec error_spec_{0.0001};
|
||||
|
@ -1,68 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status VerifiedHloModule::Verify() {
|
||||
if (computation_count() == 0) {
|
||||
// The computation was never built. Nothing to verify.
|
||||
return Status::OK();
|
||||
}
|
||||
return verifier_.Run(this).status();
|
||||
}
|
||||
|
||||
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
||||
Status status = Verify();
|
||||
if (!status.ok()) {
|
||||
ADD_FAILURE() << "HloVerifier failed on module " << name()
|
||||
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
|
||||
<< ": " << status;
|
||||
}
|
||||
}
|
||||
|
||||
HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
|
||||
bool allow_mixed_precision)
|
||||
: HloTestBase(
|
||||
/*verifier_layout_sensitive=*/layout_sensitive,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
|
||||
verifier_layout_sensitive_(layout_sensitive),
|
||||
allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
|
||||
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||
HloVerifiedTestBase::ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text, const HloModuleConfig& config) {
|
||||
auto module = CreateNewVerifiedModule(TestName());
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
|
||||
const string& name) {
|
||||
return absl::make_unique<VerifiedHloModule>(
|
||||
name, GetModuleConfigForTest(), verifier_layout_sensitive_,
|
||||
allow_mixed_precision_in_hlo_verifier_);
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -1,86 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/base/macros.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// An HLO module derived class which verifies itself on destruction. This class
|
||||
// is intended to be used in unit tests. Any verification errors are raised via
|
||||
// ADD_FAILURE.
|
||||
class VerifiedHloModule : public HloModule {
|
||||
public:
|
||||
VerifiedHloModule(const string& name, const HloModuleConfig& config,
|
||||
bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier)
|
||||
: HloModule(name, config),
|
||||
verifier_(verifier_layout_sensitive,
|
||||
allow_mixed_precision_in_hlo_verifier) {}
|
||||
|
||||
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
|
||||
|
||||
// Verifies the module using HloVerifier and returns the status.
|
||||
Status Verify();
|
||||
|
||||
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
|
||||
// included in the failure message.
|
||||
void VerifyOrAddFailure(const string& message);
|
||||
|
||||
private:
|
||||
HloVerifier verifier_;
|
||||
};
|
||||
|
||||
// A base class for HLO tests that stores a default VerifiedHloModule.
|
||||
class HloVerifiedTestBase : public HloTestBase {
|
||||
protected:
|
||||
HloVerifiedTestBase(bool layout_sensitive = false,
|
||||
bool allow_mixed_precision = false);
|
||||
|
||||
// Constructs a default shape verifier.
|
||||
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
|
||||
|
||||
// Parses the given string and returns module as a VerifiedHloModule.
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text,
|
||||
const HloModuleConfig& config = HloModuleConfig());
|
||||
|
||||
// Creates and returns a verified HLO module with the given name.
|
||||
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
|
||||
const string& name = TestName());
|
||||
|
||||
// CreateNewUnverifiedModule creates an *unverified* module, which presumably
|
||||
// isn't what you want if you're using HloVerifiedTestBase, so we delete this
|
||||
// function to keep you from accidentally calling it. If you really want it,
|
||||
// you can get it by calling HloTestBase::CreateNewUnverifiedModule().
|
||||
std::unique_ptr<HloModule> CreateNewUnverifiedModule(
|
||||
const string& name = TestName()) = delete;
|
||||
|
||||
private:
|
||||
bool verifier_layout_sensitive_;
|
||||
bool allow_mixed_precision_in_hlo_verifier_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_
|
@ -1,115 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// This class includes unit tests which are expected to fail because invalid HLO
|
||||
// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to
|
||||
// include the necessary gunit parts to test this test machinery (needs the
|
||||
// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the
|
||||
// disabled tests enabled and failures can be manually compared against
|
||||
// expectations.
|
||||
class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
|
||||
|
||||
XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
|
||||
// Test shouldn't fail if no module is created at all.
|
||||
}
|
||||
|
||||
XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewUnverifiedModule) {
|
||||
// Call CreateNewUnverifiedModule and build up a valid module.
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
}
|
||||
|
||||
// This test is expected to fail. See test class comment.
|
||||
XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewUnverifiedModule) {
|
||||
// Call CreateNewUnverifiedModule and build up a invalid module.
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
*module->entry_computation()->root_instruction()->mutable_shape() =
|
||||
ShapeUtil::MakeShape(PRED, {1, 2, 3});
|
||||
}
|
||||
|
||||
XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule ParseAndReturnVerifiedModuleGood
|
||||
|
||||
ENTRY entry {
|
||||
x = f32[] parameter(0)
|
||||
y = f32[] parameter(1)
|
||||
ROOT add = f32[] add(x,y)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
|
||||
}
|
||||
|
||||
XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule ParseAndReturnVerifiedModuleGood
|
||||
|
||||
ENTRY entry {
|
||||
x = f32[] parameter(0)
|
||||
y = f32[] parameter(1)
|
||||
ROOT add = f32[] add(x,y)
|
||||
}
|
||||
|
||||
RANDOM GARBAGE
|
||||
)";
|
||||
|
||||
ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
|
||||
}
|
||||
|
||||
// This test is expected to fail. See test class comment.
|
||||
XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule ParseAndReturnVerifiedModuleBad
|
||||
|
||||
ENTRY entry {
|
||||
x = f32[] parameter(0)
|
||||
y = f32[] parameter(1)
|
||||
ROOT add = f32[1234] add(x,y)
|
||||
}
|
||||
)";
|
||||
|
||||
ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user