[XLA] Merge HloTestBase and HloVerifiedTestBase.

No need to have this distinction any longer; you can simply call
CreateNewUnverifiedModule or CreateNewVerifiedModule as you please.

PiperOrigin-RevId: 221018719
This commit is contained in:
Justin Lebar 2018-11-11 16:18:06 -08:00 committed by TensorFlower Gardener
parent a87acdb86d
commit 388eb3753b
56 changed files with 219 additions and 474 deletions

View File

@ -87,7 +87,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@ -124,7 +123,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@ -164,7 +162,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
@ -282,7 +279,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
@ -365,7 +362,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@ -421,7 +417,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -467,7 +462,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -520,7 +514,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@ -569,7 +562,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -592,7 +584,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@ -1089,7 +1080,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1172,7 +1162,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1433,7 +1422,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
@ -1509,7 +1497,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@ -1561,7 +1548,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1598,7 +1584,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1648,7 +1633,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@ -1707,7 +1692,7 @@ tf_cc_test(
":while_loop_analysis",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@ -1740,7 +1725,7 @@ tf_cc_test(
":hlo_matchers",
":while_loop_simplifier",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
@ -1771,7 +1756,7 @@ tf_cc_test(
":hlo_matchers",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
],
)
@ -1799,7 +1784,7 @@ tf_cc_test(
":implicit_broadcast_remover",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
],
)
@ -1844,7 +1829,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@ -1878,7 +1862,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@ -2284,7 +2268,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -2347,7 +2330,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -2444,7 +2426,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@ -2562,7 +2543,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@ -2629,7 +2609,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -2691,7 +2670,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@ -2732,7 +2711,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
@ -2771,7 +2749,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@ -2846,7 +2823,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
@ -3034,7 +3010,6 @@ tf_cc_test(
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
],
@ -3361,7 +3336,7 @@ tf_cc_test(
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
],
)
@ -3391,7 +3366,7 @@ tf_cc_test(
":while_loop_constant_sinking",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
],
)
@ -3452,7 +3427,7 @@ tf_cc_test(
":indexed_array_analysis",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:test",
],
@ -3549,7 +3524,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"@com_google_absl//absl/memory",

View File

