From b6d720fd9303d2841ddbb34536653eb1ec3b3dd9 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Fri, 22 Nov 2019 15:09:37 -0800 Subject: [PATCH] [XLA] Use alias analysis to find all aliased required assmts in mem space assmt. We were previously using dataflow analysis and that can be incorrect since it didn't know about aliased buffers. Due to that, we might end up assigning module inputs and outputs to the alternate memory space even though that is not allowed. PiperOrigin-RevId: 282050905 Change-Id: I8a257e15a00f2fb1a0155ac9eaa98a582a05cadc --- .../xla/service/memory_space_assignment.cc | 41 +++---- .../service/memory_space_assignment_test.cc | 105 ++++++++++++++++++ 2 files changed, 126 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 08d932866c0..751d258142a 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -370,9 +370,7 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { // adding a required assignment. // TODO(berkin): If these values are already marked alternate memory, use // those instead. 
- const HloDataflowAnalysis& dataflow_analysis = - alias_analysis_.dataflow_analysis(); - const HloModule& module = dataflow_analysis.module(); + const HloModule& module = alias_analysis_.dataflow_analysis().module(); const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); HloComputation* entry_computation = module.entry_computation(); for (HloInstruction* parameter_instruction : @@ -382,15 +380,16 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { ShapeUtil::ForEachSubshape( parameter_instruction->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { - for (const HloValue* value : - dataflow_analysis.GetValueSet(parameter_instruction, index) - .values()) { - VLOG(3) << "Adding required assignment for parameter value = " - << value->ToShortString() - << " time = " << parameter_instruction_time; - required_assignments_[value].push_back( - {/*memory_space=*/MemorySpace::kDefault, - /*time=*/parameter_instruction_time}); + for (const HloBuffer* buffer : + alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) { + for (const HloValue* value : buffer->values()) { + VLOG(3) << "Adding required assignment for parameter value = " + << value->ToShortString() + << " time = " << parameter_instruction_time; + required_assignments_[value].push_back( + {/*memory_space=*/MemorySpace::kDefault, + /*time=*/parameter_instruction_time}); + } } }); } @@ -399,14 +398,16 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { ShapeUtil::ForEachSubshape( root_instruction->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { - for (const HloValue* value : - dataflow_analysis.GetValueSet(root_instruction, index).values()) { - VLOG(3) << "Adding required assignment for output value = " - << value->ToShortString() - << " time = " << root_instruction_time; - required_assignments_[value].push_back( - {/*memory_space=*/MemorySpace::kDefault, - /*time=*/root_instruction_time}); + for (const 
HloBuffer* buffer : + alias_analysis_.ComputeBuffersAt(root_instruction, index)) { + for (const HloValue* value : buffer->values()) { + VLOG(3) << "Adding required assignment for output value = " + << value->ToShortString() + << " time = " << root_instruction_time; + required_assignments_[value].push_back( + {/*memory_space=*/MemorySpace::kDefault, + /*time=*/root_instruction_time}); + } } }); } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 637259032da..6041b96636e 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -1927,6 +1927,111 @@ TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) { EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem)); } +TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) { + Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); + Shape f32v1 = ShapeUtil::MakeShape(F32, {1}); + Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1}); + auto module = CreateNewVerifiedModule("SimpleWhile"); + HloSchedule schedule(module.get()); + + // A simple compare-to-limit (x < 4) computation for a While. 
+ // + // condition: + // const4[s32] -----------------------------------\ + // \ + // param[(s32,f32[1])] --- get-tuple-element[0] --- less-than + // + HloComputation* cond_computation; + { + auto builder = HloComputation::Builder("WhileCond"); + auto const4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v1, "x")); + auto index = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); + auto compare = builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index, + const4, ComparisonDirection::kLt)); + cond_computation = module->AddEmbeddedComputation(builder.Build()); + schedule.set_sequence(cond_computation, {const4, param, index, compare}); + } + + // Builds a simple body computation for a While. + // + // body: + // constv[f32[1]] --------------------------------------\ + // \ + // /--- get-tuple-elementv[1] --- addv ---\ + // param[(s32,f32[1])] ---| tuple + // \--- get-tuple-elementc[0] --- addc ---/ + // / + // const1[s32] -----------------------------------------/ + // + HloComputation* body_computation; + { + auto builder = HloComputation::Builder("WhileBody"); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto constv = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.1f}))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v1, "x")); + auto indexc = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(const1->shape(), param, 0)); + auto addc = builder.AddInstruction(HloInstruction::CreateBinary( + indexc->shape(), HloOpcode::kAdd, indexc, const1)); + auto indexv = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(constv->shape(), param, 1)); + auto addv = builder.AddInstruction(HloInstruction::CreateBinary( + 
constv->shape(), HloOpcode::kAdd, indexv, constv)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({addc, addv})); + body_computation = module->AddEmbeddedComputation(builder.Build()); + schedule.set_sequence(body_computation, {const1, constv, param, indexc, + addc, indexv, addv, tuple}); + } + + // This tests a simple while loop where the parameters are aliased with the + // output buffers. + auto builder = HloComputation::Builder("SimpleWhile"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v1, "param")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(s32, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32v1, param, 1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + auto while0 = builder.AddInstruction(HloInstruction::CreateWhile( + t_s32_f32v1, cond_computation, body_computation, tuple)); + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*max_prefetch_interval=*/50); + + // Ensure all parameters and while are placed in default memory. 
+ Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout( + F32, {4, 6}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kDefaultMemorySpace); + Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout( + xla::S32, {}, + /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kDefaultMemorySpace); + Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout( + F32, {1}, + /*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kDefaultMemorySpace); + Shape t_s32_f32v1_in_default_mem = + ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem}); + EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem)); + EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem)); +} + INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, MemorySpaceAssignmentTest, ::testing::Values(false, true));