Add an XLA "activity listener" mechanism.

This allows various components to listen to auto-clustering and JIT compilation events in TensorFlow.

PiperOrigin-RevId: 253265614
parent 4b7cbc8508
commit 5f2291877d
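Before the diff itself, here is a minimal illustrative sketch (not part of this commit) of how a component would typically hook into the new mechanism using the API added below in xla_activity_listener.h; `LoggingListener` and `InstallLoggingXlaActivityListener` are made-up names used only for this example.

#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/core/platform/logging.h"

namespace {

// Logs every auto-clustering and JIT compilation event it receives.
class LoggingListener : public tensorflow::XlaActivityListener {
 public:
  tensorflow::Status Listen(
      const tensorflow::XlaAutoClusteringActivity& activity) override {
    LOG(INFO) << "Auto-clustering activity:\n" << activity.DebugString();
    return tensorflow::Status::OK();
  }

  tensorflow::Status Listen(
      const tensorflow::XlaJitCompilationActivity& activity) override {
    LOG(INFO) << "JIT compilation activity:\n" << activity.DebugString();
    return tensorflow::Status::OK();
  }
};

}  // namespace

// Typically called once during process start-up, before any graphs run.
void InstallLoggingXlaActivityListener() {
  tensorflow::RegisterXlaActivityListener(
      absl::make_unique<LoggingListener>());
}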
tensorflow/compiler/jit
@@ -1,6 +1,7 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")

package(
    default_visibility = [
@@ -286,6 +287,7 @@ cc_library(
    srcs = ["xla_compilation_cache.cc"],
    hdrs = ["xla_compilation_cache.h"],
    deps = [
        ":xla_activity_listener",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:statusor",
@@ -512,6 +514,7 @@ cc_library(
        "mark_for_compilation_pass.cc",
        "mark_for_compilation_pass_test_helper.cc",
        "partially_decluster_pass.cc",
        "report_clustering_info_pass.cc",
    ],
    hdrs = [
        "build_xla_ops_pass.h",
@@ -525,6 +528,7 @@ cc_library(
        "mark_for_compilation_pass.h",
        "mark_for_compilation_pass_test_helper.h",
        "partially_decluster_pass.h",
        "report_clustering_info_pass.h",
    ],
    deps = [
        "compilability_check_util",
@@ -535,6 +539,7 @@ cc_library(
        ":resource_operation_safety_analysis",
        ":shape_inference_helpers",
        ":union_find",
        ":xla_activity_listener",
        ":xla_cluster_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:functional_ops",
@@ -577,6 +582,7 @@ cc_library(
    hdrs = ["xla_cluster_util.h"],
    deps = [
        ":flags",
        ":xla_activity_proto_cc",
        "//tensorflow/compiler/jit/graphcycles",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
@@ -843,6 +849,27 @@ tf_cc_test(
    ],
)

tf_cc_test(
    name = "xla_activity_listener_test",
    srcs = ["xla_activity_listener_test.cc"],
    deps = [
        ":flags",
        ":xla_activity_listener",
        ":xla_cpu_device",
        ":xla_cpu_jit",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:direct_session_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:test",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:matmul_op",
        "//tensorflow/core/kernels:partitioned_function_ops",
    ],
)

tf_custom_op_py_library(
    name = "xla_ops_py",
    kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
@@ -855,6 +882,27 @@ tf_custom_op_py_library(
    ],
)

cc_library(
    name = "xla_activity_listener",
    srcs = ["xla_activity_listener.cc"],
    hdrs = ["xla_activity_listener.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":xla_activity_proto_cc",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/synchronization",
    ],
)

tf_proto_library(
    name = "xla_activity_proto",
    srcs = ["xla_activity.proto"],
    cc_api_version = 2,
    default_header = True,
    protodeps = tf_additional_all_protos(),
    provide_cc_alias = True,
)

# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
    name = "xla_jit_headers_lib",
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
@@ -58,15 +59,22 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                      PartiallyDeclusterPass);

// The ReportClusteringInfoPass needs to run after all of the auto-clustering
// passes have run but before encapsulation has run. This way it can easily
// compute a summary of the clustering decisions we made and broadcast it via
// xla_activity_listener.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
                      ReportClusteringInfoPass);

// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes
// added for fetches. Before the _Send nodes are added, fetch nodes are
// identified by name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
                      EncapsulateSubgraphsPass);

// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 60,
                      BuildXlaOpsPass);

} // namespace tensorflow
@@ -1364,46 +1364,36 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
    return;
  }

  std::map<absl::string_view, int> cluster_name_to_size;
  std::map<absl::string_view, std::map<absl::string_view, int>>
      cluster_name_to_op_histogram;
  std::map<absl::string_view, int> unclustered_op_histogram;
  int clustered_node_count = 0;

  for (Node* n : graph_->nodes()) {
    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
    if (cluster_name) {
      clustered_node_count++;
      cluster_name_to_size[*cluster_name]++;
      cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
    } else {
      unclustered_op_histogram[n->type_string()]++;
    }
  }

  int unclustered_node_count = graph_->num_nodes() - clustered_node_count;
  XlaAutoClusteringSummary auto_clustering_info =
      GetXlaAutoClusteringSummary(*graph_);

  VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
  VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
          << RatioToString(clustered_node_count, graph_->num_nodes());
  VLOG(2) << " Built " << auto_clustering_info.clusters_size()
          << " clusters, size "
          << RatioToString(auto_clustering_info.clustered_node_count(),
                           graph_->num_nodes());

  for (const auto& cluster_name_size_pair : cluster_name_to_size) {
    absl::string_view cluster_name = cluster_name_size_pair.first;
    int size = cluster_name_size_pair.second;
  for (XlaAutoClusteringSummary::Cluster cluster :
       auto_clustering_info.clusters()) {
    absl::string_view cluster_name = cluster.name();
    int size = cluster.size();
    VLOG(2) << " " << cluster_name << " "
            << RatioToString(size, graph_->num_nodes());
    for (const auto& op_count_pair :
         cluster_name_to_op_histogram[cluster_name]) {
      VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second
    for (const XlaAutoClusteringSummary::OpAndCount& op_count :
         cluster.op_histogram()) {
      VLOG(3) << " " << op_count.op() << ": " << op_count.count()
              << " instances";
    }
  }

  if (!unclustered_op_histogram.empty()) {
  if (!auto_clustering_info.unclustered_op_histogram().empty()) {
    VLOG(2) << " Unclustered nodes: "
            << RatioToString(unclustered_node_count, graph_->num_nodes());
    for (const auto& pair : unclustered_op_histogram) {
      VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
            << RatioToString(auto_clustering_info.unclustered_node_count(),
                             graph_->num_nodes());
    for (const XlaAutoClusteringSummary::OpAndCount& op_count :
         auto_clustering_info.unclustered_op_histogram()) {
      VLOG(3) << " " << op_count.op() << ": " << op_count.count()
              << " instances";
    }
  }
tensorflow/compiler/jit/report_clustering_info_pass.cc (new file, 32 lines)
@@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/report_clustering_info_pass.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"

namespace tensorflow {
Status ReportClusteringInfoPass::Run(
    const GraphOptimizationPassOptions& options) {
  XlaAutoClusteringActivity activity;
  *activity.mutable_summary() = GetXlaAutoClusteringSummary(**options.graph);
  activity.set_global_jit_level(GetGlobalJitLevelForGraph(options));
  activity.set_cpu_global_jit_enabled(
      GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit);
  return BroadcastXlaActivity(std::move(activity));
}
} // namespace tensorflow
tensorflow/compiler/jit/report_clustering_info_pass.h (new file, 32 lines)
@@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
#define TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// This is not really an optimization pass. It does not change the graph in any
// way; instead it computes a summary of the XLA clusters in the graph and
// broadcasts it via xla_activity_listener.
class ReportClusteringInfoPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
tensorflow/compiler/jit/xla_activity.proto (new file, 104 lines)
@@ -0,0 +1,104 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package tensorflow;

import "tensorflow/core/protobuf/config.proto";

// Summarizes the results of auto-clustering a TensorFlow graph.
//
// Next ID: 5
message XlaAutoClusteringSummary {
  // Represents a single element in a histogram of ops ("op" as in "TensorFlow
  // operation").
  //
  // Next ID: 3
  message OpAndCount {
    // The TensorFlow operation (like MatMul, Add etc.)
    string op = 1;

    // The number of times this op occurs.
    int32 count = 2;
  }

  // Describes a single XLA cluster.
  //
  // Next ID: 4
  message Cluster {
    string name = 1;

    // The number of nodes in the cluster.
    int32 size = 2;

    // A histogram of the TF operations in this cluster.
    repeated OpAndCount op_histogram = 3;
  };

  // The number of nodes in the graph that are not inside an XLA cluster.
  int32 unclustered_node_count = 1;

  // The number of nodes in the graph that are in an XLA cluster.
  int32 clustered_node_count = 2;

  // All of the XLA clusters in the TF graph.
  repeated Cluster clusters = 3;

  // A histogram of the TF operations that were not clustered.
  repeated OpAndCount unclustered_op_histogram = 4;
}

// Listeners listening for auto clustering events get messages of this type.
//
// Next ID: 5
message XlaAutoClusteringActivity {
  // xla_activity_listener sets this to the global process id, as decided by
  // the callback registered via `SetGlobalProcessIdMaker`.
  string global_process_id = 1;

  // The value of GlobalJitLevel, as determined by `GetGlobalJitLevelForGraph`.
  // This determines if global auto-clustering is enabled.
  OptimizerOptions.GlobalJitLevel global_jit_level = 2;

  // Whether --tf_xla_cpu_global_jit is enabled in TF_XLA_FLAGS.
  bool cpu_global_jit_enabled = 3;

  XlaAutoClusteringSummary summary = 4;
}

// Listeners listening for JIT compilation events get messages of this type.
// Each instance of XlaJitCompilationActivity corresponds to a single
// compilation of a single XLA cluster. E.g. if a graph has two clusters, A and
// B, and A is compiled 5 times and B is compiled 2 times then we will generate
// 7 instances of XlaJitCompilationActivity.
//
// Next ID: 6
message XlaJitCompilationActivity {
  // xla_activity_listener sets this to the global process id, as decided by
  // the callback registered via `SetGlobalProcessIdMaker`.
  string global_process_id = 1;

  string cluster_name = 2;

  // The number of times this cluster has been compiled.
  int32 compile_count = 3;

  // Microseconds spent in the individual compilation being reported.
  int64 compile_time_us = 4;

  // Total microseconds spent in (re-)compiling this cluster so far.
  int64 cumulative_compile_time_us = 5;
}
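To make the intent of XlaJitCompilationActivity concrete, here is a small hypothetical listener (not part of this change) that uses the fields defined above to warn about clusters that keep getting recompiled; `RecompilationWarningListener` and the threshold of 10 are arbitrary illustrative choices.

#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

class RecompilationWarningListener : public XlaActivityListener {
 public:
  Status Listen(const XlaAutoClusteringActivity& activity) override {
    return Status::OK();  // Only interested in JIT compilation events.
  }

  Status Listen(const XlaJitCompilationActivity& activity) override {
    // compile_count is the number of times this cluster has been compiled, so
    // a large value usually indicates unstable shapes and wasted compile time.
    if (activity.compile_count() > 10) {
      LOG(WARNING) << "Cluster " << activity.cluster_name() << " compiled "
                   << activity.compile_count() << " times ("
                   << activity.cumulative_compile_time_us() << " us total).";
    }
    return Status::OK();
  }
};

}  // namespace tensorflow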
tensorflow/compiler/jit/xla_activity_listener.cc (new file, 132 lines)
@@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_activity_listener.h"

#include "absl/synchronization/mutex.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {
// The list of all registered `XlaActivityListener`s.
struct XlaActivityListenerList {
  absl::Mutex mutex;
  std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
};

XlaActivityListenerList* GetXlaActivityListenerList() {
  static XlaActivityListenerList* listener_list = new XlaActivityListenerList;
  return listener_list;
}

template <typename FnTy>
Status ForEachListener(FnTy fn) {
  XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
  absl::ReaderMutexLock reader_lock(&listener_list->mutex);

  for (const std::unique_ptr<XlaActivityListener>& listener :
       listener_list->listeners) {
    TF_RETURN_IF_ERROR(fn(listener.get()));
  }

  return Status::OK();
}

struct GlobalProcessIdMakerStorage {
  absl::Mutex mutex;
  GlobalProcessIdMaker global_process_id_maker GUARDED_BY(mutex);

  // True if we have used the process ID generated by `global_process_id_maker`.
  // We disallow setting the global process id maker once we have broadcasted
  // messages with global_process_id set to "unknown".
  bool has_been_used GUARDED_BY(mutex) = false;
};

GlobalProcessIdMakerStorage* GetGlobalProcessIdMakerStorage() {
  static GlobalProcessIdMakerStorage* global_process_id_maker_storage =
      new GlobalProcessIdMakerStorage;
  return global_process_id_maker_storage;
}

GlobalProcessIdMaker GetGlobalProcessIdMaker() {
  GlobalProcessIdMakerStorage* global_process_id_maker_storage =
      GetGlobalProcessIdMakerStorage();
  {
    absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex);
    if (global_process_id_maker_storage->has_been_used) {
      return global_process_id_maker_storage->global_process_id_maker;
    }
  }

  {
    absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex);
    global_process_id_maker_storage->has_been_used = true;
  }

  {
    absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex);
    return global_process_id_maker_storage->global_process_id_maker;
  }
}

absl::string_view GetGlobalProcessId() {
  static std::string* cached_process_id = [&] {
    std::string* result = new std::string;
    GlobalProcessIdMaker maker = GetGlobalProcessIdMaker();
    *result = maker ? maker() : "unknown";
    return result;
  }();
  return *cached_process_id;
}
} // namespace

Status BroadcastXlaActivity(
    XlaAutoClusteringActivity auto_clustering_activity) {
  auto_clustering_activity.set_global_process_id(
      std::string(GetGlobalProcessId()));
  return ForEachListener([&](XlaActivityListener* listener) {
    return listener->Listen(auto_clustering_activity);
  });
}

Status BroadcastXlaActivity(
    XlaJitCompilationActivity jit_compilation_activity) {
  jit_compilation_activity.set_global_process_id(
      std::string(GetGlobalProcessId()));
  return ForEachListener([&](XlaActivityListener* listener) {
    return listener->Listen(jit_compilation_activity);
  });
}

void RegisterXlaActivityListener(
    std::unique_ptr<XlaActivityListener> listener) {
  XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
  absl::WriterMutexLock writer_lock(&listener_list->mutex);

  listener_list->listeners.push_back(std::move(listener));
}

void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker) {
  GlobalProcessIdMakerStorage* global_process_id_maker_storage =
      GetGlobalProcessIdMakerStorage();
  absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex);
  CHECK(!global_process_id_maker_storage->has_been_used);
  global_process_id_maker_storage->global_process_id_maker =
      std::move(global_process_id_maker);
}

XlaActivityListener::~XlaActivityListener() {}

} // namespace tensorflow
tensorflow/compiler/jit/xla_activity_listener.h (new file, 72 lines)
@@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_

#include <functional>
#include <memory>

#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
// Broadcast `auto_clustering_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity);

// Broadcast `jit_compilation_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity);

// Various components of the system can subclass XlaActivityListener to receive
// notifications on auto-clustering and JIT compilation events.
//
// Subclasses of XlaActivityListener must be thread safe.
class XlaActivityListener {
 public:
  // Called after TensorFlow auto-clusters a graph.
  virtual Status Listen(
      const XlaAutoClusteringActivity& auto_clustering_activity) = 0;

  // Called after TensorFlow JIT compiles an XLA cluster.
  virtual Status Listen(
      const XlaJitCompilationActivity& jit_compilation_activity) = 0;

  virtual ~XlaActivityListener();
};

// Registers an `XlaActivityListener`, which will be invoked on all subsequent
// `BroadcastXlaActivity` calls.
void RegisterXlaActivityListener(std::unique_ptr<XlaActivityListener> listener);

using GlobalProcessIdMaker = std::function<std::string()>;

// Installs `global_process_id_maker` as a "global process id" maker.
//
// The value returned by the global process ID maker, if one is installed, is
// stored in the global_process_id field of the Xla*Activity messages before
// they're fed to the registered activity listeners. If no ID maker is
// installed then global_process_id is set to "unknown".
//
// `global_process_id_maker` must be thread safe.
//
// The global process id maker is used to tag *Activity messages so that the
// broadcasting process can be uniquely identified. Therefore the global
// process id maker
//
//  - Must always return the same value within the same process.
//  - Cannot be set or changed after we have broadcasted any XLA activity.
void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker);

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
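A minimal usage sketch (again, not part of the diff) for the ordering constraint described above: the ID maker has to be installed before the first BroadcastXlaActivity call, since broadcasting consumes the current maker and SetGlobalProcessIdMaker CHECK-fails after that point. The "trainer_job:task_0" value and function name are made-up examples.

#include <string>

#include "tensorflow/compiler/jit/xla_activity_listener.h"

// Runs during early process initialization, before any session executes.
void ConfigureXlaActivityProcessId() {
  // Every Xla*Activity message broadcast after this point will carry
  // global_process_id == "trainer_job:task_0".
  tensorflow::SetGlobalProcessIdMaker(
      []() { return std::string("trainer_job:task_0"); });
}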
tensorflow/compiler/jit/xla_activity_listener_test.cc (new file, 195 lines)
@@ -0,0 +1,195 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_activity_listener.h"

#include <cstdlib>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/direct_session.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

class TestListener : public XlaActivityListener {
 public:
  Status Listen(
      const XlaAutoClusteringActivity& auto_clustering_activity) override {
    auto_clustering_activity_ = auto_clustering_activity;
    return Status::OK();
  }

  Status Listen(
      const XlaJitCompilationActivity& jit_compilation_activity) override {
    jit_compilation_activity_ = jit_compilation_activity;
    return Status::OK();
  }

  ~TestListener() override {}

  const XlaAutoClusteringActivity& auto_clustering_activity() const {
    return auto_clustering_activity_;
  }
  const XlaJitCompilationActivity& jit_compilation_activity() const {
    return jit_compilation_activity_;
  }

 private:
  XlaAutoClusteringActivity auto_clustering_activity_;
  XlaJitCompilationActivity jit_compilation_activity_;
};

class XlaActivityListenerTest : public ::testing::Test {
 protected:
  XlaActivityListenerTest() {
    auto listener = absl::make_unique<TestListener>();
    listener_ = listener.get();
    RegisterXlaActivityListener(std::move(listener));
    SetGlobalProcessIdMaker([]() { return "42-xyz"; });
  }

  TestListener* listener() const { return listener_; }

 private:
  TestListener* listener_;
};

GraphDef CreateGraphDef() {
  Scope root = Scope::NewRootScope().ExitOnError().WithAssignedDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  for (int i = 0; i < 5; i++) {
    a = ops::MatMul(root.WithOpName(absl::StrCat("matmul_", i)), a, a);
    a = ops::Add(root.WithOpName(absl::StrCat("add_", i)), a, a);
  }

  GraphDef graph_def;
  root.graph()->ToGraphDef(&graph_def);
  return graph_def;
}

TEST_F(XlaActivityListenerTest, Test) {
  GraphDef graph_def = CreateGraphDef();
  SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(OptimizerOptions::ON_2);
  std::unique_ptr<Session> session(NewSession(options));

  TF_ASSERT_OK(session->Create(graph_def));

  std::vector<std::string> output_names = {std::string("add_4:0")};

  Tensor tensor_2x2(DT_FLOAT, TensorShape({2, 2}));
  for (int i = 0; i < 4; i++) {
    tensor_2x2.matrix<float>()(i / 2, i % 2) = 5 * i;
  }

  Tensor tensor_3x3(DT_FLOAT, TensorShape({3, 3}));
  for (int i = 0; i < 9; i++) {
    tensor_3x3.matrix<float>()(i / 3, i % 3) = 5 * i;
  }

  std::vector<std::pair<string, Tensor>> inputs_2x2 = {{"A", tensor_2x2}};

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(inputs_2x2, output_names, /*target_node_names=*/{},
                            &outputs));

  absl::string_view expected_auto_clustering_activity =
      R"(global_process_id: "42-xyz"
global_jit_level: ON_2
cpu_global_jit_enabled: true
summary {
  unclustered_node_count: 4
  clustered_node_count: 14
  clusters {
    name: "cluster_0"
    size: 14
    op_histogram {
      op: "Add"
      count: 1
    }
    op_histogram {
      op: "Const"
      count: 4
    }
    op_histogram {
      op: "MatMul"
      count: 5
    }
    op_histogram {
      op: "Mul"
      count: 4
    }
  }
  unclustered_op_histogram {
    op: "NoOp"
    count: 2
  }
  unclustered_op_histogram {
    op: "_Arg"
    count: 1
  }
  unclustered_op_histogram {
    op: "_Retval"
    count: 1
  }
}
)";

  EXPECT_EQ(listener()->auto_clustering_activity().DebugString(),
            expected_auto_clustering_activity);

  EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
  EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 1);

  int64 first_compile_time =
      listener()->jit_compilation_activity().compile_time_us();
  EXPECT_GT(first_compile_time, 0);
  EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
            first_compile_time);

  std::vector<std::pair<string, Tensor>> inputs_3x3 = {{"A", tensor_3x3}};

  outputs.clear();
  for (int i = 0; i < 3; i++) {
    TF_ASSERT_OK(session->Run(inputs_3x3, output_names,
                              /*target_node_names=*/{}, &outputs));
  }

  EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
  EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 2);

  EXPECT_GT(listener()->jit_compilation_activity().compile_time_us(), 0);
  EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
            first_compile_time +
                listener()->jit_compilation_activity().compile_time_us());
}

} // namespace
} // namespace tensorflow

int main(int argc, char** argv) {
  tensorflow::GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit = true;
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
@@ -318,4 +318,72 @@ bool IsShapeConsumerOp(const Node& node) {
  return node.type_string() == "Shape" || node.type_string() == "Rank" ||
         node.type_string() == "Size";
}

namespace {
struct ClusterInfo {
  int size;

  // Maps op names to the number of times they appear in the cluster.
  absl::flat_hash_map<absl::string_view, int> op_histogram;
};

void HistogramMapToRepeatedOpAndCount(
    protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
    const absl::flat_hash_map<absl::string_view, int>& histogram) {
  for (const auto& pair : histogram) {
    XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
    new_entry->set_op(std::string(pair.first));
    new_entry->set_count(pair.second);
  }

  absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
                           const XlaAutoClusteringSummary::OpAndCount& b) {
    return a.op() < b.op();
  });
}

void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
                           absl::string_view name, const ClusterInfo& info) {
  result->set_name(std::string(name));
  result->set_size(info.size);
  HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
                                   info.op_histogram);
}
} // namespace

XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
  absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
  XlaAutoClusteringSummary result;

  absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;

  for (Node* n : graph.nodes()) {
    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
    if (cluster_name) {
      result.set_clustered_node_count(result.clustered_node_count() + 1);
      ClusterInfo* info = &cluster_name_to_info[*cluster_name];
      info->size++;
      info->op_histogram[n->type_string()]++;
    } else {
      result.set_unclustered_node_count(result.unclustered_node_count() + 1);
      unclustered_op_histogram[n->type_string()]++;
    }
  }

  for (const auto& pair : cluster_name_to_info) {
    XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
    ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
  }

  absl::c_sort(*result.mutable_clusters(),
               [&](const XlaAutoClusteringSummary::Cluster& a,
                   const XlaAutoClusteringSummary::Cluster& b) {
                 return a.name() < b.name();
               });

  HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
                                   unclustered_op_histogram);

  return result;
}
} // namespace tensorflow
@@ -18,8 +18,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -87,6 +89,11 @@ bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
// Returns true if `node` is an operator that consumes only the shape of its
// input, not the data itself.
bool IsShapeConsumerOp(const Node& node);

// Computes a clustering summary for `graph`. See documentation on
// `XlaAutoClusteringSummary` for details.
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
@@ -19,6 +19,7 @@ limitations under the License.

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -361,6 +362,16 @@ Status XlaCompilationCache::CompileImpl(
            << tensorflow::strings::HumanReadableElapsedTime(
                   it->second.cumulative_compile_time_us / 1.0e6)
            << ")";

    XlaJitCompilationActivity jit_compilation_activity;
    jit_compilation_activity.set_cluster_name(function.name());
    jit_compilation_activity.set_compile_count(it->second.compile_count);
    jit_compilation_activity.set_compile_time_us(compile_time_us);
    jit_compilation_activity.set_cumulative_compile_time_us(
        it->second.cumulative_compile_time_us);

    TF_RETURN_IF_ERROR(
        BroadcastXlaActivity(std::move(jit_compilation_activity)));
    }
  }
  TF_RETURN_IF_ERROR(entry->compilation_status);