diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7bd7b5964f9..5c84eecd976 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -546,8 +546,8 @@ cc_library( hdrs = ["resource_operation_safety_analysis.h"], deps = [ ":xla_cluster_util", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -718,7 +718,6 @@ cc_library( "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/cc:scope_internal", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:side_effect_util", @@ -731,6 +730,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -758,9 +758,9 @@ cc_library( deps = [ ":flags", ":xla_activity_proto_cc", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -994,13 +994,13 @@ cc_library( ":xla_activity_listener", ":xla_activity_proto_cc", ":xla_cluster_util", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 62e121420c3..87b06c2ab36 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" @@ -42,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 65da072483b..224bedabd3b 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -24,11 +24,11 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index d482642b44c..fd55cab637c 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -27,11 +27,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 967cfa3fca1..62a4c5372b0 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -30,12 +30,12 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h index c652e5fe216..3931ae6c7cc 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.h +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ #define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index e2a1d159336..bf6dd5ab9f4 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e16575bebd4..851eca4f7be 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3460,6 +3460,33 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "loop_schedule_linearizer", + srcs = ["loop_schedule_linearizer.cc"], + hdrs = ["loop_schedule_linearizer.h"], + deps = [ + ":dump", + ":hlo", + ":hlo_alias_analysis", + ":hlo_dce", + ":hlo_graph_dumper", + ":hlo_ordering", + ":hlo_pass", + ":logical_buffer", + ":tuple_simplifier", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -3488,6 +3515,28 @@ tf_cc_test( ], ) +tf_cc_test( + name = "loop_schedule_linearizer_test", + srcs = ["loop_schedule_linearizer_test.cc"], + deps = [ + ":copy_insertion", + ":hlo", + ":hlo_graph_dumper", + ":hlo_matchers", + ":hlo_runner", + ":loop_schedule_linearizer", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "memory_space_assignment_utils", srcs = ["memory_space_assignment_utils.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d2c462e0957..05535c0dbe9 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1239,6 +1239,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:logistic_expander", + "//tensorflow/compiler/xla/service:loop_schedule_linearizer", "//tensorflow/compiler/xla/service:qr_expander", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 25fbc0e05cb..6e81dc0d5e2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -92,6 +92,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" +#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h" #include "tensorflow/compiler/xla/service/qr_expander.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" @@ -362,6 +363,7 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { if (hlo_module->config().alias_passthrough_params()) { pipeline.AddPass(); } + pipeline.AddPass(GetCanShareBuffer()); pipeline.AddPass(GetCanShareBuffer()); pipeline.AddPass(); return pipeline.Run(hlo_module).status(); diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/xla/service/graphcycles/BUILD similarity index 96% rename from tensorflow/compiler/jit/graphcycles/BUILD rename to tensorflow/compiler/xla/service/graphcycles/BUILD index 23d994c27c5..0c1ba803ccf 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/xla/service/graphcycles/BUILD @@ -4,6 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( default_visibility = [ "//tensorflow/compiler/tf2xla:internal", + "//tensorflow/compiler/xla:internal", ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/xla/service/graphcycles/graphcycles.cc similarity index 99% rename from tensorflow/compiler/jit/graphcycles/graphcycles.cc rename to tensorflow/compiler/xla/service/graphcycles/graphcycles.cc index 416e101a025..69e10871815 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/xla/service/graphcycles/graphcycles.cc @@ -29,7 +29,7 @@ limitations under the License. // (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y]. // (3) Otherwise: adjust ranks in the neighborhood of x and y. -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include #include @@ -38,7 +38,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/jit/graphcycles/ordered_set.h" +#include "tensorflow/compiler/xla/service/graphcycles/ordered_set.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/xla/service/graphcycles/graphcycles.h similarity index 96% rename from tensorflow/compiler/jit/graphcycles/graphcycles.h rename to tensorflow/compiler/xla/service/graphcycles/graphcycles.h index 3e20c4e641c..5028091c928 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/xla/service/graphcycles/graphcycles.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ -#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_ #include @@ -149,4 +149,4 @@ class GraphCycles { }; } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_GRAPHCYCLES_H_ diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/xla/service/graphcycles/graphcycles_test.cc similarity index 99% rename from tensorflow/compiler/jit/graphcycles/graphcycles_test.cc rename to tensorflow/compiler/xla/service/graphcycles/graphcycles_test.cc index 5b7eec19e27..f44a36c677f 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/xla/service/graphcycles/graphcycles_test.cc @@ -15,7 +15,7 @@ limitations under the License. // A test for the GraphCycles interface. -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include #include diff --git a/tensorflow/compiler/jit/graphcycles/ordered_set.h b/tensorflow/compiler/xla/service/graphcycles/ordered_set.h similarity index 93% rename from tensorflow/compiler/jit/graphcycles/ordered_set.h rename to tensorflow/compiler/xla/service/graphcycles/ordered_set.h index 0417782b984..622c5d3afb9 100644 --- a/tensorflow/compiler/jit/graphcycles/ordered_set.h +++ b/tensorflow/compiler/xla/service/graphcycles/ordered_set.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_ -#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_ #include @@ -82,4 +82,4 @@ class OrderedSet { }; } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GRAPHCYCLES_ORDERED_SET_H_ diff --git a/tensorflow/compiler/jit/graphcycles/ordered_set_test.cc b/tensorflow/compiler/xla/service/graphcycles/ordered_set_test.cc similarity index 97% rename from tensorflow/compiler/jit/graphcycles/ordered_set_test.cc rename to tensorflow/compiler/xla/service/graphcycles/ordered_set_test.cc index 38ac1cfe9b6..eec433b979b 100644 --- a/tensorflow/compiler/jit/graphcycles/ordered_set_test.cc +++ b/tensorflow/compiler/xla/service/graphcycles/ordered_set_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/graphcycles/ordered_set.h" +#include "tensorflow/compiler/xla/service/graphcycles/ordered_set.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/service/loop_schedule_linearizer.cc b/tensorflow/compiler/xla/service/loop_schedule_linearizer.cc new file mode 100644 index 00000000000..0da457c829c --- /dev/null +++ b/tensorflow/compiler/xla/service/loop_schedule_linearizer.cc @@ -0,0 +1,166 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h" + +#include "tensorflow/compiler/xla/service/dump.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" + +namespace xla { + +namespace { + +// Calculate ordering for HLO, for fast online checking of whether adding +// additional dependencies would create cycles. +struct ComputationInstructionOrdering { + explicit ComputationInstructionOrdering(const HloComputation& computation) { + for (const HloInstruction* instr : computation.instructions()) { + for (const HloInstruction* control_pred : instr->control_predecessors()) { + CHECK(this->InsertEdge(*control_pred, *instr)) + << "Graph already contained a cycle"; + } + + for (int op_id = 0; op_id < instr->operand_count(); op_id++) { + const HloInstruction* op = instr->operand(op_id); + CHECK(this->InsertEdge(*op, *instr)) + << "Graph already contained a cycle"; + } + } + } + + int32 NodeIdForInstruction(const HloInstruction& instr) { + int32 instruction_id = instr.unique_id(); + auto it = node_id_to_graph_id.find(instruction_id); + + if (it != node_id_to_graph_id.end()) { + return it->second; + } + int32 node_id = graph_cycles.NewNode(); + node_id_to_graph_id[instruction_id] = node_id; + return node_id; + } + + // Returns `false` if adding an edge would have introduced a cycle. Does not + // add an edge in that case. Returns `true` otherwise. + bool InsertEdge(const HloInstruction& source, const HloInstruction& dest) { + int32 source_id = NodeIdForInstruction(source); + int32 dest_id = NodeIdForInstruction(dest); + return graph_cycles.InsertEdge(source_id, dest_id); + } + + absl::flat_hash_map node_id_to_graph_id; + + tensorflow::GraphCycles graph_cycles; +}; + +} // namespace + +static StatusOr AddControlEdgesForLoopWrites( + HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) { + HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis(); + HloComputation* body = xla_while->while_body(); + HloInstruction* root = body->root_instruction(); + HloInstruction* input = body->parameter_instruction(0); + + bool changed = false; + + // Compute dependency ordering ourselves. The reason we don't reuse other + // computations is because it is hard to extract the underlying graph from + // those abstractions. + ComputationInstructionOrdering ordering(*body); + ShapeTree indices_to_copy(xla_while->shape()); + + for (auto& p : indices_to_copy) { + const ShapeIndex& index = p.first; + + if (index.empty()) { + continue; + } + + if (dataflow.GetValueSet(root, index).values().size() > 1 || + dataflow.GetValueSet(input, index).values().size() > 1) { + VLOG(2) << "Index " << index.ToString() << " is associated with multiple " + << "values, not attempting to introduce stricter dependencies"; + } else { + HloValue& value_at_root = dataflow.GetUniqueValueAt(root, index); + HloValue& value_at_input = dataflow.GetUniqueValueAt(input, index); + + if (value_at_root.shape().IsTuple()) { + // TODO(cheshire): For simplicity we currently do not handle nested + // tuples, as we haven't seen them in the examples we care about. + continue; + } + + // TODO(cheshire): This is too conservative and does not take aliasing + // into account. + HloInstruction* write = value_at_root.defining_instruction(); + + for (const HloUse& use : value_at_input.uses()) { + HloInstruction* read = use.instruction; + + if (read != write && + value_at_root != value_at_input + + // TODO(cheshire): Parents sometimes differ in case of e.g. nested + // loops, where the value is read/written into in the inner loop. + // For now we skip this case for simplicity (as the inner loop + // performance is more important in any case) + && read->parent() == write->parent()) { + VLOG(2) << "Inside " << body->name() << ", index " + << index.ToString(); + if (!ordering.InsertEdge(*read, *write)) { + VLOG(2) << "Not adding a control dependency from " + << read->ToShortString() << " to " << write->ToShortString() + << " as it would introduce a cycle"; + continue; + } + + changed |= absl::c_linear_search(read->control_successors(), write); + + // Unless we want a copy, read should happen before write. + TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write)); + VLOG(2) << "Adding dependency: " << read->ToShortString() + << " before " << write->ToShortString(); + } + } + } + } + return changed; +} + +StatusOr LoopScheduleLinearizer::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, can_share_buffer_)); + + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kWhile) { + StatusOr updated_loop = + AddControlEdgesForLoopWrites(instruction, *alias_analysis); + TF_RETURN_IF_ERROR(updated_loop.status()); + changed |= *updated_loop; + } + } + } + DumpHloModuleDuringPassIfEnabled( + name(), "after inserting control edges inside while loop bodies", + *module); + + return changed; +} + +} // end namespace xla diff --git a/tensorflow/compiler/xla/service/loop_schedule_linearizer.h b/tensorflow/compiler/xla/service/loop_schedule_linearizer.h new file mode 100644 index 00000000000..67ef37bcc5b --- /dev/null +++ b/tensorflow/compiler/xla/service/loop_schedule_linearizer.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_ + +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Adds control dependency edges from instructions which "write" values inside +// the loop, to instructions which "read" those same values, in order to avoid +// extraneous copies. This is not always possible with our buffer layout +// constraints (that is, assuming that every element of the tuple the while loop +// operates upon gets the same buffer) as it may create cycles (an easiest +// example of a dependency cycle is a loop doing `(a, b) = (b, a)`). Thus we +// take a best-effort approach instead: add dependency edges only if we can show +// they don't create a cycle. +class LoopScheduleLinearizer : public HloModulePass { + public: + absl::string_view name() const override { return "loop-schedule-linearizer"; } + + explicit LoopScheduleLinearizer( + const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) + : can_share_buffer_(can_share_buffer) {} + + StatusOr Run(HloModule* module) override; + + private: + // Backend specific function that decides whether an instruction can share + // buffer with its operand. + HloDataflowAnalysis::CanShareBuffer can_share_buffer_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOOP_SCHEDULE_LINEARIZER_H_ diff --git a/tensorflow/compiler/xla/service/loop_schedule_linearizer_test.cc b/tensorflow/compiler/xla/service/loop_schedule_linearizer_test.cc new file mode 100644 index 00000000000..d3f6d8b01a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/loop_schedule_linearizer_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h" + +#include + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace xla { +namespace { + +int64 CountCopies(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +int64 CountCopies(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountCopies(*computation); + } + return count; +} + +int64 CountControlEdges(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + count += instruction->control_successors().size(); + } + return count; +} + +int64 CountControlEdges(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountControlEdges(*computation); + } + return count; +} + +class LoopScheduleLinearizerTest : public HloTestBase { + protected: + void InsertCopies(HloModule* module) { + LoopScheduleLinearizer loop_schedule_linearizer; + ASSERT_IS_OK(loop_schedule_linearizer.Run(module).status()); + + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module).status()); + } +}; + +TEST_F(LoopScheduleLinearizerTest, NoExtraCopiesRequired) { + absl::string_view hlo_string = R"( +HloModule module + +while_body { + input = (s32[], s32[]) parameter(0) + counter = s32[] get-tuple-element(input), index=0 + buffer = s32[] get-tuple-element(input), index=1 + + one = s32[] constant(1) + + updated_counter = s32[] add(counter, one) + + updated_buffer = s32[] add(buffer, counter) + ROOT out = (s32[], s32[]) tuple(updated_counter, updated_buffer) +} + +while_cond { + input = (s32[], s32[]) parameter(0) + counter = s32[] get-tuple-element(input), index=0 + bound = s32[] constant(100) + ROOT cmp = pred[] compare(counter, bound), direction=LT +} + +ENTRY entry { + zero = s32[] constant(0) + buffer = s32[] parameter(0) + while_input = (s32[], s32[]) tuple(zero, buffer) + ROOT out = (s32[], s32[]) while(while_input), condition=while_cond, body=while_body +} + + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies( + *module->entry_computation()->root_instruction()->while_body()), + 0); + EXPECT_EQ(CountControlEdges( + *module->entry_computation()->root_instruction()->while_body()), + 1); +} + +} // namespace +} // namespace xla