Add an optimization pass that clones Constant nodes to make larger clusters
I don't particularly love this approach since IMO it is papering over a problem in mark_for_compilation_pass -- mark_for_compilation_pass should instead rematerialize constants as necessary to create larger clusters. But this is what fits in best with scheme we have today. PiperOrigin-RevId: 236729916
This commit is contained in:
parent
020b382095
commit
0c9cb95315
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
||||
@ -153,6 +152,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
|
||||
exit_on_error_(other.impl()->exit_on_error_),
|
||||
kernel_label_(other.impl()->kernel_label_),
|
||||
device_(device),
|
||||
assigned_device_(other.impl()->assigned_device_),
|
||||
xla_cluster_(other.impl()->xla_cluster_),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
@ -313,11 +314,10 @@ Status Scope::ToGraphDef(GraphDef* gdef) const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Scope::ToGraph(Graph* g) const {
|
||||
Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const {
|
||||
if (ok()) {
|
||||
GraphDef graph_def;
|
||||
graph()->ToGraphDef(&graph_def);
|
||||
GraphConstructorOptions opts;
|
||||
UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
|
||||
}
|
||||
return *impl()->status_;
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
@ -205,14 +206,15 @@ class Scope {
|
||||
|
||||
// START_SKIP_DOXYGEN
|
||||
|
||||
/// If status() is Status::OK(), construct a Graph object using the default
|
||||
/// If status() is Status::OK(), construct a Graph object using `opts` as the
|
||||
/// GraphConstructorOptions, and return Status::OK if graph construction was
|
||||
/// successful. Otherwise, return the error status.
|
||||
// TODO(josh11b, keveman): Make this faster; right now it converts
|
||||
// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
|
||||
// edges from the source and to the sink node, resolves back edges
|
||||
// by name), and makes sure the resulting graph is valid.
|
||||
Status ToGraph(Graph* g) const;
|
||||
Status ToGraph(
|
||||
Graph* g, GraphConstructorOptions opts = GraphConstructorOptions{}) const;
|
||||
|
||||
// Calls AddNode() using this scope's ShapeRefiner. This exists in the public
|
||||
// API to prevent custom op wrappers from needing access to shape_refiner.h or
|
||||
|
@ -491,6 +491,7 @@ cc_library(
|
||||
name = "compilation_passes",
|
||||
srcs = [
|
||||
"build_xla_ops_pass.cc",
|
||||
"clone_constants_for_better_clustering.cc",
|
||||
"deadness_analysis.cc",
|
||||
"deadness_analysis_internal.h",
|
||||
"encapsulate_subgraphs_pass.cc",
|
||||
@ -503,6 +504,7 @@ cc_library(
|
||||
],
|
||||
hdrs = [
|
||||
"build_xla_ops_pass.h",
|
||||
"clone_constants_for_better_clustering.h",
|
||||
"deadness_analysis.h",
|
||||
"encapsulate_subgraphs_pass.h",
|
||||
"encapsulate_xla_computations_pass.h",
|
||||
@ -542,6 +544,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
@ -557,9 +560,11 @@ cc_library(
|
||||
srcs = ["xla_cluster_util.cc"],
|
||||
hdrs = ["xla_cluster_util.h"],
|
||||
deps = [
|
||||
":flags",
|
||||
":resource_operation_safety_analysis",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_bounds_check",
|
||||
"//tensorflow/core:graph",
|
||||
@ -625,6 +630,7 @@ tf_cc_test(
|
||||
size = "small",
|
||||
srcs = [
|
||||
"build_xla_ops_pass_test.cc",
|
||||
"clone_constants_for_better_clustering_test.cc",
|
||||
"compilation_passes_test_main.cc",
|
||||
"encapsulate_subgraphs_pass_test.cc",
|
||||
"encapsulate_xla_computations_pass_test.cc",
|
||||
|
187
tensorflow/compiler/jit/clone_constants_for_better_clustering.cc
Normal file
187
tensorflow/compiler/jit/clone_constants_for_better_clustering.cc
Normal file
@ -0,0 +1,187 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using se::port::StatusOr;
|
||||
|
||||
string CloneConstantsForBetterClusteringPass::GenerateUniqueName(
|
||||
const absl::flat_hash_set<string>& name_set, absl::string_view prefix) {
|
||||
string candidate;
|
||||
do {
|
||||
candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++);
|
||||
} while (name_set.contains(candidate));
|
||||
return candidate;
|
||||
}
|
||||
|
||||
StatusOr<Node*> CloneConstantsForBetterClusteringPass::CloneNode(
|
||||
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
|
||||
NodeDef new_in_def = n->def();
|
||||
new_in_def.clear_input();
|
||||
new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name()));
|
||||
Status s;
|
||||
Node* new_in = g->AddNode(new_in_def, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
g->AddControlEdge(e->src(), new_in);
|
||||
} else {
|
||||
g->AddEdge(e->src(), e->src_output(), new_in, e->dst_input());
|
||||
}
|
||||
}
|
||||
|
||||
new_in->set_assigned_device_name(n->assigned_device_name());
|
||||
return new_in;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// We only clone host constants for now since we want to avoid increasing memory
|
||||
// pressure on GPUs.
|
||||
StatusOr<bool> IsSmallHostConstant(Node* n) {
|
||||
if (!n->IsConstant()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
TF_RET_CHECK(
|
||||
DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
|
||||
if (parsed.type != DEVICE_CPU) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const TensorProto* proto = nullptr;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
|
||||
|
||||
// TODO(sanjoy): It may make sense to combine this threshold with XLA's "large
|
||||
// constant" threshold, if there is one.
|
||||
const int kSmallTensorThreshold = 16;
|
||||
int64 total_elements = 1;
|
||||
for (const auto& dim : proto->tensor_shape().dim()) {
|
||||
if (dim.size() < 0) {
|
||||
return errors::Internal("Unknown dimension size in constant tensor ",
|
||||
n->name());
|
||||
}
|
||||
total_elements *= dim.size();
|
||||
}
|
||||
return total_elements < kSmallTensorThreshold;
|
||||
}
|
||||
|
||||
bool IsInPlaceOp(absl::string_view op_name) {
|
||||
return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
|
||||
op_name == "InplaceSub";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs(
|
||||
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
|
||||
std::vector<const Edge*> in_edges;
|
||||
absl::c_copy(n->in_edges(), std::back_inserter(in_edges));
|
||||
for (const Edge* e : in_edges) {
|
||||
Node* input = e->src();
|
||||
TF_ASSIGN_OR_RETURN(bool is_small_host_constant,
|
||||
IsSmallHostConstant(input));
|
||||
if (is_small_host_constant && input->out_edges().size() != 1) {
|
||||
VLOG(2) << "Cloning small host constant " << input->name();
|
||||
TF_ASSIGN_OR_RETURN(Node* const input_cloned,
|
||||
CloneNode(g, name_set, input));
|
||||
if (e->IsControlEdge()) {
|
||||
g->AddControlEdge(input_cloned, e->dst());
|
||||
} else {
|
||||
int dst_input = e->dst_input();
|
||||
TF_RET_CHECK(e->src_output() == 0)
|
||||
<< "expected constant to have exactly one non-control output, but "
|
||||
"found output index = "
|
||||
<< e->src_output();
|
||||
g->RemoveEdge(e);
|
||||
g->AddEdge(input_cloned, 0, n, dst_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CloneConstantsForBetterClusteringPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Graph* g = options.graph->get();
|
||||
absl::flat_hash_set<string> name_set;
|
||||
absl::c_transform(g->nodes(), std::inserter(name_set, name_set.begin()),
|
||||
[](Node* n) { return n->name(); });
|
||||
std::vector<Node*> nodes;
|
||||
for (Node* n : g->nodes()) {
|
||||
// We rely on the immutability of Tensors to safely clone Const operations.
|
||||
// However, "in place" ops do not respect the immutability of Tensors so we
|
||||
// avoid this transformation when such ops are present in the graph.
|
||||
//
|
||||
// In-place operations are problematic because they break the semantic
|
||||
// illusion that tensorflow::Tensor instances are immutable. For instance
|
||||
// if we have the following graph:
|
||||
//
|
||||
// digraph {
|
||||
// SRC -> Const
|
||||
// SRC -> I
|
||||
// SRC -> V
|
||||
// Const -> Identity
|
||||
// Const -> InplaceAdd [label="x"]
|
||||
// I -> InplaceAdd [label="i"]
|
||||
// V -> InplaceAdd [label="v"]
|
||||
// InplaceAdd -> Identity [style=dotted]
|
||||
// }
|
||||
//
|
||||
// then the value produced by `Identity` is Const+I*V since InplaceAdd
|
||||
// modifies the tensor in place. However, if we clone `Const` and turn the
|
||||
// graph into:
|
||||
//
|
||||
// digraph {
|
||||
// SRC -> "Const/clone_1"
|
||||
// SRC -> "Const/clone_2"
|
||||
// SRC -> I
|
||||
// SRC -> V
|
||||
// "Const/clone_1" -> Identity
|
||||
// "Const/clone_2" -> InplaceAdd [label="x"]
|
||||
// I -> InplaceAdd [label="i"]
|
||||
// V -> InplaceAdd [label="v"]
|
||||
// InplaceAdd -> Identity [style=dotted]
|
||||
// }
|
||||
//
|
||||
// then `Identity` no longer produces Const+I*V because the InplaceAdd
|
||||
// operation only modifies Const/clone_2 in place.
|
||||
|
||||
if (IsInPlaceOp(n->type_string())) {
|
||||
return Status::OK();
|
||||
}
|
||||
nodes.push_back(n);
|
||||
}
|
||||
|
||||
// Iterate over a copy of the nodes to avoid iterating over g->nodes() while
|
||||
// creating more nodes.
|
||||
for (Node* n : nodes) {
|
||||
TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(g, name_set, n));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,74 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Clones small host constants in the graph to make it easier to form larger
|
||||
// clusters.
|
||||
//
|
||||
// This helps us in two ways:
|
||||
//
|
||||
// - It reduces dependencies between clusters. Let's say a constant C is used
|
||||
// by nodes X and Y. If X and Y are put in different clusters (for whatever
|
||||
// reason) Y's cluster now has to wait for all the operations in X's cluster
|
||||
// to finish before it starts running.
|
||||
//
|
||||
// - It lets us create bigger clusters in multi-GPU benchmarks. Consider the
|
||||
// following graph:
|
||||
//
|
||||
// digraph {
|
||||
// Const -> GPU_1
|
||||
// Const -> GPU_0_Y
|
||||
// GPU_0_X -> GPU_0_Y
|
||||
// }
|
||||
//
|
||||
// We'd cluster Const and GPU_1 together (and place it on GPU_1), and this
|
||||
// will block us from clustering GPU_0_X and GPU_0_Y together since that
|
||||
// would increase the amount of work on GPU 0 waiting on work on GPU 1.
|
||||
// However, cloning Const into two copies, one for GPU_0_Y and one for GPU_1
|
||||
// will let us create one cluster containing {Const/copy_0, GPU_1} and
|
||||
// another containing {Const/copy_1, GPU_0_X, GPU_0_Y}.
|
||||
//
|
||||
// We only clone small host constants now to avoid increasing memory consumption
|
||||
// too much. Moreover, in practice the constants we have to duplicate are
|
||||
// things like the `perm` input to `Transpose` and the `size` input to `Slice`
|
||||
// which tend to be small anyway.
|
||||
|
||||
class CloneConstantsForBetterClusteringPass : public GraphOptimizationPass {
|
||||
public:
|
||||
CloneConstantsForBetterClusteringPass() = default;
|
||||
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
private:
|
||||
Status CloneSmallHostConstantInputs(
|
||||
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n);
|
||||
string GenerateUniqueName(const absl::flat_hash_set<string>& name_set,
|
||||
absl::string_view prefix);
|
||||
se::port::StatusOr<Node*> CloneNode(
|
||||
Graph* g, const absl::flat_hash_set<string>& name_set, Node* n);
|
||||
|
||||
int unique_name_counter_ = 0;
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_
|
@ -0,0 +1,176 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/compiler/jit/node_matchers.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
using ::tensorflow::testing::FindNodeByName;
|
||||
|
||||
Status CloneConstantsForBetterClustering(const Scope& s,
|
||||
std::unique_ptr<Graph>* result) {
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
SessionOptions session_options;
|
||||
session_options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_global_jit_level(OptimizerOptions::ON_2);
|
||||
GraphOptimizationPassOptions options;
|
||||
options.graph = &graph;
|
||||
options.session_options = &session_options;
|
||||
|
||||
// Scope::ToGraph seems to drop assigned devices, probably because it goes
|
||||
// through a GraphDef. So explicitly maintain the device assignment.
|
||||
// std::unordered_map<string, string> assigned_device_names;
|
||||
// for (Node* n : s.graph()->nodes()) {
|
||||
// assigned_device_names[n->name()] = n->assigned_device_name();
|
||||
// }
|
||||
GraphConstructorOptions opts;
|
||||
opts.expect_device_spec = true;
|
||||
TF_RETURN_IF_ERROR(s.ToGraph(graph.get(), opts));
|
||||
// for (Node* n : graph->nodes()) {
|
||||
// n->set_assigned_device_name(assigned_device_names[n->name()]);
|
||||
// }
|
||||
|
||||
CloneConstantsForBetterClusteringPass rewriter;
|
||||
TF_RETURN_IF_ERROR(rewriter.Run(options));
|
||||
*result = std::move(graph);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0";
|
||||
|
||||
TEST(CloneConstantsForBetterClusteringTest, Basic) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
|
||||
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
|
||||
|
||||
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
|
||||
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
|
||||
|
||||
Output perm = ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0});
|
||||
|
||||
{
|
||||
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
|
||||
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> result;
|
||||
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
|
||||
|
||||
OutputTensor tr0_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
|
||||
|
||||
OutputTensor tr1_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
|
||||
|
||||
EXPECT_NE(tr0_perm.node, tr1_perm.node);
|
||||
}
|
||||
|
||||
TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
|
||||
|
||||
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
|
||||
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
|
||||
|
||||
Output perm = ops::Const(on_gpu.WithOpName("perm"), {3, 1, 2, 0});
|
||||
|
||||
{
|
||||
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
|
||||
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> result;
|
||||
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
|
||||
|
||||
OutputTensor tr0_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
|
||||
|
||||
OutputTensor tr1_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
|
||||
|
||||
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
|
||||
}
|
||||
|
||||
TEST(CloneConstantsForBetterClusteringTest, DontCloneLargeConstants) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
|
||||
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
|
||||
|
||||
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
|
||||
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
|
||||
|
||||
Output perm = ops::Const(
|
||||
on_cpu.WithOpName("perm"),
|
||||
{17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
||||
|
||||
{
|
||||
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
|
||||
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> result;
|
||||
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
|
||||
|
||||
OutputTensor tr0_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
|
||||
|
||||
OutputTensor tr1_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
|
||||
|
||||
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
|
||||
}
|
||||
|
||||
TEST(CloneConstantsForBetterClusteringTest, InplaceOps) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU);
|
||||
Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU);
|
||||
|
||||
Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT);
|
||||
Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT);
|
||||
|
||||
Output perm = ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0});
|
||||
|
||||
{
|
||||
Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm);
|
||||
Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm);
|
||||
}
|
||||
|
||||
Output in_place_add =
|
||||
ops::InplaceAdd(on_cpu.WithOpName("tr0"), perm,
|
||||
ops::Placeholder(on_cpu.WithOpName("i"), DT_INT32), perm);
|
||||
|
||||
std::unique_ptr<Graph> result;
|
||||
TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result));
|
||||
|
||||
OutputTensor tr0_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm));
|
||||
|
||||
OutputTensor tr1_perm;
|
||||
TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm));
|
||||
|
||||
EXPECT_EQ(tr0_perm.node, tr1_perm.node);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
|
||||
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
|
||||
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
|
||||
@ -41,6 +42,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
|
||||
|
||||
// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA:
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 5,
|
||||
CloneConstantsForBetterClusteringPass);
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
|
||||
MarkForCompilationPass);
|
||||
|
||||
|
@ -632,30 +632,6 @@ Status FindCompilationCandidates(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Determine the global jit level which is ON if either the
|
||||
// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag
|
||||
// is true.
|
||||
OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
OptimizerOptions::GlobalJitLevel global_jit_level =
|
||||
options.session_options->config.graph_options()
|
||||
.optimizer_options()
|
||||
.global_jit_level();
|
||||
if (global_jit_level == OptimizerOptions::DEFAULT) {
|
||||
// To set compilation to be on by default, change the following line.
|
||||
global_jit_level = OptimizerOptions::OFF;
|
||||
}
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
if (flags->tf_xla_auto_jit == -1 ||
|
||||
(1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
|
||||
// If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
|
||||
// the setting in ConfigProto.
|
||||
global_jit_level =
|
||||
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
|
||||
}
|
||||
return global_jit_level;
|
||||
}
|
||||
|
||||
struct Cluster {
|
||||
// Identifies the node that represents this cluster in the cycle detection
|
||||
// graph.
|
||||
|
@ -23,11 +23,13 @@ limitations under the License.
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -344,4 +346,24 @@ Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
||||
/*out_device_picked=*/nullptr);
|
||||
}
|
||||
|
||||
OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
OptimizerOptions::GlobalJitLevel global_jit_level =
|
||||
options.session_options->config.graph_options()
|
||||
.optimizer_options()
|
||||
.global_jit_level();
|
||||
if (global_jit_level == OptimizerOptions::DEFAULT) {
|
||||
// To set compilation to be on by default, change the following line.
|
||||
global_jit_level = OptimizerOptions::OFF;
|
||||
}
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
if (flags->tf_xla_auto_jit != OptimizerOptions::DEFAULT) {
|
||||
// If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
|
||||
// the setting in ConfigProto.
|
||||
global_jit_level =
|
||||
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
|
||||
}
|
||||
return global_jit_level;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -118,6 +119,13 @@ Status PickDeviceForXla(absl::Span<const string> device_names,
|
||||
Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
bool* out_can_pick_device);
|
||||
|
||||
// Determine the global jit level which is ON if either the
|
||||
// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag
|
||||
// is true.
|
||||
OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
|
||||
const GraphOptimizationPassOptions& options);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
|
||||
|
Loading…
Reference in New Issue
Block a user