[XLA] Insert control edges from write to read instructions for same buffers inside loops
Previously, if there were two non-fused instructions inside the loop, call them A and B, and A was reading and B was writing into the same buffer B, there was a necessity for copying B, as the order of (A, B) was not fixed. With this patch we make a best-effort approach to order reads before writes (this is not always possible, e.g. for a loop where every iteration swaps too argument). This drastically reduce the number of copies required in many loop , which in turn greatly improves the performance of many loops on GPU (as each copy is a separate kernel launch, taking at least ~3us of overhead). PiperOrigin-RevId: 339152422 Change-Id: Iea5b849e11fc43da2f20e6b063039ecc784363a1
This commit is contained in:
parent
7bd42cf6ba
commit
8f73770a19
@ -546,8 +546,8 @@ cc_library(
|
||||
hdrs = ["resource_operation_safety_analysis.h"],
|
||||
deps = [
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
@ -718,7 +718,6 @@ cc_library(
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc:scope_internal",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
@ -731,6 +730,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:union_find",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -758,9 +758,9 @@ cc_library(
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -994,13 +994,13 @@ cc_library(
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:union_find",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
|
@ -34,7 +34,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
@ -42,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/union_find.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
@ -24,11 +24,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/union_find.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
@ -27,11 +27,11 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
|
@ -30,12 +30,12 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/union_find.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
|
@ -3460,6 +3460,33 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "loop_schedule_linearizer",
|
||||
srcs = ["loop_schedule_linearizer.cc"],
|
||||
hdrs = ["loop_schedule_linearizer.h"],
|
||||
deps = [
|
||||
":dump",
|
||||
":hlo",
|
||||
":hlo_alias_analysis",
|
||||
":hlo_dce",
|
||||
":hlo_graph_dumper",
|
||||
":hlo_ordering",
|
||||
":hlo_pass",
|
||||
":logical_buffer",
|
||||
":tuple_simplifier",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service/graphcycles",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
@ -3488,6 +3515,28 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "loop_schedule_linearizer_test",
|
||||
srcs = ["loop_schedule_linearizer_test.cc"],
|
||||
deps = [
|
||||
":copy_insertion",
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
":hlo_matchers",
|
||||
":hlo_runner",
|
||||
":loop_schedule_linearizer",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "memory_space_assignment_utils",
|
||||
srcs = ["memory_space_assignment_utils.cc"],
|
||||
|
@ -1239,6 +1239,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:llvm_compiler",
|
||||
"//tensorflow/compiler/xla/service:logistic_expander",
|
||||
"//tensorflow/compiler/xla/service:loop_schedule_linearizer",
|
||||
"//tensorflow/compiler/xla/service:qr_expander",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
||||
|
@ -92,6 +92,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/logistic_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
|
||||
#include "tensorflow/compiler/xla/service/qr_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
|
||||
@ -362,6 +363,7 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
|
||||
if (hlo_module->config().alias_passthrough_params()) {
|
||||
pipeline.AddPass<AliasPassthroughParams>();
|
||||
}
|
||||
pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
|
||||
pipeline.AddPass<GpuCopyInsertion>(GetCanShareBuffer());
|
||||
pipeline.AddPass<GpuSanitizeConstantNames>();
|
||||
return pipeline.Run(hlo_module).status();
|
||||
|
@ -4,6 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
"//tensorflow/compiler/xla:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
// (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y].
|
||||
// (3) Otherwise: adjust ranks in the neighborhood of x and y.
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/ordered_set.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -149,4 +149,4 @@ class GraphCycles {
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
// A test for the GraphCycles interface.
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
|
||||
#include <optional>
|
||||
#include <random>
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -82,4 +82,4 @@ class OrderedSet {
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/ordered_set.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
166
tensorflow/compiler/xla/service/loop_schedule_linearizer.cc
Normal file
166
tensorflow/compiler/xla/service/loop_schedule_linearizer.cc
Normal file
@ -0,0 +1,166 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/dump.h"
|
||||
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// Calculate ordering for HLO, for fast online checking of whether adding
|
||||
// additional dependencies would create cycles.
|
||||
struct ComputationInstructionOrdering {
|
||||
explicit ComputationInstructionOrdering(const HloComputation& computation) {
|
||||
for (const HloInstruction* instr : computation.instructions()) {
|
||||
for (const HloInstruction* control_pred : instr->control_predecessors()) {
|
||||
CHECK(this->InsertEdge(*control_pred, *instr))
|
||||
<< "Graph already contained a cycle";
|
||||
}
|
||||
|
||||
for (int op_id = 0; op_id < instr->operand_count(); op_id++) {
|
||||
const HloInstruction* op = instr->operand(op_id);
|
||||
CHECK(this->InsertEdge(*op, *instr))
|
||||
<< "Graph already contained a cycle";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int32 NodeIdForInstruction(const HloInstruction& instr) {
|
||||
int32 instruction_id = instr.unique_id();
|
||||
auto it = node_id_to_graph_id.find(instruction_id);
|
||||
|
||||
if (it != node_id_to_graph_id.end()) {
|
||||
return it->second;
|
||||
}
|
||||
int32 node_id = graph_cycles.NewNode();
|
||||
node_id_to_graph_id[instruction_id] = node_id;
|
||||
return node_id;
|
||||
}
|
||||
|
||||
// Returns `false` if adding an edge would have introduced a cycle. Does not
|
||||
// add an edge in that case. Returns `true` otherwise.
|
||||
bool InsertEdge(const HloInstruction& source, const HloInstruction& dest) {
|
||||
int32 source_id = NodeIdForInstruction(source);
|
||||
int32 dest_id = NodeIdForInstruction(dest);
|
||||
return graph_cycles.InsertEdge(source_id, dest_id);
|
||||
}
|
||||
|
||||
absl::flat_hash_map<int32, int32> node_id_to_graph_id;
|
||||
|
||||
tensorflow::GraphCycles graph_cycles;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static StatusOr<bool> AddControlEdgesForLoopWrites(
|
||||
HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) {
|
||||
HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis();
|
||||
HloComputation* body = xla_while->while_body();
|
||||
HloInstruction* root = body->root_instruction();
|
||||
HloInstruction* input = body->parameter_instruction(0);
|
||||
|
||||
bool changed = false;
|
||||
|
||||
// Compute dependency ordering ourselves. The reason we don't reuse other
|
||||
// computations is because it is hard to extract the underlying graph from
|
||||
// those abstractions.
|
||||
ComputationInstructionOrdering ordering(*body);
|
||||
ShapeTree<bool> indices_to_copy(xla_while->shape());
|
||||
|
||||
for (auto& p : indices_to_copy) {
|
||||
const ShapeIndex& index = p.first;
|
||||
|
||||
if (index.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (dataflow.GetValueSet(root, index).values().size() > 1 ||
|
||||
dataflow.GetValueSet(input, index).values().size() > 1) {
|
||||
VLOG(2) << "Index " << index.ToString() << " is associated with multiple "
|
||||
<< "values, not attempting to introduce stricter dependencies";
|
||||
} else {
|
||||
HloValue& value_at_root = dataflow.GetUniqueValueAt(root, index);
|
||||
HloValue& value_at_input = dataflow.GetUniqueValueAt(input, index);
|
||||
|
||||
if (value_at_root.shape().IsTuple()) {
|
||||
// TODO(cheshire): For simplicity we currently do not handle nested
|
||||
// tuples, as we haven't seen them in the examples we care about.
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(cheshire): This is too conservative and does not take aliasing
|
||||
// into account.
|
||||
HloInstruction* write = value_at_root.defining_instruction();
|
||||
|
||||
for (const HloUse& use : value_at_input.uses()) {
|
||||
HloInstruction* read = use.instruction;
|
||||
|
||||
if (read != write &&
|
||||
value_at_root != value_at_input
|
||||
|
||||
// TODO(cheshire): Parents sometimes differ in case of e.g. nested
|
||||
// loops, where the value is read/written into in the inner loop.
|
||||
// For now we skip this case for simplicity (as the inner loop
|
||||
// performance is more important in any case)
|
||||
&& read->parent() == write->parent()) {
|
||||
VLOG(2) << "Inside " << body->name() << ", index "
|
||||
<< index.ToString();
|
||||
if (!ordering.InsertEdge(*read, *write)) {
|
||||
VLOG(2) << "Not adding a control dependency from "
|
||||
<< read->ToShortString() << " to " << write->ToShortString()
|
||||
<< " as it would introduce a cycle";
|
||||
continue;
|
||||
}
|
||||
|
||||
changed |= absl::c_linear_search(read->control_successors(), write);
|
||||
|
||||
// Unless we want a copy, read should happen before write.
|
||||
TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write));
|
||||
VLOG(2) << "Adding dependency: " << read->ToShortString()
|
||||
<< " before " << write->ToShortString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
StatusOr<bool> LoopScheduleLinearizer::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
bool changed = false;
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instruction :
|
||||
computation->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||
StatusOr<bool> updated_loop =
|
||||
AddControlEdgesForLoopWrites(instruction, *alias_analysis);
|
||||
TF_RETURN_IF_ERROR(updated_loop.status());
|
||||
changed |= *updated_loop;
|
||||
}
|
||||
}
|
||||
}
|
||||
DumpHloModuleDuringPassIfEnabled(
|
||||
name(), "after inserting control edges inside while loop bodies",
|
||||
*module);
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // end namespace xla
|
53
tensorflow/compiler/xla/service/loop_schedule_linearizer.h
Normal file
53
tensorflow/compiler/xla/service/loop_schedule_linearizer.h
Normal file
@ -0,0 +1,53 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Adds control dependency edges from instructions which "write" values inside
|
||||
// the loop, to instructions which "read" those same values, in order to avoid
|
||||
// extraneous copies. This is not always possible with our buffer layout
|
||||
// constraints (that is, assuming that every element of the tuple the while loop
|
||||
// operates upon gets the same buffer) as it may create cycles (an easiest
|
||||
// example of a dependency cycle is a loop doing `(a, b) = (b, a)`). Thus we
|
||||
// take a best-effort approach instead: add dependency edges only if we can show
|
||||
// they don't create a cycle.
|
||||
class LoopScheduleLinearizer : public HloModulePass {
|
||||
public:
|
||||
absl::string_view name() const override { return "loop-schedule-linearizer"; }
|
||||
|
||||
explicit LoopScheduleLinearizer(
|
||||
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr)
|
||||
: can_share_buffer_(can_share_buffer) {}
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
// Backend specific function that decides whether an instruction can share
|
||||
// buffer with its operand.
|
||||
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_
|
128
tensorflow/compiler/xla/service/loop_schedule_linearizer_test.cc
Normal file
128
tensorflow/compiler/xla/service/loop_schedule_linearizer_test.cc
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/copy_insertion.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_runner.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
int64 CountCopies(const HloComputation& computation) {
|
||||
int64 count = 0;
|
||||
for (const auto& instruction : computation.instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kCopy) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
int64 CountCopies(const HloModule& module) {
|
||||
int64 count = 0;
|
||||
for (const auto& computation : module.computations()) {
|
||||
count += CountCopies(*computation);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
int64 CountControlEdges(const HloComputation& computation) {
|
||||
int64 count = 0;
|
||||
for (const auto& instruction : computation.instructions()) {
|
||||
count += instruction->control_successors().size();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
int64 CountControlEdges(const HloModule& module) {
|
||||
int64 count = 0;
|
||||
for (const auto& computation : module.computations()) {
|
||||
count += CountControlEdges(*computation);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
class LoopScheduleLinearizerTest : public HloTestBase {
|
||||
protected:
|
||||
void InsertCopies(HloModule* module) {
|
||||
LoopScheduleLinearizer loop_schedule_linearizer;
|
||||
ASSERT_IS_OK(loop_schedule_linearizer.Run(module).status());
|
||||
|
||||
CopyInsertion copy_insertion;
|
||||
ASSERT_IS_OK(copy_insertion.Run(module).status());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(LoopScheduleLinearizerTest, NoExtraCopiesRequired) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
while_body {
|
||||
input = (s32[], s32[]) parameter(0)
|
||||
counter = s32[] get-tuple-element(input), index=0
|
||||
buffer = s32[] get-tuple-element(input), index=1
|
||||
|
||||
one = s32[] constant(1)
|
||||
|
||||
updated_counter = s32[] add(counter, one)
|
||||
|
||||
updated_buffer = s32[] add(buffer, counter)
|
||||
ROOT out = (s32[], s32[]) tuple(updated_counter, updated_buffer)
|
||||
}
|
||||
|
||||
while_cond {
|
||||
input = (s32[], s32[]) parameter(0)
|
||||
counter = s32[] get-tuple-element(input), index=0
|
||||
bound = s32[] constant(100)
|
||||
ROOT cmp = pred[] compare(counter, bound), direction=LT
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
zero = s32[] constant(0)
|
||||
buffer = s32[] parameter(0)
|
||||
while_input = (s32[], s32[]) tuple(zero, buffer)
|
||||
ROOT out = (s32[], s32[]) while(while_input), condition=while_cond, body=while_body
|
||||
}
|
||||
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(
|
||||
*module->entry_computation()->root_instruction()->while_body()),
|
||||
0);
|
||||
EXPECT_EQ(CountControlEdges(
|
||||
*module->entry_computation()->root_instruction()->while_body()),
|
||||
1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user