From 0c9cb95315c498eefde79ecfd33de68e1104ab23 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 4 Mar 2019 14:51:15 -0800 Subject: [PATCH] Add an optimization pass that clones Constant nodes to make larger clusters I don't particularly love this approach since IMO it is papering over a problem in mark_for_compilation_pass -- mark_for_compilation_pass should instead rematerialize constants as necessary to create larger clusters. But this is what fits in best with the scheme we have today. PiperOrigin-RevId: 236729916 --- tensorflow/cc/framework/scope.cc | 6 +- tensorflow/cc/framework/scope.h | 6 +- tensorflow/compiler/jit/BUILD | 6 + .../clone_constants_for_better_clustering.cc | 187 ++++++++++++++++++ .../clone_constants_for_better_clustering.h | 74 +++++++ ...ne_constants_for_better_clustering_test.cc | 176 +++++++++++++++++ .../jit/jit_compilation_pass_registration.cc | 4 + .../compiler/jit/mark_for_compilation_pass.cc | 24 --- tensorflow/compiler/jit/xla_cluster_util.cc | 22 +++ tensorflow/compiler/jit/xla_cluster_util.h | 8 + 10 files changed, 484 insertions(+), 29 deletions(-) create mode 100644 tensorflow/compiler/jit/clone_constants_for_better_clustering.cc create mode 100644 tensorflow/compiler/jit/clone_constants_for_better_clustering.h create mode 100644 tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 81785b2d89b..134d64af140 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -153,6 +152,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device) exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(device), + assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -313,11 +314,10 @@ Status Scope::ToGraphDef(GraphDef* gdef) const { return Status::OK(); } -Status Scope::ToGraph(Graph* g) const { +Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const { if (ok()) { GraphDef graph_def; graph()->ToGraphDef(&graph_def); - GraphConstructorOptions opts; UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g)); } return *impl()->status_; diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 0a75f23725c..1e17b74bc8f 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -205,14 +206,15 @@ class Scope { // START_SKIP_DOXYGEN - /// If status() is Status::OK(), construct a Graph object using the default + /// If status() is Status::OK(), construct a Graph object using `opts` as the /// GraphConstructorOptions, and return Status::OK if graph construction was /// successful. Otherwise, return the error status. 
// TODO(josh11b, keveman): Make this faster; right now it converts // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds // edges from the source and to the sink node, resolves back edges // by name), and makes sure the resulting graph is valid. - Status ToGraph(Graph* g) const; + Status ToGraph( + Graph* g, GraphConstructorOptions opts = GraphConstructorOptions{}) const; // Calls AddNode() using this scope's ShapeRefiner. This exists in the public // API to prevent custom op wrappers from needing access to shape_refiner.h or diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index dc840ef6305..b846ad789e5 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -491,6 +491,7 @@ cc_library( name = "compilation_passes", srcs = [ "build_xla_ops_pass.cc", + "clone_constants_for_better_clustering.cc", "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", @@ -503,6 +504,7 @@ cc_library( ], hdrs = [ "build_xla_ops_pass.h", + "clone_constants_for_better_clustering.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "encapsulate_xla_computations_pass.h", @@ -542,6 +544,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -557,9 +560,11 @@ cc_library( srcs = ["xla_cluster_util.cc"], hdrs = ["xla_cluster_util.h"], deps = [ + ":flags", ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_bounds_check", "//tensorflow/core:graph", @@ -625,6 +630,7 @@ tf_cc_test( size = "small", srcs = [ "build_xla_ops_pass_test.cc", + "clone_constants_for_better_clustering_test.cc", 
"compilation_passes_test_main.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc new file mode 100644 index 00000000000..848a6362a4a --- /dev/null +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -0,0 +1,187 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h" + +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { + +using se::port::StatusOr; + +string CloneConstantsForBetterClusteringPass::GenerateUniqueName( + const absl::flat_hash_set& name_set, absl::string_view prefix) { + string candidate; + do { + candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++); + } while (name_set.contains(candidate)); + return candidate; +} + +StatusOr CloneConstantsForBetterClusteringPass::CloneNode( + Graph* g, const absl::flat_hash_set& name_set, Node* n) { + NodeDef new_in_def = n->def(); + new_in_def.clear_input(); + new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name())); + Status s; + Node* new_in = g->AddNode(new_in_def, &s); + TF_RETURN_IF_ERROR(s); + + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), new_in); + } else { + g->AddEdge(e->src(), e->src_output(), new_in, e->dst_input()); + } + } + + new_in->set_assigned_device_name(n->assigned_device_name()); + return new_in; +} + +namespace { +// We only clone host constants for now since we want to avoid increasing memory +// pressure on GPUs. +StatusOr IsSmallHostConstant(Node* n) { + if (!n->IsConstant()) { + return false; + } + + DeviceNameUtils::ParsedName parsed; + TF_RET_CHECK( + DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed)); + if (parsed.type != DEVICE_CPU) { + return false; + } + + const TensorProto* proto = nullptr; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); + + // TODO(sanjoy): It may make sense to combine this threshold with XLA's "large + // constant" threshold, if there is one. 
+ const int kSmallTensorThreshold = 16; + int64 total_elements = 1; + for (const auto& dim : proto->tensor_shape().dim()) { + if (dim.size() < 0) { + return errors::Internal("Unknown dimension size in constant tensor ", + n->name()); + } + total_elements *= dim.size(); + } + return total_elements < kSmallTensorThreshold; +} + +bool IsInPlaceOp(absl::string_view op_name) { + return op_name == "InplaceUpdate" || op_name == "InplaceAdd" || + op_name == "InplaceSub"; +} +} // namespace + +Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs( + Graph* g, const absl::flat_hash_set& name_set, Node* n) { + std::vector in_edges; + absl::c_copy(n->in_edges(), std::back_inserter(in_edges)); + for (const Edge* e : in_edges) { + Node* input = e->src(); + TF_ASSIGN_OR_RETURN(bool is_small_host_constant, + IsSmallHostConstant(input)); + if (is_small_host_constant && input->out_edges().size() != 1) { + VLOG(2) << "Cloning small host constant " << input->name(); + TF_ASSIGN_OR_RETURN(Node* const input_cloned, + CloneNode(g, name_set, input)); + if (e->IsControlEdge()) { + g->AddControlEdge(input_cloned, e->dst()); + } else { + int dst_input = e->dst_input(); + TF_RET_CHECK(e->src_output() == 0) + << "expected constant to have exactly one non-control output, but " + "found output index = " + << e->src_output(); + g->RemoveEdge(e); + g->AddEdge(input_cloned, 0, n, dst_input); + } + } + } + return Status::OK(); +} + +Status CloneConstantsForBetterClusteringPass::Run( + const GraphOptimizationPassOptions& options) { + if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) { + return Status::OK(); + } + + Graph* g = options.graph->get(); + absl::flat_hash_set name_set; + absl::c_transform(g->nodes(), std::inserter(name_set, name_set.begin()), + [](Node* n) { return n->name(); }); + std::vector nodes; + for (Node* n : g->nodes()) { + // We rely on the immutability of Tensors to safely clone Const operations. 
+ // However, "in place" ops do not respect the immutability of Tensors so we + // avoid this transformation when such ops are present in the graph. + // + // In-place operations are problematic because they break the semantic + // illusion that tensorflow::Tensor instances are immutable. For instance + // if we have the following graph: + // + // digraph { + // SRC -> Const + // SRC -> I + // SRC -> V + // Const -> Identity + // Const -> InplaceAdd [label="x"] + // I -> InplaceAdd [label="i"] + // V -> InplaceAdd [label="v"] + // InplaceAdd -> Identity [style=dotted] + // } + // + // then the value produced by `Identity` is Const+I*V since InplaceAdd + // modifies the tensor in place. However, if we clone `Const` and turn the + // graph into: + // + // digraph { + // SRC -> "Const/clone_1" + // SRC -> "Const/clone_2" + // SRC -> I + // SRC -> V + // "Const/clone_1" -> Identity + // "Const/clone_2" -> InplaceAdd [label="x"] + // I -> InplaceAdd [label="i"] + // V -> InplaceAdd [label="v"] + // InplaceAdd -> Identity [style=dotted] + // } + // + // then `Identity` no longer produces Const+I*V because the InplaceAdd + // operation only modifies Const/clone_2 in place. + + if (IsInPlaceOp(n->type_string())) { + return Status::OK(); + } + nodes.push_back(n); + } + + // Iterate over a copy of the nodes to avoid iterating over g->nodes() while + // creating more nodes. + for (Node* n : nodes) { + TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(g, name_set, n)); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.h b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h new file mode 100644 index 00000000000..f67da75b34f --- /dev/null +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ +#define TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +// Clones small host constants in the graph to make it easier to form larger +// clusters. +// +// This helps us in two ways: +// +// - It reduces dependencies between clusters. Let's say a constant C is used +// by nodes X and Y. If X and Y are put in different clusters (for whatever +// reason) Y's cluster now has to wait for all the operations in X's cluster +// to finish before it starts running. +// +// - It lets us create bigger clusters in multi-GPU benchmarks. Consider the +// following graph: +// +// digraph { +// Const -> GPU_1 +// Const -> GPU_0_Y +// GPU_0_X -> GPU_0_Y +// } +// +// We'd cluster Const and GPU_1 together (and place it on GPU_1), and this +// will block us from clustering GPU_0_X and GPU_0_Y together since that +// would increase the amount of work on GPU 0 waiting on work on GPU 1. +// However, cloning Const into two copies, one for GPU_0_Y and one for GPU_1 +// will let us create one cluster containing {Const/copy_0, GPU_1} and +// another containing {Const/copy_1, GPU_0_X, GPU_0_Y}. 
+// +// We only clone small host constants now to avoid increasing memory consumption +// too much. Moreover, in practice the constants we have to duplicate are +// things like the `perm` input to `Transpose` and the `size` input to `Slice` +// which tend to be small anyway. + +class CloneConstantsForBetterClusteringPass : public GraphOptimizationPass { + public: + CloneConstantsForBetterClusteringPass() = default; + + Status Run(const GraphOptimizationPassOptions& options) override; + + private: + Status CloneSmallHostConstantInputs( + Graph* g, const absl::flat_hash_set& name_set, Node* n); + string GenerateUniqueName(const absl::flat_hash_set& name_set, + absl::string_view prefix); + se::port::StatusOr CloneNode( + Graph* g, const absl::flat_hash_set& name_set, Node* n); + + int unique_name_counter_ = 0; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc new file mode 100644 index 00000000000..31543d1c3f8 --- /dev/null +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { +using ::tensorflow::testing::FindNodeByName; + +Status CloneConstantsForBetterClustering(const Scope& s, + std::unique_ptr* result) { + auto graph = absl::make_unique(OpRegistry::Global()); + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + GraphOptimizationPassOptions options; + options.graph = &graph; + options.session_options = &session_options; + + // Scope::ToGraph seems to drop assigned devices, probably because it goes + // through a GraphDef. So explicitly maintain the device assignment. 
+ // std::unordered_map assigned_device_names; + // for (Node* n : s.graph()->nodes()) { + // assigned_device_names[n->name()] = n->assigned_device_name(); + // } + GraphConstructorOptions opts; + opts.expect_device_spec = true; + TF_RETURN_IF_ERROR(s.ToGraph(graph.get(), opts)); + // for (Node* n : graph->nodes()) { + // n->set_assigned_device_name(assigned_device_names[n->name()]); + // } + + CloneConstantsForBetterClusteringPass rewriter; + TF_RETURN_IF_ERROR(rewriter.Run(options)); + *result = std::move(graph); + return Status::OK(); +} + +const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0"; +const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0"; + +TEST(CloneConstantsForBetterClusteringTest, Basic) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_NE(tr0_perm.node, tr1_perm.node); +} + +TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = 
ops::Const(on_gpu.WithOpName("perm"), {3, 1, 2, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_EQ(tr0_perm.node, tr1_perm.node); +} + +TEST(CloneConstantsForBetterClusteringTest, DontCloneLargeConstants) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = ops::Const( + on_cpu.WithOpName("perm"), + {17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_EQ(tr0_perm.node, tr1_perm.node); +} + +TEST(CloneConstantsForBetterClusteringTest, InplaceOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = 
ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + Output in_place_add = + ops::InplaceAdd(on_cpu.WithOpName("tr0"), perm, + ops::Placeholder(on_cpu.WithOpName("i"), DT_INT32), perm); + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_EQ(tr0_perm.node, tr1_perm.node); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index f79bdc1e2e8..7326b6c222b 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/build_xla_ops_pass.h" +#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" @@ -41,6 +42,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, // POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA: +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 5, + CloneConstantsForBetterClusteringPass); + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index fed8af3465f..11a710b2a4e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -632,30 +632,6 @@ Status FindCompilationCandidates( return Status::OK(); } -// Determine the global jit level which is ON if either the -// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag -// is true. -OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( - const GraphOptimizationPassOptions& options) { - OptimizerOptions::GlobalJitLevel global_jit_level = - options.session_options->config.graph_options() - .optimizer_options() - .global_jit_level(); - if (global_jit_level == OptimizerOptions::DEFAULT) { - // To set compilation to be on by default, change the following line. 
- global_jit_level = OptimizerOptions::OFF; - } - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - if (flags->tf_xla_auto_jit == -1 || - (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { - // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides - // the setting in ConfigProto. - global_jit_level = - static_cast(flags->tf_xla_auto_jit); - } - return global_jit_level; -} - struct Cluster { // Identifies the node that represents this cluster in the cycle detection // graph. diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 797672c8483..eaa7015768c 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -23,11 +23,13 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -344,4 +346,24 @@ Status CanPickDeviceForXla(absl::Span device_names, /*out_device_picked=*/nullptr); } +OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( + const GraphOptimizationPassOptions& options) { + OptimizerOptions::GlobalJitLevel global_jit_level = + options.session_options->config.graph_options() + .optimizer_options() + .global_jit_level(); + if (global_jit_level == OptimizerOptions::DEFAULT) { + // To set compilation to be on by default, change the following line. 
+ global_jit_level = OptimizerOptions::OFF; + } + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_auto_jit != OptimizerOptions::DEFAULT) { + // If the flag tf_xla_auto_jit is set to a non-DEFAULT value, it overrides + // the setting in ConfigProto. + global_jit_level = + static_cast(flags->tf_xla_auto_jit); + } + return global_jit_level; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 806fd939bde..ddca0aaeabb 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" namespace tensorflow { @@ -118,6 +119,13 @@ Status PickDeviceForXla(absl::Span device_names, Status CanPickDeviceForXla(absl::Span device_names, bool allow_mixing_unknown_and_cpu, bool* out_can_pick_device); + +// Returns the global jit level: the session's optimizer-options setting (with +// DEFAULT treated as OFF), unless the --tf_xla_auto_jit flag is set to a +// non-DEFAULT value, in which case the flag overrides it. +OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( + const GraphOptimizationPassOptions& options); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_