Add an optimization pass that clones Constant nodes to make larger clusters

I don't particularly love this approach since IMO it is papering over a problem
in mark_for_compilation_pass -- mark_for_compilation_pass should instead
rematerialize constants as necessary to create larger clusters.  But this is
what fits best with the scheme we have today.

PiperOrigin-RevId: 236729916
Sanjoy Das 2019-03-04 14:51:15 -08:00 committed by TensorFlower Gardener
parent 020b382095
commit 0c9cb95315
10 changed files with 484 additions and 29 deletions


@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -153,6 +152,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(device),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -313,11 +314,10 @@ Status Scope::ToGraphDef(GraphDef* gdef) const {
return Status::OK();
}
Status Scope::ToGraph(Graph* g) const {
Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const {
if (ok()) {
GraphDef graph_def;
graph()->ToGraphDef(&graph_def);
GraphConstructorOptions opts;
UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
}
return *impl()->status_;


@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@ -205,14 +206,15 @@ class Scope {
// START_SKIP_DOXYGEN
/// If status() is Status::OK(), construct a Graph object using the default
/// If status() is Status::OK(), construct a Graph object using `opts` as the
/// GraphConstructorOptions, and return Status::OK if graph construction was
/// successful. Otherwise, return the error status.
// TODO(josh11b, keveman): Make this faster; right now it converts
// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
// edges from the source and to the sink node, resolves back edges
// by name), and makes sure the resulting graph is valid.
Status ToGraph(Graph* g) const;
Status ToGraph(
Graph* g, GraphConstructorOptions opts = GraphConstructorOptions{}) const;
// Calls AddNode() using this scope's ShapeRefiner. This exists in the public
// API to prevent custom op wrappers from needing access to shape_refiner.h or
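
A usage sketch for the new overload (not part of this change; the helper name and the device string are placeholders borrowed from the test added later in this commit): callers can now thread a GraphConstructorOptions through ToGraph, for example to keep device assignments when the scope's graph is rebuilt from a GraphDef.

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"

namespace tensorflow {

// Builds a one-node graph whose Const op carries a device string, and keeps
// that device assignment through the GraphDef round trip.
Status BuildGraphWithDevices(Graph* graph) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Scope on_cpu =
      root.WithAssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0")
          .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
  ops::Const(on_cpu.WithOpName("c"), 42);

  GraphConstructorOptions opts;
  opts.expect_device_spec = true;  // every node must carry a device string
  return root.ToGraph(graph, opts);
}

}  // namespace tensorflow

A Graph constructed with OpRegistry::Global() can be passed in; with the default argument (GraphConstructorOptions{}) the call behaves exactly like the old single-argument ToGraph.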


@ -491,6 +491,7 @@ cc_library(
name = "compilation_passes",
srcs = [
"build_xla_ops_pass.cc",
"clone_constants_for_better_clustering.cc",
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
@ -503,6 +504,7 @@ cc_library(
],
hdrs = [
"build_xla_ops_pass.h",
"clone_constants_for_better_clustering.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
@ -542,6 +544,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@ -557,9 +560,11 @@ cc_library(
srcs = ["xla_cluster_util.cc"],
hdrs = ["xla_cluster_util.h"],
deps = [
":flags",
":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:graph",
@ -625,6 +630,7 @@ tf_cc_test(
size = "small",
srcs = [
"build_xla_ops_pass_test.cc",
"clone_constants_for_better_clustering_test.cc",
"compilation_passes_test_main.cc",
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",


@ -0,0 +1,187 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
using se::port::StatusOr;
string CloneConstantsForBetterClusteringPass::GenerateUniqueName(
const absl::flat_hash_set<string>& name_set, absl::string_view prefix) {
string candidate;
do {
candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++);
} while (name_set.contains(candidate));
return candidate;
}
StatusOr<Node*> CloneConstantsForBetterClusteringPass::CloneNode(
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
NodeDef new_in_def = n->def();
new_in_def.clear_input();
new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name()));
Status s;
Node* new_in = g->AddNode(new_in_def, &s);
TF_RETURN_IF_ERROR(s);
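// Copy every incoming edge (control and data) from the original constant so
// the clone is runnable in exactly the situations the original is.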
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(e->src(), new_in);
} else {
g->AddEdge(e->src(), e->src_output(), new_in, e->dst_input());
}
}
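// Keep the clone on the device the original constant was assigned to.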
new_in->set_assigned_device_name(n->assigned_device_name());
return new_in;
}
namespace {
// We only clone host constants for now since we want to avoid increasing memory
// pressure on GPUs.
StatusOr<bool> IsSmallHostConstant(Node* n) {
if (!n->IsConstant()) {
return false;
}
DeviceNameUtils::ParsedName parsed;
TF_RET_CHECK(
DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
if (parsed.type != DEVICE_CPU) {
return false;
}
const TensorProto* proto = nullptr;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
// TODO(sanjoy): It may make sense to combine this threshold with XLA's "large
// constant" threshold, if there is one.
const int kSmallTensorThreshold = 16;
int64 total_elements = 1;
for (const auto& dim : proto->tensor_shape().dim()) {
if (dim.size() < 0) {
return errors::Internal("Unknown dimension size in constant tensor ",
n->name());
}
total_elements *= dim.size();
}
return total_elements < kSmallTensorThreshold;
}
bool IsInPlaceOp(absl::string_view op_name) {
return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
op_name == "InplaceSub";
}
} // namespace
Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs(
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
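// Iterate over a copy of the input edges: the loop body removes and adds
// edges, which would otherwise invalidate iteration over n->in_edges().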
std::vector<const Edge*> in_edges;
absl::c_copy(n->in_edges(), std::back_inserter(in_edges));
for (const Edge* e : in_edges) {
Node* input = e->src();
TF_ASSIGN_OR_RETURN(bool is_small_host_constant,
IsSmallHostConstant(input));
if (is_small_host_constant && input->out_edges().size() != 1) {
VLOG(2) << "Cloning small host constant " << input->name();
TF_ASSIGN_OR_RETURN(Node* const input_cloned,
CloneNode(g, name_set, input));
if (e->IsControlEdge()) {
g->AddControlEdge(input_cloned, e->dst());
} else {
int dst_input = e->dst_input();
TF_RET_CHECK(e->src_output() == 0)
<< "expected constant to have exactly one non-control output, but "
"found output index = "
<< e->src_output();
g->RemoveEdge(e);
g->AddEdge(input_cloned, 0, n, dst_input);
}
}
}
return Status::OK();
}
Status CloneConstantsForBetterClusteringPass::Run(
const GraphOptimizationPassOptions& options) {
if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) {
return Status::OK();
}
Graph* g = options.graph->get();
absl::flat_hash_set<string> name_set;
absl::c_transform(g->nodes(), std::inserter(name_set, name_set.begin()),
[](Node* n) { return n->name(); });
std::vector<Node*> nodes;
for (Node* n : g->nodes()) {
// We rely on the immutability of Tensors to safely clone Const operations.
// However, "in place" ops do not respect the immutability of Tensors so we
// avoid this transformation when such ops are present in the graph.
//
// In-place operations are problematic because they break the semantic
// illusion that tensorflow::Tensor instances are immutable. For instance
// if we have the following graph:
//
// digraph {
// SRC -> Const
// SRC -> I
// SRC -> V
// Const -> Identity
// Const -> InplaceAdd [label="x"]
// I -> InplaceAdd [label="i"]
// V -> InplaceAdd [label="v"]
// InplaceAdd -> Identity [style=dotted]
// }
//
// then the value produced by `Identity` is Const+I*V since InplaceAdd
// modifies the tensor in place. However, if we clone `Const` and turn the
// graph into:
//
// digraph {
// SRC -> "Const/clone_1"
// SRC -> "Const/clone_2"
// SRC -> I
// SRC -> V
// "Const/clone_1" -> Identity
// "Const/clone_2" -> InplaceAdd [label="x"]
// I -> InplaceAdd [label="i"]
// V -> InplaceAdd [label="v"]
// InplaceAdd -> Identity [style=dotted]
// }
//
// then `Identity` no longer produces Const+I*V because the InplaceAdd
// operation only modifies Const/clone_2 in place.
if (IsInPlaceOp(n->type_string())) {
return Status::OK();
}
nodes.push_back(n);
}
// Iterate over a copy of the nodes to avoid iterating over g->nodes() while
// creating more nodes.
for (Node* n : nodes) {
TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(g, name_set, n));
}
return Status::OK();
}
} // namespace tensorflow
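
A condensed sketch of driving this pass by hand, along the lines of the test harness in the next file (the helper name is made up; auto-clustering must be enabled in the session options, otherwise Run() returns early without rewriting anything):

#include <memory>

#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

Status RunCloneConstantsPass(std::unique_ptr<Graph>* graph) {
  SessionOptions session_options;
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(OptimizerOptions::ON_2);

  GraphOptimizationPassOptions options;
  options.graph = graph;                       // rewritten in place
  options.session_options = &session_options;  // read by GetGlobalJitLevel

  CloneConstantsForBetterClusteringPass pass;
  return pass.Run(options);
}

}  // namespace tensorflow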


@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_
#define TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
// Clones small host constants in the graph to make it easier to form larger
// clusters.
//
// This helps us in two ways:
//
// - It reduces dependencies between clusters. Let's say a constant C is used
// by nodes X and Y. If X and Y are put in different clusters (for whatever
// reason) Y's cluster now has to wait for all the operations in X's cluster
// to finish before it starts running.
//
// - It lets us create bigger clusters in multi-GPU benchmarks. Consider the
// following graph:
//
// digraph {
// Const -> GPU_1
// Const -> GPU_0_Y
// GPU_0_X -> GPU_0_Y
// }
//
// We'd cluster Const and GPU_1 together (and place it on GPU_1), and this
// will block us from clustering GPU_0_X and GPU_0_Y together since that
// would increase the amount of work on GPU 0 waiting on work on GPU 1.
// However, cloning Const into two copies, one for GPU_0_Y and one for GPU_1
// will let us create one cluster containing {Const/copy_0, GPU_1} and
// another containing {Const/copy_1, GPU_0_X, GPU_0_Y}.
//
// We only clone small host constants for now to avoid increasing memory consumption
// too much. Moreover, in practice the constants we have to duplicate are
// things like the `perm` input to `Transpose` and the `size` input to `Slice`
// which tend to be small anyway.
class CloneConstantsForBetterClusteringPass : public GraphOptimizationPass {
public:
CloneConstantsForBetterClusteringPass() = default;
Status Run(const GraphOptimizationPassOptions& options) override;
private:
Status CloneSmallHostConstantInputs(
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n);
string GenerateUniqueName(const absl::flat_hash_set<string>& name_set,
absl::string_view prefix);
se::port::StatusOr<Node*> CloneNode(
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n);
int unique_name_counter_ = 0;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_


@ -0,0 +1,176 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
using ::tensorflow::testing::FindNodeByName;
Status CloneConstantsForBetterClustering(const Scope& s,
std::unique_ptr<Graph>* result) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
SessionOptions session_options;
session_options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2);
GraphOptimizationPassOptions options;
options.graph = &graph;
options.session_options = &session_options;
// Round-tripping the Scope through a GraphDef would normally drop the assigned
// device names, so ask the graph constructor to expect (and keep) a device
// spec on every node.
GraphConstructorOptions opts;
opts.expect_device_spec = true;
TF_RETURN_IF_ERROR(s.ToGraph(graph.get(), opts));
CloneConstantsForBetterClusteringPass rewriter;
TF_RETURN_IF_ERROR(rewriter.Run(options));
*result = std::move(graph);
return Status::OK();
}
const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0";
const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0";
TEST(CloneConstantsForBetterClusteringTest, Basic) {
Scope root = Scope::NewRootScope().ExitOnError();
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
Output perm = ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0});
{
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
}
std::unique_ptr<Graph> result;
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
OutputTensor tr0_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
OutputTensor tr1_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
EXPECT_NE(tr0_perm.node, tr1_perm.node);
}
TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
Scope root = Scope::NewRootScope().ExitOnError();
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
Output perm = ops::Const(on_gpu.WithOpName("perm"), {3, 1, 2, 0});
{
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
}
std::unique_ptr<Graph> result;
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
OutputTensor tr0_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
OutputTensor tr1_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
}
TEST(CloneConstantsForBetterClusteringTest, DontCloneLargeConstants) {
Scope root = Scope::NewRootScope().ExitOnError();
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
Output perm = ops::Const(
on_cpu.WithOpName("perm"),
{17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
{
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
}
std::unique_ptr<Graph> result;
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
OutputTensor tr0_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
OutputTensor tr1_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
}
TEST(CloneConstantsForBetterClusteringTest, InplaceOps) {
Scope root = Scope::NewRootScope().ExitOnError();
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
Output perm = ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0});
{
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
}
Output in_place_add =
ops::InplaceAdd(on_cpu.WithOpName("tr0"), perm,
ops::Placeholder(on_cpu.WithOpName("i"), DT_INT32), perm);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
OutputTensor tr0_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
OutputTensor tr1_perm;
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
}
} // namespace
} // namespace tensorflow


@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
@ -41,6 +42,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA:
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 5,
CloneConstantsForBetterClusteringPass);
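// Phase 5 runs before MarkForCompilationPass (phase 10 below), so the cloned
// constants are already present when clusters are formed.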
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);


@ -632,30 +632,6 @@ Status FindCompilationCandidates(
return Status::OK();
}
// Determine the global jit level which is ON if either the
// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag
// is true.
OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
const GraphOptimizationPassOptions& options) {
OptimizerOptions::GlobalJitLevel global_jit_level =
options.session_options->config.graph_options()
.optimizer_options()
.global_jit_level();
if (global_jit_level == OptimizerOptions::DEFAULT) {
// To set compilation to be on by default, change the following line.
global_jit_level = OptimizerOptions::OFF;
}
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_auto_jit == -1 ||
(1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
// If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
// the setting in ConfigProto.
global_jit_level =
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
}
return global_jit_level;
}
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.


@ -23,11 +23,13 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@ -344,4 +346,24 @@ Status CanPickDeviceForXla(absl::Span<const string> device_names,
/*out_device_picked=*/nullptr);
}
OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
const GraphOptimizationPassOptions& options) {
OptimizerOptions::GlobalJitLevel global_jit_level =
options.session_options->config.graph_options()
.optimizer_options()
.global_jit_level();
if (global_jit_level == OptimizerOptions::DEFAULT) {
// To set compilation to be on by default, change the following line.
global_jit_level = OptimizerOptions::OFF;
}
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_auto_jit != OptimizerOptions::DEFAULT) {
// If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
// the setting in ConfigProto.
global_jit_level =
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
}
return global_jit_level;
}
} // namespace tensorflow


@ -20,6 +20,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h"
namespace tensorflow {
@ -118,6 +119,13 @@ Status PickDeviceForXla(absl::Span<const string> device_names,
Status CanPickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device);
// Returns the effective global jit level: the setting from the session's
// ConfigProto, overridden by the --tf_xla_auto_jit flag whenever that flag is
// set to a value other than DEFAULT.
OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
const GraphOptimizationPassOptions& options);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_