Use VerifiedHloModule in more tests.
Fix tests with HLO bugs.

PiperOrigin-RevId: 275481999
Change-Id: I803e5f455de4fe92369601d22a26fd657f524331

parent 609c4408de
commit 8d22a4426e
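The sketch below is not part of the diff; it illustrates the test pattern this change rolls out, assuming a fixture derived from HloTestBase. The test name and HLO text are hypothetical. ParseAndReturnVerifiedModule runs the HLO verifier on the parsed text, so malformed HLO that ParseAndReturnUnverifiedModule used to accept now fails at parse time, which is why several tests' HLO had to be fixed alongside the switch.

// Sketch only: before/after usage inside an XLA test (hypothetical fixture and test name).
TEST_F(ExampleTest, ParsesAndVerifiesModule) {
  const char* hlo_text = R"(
HloModule m
ENTRY main {
  p = f32[8] parameter(0)
  ROOT neg = f32[8] negate(p)
})";
  // Old pattern: ParseAndReturnUnverifiedModule(hlo_text) skipped verification.
  // New pattern: the module is checked by the HLO verifier while it is parsed.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  EXPECT_EQ(module->entry_computation()->instruction_count(), 2);
}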
@@ -466,7 +466,6 @@ tf_cc_test(
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -618,9 +617,9 @@ tf_cc_test(
srcs = ["hlo_matchers_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -631,7 +630,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_casting_utils",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
@@ -1287,7 +1285,6 @@ tf_cc_test(
":hlo_dataflow_analysis",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1364,7 +1361,6 @@ tf_cc_test(
":hlo_matchers",
":hlo_module_group",
":hlo_module_group_metadata",
":hlo_parser",
":hlo_proto",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@@ -1447,7 +1443,6 @@ tf_cc_test(
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1848,6 +1843,7 @@ tf_cc_test(
":hlo_pass",
":pattern_matcher",
":pattern_matcher_gmock",
":shape_inference",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1941,8 +1937,8 @@ tf_cc_test(
srcs = ["gather_expander_test.cc"],
deps = [
":gather_expander",
":hlo_parser",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_macros_header",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
],
@@ -1976,7 +1972,6 @@ tf_cc_test(
":conditional_simplifier",
":hlo",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@@ -2017,7 +2012,6 @@ tf_cc_test(
":convolution_group_converter",
":hlo",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -2308,7 +2302,6 @@ xla_test(
":hlo_get_dimension_size_rewriter",
":hlo_matchers",
":hlo_parser",
":hlo_runner",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -2491,7 +2484,6 @@ tf_cc_test(
":cpu_plugin",
":hlo_cost_analysis",
":hlo_execution_profile",
":hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -2505,7 +2497,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
":hlo_parser",
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:literal",
@@ -2526,7 +2517,6 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -2720,7 +2710,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_liveness_analysis",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -3176,7 +3165,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_module_dce",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",

@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -3641,9 +3642,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
}
}

auto out_dims = in_dims;
out_dims[in_channel_idx] = options.f_output_channels;

auto make_shape = [](absl::Span<const int64> dims,
bool minor_to_major_layout) {
if (minor_to_major_layout) {
@@ -3654,20 +3652,26 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
};
auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);

HloInstruction* input =
b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
HloInstruction* filter =
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
Shape out_shape = ShapeInference::InferConvolveShape(
in_shape, f_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums)
.ValueOrDie();
if (options.output_minor_to_major_layout) {
out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
{0, 1, 2, 3});
}

b.AddInstruction(HloInstruction::CreateConvolve(
out_shape, input, filter,
/*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
DefaultPrecisionConfig(2)));

// TODO(b/80488902): verify this module.
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto* computation = module->AddEntryComputation(b.Build());

AlgebraicSimplifierOptions simplifier_options;
@@ -3841,8 +3845,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {

// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// TODO(b/80488902): verify this module.
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());

// Create operand to the pad.
@@ -3883,9 +3886,10 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
dim->set_padding_high(100);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
dim->set_stride(1);
}
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
HloInstruction* reduce_init_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
@@ -3923,8 +3927,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
// TODO(b/80488902): verify this module.
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());

// Create operand to the pad.
@@ -3969,9 +3972,10 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
dim->set_padding_high(100);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
dim->set_stride(1);
}
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
HloInstruction* reduce_init_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =

@@ -2605,8 +2605,7 @@ ENTRY entry_computation {
}

)";
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
auto module = module_or_status.ConsumeValueOrDie();

RunCopyInsertion(module.get());

@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -202,7 +201,7 @@ ENTRY main {
ROOT result = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@@ -268,7 +267,7 @@ TEST_F(ConditionalSimplifierTest,
param.1), true_computation=computation.1,
false_computation=computation.2
})";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
std::unique_ptr<HloModule> module = status.ConsumeValueOrDie();
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
@@ -315,7 +314,7 @@ ENTRY main {
ROOT result = () tuple()
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@@ -359,7 +358,7 @@ ENTRY main {
ROOT result = (f32[10,10]{1,0}) tuple(get-first-index)
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@@ -405,7 +404,7 @@ ENTRY main {
ROOT result = (f32[10,10]{1,0}) tuple(get-second-index)
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@@ -471,7 +470,7 @@ ENTRY main {
ROOT add = f32[] add(gte.0, gte.1)
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());

@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@@ -44,7 +43,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2
ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,2]{2,0,1} %copy, f32[1,1,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));

auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();
@@ -76,7 +75,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2
ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,4]{2,0,1} %copy, f32[1,2,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));

auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();
@@ -106,7 +105,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16
ROOT %convolution = f32[3,3,512,1]{3,2,1,0} convolution(f32[16,19,19,512]{3,2,1,0} %input, f32[16,19,19,512]{3,2,1,0} %filter), window={size=19x19 pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, batch_group_count=512
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));

auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();

@@ -404,7 +404,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {

class WhileCopyInsertionTest : public CopyInsertionTest {
protected:
WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {}
WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {}

// Builds a While condition computation which reads the induction variable
// from the tuple parameter, and returns a predicate indicating whether this
@@ -451,8 +451,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// Use 'induction_variable' in computation with no path to output tuple.
Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
auto convert = builder.AddInstruction(
HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
HloInstruction::CreateBroadcast(data_shape_, convert, {}));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
// Create output Tuple.
@@ -521,8 +524,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));

// Use 'induction_variable' in computation with no path to output tuple.
Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
auto convert = builder.AddInstruction(
HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
HloInstruction::CreateBroadcast(data_shape_, convert, {}));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
// Create output Tuple.
@@ -685,11 +691,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v2 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
HloInstruction::CreateBroadcast(data_shape_, zero, {}));

auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
@@ -709,7 +715,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto one_vec = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto data_init =
builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));

@@ -722,7 +728,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto data_init = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto one_vec = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
{1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
@@ -840,10 +846,10 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
ASSERT_EQ(add->opcode(), HloOpcode::kAdd);
ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);

EXPECT_THAT(
while_hlo->while_body()->root_instruction(),
op::Tuple(op::Add(op::Copy(), op::Constant()),
op::Add(op::GetTupleElement(), op::Broadcast(op::Copy()))));
EXPECT_THAT(while_hlo->while_body()->root_instruction(),
op::Tuple(op::Add(op::Copy(), op::Constant()),
op::Add(op::GetTupleElement(),
op::Broadcast(op::Convert(op::Copy())))));

// Both init indices need copies as they are constants.
EXPECT_THAT(while_hlo->operand(0),
@@ -953,12 +959,17 @@ TEST_F(WhileCopyInsertionTest,
auto data_param = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "data"));
// Add dummy ops to ensure loop_init elements aren't entry parameters.
auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary(
iter_param->shape(), HloOpcode::kExp, iter_param));
Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
auto convert = builder.AddInstruction(
HloInstruction::CreateConvert(f32_scalar_shape, iter_param));
auto iter_value = builder.AddInstruction(
HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert));
auto convert2 = builder.AddInstruction(
HloInstruction::CreateConvert(induction_variable_shape_, iter_value));
auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
data_param->shape(), HloOpcode::kExp, data_param));
auto loop_init = builder.AddInstruction(
HloInstruction::CreateTuple({iter_value, data_value}));
HloInstruction::CreateTuple({convert2, data_value}));

auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition1, body1, loop_init));
@@ -983,9 +994,9 @@ TEST_F(WhileCopyInsertionTest,
EXPECT_EQ(CountCopies(*entry), 2);

EXPECT_THAT(while_hlo1->operand(0),
op::Tuple(op::Exp(), op::Copy(op::Exp())));
op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
EXPECT_THAT(while_hlo2->operand(0),
op::Tuple(op::Exp(), op::Copy(op::Exp())));
op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
}

// Tests while body computation with nested tuple elements:

@@ -565,7 +565,7 @@ ENTRY main {
operand = s32[20,10]{1,0} parameter(0)
indices = s32[32,20] parameter(1)
dynamic_size = s32[] parameter(2)
ROOT gather = f32[32,10,10]{2,1,0} gather(%operand, %indices),
ROOT gather = s32[32,20,10]{2,1,0} gather(%operand, %indices),
offset_dims={2},
collapsed_slice_dims={0},
start_index_map={0},
@@ -574,7 +574,7 @@ ENTRY main {
}
)";

TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{1, {}, 0}));

@@ -45,8 +45,8 @@ TEST_F(DynamicIndexSplitterTest, DynamicSlice) {
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
config.set_debug_options(debug_options);

TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
DynamicIndexSplitter().Run(module.get()));
EXPECT_TRUE(changed);
@@ -84,7 +84,7 @@ TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) {
config.set_debug_options(debug_options);

TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kDynamicUpdateSlice, config));
auto module, ParseAndReturnVerifiedModule(kDynamicUpdateSlice, config));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
DynamicIndexSplitter().Run(module.get()));
EXPECT_TRUE(changed);
@@ -122,8 +122,8 @@ TEST_F(DynamicIndexSplitterTest, AlreadyScalar) {
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
config.set_debug_options(debug_options);

TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
DynamicIndexSplitter().Run(module.get()));
EXPECT_FALSE(changed);

@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@@ -220,10 +219,8 @@ TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) {
class ExecutionTest : public HloTestBase {
protected:
std::unique_ptr<HloModule> GetHloModule(const string& hlo_text) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
std::unique_ptr<HloModule> module =
ParseAndReturnUnverifiedModule(hlo_text, config).ValueOrDie();
ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
return module;
}
Literal PadAndExecute(std::unique_ptr<HloModule> module,

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@@ -56,7 +55,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

DynamicParameterBinding binding;

@@ -94,7 +93,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

DynamicParameterBinding binding;

@@ -133,7 +132,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

DynamicParameterBinding binding;

@@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"

#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {
TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) {

using GatherExpanderTest = HloTestBase;

TEST_F(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
const string hlo_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

@@ -36,7 +40,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text));
ParseAndReturnVerifiedModule(hlo_text));

Status status = GatherExpander{}.Run(module.get()).status();
EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
@@ -47,7 +51,7 @@ ENTRY main {
"indices are not supported."));
}

TEST(GatherExpanderTest, AvoidDegenerateDims) {
TEST_F(GatherExpanderTest, AvoidDegenerateDims) {
const string hlo_text = R"(
HloModule TensorFlowGatherV2

@@ -63,7 +67,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text));
ParseAndReturnVerifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
ASSERT_TRUE(changed);

@@ -105,7 +109,7 @@ ENTRY main {
ShapeUtil::GetTupleElementShape(while_shape, 3)));
}

TEST(GatherExpanderTest, CheckOpMetadata) {
TEST_F(GatherExpanderTest, CheckOpMetadata) {
const string hlo_text = R"(
HloModule TensorFlowGatherV2

@@ -121,7 +125,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text));
ParseAndReturnVerifiedModule(hlo_text));
OpMetadata metadata;
metadata.set_op_name("Gather");
module->entry_computation()->root_instruction()->set_metadata(metadata);

@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -158,7 +157,7 @@ TEST_F(HloComputationTest, PostOrderTrace) {
builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1));
auto negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Trace instructions should be at the end of the sort.
EXPECT_THAT(computation->MakeInstructionPostOrder(),
@@ -697,7 +696,7 @@ ENTRY entry {
ROOT t = (f32[128], f32[128]) tuple(add, crs1)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(),
ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(),
op::Add(), op::Tuple()));

@@ -1639,7 +1639,7 @@ ENTRY root {
p1 = s32[1000] copy(param)
ROOT t = (s32[1000], s32[1000]) tuple(p0, p1)
})";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
auto entry = module_->entry_computation();
entry->GetInstructionWithName("t");
auto& dataflow_analysis = RunAnalysis(GetParam());
@@ -1990,7 +1990,7 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] {
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest()));
ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
HloDataflowAnalysis::Run(*module));
@@ -2010,7 +2010,7 @@ INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
class HloDataflowAnalysisTestBase : public HloTestBase {
protected:
void BuildModule(std::unique_ptr<HloComputation> computation) {
module_ = CreateNewUnverifiedModule();
module_ = CreateNewVerifiedModule();
computation_ = module_->AddEntryComputation(std::move(computation));
}

@@ -2228,7 +2228,7 @@ TEST_F(CanShareOperandBufferWithUserTest,
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));

auto neg = builder.AddInstruction(
HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
@@ -2256,7 +2256,7 @@ TEST_F(CanShareOperandBufferWithUserTest,
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(0)));
auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
slice_shape, param, {zero, zero}, {1, 2, 2}));
slice_shape, param, {zero, zero}, {1, 2}));

auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, param, ds, {zero, zero}));
@@ -2448,33 +2448,29 @@ TEST_F(CanShareOperandBufferWithUserTest,
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
auto builder = HloComputation::Builder(TestName());

Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {4});
Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
Shape data_shape = ShapeUtil::MakeShape(F32, {1, 8});
Shape update_shape = ShapeUtil::MakeShape(F32, {1, 4});
Shape starts_shape = ShapeUtil::MakeShape(S32, {2});
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
auto update = builder.AddInstruction(
HloInstruction::CreateParameter(1, update_shape, "update"));
auto start0 = builder.AddInstruction(
HloInstruction::CreateParameter(2, starts_shape, "start0"));
auto start1 = builder.AddInstruction(
HloInstruction::CreateParameter(3, starts_shape, "start1"));
auto start = builder.AddInstruction(
HloInstruction::CreateParameter(2, starts_shape, "start"));

auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, data, update, {start0, start1}));
data_shape, data, update, {start}));

BuildModuleAndRunAnalysis(builder.Build());

// The DynamicUpdateSlice instruction can share with the data operand, but not
// with update or starts.
// with update or start.
EXPECT_TRUE(
dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {}));
EXPECT_FALSE(
dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {}));
EXPECT_FALSE(
dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {}));
EXPECT_FALSE(
dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {}));
dataflow_analysis_->CanShareOperandBufferWithUser(start, {}, dus, {}));
}

TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
@@ -2498,7 +2494,7 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
index_vector_dim=1
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
computation_ = module_->entry_computation();
RunAnalysis();

@@ -2526,7 +2522,7 @@ TEST_F(CanShareOperandBufferWithUserTest, TriangularSolveCanShare) {
transpose_a=NO_TRANSPOSE
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
computation_ = module_->entry_computation();
RunAnalysis();

@@ -2611,7 +2607,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));

auto add = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, dot, add_operand));
@@ -2633,7 +2629,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));

auto reverse = builder.AddInstruction(
HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
@@ -2661,7 +2657,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kMultiply, operand, operand));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -2683,14 +2679,30 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
}

TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
module_ = CreateNewVerifiedModule();
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape pred_scalar_shape = ShapeUtil::MakeShape(PRED, {});

auto make_cond = [&data_shape]() {
auto b = HloComputation::Builder(TestName() + ".And");
auto p0 = b.AddInstruction(
HloInstruction::CreateParameter(0, pred_scalar_shape, "p0"));
auto p1 = b.AddInstruction(
HloInstruction::CreateParameter(1, pred_scalar_shape, "p1"));
b.AddInstruction(
HloInstruction::CreateBinary(pred_scalar_shape, HloOpcode::kAnd, p0, p1));
auto and_computation = module_->AddEmbeddedComputation(b.Build());

auto make_cond = [&data_shape, &and_computation]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
auto compare = builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {8}), data, data, ComparisonDirection::kEq));
auto true_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
builder.AddInstruction(
HloInstruction::CreateReduce(ShapeUtil::MakeShape(PRED, {}), compare,
true_value, {0}, and_computation));
return builder.Build();
};

@@ -2703,7 +2715,6 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};

module_ = CreateNewUnverifiedModule();
HloComputation* cond_computation =
module_->AddEmbeddedComputation(make_cond());
HloComputation* body_computation =
@@ -2734,11 +2745,11 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
auto one = sub_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
HloInstruction::CreateBroadcast(shape, one, {}));
auto add = sub_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));

module_ = CreateNewUnverifiedModule();
module_ = CreateNewVerifiedModule();
auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
sub_computation->CreateFusionInstruction({add, ones},
HloInstruction::FusionKind::kLoop);

@@ -76,19 +76,20 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) {
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(
auto send = builder.AddInstruction(
HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
builder.AddInstruction(HloInstruction::CreateSendDone(send));
builder.AddInstruction(HloInstruction::CreateTuple({}));

auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());

EXPECT_EQ(4, computation->instruction_count());
EXPECT_EQ(5, computation->instruction_count());

HloDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());

EXPECT_EQ(4, computation->instruction_count());
EXPECT_EQ(5, computation->instruction_count());
}

TEST_F(HloDceTest, CustomCallInstructionsWithSideEffect) {
@@ -250,7 +251,7 @@ TEST_F(HloDceTest, DeadInstructionWithCalledComputation) {
// Tests that a while instruction with an infeed (effectul instruction) in its
// body is not removed, even its user count is 0.
TEST_F(HloDceTest, CalledComputationWithSideEffect) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(F32, {});

// Condition computation of a while instruction.
@@ -274,8 +275,10 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
auto token = body_builder.AddInstruction(HloInstruction::CreateToken());
auto infeed = body_builder.AddInstruction(
HloInstruction::CreateInfeed(shape, token, ""));
body_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed));
auto infeed_data = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, infeed, 0));
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, param, infeed_data));
}
auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

@@ -306,7 +309,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {

// Tests that a nested call instruction with a side effect is not removed.
TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(F32, {});

// Nested called computation with a side effect.
@@ -328,8 +331,8 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
{
auto param = callee_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
callee_builder.AddInstruction(
HloInstruction::CreateCall(shape, {param}, nested_called_computation));
callee_builder.AddInstruction(HloInstruction::CreateCall(
ShapeUtil::MakeTokenShape(), {param}, nested_called_computation));
}
auto called_computation =
module->AddEmbeddedComputation(callee_builder.Build());
@@ -338,22 +341,20 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
auto live_call = builder.AddInstruction(
HloInstruction::CreateCall(shape, {param}, called_computation));
builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
auto live_call = builder.AddInstruction(HloInstruction::CreateCall(
ShapeUtil::MakeTokenShape(), {param}, called_computation));
auto computation = module->AddEntryComputation(builder.Build());

EXPECT_EQ(3, computation->instruction_count());
EXPECT_EQ(2, param->user_count());
EXPECT_EQ(2, computation->instruction_count());
EXPECT_EQ(1, param->user_count());
EXPECT_EQ(0, live_call->user_count());
EXPECT_TRUE(HasInstruction(*computation, live_call));

HloDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());

EXPECT_EQ(3, computation->instruction_count());
EXPECT_EQ(2, param->user_count());
EXPECT_EQ(2, computation->instruction_count());
EXPECT_EQ(1, param->user_count());
EXPECT_EQ(0, live_call->user_count());
EXPECT_TRUE(HasInstruction(*computation, live_call));
}
@@ -400,7 +401,7 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) {
}

TEST_F(HloDceTest, KeepUsedSubcomputation) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());

HloComputation::Builder subcomp_builder("reduction_subcomp");
@@ -418,7 +419,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {

// Create a dead reduce instruction.
builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeShape(F32, {}),
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
builder.AddInstruction(
@@ -428,7 +429,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {
// Add another instruction as the root of the computation that also uses
// reduce_subcomp.
builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeShape(F32, {}),
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")),
builder.AddInstruction(

@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace xla {
@@ -29,7 +28,7 @@ using ::testing::ContainsRegex;
class HloExecutionProfileTest : public HloTestBase {};

TEST_F(HloExecutionProfileTest, Basic) {
auto hlo_module = ParseAndReturnUnverifiedModule(R"(
auto hlo_module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY entry_computation {
lhs = f32[30,30]{1,0} parameter(0)

@@ -40,7 +40,7 @@ class HloGetDimensionSizeRewriterTest : public HloTestBase {
};

TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@@ -80,7 +79,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@@ -112,7 +111,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@@ -151,7 +150,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@@ -182,7 +181,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@@ -208,7 +207,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());

@@ -2895,7 +2895,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {

// These opcodes are not handled here.
case HloOpcode::kTrace:
break;
return Status::OK();
}
return InternalError(
"Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "

@@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -951,7 +950,7 @@ ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));

auto* root = module->entry_computation()->root_instruction();
auto* t1 = root->operand(0);
@@ -1187,11 +1186,12 @@ TEST_F(HloInstructionTest, FuseInstructionKeepsInstruction) {
p2 = f32[32,32]{1,0} parameter(0)
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
mul = f32[32,32]{1,0} multiply(p2, p3)
ROOT add = f32[32,32]{1,0} fusion(mul, c1), kind=kLoop, calls=fused_add
ROOT add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
HloInstruction* fused_add = module->entry_computation()->root_instruction();
HloInstruction* mul = fused_add->mutable_operand(0);
EXPECT_EQ(1, mul->user_count());
@@ -1215,11 +1215,12 @@ TEST_F(HloInstructionTest, FuseInstructionIntoMultiOutputKeepsInstruction) {
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
mul = f32[32,32]{1,0} multiply(p2, p3)
add = f32[32,32]{1,0} fusion(mul, c1), kind=kLoop, calls=fused_add
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(mul, add)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
HloInstruction* root = module->entry_computation()->root_instruction();
HloInstruction* mul = root->mutable_operand(0);
HloInstruction* fused_add = root->mutable_operand(1);
@@ -1740,7 +1741,7 @@ ENTRY entry (param: s32[]) -> s32[] {
// Check that deep clones really deep clones every instruction and
// computations, without leaving dangling pointers to the old module.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
std::unique_ptr<HloModule> clone = module->Clone();
for (HloComputation* computation : clone->computations()) {
EXPECT_EQ(computation->parent(), clone.get());
@@ -1860,7 +1861,7 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
dim_labels=b0f_0io->b0f, operand_precision={high,default}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
auto* conv = module->entry_computation()->root_instruction();

auto clone = conv->Clone();
@@ -1873,10 +1874,10 @@ TEST_F(HloInstructionTest, PreserveOuterDimensionPartitionsOnClone) {
constexpr char kHloString[] = R"(
HloModule test_module
ENTRY test {
ROOT iota = f32[100] iota(), iota_dimension=1, outer_dimension_partitions={0, 50}
ROOT iota = f32[100] iota(), iota_dimension=0, outer_dimension_partitions={0, 50}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
auto* iota = module->entry_computation()->root_instruction();

auto clone = iota->Clone();

@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@@ -59,7 +58,7 @@ class HloLivenessAnalysisTest : public HloTestBase {

// Test that add instruction at entry root is live at all output shape indices.
TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -75,7 +74,7 @@ TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {

// Test that a dead add instruction is marked as dead by analysis.
TEST_F(HloLivenessAnalysisTest, DeadAdd) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -94,7 +93,7 @@ TEST_F(HloLivenessAnalysisTest, DeadAdd) {
// Test that all output shape indices of entry root tuple (and defining
// instruction in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -113,7 +112,7 @@ TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
// Tests that all outputs of nested tuple and entry root (and defining
// instruction values appearing in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(1)
@@ -140,7 +139,7 @@ TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
// Tests that GTE at entry root of Tuple instruction only propgates liveness
// to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -162,7 +161,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
|
||||
// Tests that GTE at entry root of nested Tuple instruction only propgates
// liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -199,7 +198,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
// Tests that GTE of GTE (at entry root) of nested Tuple instruction only
// propagates liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -240,7 +239,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {

// Test that live/dead while tuple elements are marked live/dead correctly.
TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -291,8 +290,13 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
// Tests that a tuple element live in while.cond computation, propagates
// liveness to while.body.root/while.result/while.operand (where it is unused).
TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
add_S32 {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
@ -305,8 +309,10 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
SimpleLoop.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
get-tuple-element.4 = s32[3]{0} get-tuple-element(loop_var.2), index=1
zero = s32[] constant(0)
reduce = s32[] reduce(get-tuple-element.4, zero), dimensions={0}, to_apply=add_S32
add.1 = s32[] add(get-tuple-element.3, reduce)
constant.2 = s32[] constant(5)
ROOT less-than = pred[] compare(add.1, constant.2), direction=LT
}
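The hunk above is one of the "HLO bugs" this commit fixes: the old while condition declared tuple element 1 (an s32[3]) as a scalar s32[], which the verifier rejects, so the fixed test reduces the vector with add_S32 before comparing. A hedged sketch of how the verifying parse surfaces that kind of shape mismatch follows; the fixture and test names are hypothetical, and the expectation that verification reports a non-OK status is an assumption about the fixture helper rather than something shown in this diff.

// Sketch only: names are illustrative; the HLO mirrors the bug fixed above,
// where tuple element 1 (an s32[3]) was read back as a scalar.
class VerifierSketchTest : public xla::HloTestBase {};

TEST_F(VerifierSketchTest, VerifiedParseRejectsWrongGteShape) {
  const char* kBadCondition = R"(
HloModule m
ENTRY cond {
  p = (s32[], s32[3]{0}) parameter(0)
  bad = s32[] get-tuple-element(p), index=1
  five = s32[] constant(5)
  ROOT lt = pred[] compare(bad, five), direction=LT
})";
  // The raw parser accepts the declared (wrong) shape, which is how this bug
  // survived; the verifying parse should return a non-OK status instead.
  EXPECT_FALSE(ParseAndReturnVerifiedModule(kBadCondition).ok());
}
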
@ -338,14 +344,14 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {}));
}

// Tests that a use of while.result{0} propagates liveness to
// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}.
TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[], s32[]) parameter(0)
@ -399,7 +405,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
}

TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule OutfeedLoop
WhileBody {
body_param = (s32[]) parameter(0)
@ -432,7 +438,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
}

TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule OutfeedLoop
InnerWhileBody {
body_param = (s32[]) parameter(0)

@ -14,9 +14,10 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_matchers.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace op = xla::testing::opcode_matchers;
using ::testing::_;
@ -25,6 +26,8 @@ using ::testing::Eq;
namespace xla {
namespace {

using HloMatchersTest = HloTestBase;

string DescribeHloMatcher(const ::testing::Matcher<const HloInstruction*>& m) {
std::stringstream ss;
m.DescribeTo(&ss);
@ -39,7 +42,7 @@ string Explain(const T& t, const M& m) {
return listener.str();
}

TEST(HloMatchersTest, Test) {
TEST_F(HloMatchersTest, Test) {
auto shape = ShapeUtil::MakeShape(F32, {1});
auto param = HloInstruction::CreateParameter(0, shape, "param");
auto mul = HloInstruction::CreateBinary(shape, HloOpcode::kMultiply,
@ -85,7 +88,7 @@ TEST(HloMatchersTest, Test) {
"add, (%param = f32[1]{0} parameter(0))"));
}

TEST(HloMatchersTest, CustomCallMatcher) {
TEST_F(HloMatchersTest, CustomCallMatcher) {
auto c1 =
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
auto c2 =
@ -116,7 +119,7 @@ TEST(HloMatchersTest, CustomCallMatcher) {
R"(custom-call with call target that is equal to "foo_target")");
}

TEST(HloMatchersTest, ShapeMatcher) {
TEST_F(HloMatchersTest, ShapeMatcher) {
auto p0 = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param");

@ -154,7 +157,7 @@ TEST(HloMatchersTest, ShapeMatcher) {
"(expected: f32[7,5]{1,0})");
}

TEST(HloMatchersTest, ShardingMatcher) {
TEST_F(HloMatchersTest, ShardingMatcher) {
auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}),
"param.0");
p0->clear_sharding();
@ -196,7 +199,7 @@ TEST(HloMatchersTest, ShardingMatcher) {
"has incorrect sharding (expected: {maximal device=0})");
}

TEST(HloMatchersTest, DotMatcher) {
TEST_F(HloMatchersTest, DotMatcher) {
string hlo_string = R"(
HloModule DotOperationFusion_TransposeFusion

@ -208,7 +211,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
HloInstruction* root = module->entry_computation()->root_instruction();

EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1),
@ -232,7 +235,7 @@ ENTRY DotOperationFusion_TransposeFusion {
"rhs_contracting_dimensions (got {0} want {1})");
}

TEST(HloMatchersTest, ComparisonMatcher) {
TEST_F(HloMatchersTest, ComparisonMatcher) {
auto shape = ShapeUtil::MakeShape(F32, {1});
auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
@ -264,7 +267,7 @@ TEST(HloMatchersTest, ComparisonMatcher) {
"has wrong comparison direction (got EQ, want NE)"));
}

TEST(HloMatchersTest, AsyncCopyMatcher) {
TEST_F(HloMatchersTest, AsyncCopyMatcher) {
Shape shape_memspace1 = ShapeUtil::MakeShapeWithLayout(
F32, {16}, /*minor_to_major=*/{0}, /*tiles=*/{},
/*element_size_in_bits=*/0, /*memory_space=*/1);

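In hlo_matchers_test.cc the plain TEST macros become TEST_F on the new HloMatchersTest alias of HloTestBase, presumably because ParseAndReturnVerifiedModule is a member of the fixture rather than a free function. A minimal sketch of that pattern follows; the alias is the one added in the diff, while the HLO text, test name, and assertions are illustrative only.

// Sketch only: TEST_F gives the body access to HloTestBase members such as
// ParseAndReturnVerifiedModule; a plain TEST has no fixture to supply them.
using HloMatchersTest = xla::HloTestBase;

TEST_F(HloMatchersTest, ParsesWithFixtureHelper) {
  const char* hlo = R"(
HloModule m
ENTRY e {
  p0 = f32[2,3]{1,0} parameter(0)
  p1 = f32[3,4]{1,0} parameter(1)
  ROOT dot = f32[2,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
            xla::HloOpcode::kDot);
}
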
@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -137,7 +136,7 @@ ENTRY root {
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
@ -190,7 +189,7 @@ ENTRY entry {
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));

auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
@ -334,7 +333,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
EXPECT_FALSE(module->has_schedule());
TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status());
ASSERT_TRUE(module->has_schedule());

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@ -71,7 +70,7 @@ class HloModuleDceTest : public HloTestBase {

// Tests that a while with all outputs live is unmodified.
TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -108,7 +107,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
// Tests a while loop with one unused output (which is used in the while loop
// body by an instruction with side-effects: rng) is unmodified.
TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], f32[]) parameter(0)
@ -118,7 +117,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1
constant.2 = f32[] constant(1.0)
rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform
add.1 = s32[] add(get-tuple-element.2, constant.2)
add.1 = f32[] add(get-tuple-element.2, constant.2)
ROOT tuple = (s32[], f32[]) tuple(add, add.1)
}
SimpleLoop.condition {
@ -148,7 +147,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
// Tests that a while loop with one dead tuple element at {1} has its while
// loop body modified to make that tuple element pass-through the while body.
TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -191,7 +190,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
// dead in while.body{1} and at while.result{1}) propagates liveness of this
// tuple element to while.body{1} and at while.result{1}.
TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[]) parameter(0)
@ -233,7 +232,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
// Tests that HloModuleDCE can remove a dead tuple element at index {1} between
// two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -301,7 +300,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
// while.2{1}, between two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
@ -367,7 +366,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {

// Tests that a while whose body has outfeed operations is not DCE-ed.
TEST_F(HloModuleDceTest, WhileWithOutfeed) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule OutfeedLoop
WhileBody {
body_param = (s32[]) parameter(0)
@ -404,7 +403,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) {
// variable changes are not elided within the loop body, if the condition
// computation uses them.
TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule InfiniteLoop
WhileBody {
body_param = (s32[], s32[]) parameter(0)

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -45,7 +44,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));
HloModuleGroup group(std::move(module));

EXPECT_EQ(group.modules().size(), 1);
@ -84,9 +83,9 @@ ENTRY %entry (a: f32[]) -> f32[] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
ParseAndReturnUnverifiedModule(text_0));
ParseAndReturnVerifiedModule(text_0));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
ParseAndReturnUnverifiedModule(text_1));
ParseAndReturnVerifiedModule(text_1));
std::vector<std::unique_ptr<HloModule>> modules;
modules.push_back(std::move(module_0));
modules.push_back(std::move(module_1));
@ -123,9 +122,9 @@ ENTRY %entry (a: f32[]) -> f32[] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
ParseAndReturnUnverifiedModule(text_0));
ParseAndReturnVerifiedModule(text_0));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
ParseAndReturnUnverifiedModule(text_1));
ParseAndReturnVerifiedModule(text_1));
HloModuleGroup group(TestName());
group.push_back(std::move(module_0));
group.push_back(std::move(module_1));
@ -179,7 +178,7 @@ ENTRY entry {
const int64 send_channel = i;
const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(absl::StrFormat(
ParseAndReturnVerifiedModule(absl::StrFormat(
text, i, send_channel, send_channel,
recv_channel, recv_channel)));
group.push_back(std::move(module));

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -215,7 +214,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));
ASSERT_FALSE(module->has_schedule());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module_copy,
@ -237,7 +236,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));
ASSERT_TRUE(module->has_schedule());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module_copy,
@ -274,7 +273,7 @@ ENTRY ReduceR3ToR2.v3 {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));

// Perform various transformations on the graph:
//

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -338,7 +337,7 @@ ENTRY while.v11 {
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
DependencyHloOrdering ordering(module.get());
ordering.ToString(); // Shouldn't crash.
}
@ -375,7 +374,7 @@ ENTRY root {
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
@ -507,23 +506,22 @@ TEST_F(HloOrderingTest, InterferenceWithOuterRoot) {
absl::string_view hlo_string = R"(
HloModule InterferenceWithOuterRoot, is_scheduled=true

Emmbedded (embedded_param: f32[42]) -> f32[42] {
embedded_param = f32[42]{0} parameter(0)
multiply = f32[42]{0} multiply(embedded_param, embedded_param)
ROOT log = f32[42]{0} log(multiply)
Emmbedded (embedded_param: f32[4096,4096]) -> f32[4096,4096] {
embedded_param = f32[4096,4096]{1,0} parameter(0)
multiply = f32[4096,4096]{1,0} multiply(embedded_param, embedded_param)
ROOT log = f32[4096,4096]{1,0} log(multiply)
}

ENTRY InterferenceWithOuterRoot {
param = f32[4096,4096]{1,0} parameter(0)
ROOT add = f32[4096,4096]{1,0} add(param, param)
call = f32[42]{0} call(param), to_apply=Emmbedded
call = f32[4096,4096]{1,0} call(param), to_apply=Emmbedded
}

)";
HloModuleConfig hlo_config;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string, hlo_config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, hlo_config));
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());

@ -602,9 +602,8 @@ ENTRY %entry {
}
)";

TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));

TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
@ -643,9 +642,8 @@ ENTRY %entry {
}
)";

TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));

TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(