Use VerifiedHloModule in more tests.

Fix tests with HLO bugs.

PiperOrigin-RevId: 275481999
Change-Id: I803e5f455de4fe92369601d22a26fd657f524331
This commit is contained in:
Adrian Kuegel 2019-10-18 08:56:58 -07:00 committed by TensorFlower Gardener
parent 609c4408de
commit 8d22a4426e
27 changed files with 228 additions and 217 deletions

View File

@ -466,7 +466,6 @@ tf_cc_test(
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@ -618,9 +617,9 @@ tf_cc_test(
srcs = ["hlo_matchers_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -631,7 +630,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_casting_utils",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
@ -1287,7 +1285,6 @@ tf_cc_test(
":hlo_dataflow_analysis",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@ -1364,7 +1361,6 @@ tf_cc_test(
":hlo_matchers",
":hlo_module_group",
":hlo_module_group_metadata",
":hlo_parser",
":hlo_proto",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@ -1447,7 +1443,6 @@ tf_cc_test(
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@ -1848,6 +1843,7 @@ tf_cc_test(
":hlo_pass",
":pattern_matcher",
":pattern_matcher_gmock",
":shape_inference",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -1941,8 +1937,8 @@ tf_cc_test(
srcs = ["gather_expander_test.cc"],
deps = [
":gather_expander",
":hlo_parser",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_macros_header",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
@ -1976,7 +1972,6 @@ tf_cc_test(
":conditional_simplifier",
":hlo",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@ -2017,7 +2012,6 @@ tf_cc_test(
":convolution_group_converter",
":hlo",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@ -2308,7 +2302,6 @@ xla_test(
":hlo_get_dimension_size_rewriter",
":hlo_matchers",
":hlo_parser",
":hlo_runner",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -2491,7 +2484,6 @@ tf_cc_test(
":cpu_plugin",
":hlo_cost_analysis",
":hlo_execution_profile",
":hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@ -2505,7 +2497,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
":hlo_parser",
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:literal",
@ -2526,7 +2517,6 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -2720,7 +2710,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_liveness_analysis",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -3176,7 +3165,6 @@ tf_cc_test(
deps = [
":hlo",
":hlo_module_dce",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -3641,9 +3642,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
}
}
auto out_dims = in_dims;
out_dims[in_channel_idx] = options.f_output_channels;
auto make_shape = [](absl::Span<const int64> dims,
bool minor_to_major_layout) {
if (minor_to_major_layout) {
@ -3654,20 +3652,26 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
};
auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);
HloInstruction* input =
b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
HloInstruction* filter =
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
Shape out_shape = ShapeInference::InferConvolveShape(
in_shape, f_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums)
.ValueOrDie();
if (options.output_minor_to_major_layout) {
out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
{0, 1, 2, 3});
}
b.AddInstruction(HloInstruction::CreateConvolve(
out_shape, input, filter,
/*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
DefaultPrecisionConfig(2)));
// TODO(b/80488902): verify this module.
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifierOptions simplifier_options;
@ -3841,8 +3845,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// TODO(b/80488902): verify this module.
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@ -3883,9 +3886,10 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
dim->set_padding_high(100);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
dim->set_stride(1);
}
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
HloInstruction* reduce_init_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
@ -3923,8 +3927,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
// TODO(b/80488902): verify this module.
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@ -3969,9 +3972,10 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
dim->set_padding_high(100);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
dim->set_stride(1);
}
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
HloInstruction* reduce_init_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =

View File

@ -2605,8 +2605,7 @@ ENTRY entry_computation {
}
)";
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
auto module = module_or_status.ConsumeValueOrDie();
RunCopyInsertion(module.get());

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -202,7 +201,7 @@ ENTRY main {
ROOT result = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@ -268,7 +267,7 @@ TEST_F(ConditionalSimplifierTest,
param.1), true_computation=computation.1,
false_computation=computation.2
})";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
std::unique_ptr<HloModule> module = status.ConsumeValueOrDie();
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
@ -315,7 +314,7 @@ ENTRY main {
ROOT result = () tuple()
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@ -359,7 +358,7 @@ ENTRY main {
ROOT result = (f32[10,10]{1,0}) tuple(get-first-index)
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@ -405,7 +404,7 @@ ENTRY main {
ROOT result = (f32[10,10]{1,0}) tuple(get-second-index)
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
@ -471,7 +470,7 @@ ENTRY main {
ROOT add = f32[] add(gte.0, gte.1)
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
auto status = ParseAndReturnVerifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -44,7 +43,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2
ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,2]{2,0,1} %copy, f32[1,1,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();
@ -76,7 +75,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2
ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,4]{2,0,1} %copy, f32[1,2,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();
@ -106,7 +105,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16
ROOT %convolution = f32[3,3,512,1]{3,2,1,0} convolution(f32[16,19,19,512]{3,2,1,0} %input, f32[16,19,19,512]{3,2,1,0} %filter), window={size=19x19 pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, batch_group_count=512
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();

View File

@ -404,7 +404,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
class WhileCopyInsertionTest : public CopyInsertionTest {
protected:
WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {}
WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {}
// Builds a While condition computation which reads the induction variable
// from the tuple parameter, and returns a predicate indicating whether this
@ -451,8 +451,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// Use 'induction_variable' in computation with no path to output tuple.
Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
auto convert = builder.AddInstruction(
HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
HloInstruction::CreateBroadcast(data_shape_, convert, {}));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
// Create output Tuple.
@ -521,8 +524,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// Use 'induction_variable' in computation with no path to output tuple.
Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
auto convert = builder.AddInstruction(
HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
HloInstruction::CreateBroadcast(data_shape_, convert, {}));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
// Create output Tuple.
@ -685,11 +691,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v2 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
@ -709,7 +715,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto one_vec = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto data_init =
builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
@ -722,7 +728,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto data_init = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto one_vec = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
{1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
@ -840,10 +846,10 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
ASSERT_EQ(add->opcode(), HloOpcode::kAdd);
ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
EXPECT_THAT(
while_hlo->while_body()->root_instruction(),
op::Tuple(op::Add(op::Copy(), op::Constant()),
op::Add(op::GetTupleElement(), op::Broadcast(op::Copy()))));
EXPECT_THAT(while_hlo->while_body()->root_instruction(),
op::Tuple(op::Add(op::Copy(), op::Constant()),
op::Add(op::GetTupleElement(),
op::Broadcast(op::Convert(op::Copy())))));
// Both init indices need copies as they are constants.
EXPECT_THAT(while_hlo->operand(0),
@ -953,12 +959,17 @@ TEST_F(WhileCopyInsertionTest,
auto data_param = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "data"));
// Add dummy ops to ensure loop_init elements aren't entry parameters.
auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary(
iter_param->shape(), HloOpcode::kExp, iter_param));
Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
auto convert = builder.AddInstruction(
HloInstruction::CreateConvert(f32_scalar_shape, iter_param));
auto iter_value = builder.AddInstruction(
HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert));
auto convert2 = builder.AddInstruction(
HloInstruction::CreateConvert(induction_variable_shape_, iter_value));
auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
data_param->shape(), HloOpcode::kExp, data_param));
auto loop_init = builder.AddInstruction(
HloInstruction::CreateTuple({iter_value, data_value}));
HloInstruction::CreateTuple({convert2, data_value}));
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition1, body1, loop_init));
@ -983,9 +994,9 @@ TEST_F(WhileCopyInsertionTest,
EXPECT_EQ(CountCopies(*entry), 2);
EXPECT_THAT(while_hlo1->operand(0),
op::Tuple(op::Exp(), op::Copy(op::Exp())));
op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
EXPECT_THAT(while_hlo2->operand(0),
op::Tuple(op::Exp(), op::Copy(op::Exp())));
op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
}
// Tests while body computation with nested tuple elements:

View File

@ -565,7 +565,7 @@ ENTRY main {
operand = s32[20,10]{1,0} parameter(0)
indices = s32[32,20] parameter(1)
dynamic_size = s32[] parameter(2)
ROOT gather = f32[32,10,10]{2,1,0} gather(%operand, %indices),
ROOT gather = s32[32,20,10]{2,1,0} gather(%operand, %indices),
offset_dims={2},
collapsed_slice_dims={0},
start_index_map={0},
@ -574,7 +574,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{1, {}, 0}));

View File

@ -45,8 +45,8 @@ TEST_F(DynamicIndexSplitterTest, DynamicSlice) {
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
DynamicIndexSplitter().Run(module.get()));
EXPECT_TRUE(changed);
@ -84,7 +84,7 @@ TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) {
config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kDynamicUpdateSlice, config));
auto module, ParseAndReturnVerifiedModule(kDynamicUpdateSlice, config));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
DynamicIndexSplitter().Run(module.get()));
EXPECT_TRUE(changed);
@ -122,8 +122,8 @@ TEST_F(DynamicIndexSplitterTest, AlreadyScalar) {
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kDynamicSlice, config));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
DynamicIndexSplitter().Run(module.get()));
EXPECT_FALSE(changed);

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@ -220,10 +219,8 @@ TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) {
class ExecutionTest : public HloTestBase {
protected:
std::unique_ptr<HloModule> GetHloModule(const string& hlo_text) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
std::unique_ptr<HloModule> module =
ParseAndReturnUnverifiedModule(hlo_text, config).ValueOrDie();
ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
return module;
}
Literal PadAndExecute(std::unique_ptr<HloModule> module,

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -56,7 +55,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
DynamicParameterBinding binding;
@ -94,7 +93,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
DynamicParameterBinding binding;
@ -133,7 +132,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
DynamicParameterBinding binding;

View File

@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
using GatherExpanderTest = HloTestBase;
TEST_F(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
const string hlo_text = R"(
HloModule TensorFlowGatherMultipleBatchDims
@ -36,7 +40,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text));
ParseAndReturnVerifiedModule(hlo_text));
Status status = GatherExpander{}.Run(module.get()).status();
EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
@ -47,7 +51,7 @@ ENTRY main {
"indices are not supported."));
}
TEST(GatherExpanderTest, AvoidDegenerateDims) {
TEST_F(GatherExpanderTest, AvoidDegenerateDims) {
const string hlo_text = R"(
HloModule TensorFlowGatherV2
@ -63,7 +67,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text));
ParseAndReturnVerifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
ASSERT_TRUE(changed);
@ -105,7 +109,7 @@ ENTRY main {
ShapeUtil::GetTupleElementShape(while_shape, 3)));
}
TEST(GatherExpanderTest, CheckOpMetadata) {
TEST_F(GatherExpanderTest, CheckOpMetadata) {
const string hlo_text = R"(
HloModule TensorFlowGatherV2
@ -121,7 +125,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text));
ParseAndReturnVerifiedModule(hlo_text));
OpMetadata metadata;
metadata.set_op_name("Gather");
module->entry_computation()->root_instruction()->set_metadata(metadata);

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -158,7 +157,7 @@ TEST_F(HloComputationTest, PostOrderTrace) {
builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1));
auto negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
// Trace instructions should be at the end of the sort.
EXPECT_THAT(computation->MakeInstructionPostOrder(),
@ -697,7 +696,7 @@ ENTRY entry {
ROOT t = (f32[128], f32[128]) tuple(add, crs1)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(),
ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(),
op::Add(), op::Tuple()));

View File

@ -1639,7 +1639,7 @@ ENTRY root {
p1 = s32[1000] copy(param)
ROOT t = (s32[1000], s32[1000]) tuple(p0, p1)
})";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
auto entry = module_->entry_computation();
entry->GetInstructionWithName("t");
auto& dataflow_analysis = RunAnalysis(GetParam());
@ -1990,7 +1990,7 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] {
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest()));
ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
HloDataflowAnalysis::Run(*module));
@ -2010,7 +2010,7 @@ INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
class HloDataflowAnalysisTestBase : public HloTestBase {
protected:
void BuildModule(std::unique_ptr<HloComputation> computation) {
module_ = CreateNewUnverifiedModule();
module_ = CreateNewVerifiedModule();
computation_ = module_->AddEntryComputation(std::move(computation));
}
@ -2228,7 +2228,7 @@ TEST_F(CanShareOperandBufferWithUserTest,
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));
auto neg = builder.AddInstruction(
HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
@ -2256,7 +2256,7 @@ TEST_F(CanShareOperandBufferWithUserTest,
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(0)));
auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
slice_shape, param, {zero, zero}, {1, 2, 2}));
slice_shape, param, {zero, zero}, {1, 2}));
auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, param, ds, {zero, zero}));
@ -2448,33 +2448,29 @@ TEST_F(CanShareOperandBufferWithUserTest,
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {4});
Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
Shape data_shape = ShapeUtil::MakeShape(F32, {1, 8});
Shape update_shape = ShapeUtil::MakeShape(F32, {1, 4});
Shape starts_shape = ShapeUtil::MakeShape(S32, {2});
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
auto update = builder.AddInstruction(
HloInstruction::CreateParameter(1, update_shape, "update"));
auto start0 = builder.AddInstruction(
HloInstruction::CreateParameter(2, starts_shape, "start0"));
auto start1 = builder.AddInstruction(
HloInstruction::CreateParameter(3, starts_shape, "start1"));
auto start = builder.AddInstruction(
HloInstruction::CreateParameter(2, starts_shape, "start"));
auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, data, update, {start0, start1}));
data_shape, data, update, {start}));
BuildModuleAndRunAnalysis(builder.Build());
// The DynamicUpdateSlice instruction can share with the data operand, but not
// with update or starts.
// with update or start.
EXPECT_TRUE(
dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {}));
EXPECT_FALSE(
dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {}));
EXPECT_FALSE(
dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {}));
EXPECT_FALSE(
dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {}));
dataflow_analysis_->CanShareOperandBufferWithUser(start, {}, dus, {}));
}
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
@ -2498,7 +2494,7 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
index_vector_dim=1
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
computation_ = module_->entry_computation();
RunAnalysis();
@ -2526,7 +2522,7 @@ TEST_F(CanShareOperandBufferWithUserTest, TriangularSolveCanShare) {
transpose_a=NO_TRANSPOSE
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
computation_ = module_->entry_computation();
RunAnalysis();
@ -2611,7 +2607,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, dot, add_operand));
@ -2633,7 +2629,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));
auto reverse = builder.AddInstruction(
HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
@ -2661,7 +2657,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
HloInstruction::CreateBroadcast(data_shape, one, {}));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kMultiply, operand, operand));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
@ -2683,14 +2679,30 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
}
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
module_ = CreateNewVerifiedModule();
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape pred_scalar_shape = ShapeUtil::MakeShape(PRED, {});
auto make_cond = [&data_shape]() {
auto b = HloComputation::Builder(TestName() + ".And");
auto p0 = b.AddInstruction(
HloInstruction::CreateParameter(0, pred_scalar_shape, "p0"));
auto p1 = b.AddInstruction(
HloInstruction::CreateParameter(1, pred_scalar_shape, "p1"));
b.AddInstruction(
HloInstruction::CreateBinary(pred_scalar_shape, HloOpcode::kAnd, p0, p1));
auto and_computation = module_->AddEmbeddedComputation(b.Build());
auto make_cond = [&data_shape, &and_computation]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
auto compare = builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {8}), data, data, ComparisonDirection::kEq));
auto true_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
builder.AddInstruction(
HloInstruction::CreateReduce(ShapeUtil::MakeShape(PRED, {}), compare,
true_value, {0}, and_computation));
return builder.Build();
};
@ -2703,7 +2715,6 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
module_ = CreateNewUnverifiedModule();
HloComputation* cond_computation =
module_->AddEmbeddedComputation(make_cond());
HloComputation* body_computation =
@ -2734,11 +2745,11 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
auto one = sub_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
HloInstruction::CreateBroadcast(shape, one, {}));
auto add = sub_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
module_ = CreateNewUnverifiedModule();
module_ = CreateNewVerifiedModule();
auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
sub_computation->CreateFusionInstruction({add, ones},
HloInstruction::FusionKind::kLoop);

View File

@ -76,19 +76,20 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) {
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(
auto send = builder.AddInstruction(
HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
builder.AddInstruction(HloInstruction::CreateSendDone(send));
builder.AddInstruction(HloInstruction::CreateTuple({}));
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_EQ(5, computation->instruction_count());
HloDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_EQ(5, computation->instruction_count());
}
TEST_F(HloDceTest, CustomCallInstructionsWithSideEffect) {
@ -250,7 +251,7 @@ TEST_F(HloDceTest, DeadInstructionWithCalledComputation) {
// Tests that a while instruction with an infeed (effectul instruction) in its
// body is not removed, even its user count is 0.
TEST_F(HloDceTest, CalledComputationWithSideEffect) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(F32, {});
// Condition computation of a while instruction.
@ -274,8 +275,10 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
auto token = body_builder.AddInstruction(HloInstruction::CreateToken());
auto infeed = body_builder.AddInstruction(
HloInstruction::CreateInfeed(shape, token, ""));
body_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed));
auto infeed_data = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, infeed, 0));
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, param, infeed_data));
}
auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
@ -306,7 +309,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
// Tests that a nested call instruction with a side effect is not removed.
TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(F32, {});
// Nested called computation with a side effect.
@ -328,8 +331,8 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
{
auto param = callee_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
callee_builder.AddInstruction(
HloInstruction::CreateCall(shape, {param}, nested_called_computation));
callee_builder.AddInstruction(HloInstruction::CreateCall(
ShapeUtil::MakeTokenShape(), {param}, nested_called_computation));
}
auto called_computation =
module->AddEmbeddedComputation(callee_builder.Build());
@ -338,22 +341,20 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
auto live_call = builder.AddInstruction(
HloInstruction::CreateCall(shape, {param}, called_computation));
builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
auto live_call = builder.AddInstruction(HloInstruction::CreateCall(
ShapeUtil::MakeTokenShape(), {param}, called_computation));
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(3, computation->instruction_count());
EXPECT_EQ(2, param->user_count());
EXPECT_EQ(2, computation->instruction_count());
EXPECT_EQ(1, param->user_count());
EXPECT_EQ(0, live_call->user_count());
EXPECT_TRUE(HasInstruction(*computation, live_call));
HloDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
EXPECT_EQ(2, param->user_count());
EXPECT_EQ(2, computation->instruction_count());
EXPECT_EQ(1, param->user_count());
EXPECT_EQ(0, live_call->user_count());
EXPECT_TRUE(HasInstruction(*computation, live_call));
}
@ -400,7 +401,7 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) {
}
TEST_F(HloDceTest, KeepUsedSubcomputation) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());
HloComputation::Builder subcomp_builder("reduction_subcomp");
@ -418,7 +419,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {
// Create a dead reduce instruction.
builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeShape(F32, {}),
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
builder.AddInstruction(
@ -428,7 +429,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {
// Add another instruction as the root of the computation that also uses
// reduce_subcomp.
builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(F32, {1}),
ShapeUtil::MakeShape(F32, {}),
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")),
builder.AddInstruction(

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
@ -29,7 +28,7 @@ using ::testing::ContainsRegex;
class HloExecutionProfileTest : public HloTestBase {};
TEST_F(HloExecutionProfileTest, Basic) {
auto hlo_module = ParseAndReturnUnverifiedModule(R"(
auto hlo_module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY entry_computation {
lhs = f32[30,30]{1,0} parameter(0)

View File

@ -40,7 +40,7 @@ class HloGetDimensionSizeRewriterTest : public HloTestBase {
};
TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -80,7 +79,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@ -112,7 +111,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@ -151,7 +150,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@ -182,7 +181,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());
@ -208,7 +207,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
HloInputOutputAliasConfig config(
module->entry_computation()->root_instruction()->shape());

View File

@ -2895,7 +2895,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
// These opcodes are not handled here.
case HloOpcode::kTrace:
break;
return Status::OK();
}
return InternalError(
"Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@ -951,7 +950,7 @@ ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
auto* root = module->entry_computation()->root_instruction();
auto* t1 = root->operand(0);
@ -1187,11 +1186,12 @@ TEST_F(HloInstructionTest, FuseInstructionKeepsInstruction) {
p2 = f32[32,32]{1,0} parameter(0)
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
mul = f32[32,32]{1,0} multiply(p2, p3)
ROOT add = f32[32,32]{1,0} fusion(mul, c1), kind=kLoop, calls=fused_add
ROOT add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
HloInstruction* fused_add = module->entry_computation()->root_instruction();
HloInstruction* mul = fused_add->mutable_operand(0);
EXPECT_EQ(1, mul->user_count());
@ -1215,11 +1215,12 @@ TEST_F(HloInstructionTest, FuseInstructionIntoMultiOutputKeepsInstruction) {
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
mul = f32[32,32]{1,0} multiply(p2, p3)
add = f32[32,32]{1,0} fusion(mul, c1), kind=kLoop, calls=fused_add
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(mul, add)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
HloInstruction* root = module->entry_computation()->root_instruction();
HloInstruction* mul = root->mutable_operand(0);
HloInstruction* fused_add = root->mutable_operand(1);
@ -1740,7 +1741,7 @@ ENTRY entry (param: s32[]) -> s32[] {
// Check that deep clones really deep clones every instruction and
// computations, without leaving dangling pointers to the old module.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
std::unique_ptr<HloModule> clone = module->Clone();
for (HloComputation* computation : clone->computations()) {
EXPECT_EQ(computation->parent(), clone.get());
@ -1860,7 +1861,7 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
dim_labels=b0f_0io->b0f, operand_precision={high,default}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
auto* conv = module->entry_computation()->root_instruction();
auto clone = conv->Clone();
@ -1873,10 +1874,10 @@ TEST_F(HloInstructionTest, PreserveOuterDimensionPartitionsOnClone) {
constexpr char kHloString[] = R"(
HloModule test_module
ENTRY test {
ROOT iota = f32[100] iota(), iota_dimension=1, outer_dimension_partitions={0, 50}
ROOT iota = f32[100] iota(), iota_dimension=0, outer_dimension_partitions={0, 50}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloString));
ParseAndReturnVerifiedModule(kHloString));
auto* iota = module->entry_computation()->root_instruction();
auto clone = iota->Clone();

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@ -59,7 +58,7 @@ class HloLivenessAnalysisTest : public HloTestBase {
// Test that add instruction at entry root is live at all output shape indices.
TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -75,7 +74,7 @@ TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
// Test that a dead add instruction is marked as dead by analysis.
TEST_F(HloLivenessAnalysisTest, DeadAdd) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -94,7 +93,7 @@ TEST_F(HloLivenessAnalysisTest, DeadAdd) {
// Test that all output shape indices of entry root tuple (and defining
// instruction in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -113,7 +112,7 @@ TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
// Tests that all outputs of nested tuple and entry root (and defining
// instruction values appearing in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(1)
@ -140,7 +139,7 @@ TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
// Tests that GTE at entry root of Tuple instruction only propgates liveness
// to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -162,7 +161,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
// Tests that GTE at entry root of nested Tuple instruction only propgates
// liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -199,7 +198,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
// Tests that GTE of GTE (at entry root) of nested Tuple instruction only
// propgates liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@ -240,7 +239,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
// Test that live/dead while tuple elements are marked live/dead correctly.
TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -291,8 +290,13 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
// Tests that a tuple element live in while.cond computation, propagates
// liveness to while.body.root/while.result/while.operand (where it is unused).
TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
add_S32 {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
@ -305,8 +309,10 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
SimpleLoop.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
get-tuple-element.4 = s32[3]{0} get-tuple-element(loop_var.2), index=1
zero = s32[] constant(0)
reduce = s32[] reduce(get-tuple-element.4, zero), dimensions={0}, to_apply=add_S32
add.1 = s32[] add(get-tuple-element.3, reduce)
constant.2 = s32[] constant(5)
ROOT less-than = pred[] compare(add.1, constant.2), direction=LT
}
@ -338,14 +344,14 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {}));
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {}));
}
// Tests that a use of while.result{0} propagates liveness to
// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}.
TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[], s32[]) parameter(0)
@ -399,7 +405,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
}
TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule OutfeedLoop
WhileBody {
body_param = (s32[]) parameter(0)
@ -432,7 +438,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
}
TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule OutfeedLoop
InnerWhileBody {
body_param = (s32[]) parameter(0)

View File

@ -14,9 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
@ -25,6 +26,8 @@ using ::testing::Eq;
namespace xla {
namespace {
using HloMatchersTest = HloTestBase;
string DescribeHloMatcher(const ::testing::Matcher<const HloInstruction*>& m) {
std::stringstream ss;
m.DescribeTo(&ss);
@ -39,7 +42,7 @@ string Explain(const T& t, const M& m) {
return listener.str();
}
TEST(HloMatchersTest, Test) {
TEST_F(HloMatchersTest, Test) {
auto shape = ShapeUtil::MakeShape(F32, {1});
auto param = HloInstruction::CreateParameter(0, shape, "param");
auto mul = HloInstruction::CreateBinary(shape, HloOpcode::kMultiply,
@ -85,7 +88,7 @@ TEST(HloMatchersTest, Test) {
"add, (%param = f32[1]{0} parameter(0))"));
}
TEST(HloMatchersTest, CustomCallMatcher) {
TEST_F(HloMatchersTest, CustomCallMatcher) {
auto c1 =
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
auto c2 =
@ -116,7 +119,7 @@ TEST(HloMatchersTest, CustomCallMatcher) {
R"(custom-call with call target that is equal to "foo_target")");
}
TEST(HloMatchersTest, ShapeMatcher) {
TEST_F(HloMatchersTest, ShapeMatcher) {
auto p0 = HloInstruction::CreateParameter(
0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param");
@ -154,7 +157,7 @@ TEST(HloMatchersTest, ShapeMatcher) {
"(expected: f32[7,5]{1,0})");
}
TEST(HloMatchersTest, ShardingMatcher) {
TEST_F(HloMatchersTest, ShardingMatcher) {
auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}),
"param.0");
p0->clear_sharding();
@ -196,7 +199,7 @@ TEST(HloMatchersTest, ShardingMatcher) {
"has incorrect sharding (expected: {maximal device=0})");
}
TEST(HloMatchersTest, DotMatcher) {
TEST_F(HloMatchersTest, DotMatcher) {
string hlo_string = R"(
HloModule DotOperationFusion_TransposeFusion
@ -208,7 +211,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1),
@ -232,7 +235,7 @@ ENTRY DotOperationFusion_TransposeFusion {
"rhs_contracting_dimensions (got {0} want {1})");
}
TEST(HloMatchersTest, ComparisonMatcher) {
TEST_F(HloMatchersTest, ComparisonMatcher) {
auto shape = ShapeUtil::MakeShape(F32, {1});
auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
@ -264,7 +267,7 @@ TEST(HloMatchersTest, ComparisonMatcher) {
"has wrong comparison direction (got EQ, want NE)"));
}
TEST(HloMatchersTest, AsyncCopyMatcher) {
TEST_F(HloMatchersTest, AsyncCopyMatcher) {
Shape shape_memspace1 = ShapeUtil::MakeShapeWithLayout(
F32, {16}, /*minor_to_major=*/{0}, /*tiles=*/{},
/*element_size_in_bits=*/0, /*memory_space=*/1);

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -137,7 +136,7 @@ ENTRY root {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
@ -190,7 +189,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
@ -334,7 +333,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
EXPECT_FALSE(module->has_schedule());
TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status());
ASSERT_TRUE(module->has_schedule());

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@ -71,7 +70,7 @@ class HloModuleDceTest : public HloTestBase {
// Tests that a while with all outputs live is unmodified.
TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -108,7 +107,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
// Tests a while loop with one unused output (which is used in the while loop
// body by an instruction with side-effects: rng) is unmodified.
TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], f32[]) parameter(0)
@ -118,7 +117,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1
constant.2 = f32[] constant(1.0)
rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform
add.1 = s32[] add(get-tuple-element.2, constant.2)
add.1 = f32[] add(get-tuple-element.2, constant.2)
ROOT tuple = (s32[], f32[]) tuple(add, add.1)
}
SimpleLoop.condition {
@ -148,7 +147,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
// Tests that a while loop with one dead tuple element at {1} has its while
// loop body modified to make that tuple element pass-through the while body.
TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -191,7 +190,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
// dead in while.body{1} and at while.result{1}) propgates liveness of this
// tuple element to while.body{1} and at while.result{1}.
TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[]) parameter(0)
@ -233,7 +232,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
// Tests that HloModuleDCE can remove a dead tuple element at index {1} between
// two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@ -301,7 +300,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
// while.2{1}, between two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
@ -367,7 +366,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
// Tests that a while whose body has outfeed operations is not DCE-ed.
TEST_F(HloModuleDceTest, WhileWithOutfeed) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule OutfeedLoop
WhileBody {
body_param = (s32[]) parameter(0)
@ -404,7 +403,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) {
// variable changes are not elided within the loop body, if the condition
// computation uses them.
TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule InfiniteLoop
WhileBody {
body_param = (s32[], s32[]) parameter(0)

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -45,7 +44,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));
HloModuleGroup group(std::move(module));
EXPECT_EQ(group.modules().size(), 1);
@ -84,9 +83,9 @@ ENTRY %entry (a: f32[]) -> f32[] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
ParseAndReturnUnverifiedModule(text_0));
ParseAndReturnVerifiedModule(text_0));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
ParseAndReturnUnverifiedModule(text_1));
ParseAndReturnVerifiedModule(text_1));
std::vector<std::unique_ptr<HloModule>> modules;
modules.push_back(std::move(module_0));
modules.push_back(std::move(module_1));
@ -123,9 +122,9 @@ ENTRY %entry (a: f32[]) -> f32[] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
ParseAndReturnUnverifiedModule(text_0));
ParseAndReturnVerifiedModule(text_0));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
ParseAndReturnUnverifiedModule(text_1));
ParseAndReturnVerifiedModule(text_1));
HloModuleGroup group(TestName());
group.push_back(std::move(module_0));
group.push_back(std::move(module_1));
@ -179,7 +178,7 @@ ENTRY entry {
const int64 send_channel = i;
const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(absl::StrFormat(
ParseAndReturnVerifiedModule(absl::StrFormat(
text, i, send_channel, send_channel,
recv_channel, recv_channel)));
group.push_back(std::move(module));

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -215,7 +214,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));
ASSERT_FALSE(module->has_schedule());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module_copy,
@ -237,7 +236,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));
ASSERT_TRUE(module->has_schedule());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module_copy,
@ -274,7 +273,7 @@ ENTRY ReduceR3ToR2.v3 {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(text));
ParseAndReturnVerifiedModule(text));
// Perform various transformations on the graph:
//

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -338,7 +337,7 @@ ENTRY while.v11 {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
DependencyHloOrdering ordering(module.get());
ordering.ToString(); // Shouldn't crash.
}
@ -375,7 +374,7 @@ ENTRY root {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
@ -507,23 +506,22 @@ TEST_F(HloOrderingTest, InterferenceWithOuterRoot) {
absl::string_view hlo_string = R"(
HloModule InterferenceWithOuterRoot, is_scheduled=true
Emmbedded (embedded_param: f32[42]) -> f32[42] {
embedded_param = f32[42]{0} parameter(0)
multiply = f32[42]{0} multiply(embedded_param, embedded_param)
ROOT log = f32[42]{0} log(multiply)
Emmbedded (embedded_param: f32[4096,4096]) -> f32[4096,4096] {
embedded_param = f32[4096,4096]{1,0} parameter(0)
multiply = f32[4096,4096]{1,0} multiply(embedded_param, embedded_param)
ROOT log = f32[4096,4096]{1,0} log(multiply)
}
ENTRY InterferenceWithOuterRoot {
param = f32[4096,4096]{1,0} parameter(0)
ROOT add = f32[4096,4096]{1,0} add(param, param)
call = f32[42]{0} call(param), to_apply=Emmbedded
call = f32[4096,4096]{1,0} call(param), to_apply=Emmbedded
}
)";
HloModuleConfig hlo_config;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string, hlo_config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, hlo_config));
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());

View File

@ -602,9 +602,8 @@ ENTRY %entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
@ -643,9 +642,8 @@ ENTRY %entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(