Add an XLA "activity listener" mechanism.

This allows various components to listen to auto-clustering and JIT compilation
events in TensorFlow.

PiperOrigin-RevId: 253265614
This commit is contained in:
Sanjoy Das 2019-06-14 11:33:09 -07:00 committed by TensorFlower Gardener
parent 4b7cbc8508
commit 5f2291877d
12 changed files with 731 additions and 32 deletions

View File

@ -1,6 +1,7 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
package(
default_visibility = [
@ -286,6 +287,7 @@ cc_library(
srcs = ["xla_compilation_cache.cc"],
hdrs = ["xla_compilation_cache.h"],
deps = [
":xla_activity_listener",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
@ -512,6 +514,7 @@ cc_library(
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
"report_clustering_info_pass.cc",
],
hdrs = [
"build_xla_ops_pass.h",
@ -525,6 +528,7 @@ cc_library(
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
"report_clustering_info_pass.h",
],
deps = [
"compilability_check_util",
@ -535,6 +539,7 @@ cc_library(
":resource_operation_safety_analysis",
":shape_inference_helpers",
":union_find",
":xla_activity_listener",
":xla_cluster_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:functional_ops",
@ -577,6 +582,7 @@ cc_library(
hdrs = ["xla_cluster_util.h"],
deps = [
":flags",
":xla_activity_proto_cc",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -843,6 +849,27 @@ tf_cc_test(
],
)
tf_cc_test(
name = "xla_activity_listener_test",
srcs = ["xla_activity_listener_test.cc"],
deps = [
":flags",
":xla_activity_listener",
":xla_cpu_device",
":xla_cpu_jit",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session_internal",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:partitioned_function_ops",
],
)
tf_custom_op_py_library(
name = "xla_ops_py",
kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
@ -855,6 +882,27 @@ tf_custom_op_py_library(
],
)
cc_library(
name = "xla_activity_listener",
srcs = ["xla_activity_listener.cc"],
hdrs = ["xla_activity_listener.h"],
visibility = ["//visibility:public"],
deps = [
":xla_activity_proto_cc",
"//tensorflow/core:lib",
"@com_google_absl//absl/synchronization",
],
)
tf_proto_library(
name = "xla_activity_proto",
srcs = ["xla_activity.proto"],
cc_api_version = 2,
default_header = True,
protodeps = tf_additional_all_protos(),
provide_cc_alias = True,
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
@ -58,15 +59,22 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
PartiallyDeclusterPass);
// ReportClusteringInfoPass pass needs to run after all of the auto-clustering
// passes have run but before encapsulation has run. This way it can easily
// compute a summary of the clustering decisions we made and broadcast it via
// xla_activity_listener.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
ReportClusteringInfoPass);
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph been rewritten to have _Send nodes added
// for fetches. Before the _Send nodes are added, fetch nodes are identified by
// name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 60,
BuildXlaOpsPass);
} // namespace tensorflow

View File

@ -1364,46 +1364,36 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
return;
}
std::map<absl::string_view, int> cluster_name_to_size;
std::map<absl::string_view, std::map<absl::string_view, int>>
cluster_name_to_op_histogram;
std::map<absl::string_view, int> unclustered_op_histogram;
int clustered_node_count = 0;
for (Node* n : graph_->nodes()) {
absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
if (cluster_name) {
clustered_node_count++;
cluster_name_to_size[*cluster_name]++;
cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
} else {
unclustered_op_histogram[n->type_string()]++;
}
}
int unclustered_node_count = graph_->num_nodes() - clustered_node_count;
XlaAutoClusteringSummary auto_clustering_info =
GetXlaAutoClusteringSummary(*graph_);
VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
<< RatioToString(clustered_node_count, graph_->num_nodes());
VLOG(2) << " Built " << auto_clustering_info.clusters_size()
<< " clusters, size "
<< RatioToString(auto_clustering_info.clustered_node_count(),
graph_->num_nodes());
for (const auto& cluster_name_size_pair : cluster_name_to_size) {
absl::string_view cluster_name = cluster_name_size_pair.first;
int size = cluster_name_size_pair.second;
for (XlaAutoClusteringSummary::Cluster cluster :
auto_clustering_info.clusters()) {
absl::string_view cluster_name = cluster.name();
int size = cluster.size();
VLOG(2) << " " << cluster_name << " "
<< RatioToString(size, graph_->num_nodes());
for (const auto& op_count_pair :
cluster_name_to_op_histogram[cluster_name]) {
VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second
for (const XlaAutoClusteringSummary::OpAndCount& op_count :
cluster.op_histogram()) {
VLOG(3) << " " << op_count.op() << ": " << op_count.count()
<< " instances";
}
}
if (!unclustered_op_histogram.empty()) {
if (!auto_clustering_info.unclustered_op_histogram().empty()) {
VLOG(2) << " Unclustered nodes: "
<< RatioToString(unclustered_node_count, graph_->num_nodes());
for (const auto& pair : unclustered_op_histogram) {
VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
<< RatioToString(auto_clustering_info.unclustered_node_count(),
graph_->num_nodes());
for (const XlaAutoClusteringSummary::OpAndCount& op_count :
auto_clustering_info.unclustered_op_histogram()) {
VLOG(3) << " " << op_count.op() << ": " << op_count.count()
<< " instances";
}
}

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
namespace tensorflow {
Status ReportClusteringInfoPass::Run(
const GraphOptimizationPassOptions& options) {
XlaAutoClusteringActivity activity;
*activity.mutable_summary() = GetXlaAutoClusteringSummary(**options.graph);
activity.set_global_jit_level(GetGlobalJitLevelForGraph(options));
activity.set_cpu_global_jit_enabled(
GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit);
return BroadcastXlaActivity(std::move(activity));
}
} // namespace tensorflow

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
#define TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// This is not really an optimization pass. It does not change the graph in any
// way; instead it computes a summary of the XLA clusters in the graph and
// broadcasts it via xla_activity_listener.
class ReportClusteringInfoPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_

View File

@ -0,0 +1,104 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow;
import "tensorflow/core/protobuf/config.proto";
// Summarizes the results of auto-clustering a TensorFlow graph.
//
// Next ID: 5
message XlaAutoClusteringSummary {
// Represents a single element in a histogram of ops ("op" as in "TensorFlow
// operation").
//
// Next ID: 3
message OpAndCount {
// The TensorFlow operation (like MatMult, Add etc.)
string op = 1;
// The number of times this occurs.
int32 count = 2;
}
// Describes a single XLA cluster.
//
// Next ID: 4
message Cluster {
string name = 1;
// The number of nodes in the cluster.
int32 size = 2;
// A histogram of the TF operations in this cluster.
repeated OpAndCount op_histogram = 3;
};
// The number of nodes in the graph that are not inside an XLA cluster.
int32 unclustered_node_count = 1;
// The number of nodes in the graph that are in an XLA cluster.
int32 clustered_node_count = 2;
// All of the XLA clusters in the TF graph.
repeated Cluster clusters = 3;
// A histogram of the TF operations that were not clustered.
repeated OpAndCount unclustered_op_histogram = 4;
}
// Listeners listening for auto clustering events get messages of this type.
//
// Next ID: 5
message XlaAutoClusteringActivity {
// xla_activity_listener set this to the global process id, as decided by the
// callback registered via `SetGlobalProcessIdMaker.
string global_process_id = 1;
// The value of GlobalJitLevel, as determined by `GetGlobalJitLevelForGraph`.
// This determines if global auto-clustering is enabled.
OptimizerOptions.GlobalJitLevel global_jit_level = 2;
// Whether --tf_xla_cpu_global_jit is enabled in TF_XLA_FLAGS.
bool cpu_global_jit_enabled = 3;
XlaAutoClusteringSummary summary = 4;
}
// Listeners listening for JIT compilation events get messages of this type.
// Each instance of XlaJitCompilationActivity corresponds to a single
// compilation of a single XLA cluster. E.g. if a graph has two clusters, A and
// B, and A is compiled 5 times and B is compiled 2 times then we will generate
// 7 instances of XlaJitCompilationActivity.
//
// Next ID: 6
message XlaJitCompilationActivity {
// xla_activity_listener set this to the global process id, as decided by the
// callback registered via `SetGlobalProcessIdMaker.
string global_process_id = 1;
string cluster_name = 2;
// The number of time this cluster has been compiled.
int32 compile_count = 3;
// Microseconds spent in the individual compilation being reported.
int64 compile_time_us = 4;
// Total microseconds spent in (re-)compiling this cluster so far.
int64 cumulative_compile_time_us = 5;
}

View File

@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace {
// The list of all registered `XlaActivityListener`s.
struct XlaActivityListenerList {
absl::Mutex mutex;
std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
};
XlaActivityListenerList* GetXlaActivityListenerList() {
static XlaActivityListenerList* listener_list = new XlaActivityListenerList;
return listener_list;
}
template <typename FnTy>
Status ForEachListener(FnTy fn) {
XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
absl::ReaderMutexLock reader_lock(&listener_list->mutex);
for (const std::unique_ptr<XlaActivityListener>& listener :
listener_list->listeners) {
TF_RETURN_IF_ERROR(fn(listener.get()));
}
return Status::OK();
}
struct GlobalProcessIdMakerStorage {
absl::Mutex mutex;
GlobalProcessIdMaker global_process_id_maker GUARDED_BY(mutex);
// True if we have used the process ID generated by `global_process_id_maker`.
// We disallow setting the global process id maker once we have broadcasted
// messages with global_process_id set to "unknown".
bool has_been_used GUARDED_BY(mutex) = false;
};
GlobalProcessIdMakerStorage* GetGlobalProcessIdMakerStorage() {
static GlobalProcessIdMakerStorage* global_process_id_maker_storage =
new GlobalProcessIdMakerStorage;
return global_process_id_maker_storage;
}
GlobalProcessIdMaker GetGlobalProcessIdMaker() {
GlobalProcessIdMakerStorage* global_process_id_maker_storage =
GetGlobalProcessIdMakerStorage();
{
absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex);
if (global_process_id_maker_storage->has_been_used) {
return global_process_id_maker_storage->global_process_id_maker;
}
}
{
absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex);
global_process_id_maker_storage->has_been_used = true;
}
{
absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex);
return global_process_id_maker_storage->global_process_id_maker;
}
}
absl::string_view GetGlobalProcessId() {
static std::string* cached_process_id = [&] {
std::string* result = new std::string;
GlobalProcessIdMaker maker = GetGlobalProcessIdMaker();
*result = maker ? maker() : "unknown";
return result;
}();
return *cached_process_id;
}
} // namespace
Status BroadcastXlaActivity(
XlaAutoClusteringActivity auto_clustering_activity) {
auto_clustering_activity.set_global_process_id(
std::string(GetGlobalProcessId()));
return ForEachListener([&](XlaActivityListener* listener) {
return listener->Listen(auto_clustering_activity);
});
}
Status BroadcastXlaActivity(
XlaJitCompilationActivity jit_compilation_activity) {
jit_compilation_activity.set_global_process_id(
std::string(GetGlobalProcessId()));
return ForEachListener([&](XlaActivityListener* listener) {
return listener->Listen(jit_compilation_activity);
});
}
void RegisterXlaActivityListener(
std::unique_ptr<XlaActivityListener> listener) {
XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
absl::WriterMutexLock writer_lock(&listener_list->mutex);
listener_list->listeners.push_back(std::move(listener));
}
void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker) {
GlobalProcessIdMakerStorage* global_process_id_maker_storage =
GetGlobalProcessIdMakerStorage();
absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex);
CHECK(!global_process_id_maker_storage->has_been_used);
global_process_id_maker_storage->global_process_id_maker =
std::move(global_process_id_maker);
}
XlaActivityListener::~XlaActivityListener() {}
} // namespace tensorflow

View File

@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
#include <memory>
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Broadcast `auto_clustering_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity);
// Broadcast `jit_compilation_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity);
// Various components of the system can subclass XlaActivityListener to
// notifications on auto-clustering and JIT compilation events.
//
// Subclasses of XlaActivityListener must be thread safe.
class XlaActivityListener {
public:
// Called after TensorFlow auto-clusters a graph.
virtual Status Listen(
const XlaAutoClusteringActivity& auto_clustering_activity) = 0;
// Called after TensorFlow JIT compiles an XLA cluster.
virtual Status Listen(
const XlaJitCompilationActivity& jit_compilation_activity) = 0;
virtual ~XlaActivityListener();
};
// Registers an `XlaActivityListener`, which will be invoked on all subsequent
// `BroadcastXlaActivity` calls.
void RegisterXlaActivityListener(std::unique_ptr<XlaActivityListener> listener);
using GlobalProcessIdMaker = std::function<std::string()>;
// Installs `global_process_id_maker` as a "global process id" maker.
//
// The value returned by the global process ID maker, if one is installed, is
// stored in the global_process_id field of the Xla*Activity messages before
// they're fed to the registered activity listeners. If no ID maker is
// installed then global_process_id is set to "unknown".
//
// `global_process_id_maker` must be thread safe.
//
// The global process id maker is used to tag *Activity messages to so that the
// broadcasting process can be uniquely identified. Therefore the global
// process id maker
//
// - Must always return the same value within the same process.
// - Cannot be set or changed after we have broadcasted any XLA activity.
void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_

View File

@ -0,0 +1,195 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include <cstdlib>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/direct_session.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class TestListener : public XlaActivityListener {
public:
Status Listen(
const XlaAutoClusteringActivity& auto_clustering_activity) override {
auto_clustering_activity_ = auto_clustering_activity;
return Status::OK();
}
Status Listen(
const XlaJitCompilationActivity& jit_compilation_activity) override {
jit_compilation_activity_ = jit_compilation_activity;
return Status::OK();
}
~TestListener() override {}
const XlaAutoClusteringActivity& auto_clustering_activity() const {
return auto_clustering_activity_;
}
const XlaJitCompilationActivity& jit_compilation_activity() const {
return jit_compilation_activity_;
}
private:
XlaAutoClusteringActivity auto_clustering_activity_;
XlaJitCompilationActivity jit_compilation_activity_;
};
class XlaActivityListenerTest : public ::testing::Test {
protected:
XlaActivityListenerTest() {
auto listener = absl::make_unique<TestListener>();
listener_ = listener.get();
RegisterXlaActivityListener(std::move(listener));
SetGlobalProcessIdMaker([]() { return "42-xyz"; });
}
TestListener* listener() const { return listener_; }
private:
TestListener* listener_;
};
GraphDef CreateGraphDef() {
Scope root = Scope::NewRootScope().ExitOnError().WithAssignedDevice(
"/job:localhost/replica:0/task:0/device:CPU:0");
Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
for (int i = 0; i < 5; i++) {
a = ops::MatMul(root.WithOpName(absl::StrCat("matmul_", i)), a, a);
a = ops::Add(root.WithOpName(absl::StrCat("add_", i)), a, a);
}
GraphDef graph_def;
root.graph()->ToGraphDef(&graph_def);
return graph_def;
}
TEST_F(XlaActivityListenerTest, Test) {
GraphDef graph_def = CreateGraphDef();
SessionOptions options;
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(graph_def));
std::vector<std::string> output_names = {std::string("add_4:0")};
Tensor tensor_2x2(DT_FLOAT, TensorShape({2, 2}));
for (int i = 0; i < 4; i++) {
tensor_2x2.matrix<float>()(i / 2, i % 2) = 5 * i;
}
Tensor tensor_3x3(DT_FLOAT, TensorShape({3, 3}));
for (int i = 0; i < 9; i++) {
tensor_3x3.matrix<float>()(i / 3, i % 3) = 5 * i;
}
std::vector<std::pair<string, Tensor>> inputs_2x2 = {{"A", tensor_2x2}};
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run(inputs_2x2, output_names, /*target_node_names=*/{},
&outputs));
absl::string_view expected_auto_clustering_activity =
R"(global_process_id: "42-xyz"
global_jit_level: ON_2
cpu_global_jit_enabled: true
summary {
unclustered_node_count: 4
clustered_node_count: 14
clusters {
name: "cluster_0"
size: 14
op_histogram {
op: "Add"
count: 1
}
op_histogram {
op: "Const"
count: 4
}
op_histogram {
op: "MatMul"
count: 5
}
op_histogram {
op: "Mul"
count: 4
}
}
unclustered_op_histogram {
op: "NoOp"
count: 2
}
unclustered_op_histogram {
op: "_Arg"
count: 1
}
unclustered_op_histogram {
op: "_Retval"
count: 1
}
}
)";
EXPECT_EQ(listener()->auto_clustering_activity().DebugString(),
expected_auto_clustering_activity);
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 1);
int64 first_compile_time =
listener()->jit_compilation_activity().compile_time_us();
EXPECT_GT(first_compile_time, 0);
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
first_compile_time);
std::vector<std::pair<string, Tensor>> inputs_3x3 = {{"A", tensor_3x3}};
outputs.clear();
for (int i = 0; i < 3; i++) {
TF_ASSERT_OK(session->Run(inputs_3x3, output_names,
/*target_node_names=*/{}, &outputs));
}
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 2);
EXPECT_GT(listener()->jit_compilation_activity().compile_time_us(), 0);
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
first_compile_time +
listener()->jit_compilation_activity().compile_time_us());
}
} // namespace
} // namespace tensorflow
int main(int argc, char** argv) {
tensorflow::GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit = true;
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -318,4 +318,72 @@ bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
node.type_string() == "Size";
}
namespace {
struct ClusterInfo {
int size;
// Maps op names to the number of times they appear in the cluster.
absl::flat_hash_map<absl::string_view, int> op_histogram;
};
void HistogramMapToRepeatedOpAndCount(
protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
const absl::flat_hash_map<absl::string_view, int>& histogram) {
for (const auto& pair : histogram) {
XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
new_entry->set_op(std::string(pair.first));
new_entry->set_count(pair.second);
}
absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
const XlaAutoClusteringSummary::OpAndCount& b) {
return a.op() < b.op();
});
}
void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
absl::string_view name, const ClusterInfo& info) {
result->set_name(std::string(name));
result->set_size(info.size);
HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
info.op_histogram);
}
} // namespace
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
XlaAutoClusteringSummary result;
absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;
for (Node* n : graph.nodes()) {
absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
if (cluster_name) {
result.set_clustered_node_count(result.clustered_node_count() + 1);
ClusterInfo* info = &cluster_name_to_info[*cluster_name];
info->size++;
info->op_histogram[n->type_string()]++;
} else {
result.set_unclustered_node_count(result.unclustered_node_count() + 1);
unclustered_op_histogram[n->type_string()]++;
}
}
for (const auto& pair : cluster_name_to_info) {
XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
}
absl::c_sort(*result.mutable_clusters(),
[&](const XlaAutoClusteringSummary::Cluster& a,
const XlaAutoClusteringSummary::Cluster& b) {
return a.name() < b.name();
});
HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
unclustered_op_histogram);
return result;
}
} // namespace tensorflow

View File

@ -18,8 +18,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h"
@ -87,6 +89,11 @@ bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
// Returns true if `node` an operator that consumes only the shape of its input,
// not the data itself.
bool IsShapeConsumerOp(const Node& node);
// Computes a clustering summary for `graph`. See documentation on
// `XlaAutoClusteringSummary` for details.
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -361,6 +362,16 @@ Status XlaCompilationCache::CompileImpl(
<< tensorflow::strings::HumanReadableElapsedTime(
it->second.cumulative_compile_time_us / 1.0e6)
<< ")";
XlaJitCompilationActivity jit_compilation_activity;
jit_compilation_activity.set_cluster_name(function.name());
jit_compilation_activity.set_compile_count(it->second.compile_count);
jit_compilation_activity.set_compile_time_us(compile_time_us);
jit_compilation_activity.set_cumulative_compile_time_us(
it->second.cumulative_compile_time_us);
TF_RETURN_IF_ERROR(
BroadcastXlaActivity(std::move(jit_compilation_activity)));
}
}
TF_RETURN_IF_ERROR(entry->compilation_status);