[XLA] Insert control edges from read to write instructions for the same buffers inside loops

Previously, if there were two non-fused instructions inside a loop body, call them A and B,
where A read from and B wrote to the same buffer, that buffer had to be copied, because the
relative order of (A, B) was not fixed.

With this patch we take a best-effort approach to ordering reads before writes (this is not
always possible, e.g. for a loop where every iteration swaps its two arguments).
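
For illustration only (this module is not part of the patch), a minimal HLO sketch of such a
swap loop, written in the same style as the test added below. Here, ordering the read of each
tuple element before the write into its buffer would require control edges in both directions,
i.e. a cycle, so a copy remains necessary:

HloModule SwapLoop

swap_body {
  input = (s32[], s32[]) parameter(0)
  a = s32[] get-tuple-element(input), index=0
  b = s32[] get-tuple-element(input), index=1
  ROOT swapped = (s32[], s32[]) tuple(b, a)
}

swap_cond {
  input = (s32[], s32[]) parameter(0)
  a = s32[] get-tuple-element(input), index=0
  bound = s32[] constant(100)
  ROOT cmp = pred[] compare(a, bound), direction=LT
}

ENTRY entry {
  zero = s32[] constant(0)
  one = s32[] constant(1)
  while_input = (s32[], s32[]) tuple(zero, one)
  ROOT out = (s32[], s32[]) while(while_input), condition=swap_cond, body=swap_body
}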

This drastically reduces the number of copies required in many loops, which in
turn greatly improves their performance on GPU (each copy is a separate kernel
launch, adding at least ~3us of overhead).

PiperOrigin-RevId: 339152422
Change-Id: Iea5b849e11fc43da2f20e6b063039ecc784363a1
George Karpenkov 2020-10-26 17:23:56 -07:00 committed by TensorFlower Gardener
parent 7bd42cf6ba
commit 8f73770a19
19 changed files with 420 additions and 20 deletions


@@ -546,8 +546,8 @@ cc_library(
hdrs = ["resource_operation_safety_analysis.h"],
deps = [
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
@@ -718,7 +718,6 @@ cc_library(
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:side_effect_util",
@@ -731,6 +730,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -758,9 +758,9 @@ cc_library(
deps = [
":flags",
":xla_activity_proto_cc",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@@ -994,13 +994,13 @@ cc_library(
":xla_activity_listener",
":xla_activity_proto_cc",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",


@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
@@ -42,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"


@@ -24,11 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"


@@ -27,11 +27,11 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"


@@ -30,12 +30,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"


@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"


@@ -21,8 +21,8 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h"


@@ -3460,6 +3460,33 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "loop_schedule_linearizer",
srcs = ["loop_schedule_linearizer.cc"],
hdrs = ["loop_schedule_linearizer.h"],
deps = [
":dump",
":hlo",
":hlo_alias_analysis",
":hlo_dce",
":hlo_graph_dumper",
":hlo_ordering",
":hlo_pass",
":logical_buffer",
":tuple_simplifier",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/graphcycles",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@@ -3488,6 +3515,28 @@ tf_cc_test(
],
)
tf_cc_test(
name = "loop_schedule_linearizer_test",
srcs = ["loop_schedule_linearizer_test.cc"],
deps = [
":copy_insertion",
":hlo",
":hlo_graph_dumper",
":hlo_matchers",
":hlo_runner",
":loop_schedule_linearizer",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library(
name = "memory_space_assignment_utils",
srcs = ["memory_space_assignment_utils.cc"],


@@ -1239,6 +1239,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:logistic_expander",
"//tensorflow/compiler/xla/service:loop_schedule_linearizer",
"//tensorflow/compiler/xla/service:qr_expander",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",


@@ -92,6 +92,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
#include "tensorflow/compiler/xla/service/qr_expander.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
@@ -362,6 +363,7 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
if (hlo_module->config().alias_passthrough_params()) {
pipeline.AddPass<AliasPassthroughParams>();
}
pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());
pipeline.AddPass<GpuCopyInsertion>(GetCanShareBuffer());
pipeline.AddPass<GpuSanitizeConstantNames>();
return pipeline.Run(hlo_module).status();


@@ -4,6 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
"//tensorflow/compiler/xla:internal",
],
licenses = ["notice"], # Apache 2.0
)


@@ -29,7 +29,7 @@ limitations under the License.
// (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y].
// (3) Otherwise: adjust ranks in the neighborhood of x and y.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include <algorithm>
#include <unordered_set>
@@ -38,7 +38,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/compiler/xla/service/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {


@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_
#include <vector>
@@ -149,4 +149,4 @@ class GraphCycles {
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_


@@ -15,7 +15,7 @@ limitations under the License.
// A test for the GraphCycles interface.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include <optional>
#include <random>


@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_
#include <vector>
@@ -82,4 +82,4 @@ class OrderedSet {
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_


@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/compiler/xla/service/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"


@@ -0,0 +1,166 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
namespace xla {
namespace {
// Calculate ordering for HLO, for fast online checking of whether adding
// additional dependencies would create cycles.
struct ComputationInstructionOrdering {
explicit ComputationInstructionOrdering(const HloComputation& computation) {
for (const HloInstruction* instr : computation.instructions()) {
for (const HloInstruction* control_pred : instr->control_predecessors()) {
CHECK(this->InsertEdge(*control_pred, *instr))
<< "Graph already contained a cycle";
}
for (int op_id = 0; op_id < instr->operand_count(); op_id++) {
const HloInstruction* op = instr->operand(op_id);
CHECK(this->InsertEdge(*op, *instr))
<< "Graph already contained a cycle";
}
}
}
int32 NodeIdForInstruction(const HloInstruction& instr) {
int32 instruction_id = instr.unique_id();
auto it = node_id_to_graph_id.find(instruction_id);
if (it != node_id_to_graph_id.end()) {
return it->second;
}
int32 node_id = graph_cycles.NewNode();
node_id_to_graph_id[instruction_id] = node_id;
return node_id;
}
// Returns `false` if adding an edge would have introduced a cycle. Does not
// add an edge in that case. Returns `true` otherwise.
bool InsertEdge(const HloInstruction& source, const HloInstruction& dest) {
int32 source_id = NodeIdForInstruction(source);
int32 dest_id = NodeIdForInstruction(dest);
return graph_cycles.InsertEdge(source_id, dest_id);
}
absl::flat_hash_map<int32, int32> node_id_to_graph_id;
tensorflow::GraphCycles graph_cycles;
};
} // namespace
static StatusOr<bool> AddControlEdgesForLoopWrites(
HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) {
HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis();
HloComputation* body = xla_while->while_body();
HloInstruction* root = body->root_instruction();
HloInstruction* input = body->parameter_instruction(0);
bool changed = false;
// Compute dependency ordering ourselves. The reason we don't reuse other
// computations is because it is hard to extract the underlying graph from
// those abstractions.
ComputationInstructionOrdering ordering(*body);
ShapeTree<bool> indices_to_copy(xla_while->shape());
for (auto& p : indices_to_copy) {
const ShapeIndex& index = p.first;
if (index.empty()) {
continue;
}
if (dataflow.GetValueSet(root, index).values().size() > 1 ||
dataflow.GetValueSet(input, index).values().size() > 1) {
VLOG(2) << "Index " << index.ToString() << " is associated with multiple "
<< "values, not attempting to introduce stricter dependencies";
} else {
HloValue& value_at_root = dataflow.GetUniqueValueAt(root, index);
HloValue& value_at_input = dataflow.GetUniqueValueAt(input, index);
if (value_at_root.shape().IsTuple()) {
// TODO(cheshire): For simplicity we currently do not handle nested
// tuples, as we haven't seen them in the examples we care about.
continue;
}
// TODO(cheshire): This is too conservative and does not take aliasing
// into account.
HloInstruction* write = value_at_root.defining_instruction();
for (const HloUse& use : value_at_input.uses()) {
HloInstruction* read = use.instruction;
if (read != write &&
value_at_root != value_at_input
// TODO(cheshire): Parents sometimes differ in case of e.g. nested
// loops, where the value is read/written into in the inner loop.
// For now we skip this case for simplicity (as the inner loop
// performance is more important in any case)
&& read->parent() == write->parent()) {
VLOG(2) << "Inside " << body->name() << ", index "
<< index.ToString();
if (!ordering.InsertEdge(*read, *write)) {
VLOG(2) << "Not adding a control dependency from "
<< read->ToShortString() << " to " << write->ToShortString()
<< " as it would introduce a cycle";
continue;
}
changed |= !absl::c_linear_search(read->control_successors(), write);
// Unless we want a copy, read should happen before write.
TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write));
VLOG(2) << "Adding dependency: " << read->ToShortString()
<< " before " << write->ToShortString();
}
}
}
}
return changed;
}
StatusOr<bool> LoopScheduleLinearizer::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, can_share_buffer_));
bool changed = false;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kWhile) {
StatusOr<bool> updated_loop =
AddControlEdgesForLoopWrites(instruction, *alias_analysis);
TF_RETURN_IF_ERROR(updated_loop.status());
changed |= *updated_loop;
}
}
}
DumpHloModuleDuringPassIfEnabled(
name(), "after inserting control edges inside while loop bodies",
*module);
return changed;
}
} // end namespace xla


@@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// Adds control dependency edges from instructions which "write" values inside
// the loop, to instructions which "read" those same values, in order to avoid
// extraneous copies. This is not always possible with our buffer layout
// constraints (that is, assuming that every element of the tuple the while loop
// operates upon gets the same buffer) as it may create cycles (an easiest
// example of a dependency cycle is a loop doing `(a, b) = (b, a)`). Thus we
// take a best-effort approach instead: add dependency edges only if we can show
// they don't create a cycle.
class LoopScheduleLinearizer : public HloModulePass {
public:
absl::string_view name() const override { return "loop-schedule-linearizer"; }
explicit LoopScheduleLinearizer(
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr)
: can_share_buffer_(can_share_buffer) {}
StatusOr<bool> Run(HloModule* module) override;
private:
// Backend specific function that decides whether an instruction can share
// buffer with its operand.
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_
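
For reference, a minimal usage sketch (not part of the commit): the pass is intended to run
right before copy insertion, as in the GPU pipeline change above and the test below. The
wrapper function name and the `module` / `can_share_buffer` parameters are assumptions made
for illustration.

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"

namespace xla {

// Sketch only: add read-before-write control edges inside while loop bodies,
// then run copy insertion, which now sees the extra ordering constraints and
// can elide the corresponding copies.
Status LinearizeLoopsThenInsertCopies(
    HloModule* module,
    const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
  LoopScheduleLinearizer linearizer(can_share_buffer);
  Status linearized = linearizer.Run(module).status();
  if (!linearized.ok()) {
    return linearized;
  }
  CopyInsertion copy_insertion;
  return copy_insertion.Run(module).status();
}

}  // namespace xla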


@@ -0,0 +1,128 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
#include <set>
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace xla {
namespace {
int64 CountCopies(const HloComputation& computation) {
int64 count = 0;
for (const auto& instruction : computation.instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) {
count++;
}
}
return count;
}
int64 CountCopies(const HloModule& module) {
int64 count = 0;
for (const auto& computation : module.computations()) {
count += CountCopies(*computation);
}
return count;
}
int64 CountControlEdges(const HloComputation& computation) {
int64 count = 0;
for (const auto& instruction : computation.instructions()) {
count += instruction->control_successors().size();
}
return count;
}
int64 CountControlEdges(const HloModule& module) {
int64 count = 0;
for (const auto& computation : module.computations()) {
count += CountControlEdges(*computation);
}
return count;
}
class LoopScheduleLinearizerTest : public HloTestBase {
protected:
void InsertCopies(HloModule* module) {
LoopScheduleLinearizer loop_schedule_linearizer;
ASSERT_IS_OK(loop_schedule_linearizer.Run(module).status());
CopyInsertion copy_insertion;
ASSERT_IS_OK(copy_insertion.Run(module).status());
}
};
TEST_F(LoopScheduleLinearizerTest, NoExtraCopiesRequired) {
absl::string_view hlo_string = R"(
HloModule module
while_body {
input = (s32[], s32[]) parameter(0)
counter = s32[] get-tuple-element(input), index=0
buffer = s32[] get-tuple-element(input), index=1
one = s32[] constant(1)
updated_counter = s32[] add(counter, one)
updated_buffer = s32[] add(buffer, counter)
ROOT out = (s32[], s32[]) tuple(updated_counter, updated_buffer)
}
while_cond {
input = (s32[], s32[]) parameter(0)
counter = s32[] get-tuple-element(input), index=0
bound = s32[] constant(100)
ROOT cmp = pred[] compare(counter, bound), direction=LT
}
ENTRY entry {
zero = s32[] constant(0)
buffer = s32[] parameter(0)
while_input = (s32[], s32[]) tuple(zero, buffer)
ROOT out = (s32[], s32[]) while(while_input), condition=while_cond, body=while_body
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(
*module->entry_computation()->root_instruction()->while_body()),
0);
EXPECT_EQ(CountControlEdges(
*module->entry_computation()->root_instruction()->while_body()),
1);
}
} // namespace
} // namespace xla