[XLA:CPU] Remove code from parallel CPU backend outlining that was causing unnecessary copies to be inserted, and which is no longer necessary since we added co-located buffer support for kCall.
*) All bitcast copy is no longer necessary as CopyInsertion will insert copies at the root of the computation for a parameter which is live-out. *) Copy if root does not define buffer no longer necessary because colocated assignment looks at points-to set of root instruction. PiperOrigin-RevId: 168412076
This commit is contained in:
parent
5da4df92c2
commit
7e023d865d
@ -519,8 +519,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||||
"//tensorflow/compiler/xla/service:logical_buffer",
|
|
||||||
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -21,8 +21,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
@ -64,13 +62,11 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
|
|||||||
instructions_to_outline.clear();
|
instructions_to_outline.clear();
|
||||||
HloInstruction* outline_candidate = instruction;
|
HloInstruction* outline_candidate = instruction;
|
||||||
instructions_to_outline.push_back(outline_candidate);
|
instructions_to_outline.push_back(outline_candidate);
|
||||||
bool all_bitcasts = outline_candidate->opcode() == HloOpcode::kBitcast;
|
|
||||||
|
|
||||||
// Outline sole users with the current instruction.
|
// Outline sole users with the current instruction.
|
||||||
while (CanOutlineWithUser(outline_candidate)) {
|
while (CanOutlineWithUser(outline_candidate)) {
|
||||||
HloInstruction* prior_candidate = outline_candidate;
|
HloInstruction* prior_candidate = outline_candidate;
|
||||||
outline_candidate = *outline_candidate->users().begin();
|
outline_candidate = *outline_candidate->users().begin();
|
||||||
all_bitcasts &= outline_candidate->opcode() == HloOpcode::kBitcast;
|
|
||||||
if (std::any_of(outline_candidate->operands().begin(),
|
if (std::any_of(outline_candidate->operands().begin(),
|
||||||
outline_candidate->operands().end(),
|
outline_candidate->operands().end(),
|
||||||
[&](const HloInstruction* operand) {
|
[&](const HloInstruction* operand) {
|
||||||
@ -86,24 +82,6 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
|
|||||||
}
|
}
|
||||||
instructions_to_outline.push_back(outline_candidate);
|
instructions_to_outline.push_back(outline_candidate);
|
||||||
}
|
}
|
||||||
// If all instructions in the outline candidates are a bitcast, then create
|
|
||||||
// a copy at the head of the bitcasts and include it in the outlined
|
|
||||||
// instructions. The underlying problem is that a computation which forwards
|
|
||||||
// a parameter buffer to the output is not properly handled by the backends
|
|
||||||
// or analysis.
|
|
||||||
//
|
|
||||||
// This would be better handled by being smarter about choosing outline
|
|
||||||
// candidates in the first place.
|
|
||||||
if (all_bitcasts) {
|
|
||||||
// 'head' is the first instruction in the chain of bitcasts.
|
|
||||||
HloInstruction* head = instructions_to_outline[0];
|
|
||||||
HloInstruction* head_operand = head->mutable_operand(0);
|
|
||||||
HloInstruction* copy =
|
|
||||||
entry_computation->AddInstruction(HloInstruction::CreateUnary(
|
|
||||||
head_operand->shape(), HloOpcode::kCopy, head_operand));
|
|
||||||
TF_RETURN_IF_ERROR(head->ReplaceOperandWith(0, copy));
|
|
||||||
instructions_to_outline.insert(instructions_to_outline.begin(), copy);
|
|
||||||
}
|
|
||||||
|
|
||||||
outlined.insert(instructions_to_outline.begin(),
|
outlined.insert(instructions_to_outline.begin(),
|
||||||
instructions_to_outline.end());
|
instructions_to_outline.end());
|
||||||
@ -122,25 +100,6 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
|
|||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
|
|
||||||
TuplePointsToAnalysis::Run(module));
|
|
||||||
for (auto& computation : module->computations()) {
|
|
||||||
if (computation->IsFusionComputation()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
HloInstruction* root = computation->root_instruction();
|
|
||||||
// Copy root instruction if it does not define its own top-level buffer.
|
|
||||||
// TODO(b/32885001) Remove these copies (at least for the unambiguous case).
|
|
||||||
// TODO(b/32885001) Perform shallow copy if root value is a tuple.
|
|
||||||
if (!points_to_analysis->InstructionDefinesBufferAtIndex(root,
|
|
||||||
/*index=*/{})) {
|
|
||||||
HloInstruction* copy = computation->AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(root->shape(), HloOpcode::kCopy, root));
|
|
||||||
computation->set_root_instruction(copy);
|
|
||||||
changed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT");
|
XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT");
|
||||||
XLA_VLOG_LINES(2, module->ToString());
|
XLA_VLOG_LINES(2, module->ToString());
|
||||||
return changed;
|
return changed;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user