Add an XLA "activity listener" mechanism.
This allows various components to listen to auto-clustering and JIT compilation events in TensorFlow.

PiperOrigin-RevId: 253265614

parent: 4b7cbc8508
commit: 5f2291877d
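Before the diff itself, here is a minimal sketch (not part of this change) of how a component might consume the new mechanism, using only the XlaActivityListener API added by this commit; the class and function names below are illustrative only:

    #include "tensorflow/compiler/jit/xla_activity_listener.h"

    #include "absl/memory/memory.h"
    #include "tensorflow/core/platform/logging.h"

    namespace {

    // Hypothetical listener that simply logs every activity it receives.
    class LoggingXlaActivityListener : public tensorflow::XlaActivityListener {
     public:
      tensorflow::Status Listen(
          const tensorflow::XlaAutoClusteringActivity& activity) override {
        VLOG(1) << "Auto-clustering: " << activity.DebugString();
        return tensorflow::Status::OK();
      }

      tensorflow::Status Listen(
          const tensorflow::XlaJitCompilationActivity& activity) override {
        VLOG(1) << "JIT compile of " << activity.cluster_name() << " took "
                << activity.compile_time_us() << " us";
        return tensorflow::Status::OK();
      }
    };

    }  // namespace

    // Call once at startup, before any clustering or compilation happens.
    void InstallLoggingListener() {
      tensorflow::RegisterXlaActivityListener(
          absl::make_unique<LoggingXlaActivityListener>());
    }

Registration is append-only and listeners are invoked under a reader lock, so a listener like this should be installed once at startup and must be thread safe.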

tensorflow/compiler/jit/BUILD

@@ -1,6 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
 
 package(
     default_visibility = [
@@ -286,6 +287,7 @@ cc_library(
     srcs = ["xla_compilation_cache.cc"],
     hdrs = ["xla_compilation_cache.h"],
     deps = [
+        ":xla_activity_listener",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:statusor",
@@ -512,6 +514,7 @@ cc_library(
         "mark_for_compilation_pass.cc",
         "mark_for_compilation_pass_test_helper.cc",
         "partially_decluster_pass.cc",
+        "report_clustering_info_pass.cc",
     ],
     hdrs = [
         "build_xla_ops_pass.h",
@@ -525,6 +528,7 @@ cc_library(
         "mark_for_compilation_pass.h",
         "mark_for_compilation_pass_test_helper.h",
         "partially_decluster_pass.h",
+        "report_clustering_info_pass.h",
     ],
     deps = [
         ":compilability_check_util",
@@ -535,6 +539,7 @@ cc_library(
         ":resource_operation_safety_analysis",
         ":shape_inference_helpers",
         ":union_find",
+        ":xla_activity_listener",
         ":xla_cluster_util",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:functional_ops",
@@ -577,6 +582,7 @@ cc_library(
     hdrs = ["xla_cluster_util.h"],
     deps = [
         ":flags",
+        ":xla_activity_proto_cc",
         "//tensorflow/compiler/jit/graphcycles",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
@@ -843,6 +849,27 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "xla_activity_listener_test",
+    srcs = ["xla_activity_listener_test.cc"],
+    deps = [
+        ":flags",
+        ":xla_activity_listener",
+        ":xla_cpu_device",
+        ":xla_cpu_jit",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:direct_session_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:ops",
+        "//tensorflow/core:test",
+        "//tensorflow/core/kernels:cwise_op",
+        "//tensorflow/core/kernels:matmul_op",
+        "//tensorflow/core/kernels:partitioned_function_ops",
+    ],
+)
+
 tf_custom_op_py_library(
     name = "xla_ops_py",
     kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
@@ -855,6 +882,27 @@ tf_custom_op_py_library(
     ],
 )
 
+cc_library(
+    name = "xla_activity_listener",
+    srcs = ["xla_activity_listener.cc"],
+    hdrs = ["xla_activity_listener.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":xla_activity_proto_cc",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+tf_proto_library(
+    name = "xla_activity_proto",
+    srcs = ["xla_activity.proto"],
+    cc_api_version = 2,
+    default_header = True,
+    protodeps = tf_additional_all_protos(),
+    provide_cc_alias = True,
+)
+
 # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
 cc_header_only_library(
     name = "xla_jit_headers_lib",

tensorflow/compiler/jit/jit_compilation_pass_registration.cc

@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 
 namespace tensorflow {
@@ -58,15 +59,22 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
 REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                       PartiallyDeclusterPass);
 
+// ReportClusteringInfoPass pass needs to run after all of the auto-clustering
+// passes have run but before encapsulation has run. This way it can easily
+// compute a summary of the clustering decisions we made and broadcast it via
+// xla_activity_listener.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
+                      ReportClusteringInfoPass);
+
 // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
 // also need to run it after the graph been rewritten to have _Send nodes added
 // for fetches. Before the _Send nodes are added, fetch nodes are identified by
 // name, and encapsulation might remove that node from the graph.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
                       EncapsulateSubgraphsPass);
 
 // Must run after EncapsulateSubgraphsPass.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 60,
                       BuildXlaOpsPass);
 
 }  // namespace tensorflow

tensorflow/compiler/jit/mark_for_compilation_pass.cc

@@ -1364,46 +1364,36 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
     return;
   }
 
-  std::map<absl::string_view, int> cluster_name_to_size;
-  std::map<absl::string_view, std::map<absl::string_view, int>>
-      cluster_name_to_op_histogram;
-  std::map<absl::string_view, int> unclustered_op_histogram;
-  int clustered_node_count = 0;
-
-  for (Node* n : graph_->nodes()) {
-    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
-    if (cluster_name) {
-      clustered_node_count++;
-      cluster_name_to_size[*cluster_name]++;
-      cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
-    } else {
-      unclustered_op_histogram[n->type_string()]++;
-    }
-  }
-
-  int unclustered_node_count = graph_->num_nodes() - clustered_node_count;
+  XlaAutoClusteringSummary auto_clustering_info =
+      GetXlaAutoClusteringSummary(*graph_);
 
   VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
-  VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
-          << RatioToString(clustered_node_count, graph_->num_nodes());
+  VLOG(2) << " Built " << auto_clustering_info.clusters_size()
+          << " clusters, size "
+          << RatioToString(auto_clustering_info.clustered_node_count(),
+                           graph_->num_nodes());
 
-  for (const auto& cluster_name_size_pair : cluster_name_to_size) {
-    absl::string_view cluster_name = cluster_name_size_pair.first;
-    int size = cluster_name_size_pair.second;
+  for (XlaAutoClusteringSummary::Cluster cluster :
+       auto_clustering_info.clusters()) {
+    absl::string_view cluster_name = cluster.name();
+    int size = cluster.size();
     VLOG(2) << " " << cluster_name << " "
             << RatioToString(size, graph_->num_nodes());
-    for (const auto& op_count_pair :
-         cluster_name_to_op_histogram[cluster_name]) {
-      VLOG(3) << "  " << op_count_pair.first << ": " << op_count_pair.second
-              << " instances";
+    for (const XlaAutoClusteringSummary::OpAndCount& op_count :
+         cluster.op_histogram()) {
+      VLOG(3) << "  " << op_count.op() << ": " << op_count.count()
+              << " instances";
     }
   }
 
-  if (!unclustered_op_histogram.empty()) {
+  if (!auto_clustering_info.unclustered_op_histogram().empty()) {
     VLOG(2) << " Unclustered nodes: "
-            << RatioToString(unclustered_node_count, graph_->num_nodes());
-    for (const auto& pair : unclustered_op_histogram) {
-      VLOG(3) << "  " << pair.first << ": " << pair.second << " instances";
+            << RatioToString(auto_clustering_info.unclustered_node_count(),
+                             graph_->num_nodes());
+    for (const XlaAutoClusteringSummary::OpAndCount& op_count :
+         auto_clustering_info.unclustered_op_histogram()) {
+      VLOG(3) << "  " << op_count.op() << ": " << op_count.count()
+              << " instances";
     }
   }

tensorflow/compiler/jit/report_clustering_info_pass.cc (new file, 32 lines)

@@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/report_clustering_info_pass.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"

namespace tensorflow {
Status ReportClusteringInfoPass::Run(
    const GraphOptimizationPassOptions& options) {
  XlaAutoClusteringActivity activity;
  *activity.mutable_summary() = GetXlaAutoClusteringSummary(**options.graph);
  activity.set_global_jit_level(GetGlobalJitLevelForGraph(options));
  activity.set_cpu_global_jit_enabled(
      GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit);
  return BroadcastXlaActivity(std::move(activity));
}
}  // namespace tensorflow

tensorflow/compiler/jit/report_clustering_info_pass.h (new file, 32 lines)

@@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
#define TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// This is not really an optimization pass. It does not change the graph in any
// way; instead it computes a summary of the XLA clusters in the graph and
// broadcasts it via xla_activity_listener.
class ReportClusteringInfoPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;
};
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_

tensorflow/compiler/jit/xla_activity.proto (new file, 104 lines)

@@ -0,0 +1,104 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package tensorflow;

import "tensorflow/core/protobuf/config.proto";

// Summarizes the results of auto-clustering a TensorFlow graph.
//
// Next ID: 5
message XlaAutoClusteringSummary {
  // Represents a single element in a histogram of ops ("op" as in "TensorFlow
  // operation").
  //
  // Next ID: 3
  message OpAndCount {
    // The TensorFlow operation (like MatMul, Add etc.)
    string op = 1;

    // The number of times this occurs.
    int32 count = 2;
  }

  // Describes a single XLA cluster.
  //
  // Next ID: 4
  message Cluster {
    string name = 1;

    // The number of nodes in the cluster.
    int32 size = 2;

    // A histogram of the TF operations in this cluster.
    repeated OpAndCount op_histogram = 3;
  };

  // The number of nodes in the graph that are not inside an XLA cluster.
  int32 unclustered_node_count = 1;

  // The number of nodes in the graph that are in an XLA cluster.
  int32 clustered_node_count = 2;

  // All of the XLA clusters in the TF graph.
  repeated Cluster clusters = 3;

  // A histogram of the TF operations that were not clustered.
  repeated OpAndCount unclustered_op_histogram = 4;
}

// Listeners listening for auto clustering events get messages of this type.
//
// Next ID: 5
message XlaAutoClusteringActivity {
  // xla_activity_listener sets this to the global process id, as decided by
  // the callback registered via `SetGlobalProcessIdMaker`.
  string global_process_id = 1;

  // The value of GlobalJitLevel, as determined by `GetGlobalJitLevelForGraph`.
  // This determines if global auto-clustering is enabled.
  OptimizerOptions.GlobalJitLevel global_jit_level = 2;

  // Whether --tf_xla_cpu_global_jit is enabled in TF_XLA_FLAGS.
  bool cpu_global_jit_enabled = 3;

  XlaAutoClusteringSummary summary = 4;
}

// Listeners listening for JIT compilation events get messages of this type.
// Each instance of XlaJitCompilationActivity corresponds to a single
// compilation of a single XLA cluster. E.g. if a graph has two clusters, A and
// B, and A is compiled 5 times and B is compiled 2 times then we will generate
// 7 instances of XlaJitCompilationActivity.
//
// Next ID: 6
message XlaJitCompilationActivity {
  // xla_activity_listener sets this to the global process id, as decided by
  // the callback registered via `SetGlobalProcessIdMaker`.
  string global_process_id = 1;

  string cluster_name = 2;

  // The number of times this cluster has been compiled.
  int32 compile_count = 3;

  // Microseconds spent in the individual compilation being reported.
  int64 compile_time_us = 4;

  // Total microseconds spent in (re-)compiling this cluster so far.
  int64 cumulative_compile_time_us = 5;
}

tensorflow/compiler/jit/xla_activity_listener.cc (new file, 132 lines)

@@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_activity_listener.h"

#include "absl/synchronization/mutex.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {
// The list of all registered `XlaActivityListener`s.
struct XlaActivityListenerList {
  absl::Mutex mutex;
  std::vector<std::unique_ptr<XlaActivityListener>> listeners
      GUARDED_BY(mutex);
};

XlaActivityListenerList* GetXlaActivityListenerList() {
  static XlaActivityListenerList* listener_list = new XlaActivityListenerList;
  return listener_list;
}

template <typename FnTy>
Status ForEachListener(FnTy fn) {
  XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
  absl::ReaderMutexLock reader_lock(&listener_list->mutex);

  for (const std::unique_ptr<XlaActivityListener>& listener :
       listener_list->listeners) {
    TF_RETURN_IF_ERROR(fn(listener.get()));
  }

  return Status::OK();
}

struct GlobalProcessIdMakerStorage {
  absl::Mutex mutex;
  GlobalProcessIdMaker global_process_id_maker GUARDED_BY(mutex);

  // True if we have used the process ID generated by `global_process_id_maker`.
  // We disallow setting the global process id maker once we have broadcasted
  // messages with global_process_id set to "unknown".
  bool has_been_used GUARDED_BY(mutex) = false;
};

GlobalProcessIdMakerStorage* GetGlobalProcessIdMakerStorage() {
  static GlobalProcessIdMakerStorage* global_process_id_maker_storage =
      new GlobalProcessIdMakerStorage;
  return global_process_id_maker_storage;
}

GlobalProcessIdMaker GetGlobalProcessIdMaker() {
  GlobalProcessIdMakerStorage* global_process_id_maker_storage =
      GetGlobalProcessIdMakerStorage();
  {
    absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex);
    if (global_process_id_maker_storage->has_been_used) {
      return global_process_id_maker_storage->global_process_id_maker;
    }
  }

  {
    absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex);
    global_process_id_maker_storage->has_been_used = true;
  }

  {
    absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex);
    return global_process_id_maker_storage->global_process_id_maker;
  }
}

absl::string_view GetGlobalProcessId() {
  static std::string* cached_process_id = [&] {
    std::string* result = new std::string;
    GlobalProcessIdMaker maker = GetGlobalProcessIdMaker();
    *result = maker ? maker() : "unknown";
    return result;
  }();
  return *cached_process_id;
}
}  // namespace

Status BroadcastXlaActivity(
    XlaAutoClusteringActivity auto_clustering_activity) {
  auto_clustering_activity.set_global_process_id(
      std::string(GetGlobalProcessId()));
  return ForEachListener([&](XlaActivityListener* listener) {
    return listener->Listen(auto_clustering_activity);
  });
}

Status BroadcastXlaActivity(
    XlaJitCompilationActivity jit_compilation_activity) {
  jit_compilation_activity.set_global_process_id(
      std::string(GetGlobalProcessId()));
  return ForEachListener([&](XlaActivityListener* listener) {
    return listener->Listen(jit_compilation_activity);
  });
}

void RegisterXlaActivityListener(
    std::unique_ptr<XlaActivityListener> listener) {
  XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
  absl::WriterMutexLock writer_lock(&listener_list->mutex);

  listener_list->listeners.push_back(std::move(listener));
}

void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker) {
  GlobalProcessIdMakerStorage* global_process_id_maker_storage =
      GetGlobalProcessIdMakerStorage();
  absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex);
  CHECK(!global_process_id_maker_storage->has_been_used);
  global_process_id_maker_storage->global_process_id_maker =
      std::move(global_process_id_maker);
}

XlaActivityListener::~XlaActivityListener() {}

}  // namespace tensorflow

tensorflow/compiler/jit/xla_activity_listener.h (new file, 72 lines)

@@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_

#include <memory>

#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
// Broadcast `auto_clustering_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity);

// Broadcast `jit_compilation_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity);

// Various components of the system can subclass XlaActivityListener to
// receive notifications on auto-clustering and JIT compilation events.
//
// Subclasses of XlaActivityListener must be thread safe.
class XlaActivityListener {
 public:
  // Called after TensorFlow auto-clusters a graph.
  virtual Status Listen(
      const XlaAutoClusteringActivity& auto_clustering_activity) = 0;

  // Called after TensorFlow JIT compiles an XLA cluster.
  virtual Status Listen(
      const XlaJitCompilationActivity& jit_compilation_activity) = 0;

  virtual ~XlaActivityListener();
};

// Registers an `XlaActivityListener`, which will be invoked on all subsequent
// `BroadcastXlaActivity` calls.
void RegisterXlaActivityListener(std::unique_ptr<XlaActivityListener> listener);

using GlobalProcessIdMaker = std::function<std::string()>;

// Installs `global_process_id_maker` as a "global process id" maker.
//
// The value returned by the global process ID maker, if one is installed, is
// stored in the global_process_id field of the Xla*Activity messages before
// they're fed to the registered activity listeners. If no ID maker is
// installed then global_process_id is set to "unknown".
//
// `global_process_id_maker` must be thread safe.
//
// The global process id maker is used to tag *Activity messages so that the
// broadcasting process can be uniquely identified. Therefore the global
// process id maker
//
//  - Must always return the same value within the same process.
//  - Cannot be set or changed after we have broadcasted any XLA activity.
void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
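As a usage note, here is a small sketch of installing a process-id maker consistent with the contract above; the helper name and the hostname:pid scheme are illustrative, not something this commit prescribes:

    #include <unistd.h>

    #include "absl/strings/str_cat.h"
    #include "tensorflow/compiler/jit/xla_activity_listener.h"
    #include "tensorflow/core/platform/host_info.h"

    // Hypothetical setup helper: tag every broadcast Xla*Activity message with
    // "<hostname>:<pid>". Must run before the first BroadcastXlaActivity call,
    // since the maker cannot be changed once any activity has been broadcast.
    void InstallExampleProcessIdMaker() {
      tensorflow::SetGlobalProcessIdMaker([] {
        return absl::StrCat(tensorflow::port::Hostname(), ":", getpid());
      });
    }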

tensorflow/compiler/jit/xla_activity_listener_test.cc (new file, 195 lines)

@@ -0,0 +1,195 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_activity_listener.h"

#include <cstdlib>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/direct_session.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

class TestListener : public XlaActivityListener {
 public:
  Status Listen(
      const XlaAutoClusteringActivity& auto_clustering_activity) override {
    auto_clustering_activity_ = auto_clustering_activity;
    return Status::OK();
  }

  Status Listen(
      const XlaJitCompilationActivity& jit_compilation_activity) override {
    jit_compilation_activity_ = jit_compilation_activity;
    return Status::OK();
  }

  ~TestListener() override {}

  const XlaAutoClusteringActivity& auto_clustering_activity() const {
    return auto_clustering_activity_;
  }
  const XlaJitCompilationActivity& jit_compilation_activity() const {
    return jit_compilation_activity_;
  }

 private:
  XlaAutoClusteringActivity auto_clustering_activity_;
  XlaJitCompilationActivity jit_compilation_activity_;
};

class XlaActivityListenerTest : public ::testing::Test {
 protected:
  XlaActivityListenerTest() {
    auto listener = absl::make_unique<TestListener>();
    listener_ = listener.get();
    RegisterXlaActivityListener(std::move(listener));
    SetGlobalProcessIdMaker([]() { return "42-xyz"; });
  }

  TestListener* listener() const { return listener_; }

 private:
  TestListener* listener_;
};

GraphDef CreateGraphDef() {
  Scope root = Scope::NewRootScope().ExitOnError().WithAssignedDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  for (int i = 0; i < 5; i++) {
    a = ops::MatMul(root.WithOpName(absl::StrCat("matmul_", i)), a, a);
    a = ops::Add(root.WithOpName(absl::StrCat("add_", i)), a, a);
  }

  GraphDef graph_def;
  root.graph()->ToGraphDef(&graph_def);
  return graph_def;
}

TEST_F(XlaActivityListenerTest, Test) {
  GraphDef graph_def = CreateGraphDef();
  SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(OptimizerOptions::ON_2);
  std::unique_ptr<Session> session(NewSession(options));

  TF_ASSERT_OK(session->Create(graph_def));

  std::vector<std::string> output_names = {std::string("add_4:0")};

  Tensor tensor_2x2(DT_FLOAT, TensorShape({2, 2}));
  for (int i = 0; i < 4; i++) {
    tensor_2x2.matrix<float>()(i / 2, i % 2) = 5 * i;
  }

  Tensor tensor_3x3(DT_FLOAT, TensorShape({3, 3}));
  for (int i = 0; i < 9; i++) {
    tensor_3x3.matrix<float>()(i / 3, i % 3) = 5 * i;
  }

  std::vector<std::pair<string, Tensor>> inputs_2x2 = {{"A", tensor_2x2}};

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(inputs_2x2, output_names, /*target_node_names=*/{},
                            &outputs));

  absl::string_view expected_auto_clustering_activity =
      R"(global_process_id: "42-xyz"
global_jit_level: ON_2
cpu_global_jit_enabled: true
summary {
  unclustered_node_count: 4
  clustered_node_count: 14
  clusters {
    name: "cluster_0"
    size: 14
    op_histogram {
      op: "Add"
      count: 1
    }
    op_histogram {
      op: "Const"
      count: 4
    }
    op_histogram {
      op: "MatMul"
      count: 5
    }
    op_histogram {
      op: "Mul"
      count: 4
    }
  }
  unclustered_op_histogram {
    op: "NoOp"
    count: 2
  }
  unclustered_op_histogram {
    op: "_Arg"
    count: 1
  }
  unclustered_op_histogram {
    op: "_Retval"
    count: 1
  }
}
)";

  EXPECT_EQ(listener()->auto_clustering_activity().DebugString(),
            expected_auto_clustering_activity);

  EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
  EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 1);

  int64 first_compile_time =
      listener()->jit_compilation_activity().compile_time_us();
  EXPECT_GT(first_compile_time, 0);
  EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
            first_compile_time);

  std::vector<std::pair<string, Tensor>> inputs_3x3 = {{"A", tensor_3x3}};

  outputs.clear();
  for (int i = 0; i < 3; i++) {
    TF_ASSERT_OK(session->Run(inputs_3x3, output_names,
                              /*target_node_names=*/{}, &outputs));
  }

  EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
  EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 2);

  EXPECT_GT(listener()->jit_compilation_activity().compile_time_us(), 0);
  EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
            first_compile_time +
                listener()->jit_compilation_activity().compile_time_us());
}

}  // namespace
}  // namespace tensorflow

int main(int argc, char** argv) {
  tensorflow::GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit = true;
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

tensorflow/compiler/jit/xla_cluster_util.cc

@@ -318,4 +318,72 @@ bool IsShapeConsumerOp(const Node& node) {
   return node.type_string() == "Shape" || node.type_string() == "Rank" ||
          node.type_string() == "Size";
 }
+
+namespace {
+struct ClusterInfo {
+  int size;
+
+  // Maps op names to the number of times they appear in the cluster.
+  absl::flat_hash_map<absl::string_view, int> op_histogram;
+};
+
+void HistogramMapToRepeatedOpAndCount(
+    protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
+    const absl::flat_hash_map<absl::string_view, int>& histogram) {
+  for (const auto& pair : histogram) {
+    XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
+    new_entry->set_op(std::string(pair.first));
+    new_entry->set_count(pair.second);
+  }
+
+  absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
+                           const XlaAutoClusteringSummary::OpAndCount& b) {
+    return a.op() < b.op();
+  });
+}
+
+void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
+                           absl::string_view name, const ClusterInfo& info) {
+  result->set_name(std::string(name));
+  result->set_size(info.size);
+  HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
+                                   info.op_histogram);
+}
+}  // namespace
+
+XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
+  absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
+  XlaAutoClusteringSummary result;
+
+  absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;
+
+  for (Node* n : graph.nodes()) {
+    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
+    if (cluster_name) {
+      result.set_clustered_node_count(result.clustered_node_count() + 1);
+      ClusterInfo* info = &cluster_name_to_info[*cluster_name];
+      info->size++;
+      info->op_histogram[n->type_string()]++;
+    } else {
+      result.set_unclustered_node_count(result.unclustered_node_count() + 1);
+      unclustered_op_histogram[n->type_string()]++;
+    }
+  }
+
+  for (const auto& pair : cluster_name_to_info) {
+    XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
+    ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
+  }
+
+  absl::c_sort(*result.mutable_clusters(),
+               [&](const XlaAutoClusteringSummary::Cluster& a,
+                   const XlaAutoClusteringSummary::Cluster& b) {
+                 return a.name() < b.name();
+               });
+
+  HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
+                                   unclustered_op_histogram);
+
+  return result;
+}
 }  // namespace tensorflow

tensorflow/compiler/jit/xla_cluster_util.h

@@ -18,8 +18,10 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
 #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/graph/algorithm.h"
|
|||||||
// Returns true if `node` an operator that consumes only the shape of its input,
|
// Returns true if `node` an operator that consumes only the shape of its input,
|
||||||
// not the data itself.
|
// not the data itself.
|
||||||
bool IsShapeConsumerOp(const Node& node);
|
bool IsShapeConsumerOp(const Node& node);
|
||||||
|
|
||||||
|
// Computes a clustering summary for `graph`. See documentation on
|
||||||
|
// `XlaAutoClusteringSummary` for details.
|
||||||
|
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
|
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
|
||||||
|
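For illustration, a minimal sketch (a hypothetical helper, not part of this change) of calling the newly declared GetXlaAutoClusteringSummary directly and reading the proto fields it fills in:

    #include "tensorflow/compiler/jit/xla_cluster_util.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/platform/logging.h"

    // Hypothetical debugging helper: log one line per XLA cluster in `graph`.
    void LogClusterSizes(const tensorflow::Graph& graph) {
      tensorflow::XlaAutoClusteringSummary summary =
          tensorflow::GetXlaAutoClusteringSummary(graph);
      for (const auto& cluster : summary.clusters()) {
        VLOG(2) << cluster.name() << ": " << cluster.size() << " nodes, "
                << cluster.op_histogram_size() << " distinct ops";
      }
    }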

tensorflow/compiler/jit/xla_compilation_cache.cc

@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -361,6 +362,16 @@ Status XlaCompilationCache::CompileImpl(
               << tensorflow::strings::HumanReadableElapsedTime(
                      it->second.cumulative_compile_time_us / 1.0e6)
               << ")";
+
+      XlaJitCompilationActivity jit_compilation_activity;
+      jit_compilation_activity.set_cluster_name(function.name());
+      jit_compilation_activity.set_compile_count(it->second.compile_count);
+      jit_compilation_activity.set_compile_time_us(compile_time_us);
+      jit_compilation_activity.set_cumulative_compile_time_us(
+          it->second.cumulative_compile_time_us);
+
+      TF_RETURN_IF_ERROR(
+          BroadcastXlaActivity(std::move(jit_compilation_activity)));
     }
   }
   TF_RETURN_IF_ERROR(entry->compilation_status);