@ -33,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -54,7 +53,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
return [](const Shape&, const Shape&) { return false; };
}
class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
class AlgebraicSimplifierTest : public HloTestBase {};
// Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
@ -2906,7 +2905,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
/*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewUnverifiedModule();
auto module = CreateNewUnverifiedModule();
auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
@ -3084,7 +3083,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewUnverifiedModule();
auto module = CreateNewUnverifiedModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@ -3166,7 +3165,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewUnverifiedModule();
auto module = CreateNewUnverifiedModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@ -3846,7 +3845,7 @@ struct DotOfConcatTestSpec {
};
class DotOfConcatSimplificationTest
: public HloVerifiedTestBase,
: public HloTestBase,
public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
// Test that we transform
@ -4022,7 +4021,7 @@ struct DotOfGatherTestSpec {
};
class DotOfGatherSimplificationTest
: public HloVerifiedTestBase,
: public HloTestBase,
public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
// input: dot(DS(ctA), ctB))

View File

@ -17,14 +17,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
class BatchDotSimplificationTest : public HloVerifiedTestBase {};
class BatchDotSimplificationTest : public HloTestBase {};
TEST_F(BatchDotSimplificationTest,
ElideSingleDegenerateBatchDotDim_VectorVector) {

View File

@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
using BatchNormExpanderTest = HloVerifiedTestBase;
using BatchNormExpanderTest = HloTestBase;
// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -65,11 +65,11 @@ class TestBFloat16Support : public BFloat16Support {
}
};
class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
class BFloat16ConversionFoldingTest : public HloTestBase {
protected:
BFloat16ConversionFoldingTest()
: HloVerifiedTestBase(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/true) {}
: HloTestBase(/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -68,11 +68,11 @@ class TestBFloat16Support : public BFloat16Support {
}
};
class BFloat16NormalizationTest : public HloVerifiedTestBase {
class BFloat16NormalizationTest : public HloTestBase {
protected:
BFloat16NormalizationTest()
: HloVerifiedTestBase(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/true) {}
: HloTestBase(/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -55,11 +55,11 @@ class TestBFloat16Support : public BFloat16Support {
}
};
class BFloat16PropagationTest : public HloVerifiedTestBase {
class BFloat16PropagationTest : public HloTestBase {
protected:
BFloat16PropagationTest()
: HloVerifiedTestBase(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/true) {}
: HloTestBase(/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
// Runs the propagation pass on the given module, and returns whether the
// module is changed after this pass.

View File

@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -81,7 +81,7 @@ const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
return main_list.GetInstructions();
}
class BufferAssignmentTest : public HloVerifiedTestBase {
class BufferAssignmentTest : public HloTestBase {
protected:
~BufferAssignmentTest() override {}
@ -1818,7 +1818,7 @@ ENTRY main {
}
}
class WhileBufferAssignmentTest : public HloVerifiedTestBase {
class WhileBufferAssignmentTest : public HloTestBase {
protected:
std::unique_ptr<HloComputation> BuildWhileConditionComputation(
const string& name) {

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -31,7 +31,7 @@ namespace {
using ::testing::UnorderedElementsAre;
class CallGraphTest : public HloVerifiedTestBase {
class CallGraphTest : public HloTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation(

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -40,7 +40,7 @@ namespace {
// Tests for call inlining that are most tractable at the HLO level (vs
// ComputationBuilder API in call_test.cc).
using CallInlinerTest = HloVerifiedTestBase;
using CallInlinerTest = HloTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@ -37,7 +37,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloVerifiedTestBase {
class ConditionalSimplifierTest : public HloTestBase {
public:
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);

View File

@ -824,7 +824,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -846,7 +845,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -887,7 +885,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -971,7 +968,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@ -997,7 +993,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@ -32,7 +32,7 @@ namespace cpu {
using ::testing::ElementsAre;
class ConvCanonicalizationTest : public HloVerifiedTestBase {
class ConvCanonicalizationTest : public HloTestBase {
public:
ConvCanonicalizationTest() {
for (int i = 0; i < 2; ++i) {

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
return count;
}
class CpuCopyInsertionTest : public HloVerifiedTestBase {
class CpuCopyInsertionTest : public HloTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;

View File

@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
class CpuHloSupportCheckerTest : public HloTestBase {
protected:
CpuHloSupportChecker& checker() { return checker_; }
@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
// Since verifier is reporting sparse layouts as errors, we should
// use a regular HloModule instead of VerifiedHloModule to avoid
// verifier errors being triggered in the destructor.
auto module = HloTestBase::CreateNewUnverifiedModule();
auto module = CreateNewUnverifiedModule();
module->AddEntryComputation(builder.Build());
Status status = checker().Run(module.get()).status();

View File

@ -17,13 +17,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
class ParallelTaskAssignmentTest : public HloTestBase {
protected:
const HloCostAnalysis::ShapeSizeFunction shape_size_func_ =
cpu::CpuExecutable::ShapeSizeBytes;
@ -35,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
: HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) {
: HloTestBase(), target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}

View File

@ -19,14 +19,14 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace cpu {
namespace {
class ShapePartitionAssignerTest : public HloVerifiedTestBase {
class ShapePartitionAssignerTest : public HloTestBase {
protected:
typedef std::vector<int64> Vec;
@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
expected_partitions);
}
class ShapePartitionIteratorTest : public HloVerifiedTestBase {
class ShapePartitionIteratorTest : public HloTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
};
@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
}
}
class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
class RandomShapePartitionIteratorTest : public HloTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
RandomShapePartitionIteratorTest()

View File

@ -48,7 +48,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@ -34,7 +34,7 @@ namespace xla {
namespace cpu {
namespace {
class CpuFusionTest : public HloVerifiedTestBase {
class CpuFusionTest : public HloTestBase {
protected:
CpuFusionTest() {}

View File

@ -18,14 +18,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
class DefuserTest : public HloVerifiedTestBase {
class DefuserTest : public HloTestBase {
protected:
// Returns the number of fusion instructions in the module.
int FusionCount(const HloModule* m) {

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
class FlattenCallGraphTest : public HloVerifiedTestBase {
class FlattenCallGraphTest : public HloTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation() {

View File

@ -111,7 +111,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@ -463,7 +462,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
],
@ -627,7 +626,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
],
)
@ -849,7 +848,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@ -909,7 +907,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@ -1036,6 +1033,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
],
)

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@ -29,7 +29,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {};
class CudnnConvPadForTensorCoresTest : public HloTestBase {};
TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) {
auto module = ParseAndReturnVerifiedModule(R"(

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@ -34,11 +34,11 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
class CudnnConvRewriterTest : public HloVerifiedTestBase {
class CudnnConvRewriterTest : public HloTestBase {
public:
CudnnConvRewriterTest()
: HloVerifiedTestBase(/*layout_sensitive=*/true,
/*allow_mixed_precision=*/false) {
: HloTestBase(/*layout_sensitive=*/true,
/*allow_mixed_precision=*/false) {
for (int i = 0; i < 2; ++i) {
WindowDimension* window_dim = default_conv_window_.add_dimensions();
window_dim->set_size(1);

View File

@ -24,14 +24,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
class GpuHloScheduleTest : public HloVerifiedTestBase {
class GpuHloScheduleTest : public HloTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;

View File

@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
class GpuHloSupportCheckerTest : public HloTestBase {
protected:
GpuHloSupportChecker& checker() { return checker_; }
@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
// Since verifier is reporting sparse layouts as errors, we should
// use a regular HloModule instead of VerifiedHloModule to avoid
// verifier errors being triggered in the destructor.
auto module = HloTestBase::CreateNewUnverifiedModule();
auto module = CreateNewUnverifiedModule();
module->AddEntryComputation(builder.Build());
Status status = checker().Run(module.get()).status();

View File

@ -21,14 +21,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
class StreamAssignmentTest : public HloVerifiedTestBase {
class StreamAssignmentTest : public HloTestBase {
protected:
std::unique_ptr<HloModule> CreateNewUnverifiedModule() {
HloModuleConfig config;

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -32,7 +32,7 @@ namespace gpu {
namespace {
using match::Concatenate;
class VariadicOpSplitterTest : public HloVerifiedTestBase {};
class VariadicOpSplitterTest : public HloTestBase {};
TEST_F(VariadicOpSplitterTest, DontSplit) {
auto module = ParseAndReturnVerifiedModule(R"(

View File

@ -30,13 +30,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
class MinimumMemoryForSequenceTest : public HloTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
auto module = CreateNewVerifiedModule();
@ -351,7 +351,7 @@ class HeapSimulatorTracker {
HeapSimulator::Result result_;
};
class HeapSimulatorTest : public HloVerifiedTestBase {
class HeapSimulatorTest : public HloTestBase {
protected:
HeapSimulatorTest() {}
~HeapSimulatorTest() override {}

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
@ -39,9 +39,9 @@ namespace {
using ::testing::UnorderedElementsAre;
class HloAliasAnalysisTest : public HloVerifiedTestBase {
class HloAliasAnalysisTest : public HloTestBase {
protected:
HloAliasAnalysisTest() : HloVerifiedTestBase() {
HloAliasAnalysisTest() : HloTestBase() {
module_ = CreateNewVerifiedModule();
}

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
using HloConstantFoldingTest = HloVerifiedTestBase;
using HloConstantFoldingTest = HloTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());

View File

@ -19,13 +19,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
class HloCreationUtilsTest : public HloVerifiedTestBase {
class HloCreationUtilsTest : public HloTestBase {
protected:
std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
class HloCseTest : public HloVerifiedTestBase {
class HloCseTest : public HloTestBase {
protected:
HloCseTest() {}
};

View File

@ -22,13 +22,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class HloDomainTest : public HloVerifiedTestBase {
class HloDomainTest : public HloTestBase {
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@ -50,9 +50,9 @@ namespace {
static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
public HloTestBase {
protected:
HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) {
HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) {
evaluator_ = absl::make_unique<HloEvaluator>();
}
@ -67,7 +67,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
}
// Evaluate function that takes in a local module instead of using m_
// that is in HloVerifiedTestBase. Once m_ in HloVerifiedTestBase is
// that is in HloTestBase. Once m_ in HloTestBase is
// removed, this should be the default Evaluate function.
Literal EvaluateWithModule(
HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
@ -1298,7 +1298,7 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
class HloEvaluatorPreciseReduceTest : public HloTestBase {};
// Tests that Reduce doesn't lose precision when adding many numbers (because
// it accumulates its result in a double).

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@ -39,7 +39,7 @@ namespace {
using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;
class HloInstructionTest : public HloVerifiedTestBase {
class HloInstructionTest : public HloTestBase {
protected:
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
};

View File

@ -19,14 +19,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class HloPassPipelineTest : public HloVerifiedTestBase {
class HloPassPipelineTest : public HloTestBase {
protected:
StatusOr<HloModuleGroup> ParseModuleGroup(
absl::Span<const string> hlo_strings) {

View File

@ -20,13 +20,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace {
class HloReachabilityTest : public HloVerifiedTestBase {};
class HloReachabilityTest : public HloTestBase {};
TEST_F(HloReachabilityTest, Reachability) {
// Construct and test a reachability graph of the following form:

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
using ::testing::_;
class HloRematerializationTest : public HloVerifiedTestBase {
class HloRematerializationTest : public HloTestBase {
protected:
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@ -24,7 +24,7 @@ namespace {
using ::tensorflow::GraphDef;
class HloTfGraphBuilderTest : public HloVerifiedTestBase {
class HloTfGraphBuilderTest : public HloTestBase {
protected:
HloTfGraphBuilderTest() {}
HloTfGraphBuilder generator_;

View File

@ -35,7 +35,7 @@ namespace {
using ::testing::HasSubstr;
// This class cannot be converted to use HloVerifiedTestBase. It explicitly
// This class cannot be converted to use HloTestBase. It explicitly
// uses HloTestBase to create and test malformed HLOs.
class HloVerifierTest : public HloTestBase {
public:

View File

@ -18,14 +18,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
class ImplicitBroadcastRemoverTest : public HloTestBase {
protected:
ImplicitBroadcastRemover remover_;
};

View File

@ -16,12 +16,12 @@ limitations under the License.
#include <ctype.h>
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
namespace xla {
namespace {
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
class IndexedArrayAnalysisTest : public HloTestBase {
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -49,7 +49,7 @@ namespace {
using ::testing::ElementsAre;
class LayoutAssignmentTest : public HloVerifiedTestBase {
class LayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr) {

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
using MapInlinerTest = HloVerifiedTestBase;
using MapInlinerTest = HloTestBase;
// Test that `map` with `max` is transformed to `max`
TEST_F(MapInlinerTest, MapMax) {

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -34,7 +34,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class ReshapeMoverTest : public HloVerifiedTestBase {};
class ReshapeMoverTest : public HloTestBase {};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
auto m = CreateNewVerifiedModule();

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
class TupleSimplifierTest : public HloVerifiedTestBase {
class TupleSimplifierTest : public HloTestBase {
protected:
void Run(HloModule* module, bool change_expected) {
TupleSimplifier simplifier;

View File

@ -17,13 +17,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class WhileLoopAnalysisTest : public HloVerifiedTestBase {};
class WhileLoopAnalysisTest : public HloTestBase {};
TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) {
const char* const kHloModule = R"(

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@ -26,7 +26,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
class WhileLoopInvariantCodeMotionTest : public HloTestBase {
public:
// Makes a computation which has one parameter, of the given shape, and always
// returns PRED[]{true}. This is useful as a dummy loop condition.

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@ -29,7 +29,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
class WhileLoopSimplifierTest : public HloTestBase {
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
TF_MUST_USE_RESULT std::unique_ptr<VerifiedHloModule>

View File

@ -141,44 +141,6 @@ cc_library(
],
)
cc_library(
name = "hlo_verified_test_base",
testonly = True,
srcs = ["hlo_verified_test_base.cc"],
hdrs = ["hlo_verified_test_base.h"],
deps = [
":hlo_test_base",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "hlo_verified_test_base_test",
srcs = ["hlo_verified_test_base_test.cc"],
deps = [
":hlo_test_base",
":hlo_verified_test_base",
":test_macros_cpu",
":test_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
tf_cc_binary(
name = "local_client_aot_test_helper",
srcs = ["local_client_aot_test_helper.cc"],

View File

@ -85,6 +85,23 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
}
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
Status status = Verify();
if (!status.ok()) {
ADD_FAILURE() << "HloVerifier failed on module " << name()
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
<< ": " << status;
}
}
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier,
std::function<bool(const HloInstruction*)>
@ -100,7 +117,11 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
bool allow_mixed_precision_in_hlo_verifier,
std::function<bool(const HloInstruction*)>
instruction_can_change_layout_func)
: test_runner_(test_platform), reference_runner_(reference_platform) {
: test_runner_(test_platform),
reference_runner_(reference_platform),
verifier_layout_sensitive_(verifier_layout_sensitive),
allow_mixed_precision_in_hlo_verifier_(
allow_mixed_precision_in_hlo_verifier) {
hlo_verifier_ = absl::make_unique<HloVerifier>(
/*layout_sensitive=*/verifier_layout_sensitive,
/*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
@ -112,6 +133,32 @@ std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
}
std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
const string& name) {
return absl::make_unique<VerifiedHloModule>(
name, GetModuleConfigForTest(), verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_);
}
StatusOr<std::unique_ptr<HloModule>>
HloTestBase::ParseAndReturnUnverifiedModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
auto module = absl::make_unique<HloModule>(TestName(), config);
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
return std::move(module);
}
StatusOr<std::unique_ptr<VerifiedHloModule>>
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
auto module = absl::make_unique<VerifiedHloModule>(
TestName(), config, verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_);
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
return std::move(module);
}
/* static */
StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
HloModule* module) {

View File

@ -38,6 +38,31 @@ limitations under the License.
namespace xla {
// An HLO module derived class which verifies itself on destruction. This class
// is intended to be used in unit tests. Any verification errors are raised via
// ADD_FAILURE.
class VerifiedHloModule : public HloModule {
public:
VerifiedHloModule(const string& name, const HloModuleConfig& config,
bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier)
: HloModule(name, config),
verifier_(verifier_layout_sensitive,
allow_mixed_precision_in_hlo_verifier) {}
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
// Verifies the module using HloVerifier and returns the status.
Status Verify();
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
void VerifyOrAddFailure(const string& message);
private:
HloVerifier verifier_;
};
// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and compare the results.
// This is a lower level of abstraction than using the client interface and
@ -74,11 +99,26 @@ class HloTestBase : public ::testing::Test {
// tests.
//
// This returns a vanilla HloModule that doesn't run the HLO verifier on
// destruction. If you want to run the verifier, you want
// HloVerifiedTestBase::CreateNewVerifiedModule.
// destruction.
std::unique_ptr<HloModule> CreateNewUnverifiedModule(
const string& name = TestName());
// Like CreateNewUnverifiedModule, except the HloModule returned here runs the
// HLO verifier on destruction.
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
const string& name = TestName());
// Parses the given string and returns module as a vanilla, unverified
// HloModule.
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
// Parses the given string and returns module as a VerifiedHloModule.
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
// Runs the hlo_pass with the provided module and returns the result. This
// function also verifies that the module remains unchanged when hlo_pass
// returns false as the StatusOr value.
@ -252,6 +292,8 @@ class HloTestBase : public ::testing::Test {
HloRunner test_runner_;
HloRunner reference_runner_;
bool verifier_layout_sensitive_;
bool allow_mixed_precision_in_hlo_verifier_;
std::unique_ptr<HloVerifier> hlo_verifier_;
ErrorSpec error_spec_{0.0001};

View File

@ -1,68 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
}
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
Status status = Verify();
if (!status.ok()) {
ADD_FAILURE() << "HloVerifier failed on module " << name()
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
<< ": " << status;
}
}
HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
bool allow_mixed_precision)
: HloTestBase(
/*verifier_layout_sensitive=*/layout_sensitive,
/*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
verifier_layout_sensitive_(layout_sensitive),
allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
StatusOr<std::unique_ptr<VerifiedHloModule>>
HloVerifiedTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config) {
auto module = CreateNewVerifiedModule(TestName());
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
return std::move(module);
}
std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
const string& name) {
return absl::make_unique<VerifiedHloModule>(
name, GetModuleConfigForTest(), verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_);
}
} // namespace xla

View File

@ -1,86 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_
#include <functional>
#include <memory>
#include <utility>
#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
// An HLO module derived class which verifies itself on destruction. This class
// is intended to be used in unit tests. Any verification errors are raised via
// ADD_FAILURE.
class VerifiedHloModule : public HloModule {
public:
VerifiedHloModule(const string& name, const HloModuleConfig& config,
bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier)
: HloModule(name, config),
verifier_(verifier_layout_sensitive,
allow_mixed_precision_in_hlo_verifier) {}
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
// Verifies the module using HloVerifier and returns the status.
Status Verify();
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
void VerifyOrAddFailure(const string& message);
private:
HloVerifier verifier_;
};
// A base class for HLO tests that stores a default VerifiedHloModule.
class HloVerifiedTestBase : public HloTestBase {
protected:
HloVerifiedTestBase(bool layout_sensitive = false,
bool allow_mixed_precision = false);
// Constructs a default shape verifier.
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
// Parses the given string and returns module as a VerifiedHloModule.
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
// Creates and returns a verified HLO module with the given name.
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
const string& name = TestName());
// CreateNewUnverifiedModule creates an *unverified* module, which presumably
// isn't what you want if you're using HloVerifiedTestBase, so we delete this
// function to keep you from accidentally calling it. If you really want it,
// you can get it by calling HloTestBase::CreateNewUnverifiedModule().
std::unique_ptr<HloModule> CreateNewUnverifiedModule(
const string& name = TestName()) = delete;
private:
bool verifier_layout_sensitive_;
bool allow_mixed_precision_in_hlo_verifier_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_

View File

@ -1,115 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
// This class includes unit tests which are expected to fail because invalid HLO
// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to
// include the necessary gunit parts to test this test machinery (needs the
// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the
// disabled tests enabled and failures can be manually compared against
// expectations.
class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
// Test shouldn't fail if no module is created at all.
}
XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewUnverifiedModule) {
// Call CreateNewUnverifiedModule and build up a valid module.
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
builder.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
module->AddEntryComputation(builder.Build());
}
// This test is expected to fail. See test class comment.
XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewUnverifiedModule) {
// Call CreateNewUnverifiedModule and build up a invalid module.
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
builder.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
module->AddEntryComputation(builder.Build());
*module->entry_computation()->root_instruction()->mutable_shape() =
ShapeUtil::MakeShape(PRED, {1, 2, 3});
}
XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
const char* const hlo_string = R"(
HloModule ParseAndReturnVerifiedModuleGood
ENTRY entry {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x,y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
}
XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
const char* const hlo_string = R"(
HloModule ParseAndReturnVerifiedModuleGood
ENTRY entry {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x,y)
}
RANDOM GARBAGE
)";
ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
}
// This test is expected to fail. See test class comment.
XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
const char* const hlo_string = R"(
HloModule ParseAndReturnVerifiedModuleBad
ENTRY entry {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[1234] add(x,y)
}
)";
ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
}
} // namespace
} // namespace xla