diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5902d7f659c..3a7c774f3fc 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,6 +1,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library") package( default_visibility = [ @@ -281,6 +282,7 @@ cc_library( srcs = ["xla_compilation_cache.cc"], hdrs = ["xla_compilation_cache.h"], deps = [ + ":xla_activity_listener", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", @@ -507,6 +509,7 @@ cc_library( "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", + "report_clustering_info_pass.cc", ], hdrs = [ "build_xla_ops_pass.h", @@ -520,6 +523,7 @@ cc_library( "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", + "report_clustering_info_pass.h", ], deps = [ "compilability_check_util", @@ -530,6 +534,7 @@ cc_library( ":resource_operation_safety_analysis", ":shape_inference_helpers", ":union_find", + ":xla_activity_listener", ":xla_cluster_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:functional_ops", @@ -572,6 +577,7 @@ cc_library( hdrs = ["xla_cluster_util.h"], deps = [ ":flags", + ":xla_activity_proto_cc", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -838,6 +844,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_activity_listener_test", + srcs = ["xla_activity_listener_test.cc"], + deps = [ + ":flags", + ":xla_activity_listener", + ":xla_cpu_device", + ":xla_cpu_jit", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:partitioned_function_ops", + ], +) + tf_custom_op_py_library( name = "xla_ops_py", kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], @@ -850,6 +877,26 @@ tf_custom_op_py_library( ], ) +cc_library( + name = "xla_activity_listener", + srcs = ["xla_activity_listener.cc"], + hdrs = ["xla_activity_listener.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_activity_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", + ], +) + +tf_proto_library( + name = "xla_activity_proto", + srcs = ["xla_activity.proto"], + cc_api_version = 2, + protodeps = tf_additional_all_protos(), + provide_cc_alias = True, +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 33f92ea391a..127f0d4a82e 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "tensorflow/compiler/jit/report_clustering_info_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { @@ -58,15 +59,22 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, PartiallyDeclusterPass); +// ReportClusteringInfoPass pass needs to run after all of the auto-clustering +// passes have run but before encapsulation has run. This way it can easily +// compute a summary of the clustering decisions we made and broadcast it via +// xla_activity_listener. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, + ReportClusteringInfoPass); + // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 60, BuildXlaOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 374615a5909..4a853c6cafc 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1364,46 +1364,36 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() { return; } - std::map cluster_name_to_size; - std::map> - cluster_name_to_op_histogram; - std::map unclustered_op_histogram; - int clustered_node_count = 0; - - for (Node* n : graph_->nodes()) { - absl::optional cluster_name = GetXlaClusterForNode(*n); - if (cluster_name) { - clustered_node_count++; - cluster_name_to_size[*cluster_name]++; - cluster_name_to_op_histogram[*cluster_name][n->type_string()]++; - } else { - unclustered_op_histogram[n->type_string()]++; - } - } - - int unclustered_node_count = graph_->num_nodes() - clustered_node_count; + XlaAutoClusteringSummary auto_clustering_info = + GetXlaAutoClusteringSummary(*graph_); VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes(); - VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size " - << RatioToString(clustered_node_count, graph_->num_nodes()); + VLOG(2) << " Built " << auto_clustering_info.clusters_size() + << " clusters, size " + << RatioToString(auto_clustering_info.clustered_node_count(), + graph_->num_nodes()); - for (const auto& cluster_name_size_pair : cluster_name_to_size) { - absl::string_view cluster_name = cluster_name_size_pair.first; - int size = cluster_name_size_pair.second; + for (XlaAutoClusteringSummary::Cluster cluster : + auto_clustering_info.clusters()) { + absl::string_view cluster_name = cluster.name(); + int size = cluster.size(); VLOG(2) << " " << cluster_name << " " << RatioToString(size, graph_->num_nodes()); - for (const auto& op_count_pair : - cluster_name_to_op_histogram[cluster_name]) { - VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second + for (const XlaAutoClusteringSummary::OpAndCount& op_count : + cluster.op_histogram()) { + VLOG(3) << " " << op_count.op() << ": " << op_count.count() << " instances"; } } - if (!unclustered_op_histogram.empty()) { + if (!auto_clustering_info.unclustered_op_histogram().empty()) { VLOG(2) << " Unclustered nodes: " - << RatioToString(unclustered_node_count, graph_->num_nodes()); - for (const auto& pair : unclustered_op_histogram) { - VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; + << RatioToString(auto_clustering_info.unclustered_node_count(), + graph_->num_nodes()); + for (const XlaAutoClusteringSummary::OpAndCount& op_count : + auto_clustering_info.unclustered_op_histogram()) { + VLOG(3) << " " << op_count.op() << ": " << op_count.count() + << " instances"; } } diff --git a/tensorflow/compiler/jit/report_clustering_info_pass.cc b/tensorflow/compiler/jit/report_clustering_info_pass.cc new file mode 100644 index 00000000000..b2b71b47c79 --- /dev/null +++ b/tensorflow/compiler/jit/report_clustering_info_pass.cc @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/report_clustering_info_pass.h" + +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +namespace tensorflow { +Status ReportClusteringInfoPass::Run( + const GraphOptimizationPassOptions& options) { + XlaAutoClusteringActivity activity; + *activity.mutable_summary() = GetXlaAutoClusteringSummary(**options.graph); + activity.set_global_jit_level(GetGlobalJitLevelForGraph(options)); + activity.set_cpu_global_jit_enabled( + GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit); + return BroadcastXlaActivity(std::move(activity)); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/report_clustering_info_pass.h b/tensorflow/compiler/jit/report_clustering_info_pass.h new file mode 100644 index 00000000000..97471cff134 --- /dev/null +++ b/tensorflow/compiler/jit/report_clustering_info_pass.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// This is not really an optimization pass. It does not change the graph in any +// way; instead it computes a summary of the XLA clusters in the graph and +// broadcasts it via xla_activity_listener. +class ReportClusteringInfoPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_ diff --git a/tensorflow/compiler/jit/xla_activity.proto b/tensorflow/compiler/jit/xla_activity.proto new file mode 100644 index 00000000000..2d78a266f56 --- /dev/null +++ b/tensorflow/compiler/jit/xla_activity.proto @@ -0,0 +1,104 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; + +import "tensorflow/core/protobuf/config.proto"; + +// Summarizes the results of auto-clustering a TensorFlow graph. +// +// Next ID: 5 +message XlaAutoClusteringSummary { + // Represents a single element in a histogram of ops ("op" as in "TensorFlow + // operation"). + // + // Next ID: 3 + message OpAndCount { + // The TensorFlow operation (like MatMult, Add etc.) + string op = 1; + + // The number of times this occurs. + int32 count = 2; + } + + // Describes a single XLA cluster. + // + // Next ID: 4 + message Cluster { + string name = 1; + + // The number of nodes in the cluster. + int32 size = 2; + + // A histogram of the TF operations in this cluster. + repeated OpAndCount op_histogram = 3; + }; + + // The number of nodes in the graph that are not inside an XLA cluster. + int32 unclustered_node_count = 1; + + // The number of nodes in the graph that are in an XLA cluster. + int32 clustered_node_count = 2; + + // All of the XLA clusters in the TF graph. + repeated Cluster clusters = 3; + + // A histogram of the TF operations that were not clustered. + repeated OpAndCount unclustered_op_histogram = 4; +} + +// Listeners listening for auto clustering events get messages of this type. +// +// Next ID: 5 +message XlaAutoClusteringActivity { + // xla_activity_listener set this to the global process id, as decided by the + // callback registered via `SetGlobalProcessIdMaker. + string global_process_id = 1; + + // The value of GlobalJitLevel, as determined by `GetGlobalJitLevelForGraph`. + // This determines if global auto-clustering is enabled. + OptimizerOptions.GlobalJitLevel global_jit_level = 2; + + // Whether --tf_xla_cpu_global_jit is enabled in TF_XLA_FLAGS. + bool cpu_global_jit_enabled = 3; + + XlaAutoClusteringSummary summary = 4; +} + +// Listeners listening for JIT compilation events get messages of this type. +// Each instance of XlaJitCompilationActivity corresponds to a single +// compilation of a single XLA cluster. E.g. if a graph has two clusters, A and +// B, and A is compiled 5 times and B is compiled 2 times then we will generate +// 7 instances of XlaJitCompilationActivity. +// +// Next ID: 6 +message XlaJitCompilationActivity { + // xla_activity_listener set this to the global process id, as decided by the + // callback registered via `SetGlobalProcessIdMaker. + string global_process_id = 1; + + string cluster_name = 2; + + // The number of time this cluster has been compiled. + int32 compile_count = 3; + + // Microseconds spent in the individual compilation being reported. + int64 compile_time_us = 4; + + // Total microseconds spent in (re-)compiling this cluster so far. + int64 cumulative_compile_time_us = 5; +} diff --git a/tensorflow/compiler/jit/xla_activity_listener.cc b/tensorflow/compiler/jit/xla_activity_listener.cc new file mode 100644 index 00000000000..f0dac0c806d --- /dev/null +++ b/tensorflow/compiler/jit/xla_activity_listener.cc @@ -0,0 +1,132 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_activity_listener.h" + +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { +// The list of all registered `XlaActivityListener`s. +struct XlaActivityListenerList { + absl::Mutex mutex; + std::vector> listeners GUARDED_BY(mutex); +}; + +XlaActivityListenerList* GetXlaActivityListenerList() { + static XlaActivityListenerList* listener_list = new XlaActivityListenerList; + return listener_list; +} + +template +Status ForEachListener(FnTy fn) { + XlaActivityListenerList* listener_list = GetXlaActivityListenerList(); + absl::ReaderMutexLock reader_lock(&listener_list->mutex); + + for (const std::unique_ptr& listener : + listener_list->listeners) { + TF_RETURN_IF_ERROR(fn(listener.get())); + } + + return Status::OK(); +} + +struct GlobalProcessIdMakerStorage { + absl::Mutex mutex; + GlobalProcessIdMaker global_process_id_maker GUARDED_BY(mutex); + + // True if we have used the process ID generated by `global_process_id_maker`. + // We disallow setting the global process id maker once we have broadcasted + // messages with global_process_id set to "unknown". + bool has_been_used GUARDED_BY(mutex) = false; +}; + +GlobalProcessIdMakerStorage* GetGlobalProcessIdMakerStorage() { + static GlobalProcessIdMakerStorage* global_process_id_maker_storage = + new GlobalProcessIdMakerStorage; + return global_process_id_maker_storage; +} + +GlobalProcessIdMaker GetGlobalProcessIdMaker() { + GlobalProcessIdMakerStorage* global_process_id_maker_storage = + GetGlobalProcessIdMakerStorage(); + { + absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex); + if (global_process_id_maker_storage->has_been_used) { + return global_process_id_maker_storage->global_process_id_maker; + } + } + + { + absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex); + global_process_id_maker_storage->has_been_used = true; + } + + { + absl::ReaderMutexLock reader_lock(&global_process_id_maker_storage->mutex); + return global_process_id_maker_storage->global_process_id_maker; + } +} + +absl::string_view GetGlobalProcessId() { + static std::string* cached_process_id = [&] { + std::string* result = new std::string; + GlobalProcessIdMaker maker = GetGlobalProcessIdMaker(); + *result = maker ? maker() : "unknown"; + return result; + }(); + return *cached_process_id; +} +} // namespace + +Status BroadcastXlaActivity( + XlaAutoClusteringActivity auto_clustering_activity) { + auto_clustering_activity.set_global_process_id( + std::string(GetGlobalProcessId())); + return ForEachListener([&](XlaActivityListener* listener) { + return listener->Listen(auto_clustering_activity); + }); +} + +Status BroadcastXlaActivity( + XlaJitCompilationActivity jit_compilation_activity) { + jit_compilation_activity.set_global_process_id( + std::string(GetGlobalProcessId())); + return ForEachListener([&](XlaActivityListener* listener) { + return listener->Listen(jit_compilation_activity); + }); +} + +void RegisterXlaActivityListener( + std::unique_ptr listener) { + XlaActivityListenerList* listener_list = GetXlaActivityListenerList(); + absl::WriterMutexLock writer_lock(&listener_list->mutex); + + listener_list->listeners.push_back(std::move(listener)); +} + +void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker) { + GlobalProcessIdMakerStorage* global_process_id_maker_storage = + GetGlobalProcessIdMakerStorage(); + absl::WriterMutexLock writer_lock(&global_process_id_maker_storage->mutex); + CHECK(!global_process_id_maker_storage->has_been_used); + global_process_id_maker_storage->global_process_id_maker = + std::move(global_process_id_maker); +} + +XlaActivityListener::~XlaActivityListener() {} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_activity_listener.h b/tensorflow/compiler/jit/xla_activity_listener.h new file mode 100644 index 00000000000..0c0082ab523 --- /dev/null +++ b/tensorflow/compiler/jit/xla_activity_listener.h @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_ + +#include + +#include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +// Broadcast `auto_clustering_activity` to all the registered listeners. +Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity); + +// Broadcast `jit_compilation_activity` to all the registered listeners. +Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity); + +// Various components of the system can subclass XlaActivityListener to +// notifications on auto-clustering and JIT compilation events. +// +// Subclasses of XlaActivityListener must be thread safe. +class XlaActivityListener { + public: + // Called after TensorFlow auto-clusters a graph. + virtual Status Listen( + const XlaAutoClusteringActivity& auto_clustering_activity) = 0; + + // Called after TensorFlow JIT compiles an XLA cluster. + virtual Status Listen( + const XlaJitCompilationActivity& jit_compilation_activity) = 0; + + virtual ~XlaActivityListener(); +}; + +// Registers an `XlaActivityListener`, which will be invoked on all subsequent +// `BroadcastXlaActivity` calls. +void RegisterXlaActivityListener(std::unique_ptr listener); + +using GlobalProcessIdMaker = std::function; + +// Installs `global_process_id_maker` as a "global process id" maker. +// +// The value returned by the global process ID maker, if one is installed, is +// stored in the global_process_id field of the Xla*Activity messages before +// they're fed to the registered activity listeners. If no ID maker is +// installed then global_process_id is set to "unknown". +// +// `global_process_id_maker` must be thread safe. +// +// The global process id maker is used to tag *Activity messages to so that the +// broadcasting process can be uniquely identified. Therefore the global +// process id maker +// +// - Must always return the same value within the same process. +// - Cannot be set or changed after we have broadcasted any XLA activity. +void SetGlobalProcessIdMaker(GlobalProcessIdMaker global_process_id_maker); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_ diff --git a/tensorflow/compiler/jit/xla_activity_listener_test.cc b/tensorflow/compiler/jit/xla_activity_listener_test.cc new file mode 100644 index 00000000000..624299f5341 --- /dev/null +++ b/tensorflow/compiler/jit/xla_activity_listener_test.cc @@ -0,0 +1,195 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_activity_listener.h" + +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/list_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/common_runtime/direct_session.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class TestListener : public XlaActivityListener { + public: + Status Listen( + const XlaAutoClusteringActivity& auto_clustering_activity) override { + auto_clustering_activity_ = auto_clustering_activity; + return Status::OK(); + } + + Status Listen( + const XlaJitCompilationActivity& jit_compilation_activity) override { + jit_compilation_activity_ = jit_compilation_activity; + return Status::OK(); + } + + ~TestListener() override {} + + const XlaAutoClusteringActivity& auto_clustering_activity() const { + return auto_clustering_activity_; + } + const XlaJitCompilationActivity& jit_compilation_activity() const { + return jit_compilation_activity_; + } + + private: + XlaAutoClusteringActivity auto_clustering_activity_; + XlaJitCompilationActivity jit_compilation_activity_; +}; + +class XlaActivityListenerTest : public ::testing::Test { + protected: + XlaActivityListenerTest() { + auto listener = absl::make_unique(); + listener_ = listener.get(); + RegisterXlaActivityListener(std::move(listener)); + SetGlobalProcessIdMaker([]() { return "42-xyz"; }); + } + + TestListener* listener() const { return listener_; } + + private: + TestListener* listener_; +}; + +GraphDef CreateGraphDef() { + Scope root = Scope::NewRootScope().ExitOnError().WithAssignedDevice( + "/job:localhost/replica:0/task:0/device:CPU:0"); + Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT); + for (int i = 0; i < 5; i++) { + a = ops::MatMul(root.WithOpName(absl::StrCat("matmul_", i)), a, a); + a = ops::Add(root.WithOpName(absl::StrCat("add_", i)), a, a); + } + + GraphDef graph_def; + root.graph()->ToGraphDef(&graph_def); + return graph_def; +} + +TEST_F(XlaActivityListenerTest, Test) { + GraphDef graph_def = CreateGraphDef(); + SessionOptions options; + options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + std::unique_ptr session(NewSession(options)); + + TF_ASSERT_OK(session->Create(graph_def)); + + std::vector output_names = {std::string("add_4:0")}; + + Tensor tensor_2x2(DT_FLOAT, TensorShape({2, 2})); + for (int i = 0; i < 4; i++) { + tensor_2x2.matrix()(i / 2, i % 2) = 5 * i; + } + + Tensor tensor_3x3(DT_FLOAT, TensorShape({3, 3})); + for (int i = 0; i < 9; i++) { + tensor_3x3.matrix()(i / 3, i % 3) = 5 * i; + } + + std::vector> inputs_2x2 = {{"A", tensor_2x2}}; + + std::vector outputs; + TF_ASSERT_OK(session->Run(inputs_2x2, output_names, /*target_node_names=*/{}, + &outputs)); + + absl::string_view expected_auto_clustering_activity = + R"(global_process_id: "42-xyz" +global_jit_level: ON_2 +cpu_global_jit_enabled: true +summary { + unclustered_node_count: 4 + clustered_node_count: 14 + clusters { + name: "cluster_0" + size: 14 + op_histogram { + op: "Add" + count: 1 + } + op_histogram { + op: "Const" + count: 4 + } + op_histogram { + op: "MatMul" + count: 5 + } + op_histogram { + op: "Mul" + count: 4 + } + } + unclustered_op_histogram { + op: "NoOp" + count: 2 + } + unclustered_op_histogram { + op: "_Arg" + count: 1 + } + unclustered_op_histogram { + op: "_Retval" + count: 1 + } +} +)"; + + EXPECT_EQ(listener()->auto_clustering_activity().DebugString(), + expected_auto_clustering_activity); + + EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0"); + EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 1); + + int64 first_compile_time = + listener()->jit_compilation_activity().compile_time_us(); + EXPECT_GT(first_compile_time, 0); + EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(), + first_compile_time); + + std::vector> inputs_3x3 = {{"A", tensor_3x3}}; + + outputs.clear(); + for (int i = 0; i < 3; i++) { + TF_ASSERT_OK(session->Run(inputs_3x3, output_names, + /*target_node_names=*/{}, &outputs)); + } + + EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0"); + EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 2); + + EXPECT_GT(listener()->jit_compilation_activity().compile_time_us(), 0); + EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(), + first_compile_time + + listener()->jit_compilation_activity().compile_time_us()); +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + tensorflow::GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit = true; + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 063bb9c26a3..97737c5ee9c 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -318,4 +318,72 @@ bool IsShapeConsumerOp(const Node& node) { return node.type_string() == "Shape" || node.type_string() == "Rank" || node.type_string() == "Size"; } + +namespace { +struct ClusterInfo { + int size; + + // Maps op names to the number of times they appear in the cluster. + absl::flat_hash_map op_histogram; +}; + +void HistogramMapToRepeatedOpAndCount( + protobuf::RepeatedPtrField* result, + const absl::flat_hash_map& histogram) { + for (const auto& pair : histogram) { + XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add(); + new_entry->set_op(std::string(pair.first)); + new_entry->set_count(pair.second); + } + + absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a, + const XlaAutoClusteringSummary::OpAndCount& b) { + return a.op() < b.op(); + }); +} + +void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result, + absl::string_view name, const ClusterInfo& info) { + result->set_name(std::string(name)); + result->set_size(info.size); + HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(), + info.op_histogram); +} +} // namespace + +XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) { + absl::flat_hash_map cluster_name_to_info; + XlaAutoClusteringSummary result; + + absl::flat_hash_map unclustered_op_histogram; + + for (Node* n : graph.nodes()) { + absl::optional cluster_name = GetXlaClusterForNode(*n); + if (cluster_name) { + result.set_clustered_node_count(result.clustered_node_count() + 1); + ClusterInfo* info = &cluster_name_to_info[*cluster_name]; + info->size++; + info->op_histogram[n->type_string()]++; + } else { + result.set_unclustered_node_count(result.unclustered_node_count() + 1); + unclustered_op_histogram[n->type_string()]++; + } + } + + for (const auto& pair : cluster_name_to_info) { + XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters(); + ClusterInfoToProtobuf(new_cluster, pair.first, pair.second); + } + + absl::c_sort(*result.mutable_clusters(), + [&](const XlaAutoClusteringSummary::Cluster& a, + const XlaAutoClusteringSummary::Cluster& b) { + return a.name() < b.name(); + }); + + HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(), + unclustered_op_histogram); + + return result; +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 657075caf4d..97fe80258a1 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -18,8 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" @@ -87,6 +89,11 @@ bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def); // Returns true if `node` an operator that consumes only the shape of its input, // not the data itself. bool IsShapeConsumerOp(const Node& node); + +// Computes a clustering summary for `graph`. See documentation on +// `XlaAutoClusteringSummary` for details. +XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index f53a1e5d403..035a50e1852 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -361,6 +362,16 @@ Status XlaCompilationCache::CompileImpl( << tensorflow::strings::HumanReadableElapsedTime( it->second.cumulative_compile_time_us / 1.0e6) << ")"; + + XlaJitCompilationActivity jit_compilation_activity; + jit_compilation_activity.set_cluster_name(function.name()); + jit_compilation_activity.set_compile_count(it->second.compile_count); + jit_compilation_activity.set_compile_time_us(compile_time_us); + jit_compilation_activity.set_cumulative_compile_time_us( + it->second.cumulative_compile_time_us); + + TF_RETURN_IF_ERROR( + BroadcastXlaActivity(std::move(jit_compilation_activity))); } } TF_RETURN_IF_ERROR(entry->compilation_status);