diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 4424c29e395..70e26001903 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -533,6 +533,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
         "//tensorflow/compiler/tf2xla/cc:xla_ops",
         "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
index 756377bd950..71abee245d5 100644
--- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
@@ -394,11 +394,11 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
   return true;
 }
 
-std::unordered_set<int32> GraphCycles::Successors(int32 node) {
+std::unordered_set<int32> GraphCycles::Successors(int32 node) const {
   return rep_->nodes_[node]->out;
 }
 
-std::unordered_set<int32> GraphCycles::Predecessors(int32 node) {
+std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const {
   return rep_->nodes_[node]->in;
 }
 
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h
index 44448fa3d78..8e7801d622b 100644
--- a/tensorflow/compiler/jit/graphcycles/graphcycles.h
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h
@@ -117,8 +117,8 @@ class GraphCycles {
   // Expensive: should only be called from graphcycles_test.cc.
   bool CheckInvariants() const;
 
-  std::unordered_set<int32> Successors(int32 node);
-  std::unordered_set<int32> Predecessors(int32 node);
+  std::unordered_set<int32> Successors(int32 node) const;
+  std::unordered_set<int32> Predecessors(int32 node) const;
 
   // ----------------------------------------------------
   struct Rep;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 35fc2616ab1..22dc6f67a11 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/bounds_check.h"
@@ -52,6 +53,8 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+using xla::StatusOr;
+
 bool HasResourceOutput(const Node& node) {
   return absl::c_count(node.output_types(), DT_RESOURCE) != 0;
 }
@@ -389,46 +392,230 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
   return op_filter;
 }
 
-// Nodes that XLA can compile are put in `candidates`. Nodes put in
-// `isolated_nodes` must either be unclustered or be put in trivial single-node
-// clusters.
-Status FindCompilationCandidates(
-    const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
-    const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
+class MarkForCompilationPassImpl {
+ public:
+  struct DebugOptions {
+    // If true, do not respect the results of deadness analysis.
+    bool ignore_deadness_checks;
+
+    // If true, do not respect the _XlaCompile=false attribute.
+    bool ignore_xla_compile_attr;
+
+    int max_cluster_size;
+    int min_cluster_size;
+
+    // Compiler fuel for the auto-clustering algorithm.
+    //
+    // We decrement this value by one each time we choose a compilation
+    // candidate and we stop clustering when it hits zero. This means the
+    // initial value for this variable (via --tf_xla_clustering_fuel=N)
+    // effectively acts as a "cap" for how much we cluster and we can bisect
+    // over this initial value to discover clustering decisions that cause a
+    // miscompile or a performance regression.
+    std::atomic<int64>* fuel;
+
+    bool dump_graphs;
+  };
+
+  MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph,
+                             FunctionLibraryDefinition* flib_def, Env* env,
+                             OptimizerOptions::GlobalJitLevel global_jit_level)
+      : debug_options_(debug_options),
+        graph_(graph),
+        flib_def_(flib_def),
+        env_(env),
+        global_jit_level_(global_jit_level) {}
+
+  Status Run();
+
+ private:
+  struct Cluster {
+    // Identifies the node that represents this cluster in the cycle detection
+    // graph.
+    int representative = -1;
+
+    // The set of devices the nodes in this cluster are placed on.
+    absl::flat_hash_set<string> devices;
+
+    // If there are resource operations in the cluster then this is the device
+    // that resource operations are placed on. All resource operations in a
+    // cluster must be placed on the same device.
+    string resource_op_device;
+
+    // True if any node in the cluster has an _XlaCompile attribute set to
+    // true.
+    bool has_xla_compile_attr;
+  };
+
+  // Nodes that XLA can compile are put in `candidates`. Nodes put in
+  // `isolated_nodes` must either be unclustered or be put in trivial
+  // single-node clusters.
+  Status FindCompilationCandidates(OrderedNodeSet* candidates,
+                                   absl::flat_hash_set<Node*>* isolated_nodes);
+
+  bool CompilationDisallowedByXlaCompileAttr(Node* node,
+                                             const DeviceType& jit_device_type);
+
+  void BuildInitialClusterSet(const OrderedNodeSet& compilation_candidates,
+                              std::vector<UnionFind<Cluster>>* clusters,
+                              std::deque<UnionFind<Cluster>*>* worklist);
+
+  Status ShouldCompileClusterImpl(const Cluster& cluster, bool* should_compile,
+                                  string* device);
+
+  Status ShouldCompileCluster(const Cluster& cluster, bool* should_compile,
+                              string* device);
+
+  bool HasMismatchingXlaScope(Node* node_from, Node* node_to);
+
+  StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
+      int to_node_id, const OrderedNodeSet& compilation_candidates,
+      absl::Span<UnionFind<Cluster>> clusters, const GraphCycles& cycles);
+
+  // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
+  // and therefore not a hindrance for combining the two clusters into a larger
+  // cluster.
+  Status AreDevicesCompatible(const Cluster& cluster_a,
+                              const Cluster& cluster_b, bool* result);
+
+  void DumpPostClusteringGraphs();
+  void VLogClusteringSummary();
+
+  DebugOptions debug_options_;
+  Graph* graph_;
+  FunctionLibraryDefinition* flib_def_;
+  Env* env_;
+  OptimizerOptions::GlobalJitLevel global_jit_level_;
+  absl::flat_hash_map<int, std::pair<bool, string>>
+      should_compile_cluster_cache_;
+};
+
+StatusOr<bool>
+MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
+    int to_node_id, const OrderedNodeSet& compilation_candidates,
+    absl::Span<UnionFind<Cluster>> clusters, const GraphCycles& cycles) {
+  const Cluster& cluster_to = clusters[to_node_id].Get();
+
+  // If any of the consumer's producers are on a different device, do not
+  // cluster these nodes. This prevents other work on this device from being
+  // delayed by work on other devices.
We consider predecessors of the entire + // cluster rather than just the inputs to the node to prevent the cluster + // still being combined in cases where the 'to' cluster has multiple + // dependencies on the 'from' cluster and another dependency leads to a + // merging of the clusters. + // + // TODO(b/117085735): We probably want to handle the reciprocal of this case + // where a cluster is producing data for multiple devices. + for (const auto& in_id : cycles.Predecessors(to_node_id)) { + if (in_id >= graph_->num_node_ids()) { + continue; + } + + Node* in = graph_->FindNodeId(in_id); + const Cluster& cluster_in = clusters[in_id].Get(); + if (compilation_candidates.find(in) != compilation_candidates.cend()) { + bool devices_compatible; + TF_RETURN_IF_ERROR( + AreDevicesCompatible(cluster_to, cluster_in, &devices_compatible)); + if (!devices_compatible) { + return true; + } + } + } + + return false; +} + +bool MarkForCompilationPassImpl::HasMismatchingXlaScope(Node* node_from, + Node* node_to) { + // Look for an _XlaScope on both nodes. If both nodes have a scope and the + // scopes do not match, do not cluster along this edge. This restriction is + // overridden if the global_jit_level_ is ON. If even one of the nodes lacks + // an _XlaScope attribute, then it is treated as a "bridge" and a cluster may + // be created along it. We may want to restrict this behavior to require all + // nodes marked with _XlaCompile=true to also have a _XlaScope property set + // (and raise an error otherwise); but for now we don't do this. + if (global_jit_level_ != OptimizerOptions::OFF) { + return false; + } + + string from_scope, to_scope; + return GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && + GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && + from_scope != to_scope; +} + +void MarkForCompilationPassImpl::BuildInitialClusterSet( + const OrderedNodeSet& compilation_candidates, + std::vector>* clusters, + std::deque*>* worklist) { + clusters->resize(graph_->num_node_ids()); + for (Node* node : compilation_candidates) { + Cluster* cluster = &(*clusters)[node->id()].Get(); + cluster->representative = node->id(); + const string& device = !node->assigned_device_name().empty() + ? 
node->assigned_device_name() + : node->requested_device(); + if (HasResourceInput(*node) || HasResourceOutput(*node)) { + cluster->resource_op_device = device; + } + + cluster->has_xla_compile_attr = false; + + bool xla_compile_attr; + if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) { + cluster->has_xla_compile_attr |= xla_compile_attr; + } + + if (flib_def_->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr).ok()) { + cluster->has_xla_compile_attr |= xla_compile_attr; + } + + cluster->devices.insert(device); + worklist->push_back(&(*clusters)[node->id()]); + } +} + +Status MarkForCompilationPassImpl::FindCompilationCandidates( OrderedNodeSet* candidates, absl::flat_hash_set* isolated_nodes) { OptimizerOptions opts; std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, - flib_def, opts)); + new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION, + flib_def_, opts)); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - std::vector compile_time_const_nodes(graph.num_node_ids(), false); - TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, - &compile_time_const_nodes, lib_runtime)); - - int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + std::vector compile_time_const_nodes(graph_->num_node_ids(), false); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *graph_, /*compile_time_const_arg_indices=*/nullptr, + &compile_time_const_nodes, lib_runtime)); // Iterate over nodes in sorted order so that compiler fuel is deterministic. // We can't simply pass op_nodes().begin() and op_nodes().end to the // std::vector constructor because they're not proper iterators, with // iterator_traits defined and so on. std::vector sorted_nodes; - for (Node* node : graph.op_nodes()) { + for (Node* node : graph_->op_nodes()) { sorted_nodes.push_back(node); } std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); - if (fuel >= std::numeric_limits::max() / 2) { + if (*debug_options_.fuel >= std::numeric_limits::max() / 2) { // The assumption is that if fuel started out as INT64_MAX, it will forever // stay greater than INT64_MAX / 2. VLOG(2) << "Starting fuel: infinity"; } else { - VLOG(2) << "Starting fuel: " << fuel; + VLOG(2) << "Starting fuel: " << *debug_options_.fuel; } + std::unique_ptr deadness_analysis; + if (!debug_options_.ignore_deadness_checks) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis)); + } + + VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size(); + for (Node* node : sorted_nodes) { - if (fuel <= 0) { + if (*debug_options_.fuel <= 0) { VLOG(1) << "Hit fuel limit; not marking any remaining ops as clusterable."; break; @@ -440,14 +627,28 @@ Status FindCompilationCandidates( VLOG(4) << "Device type for " << node->name() << ": " << device_type.type_string(); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) { - // is_compilable_fn has already logged the reason if it returned false. 
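+    // XLA cannot represent the deadness semantics of merge nodes or of nodes
+    // whose inputs have mismatching deadness, so reject them here unless the
+    // deadness safety checks were explicitly disabled.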
+ if (deadness_analysis) { + if (node->IsMerge() || + deadness_analysis->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; + continue; + } + } + + if (CompilationDisallowedByXlaCompileAttr(node, device_type)) { + VLOG(2) << "Not clustering " << node->name() + << ": disallowed by _XlaCompile attribute"; continue; } const XlaOpRegistry::DeviceRegistration* registration; - CHECK( - XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), + ®istration)) { + VLOG(2) << "Rejecting " << node->name() + << ": could not find JIT device for " << device_type.type(); + continue; + } + DeviceType jit_device_type(registration->compilation_device_name); RecursiveCompilabilityChecker::OperationFilter op_filter = @@ -461,7 +662,7 @@ Status FindCompilationCandidates( if (compile_time_const_nodes[node->id()]) { const OpDef* op_def; TF_RETURN_IF_ERROR( - graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); + graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { // It is easiest to demonstrate the problem we're trying to solve with // an example. Say we have this graph: @@ -506,149 +707,342 @@ Status FindCompilationCandidates( VLOG(2) << "Isolating " << node->name() << ": must-be-constant stateful op"; isolated_nodes->insert(node); - // Keep going and execute all the other checks. } } } candidates->insert(node); - --fuel; + --(*debug_options_.fuel); } + VLOG(2) << "candidates->size() = " << candidates->size(); return Status::OK(); } -struct Cluster { - // Identifies the node that represents this cluster in the cycle detection - // graph. - int representative = -1; +bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( + Node* node, const DeviceType& device_type) { + if (debug_options_.ignore_xla_compile_attr) { + return false; + } - // The set of devices the nodes in this cluster are placed on. - absl::flat_hash_set devices; - - // If there are resource operation in the cluster then this is the device that - // resource operations are placed on. All resource operations in a cluster - // must be placed on the same device. - string resource_op_device; - - // True if any node in the cluster has an _XlaCompile attribute set to true. - bool has_xla_compile_attr; -}; - -} // anonymous namespace - -bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { - Device* device = flr->device(); const XlaOpRegistry::DeviceRegistration* registration; - CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), - ®istration)); - DeviceType jit_device_type(registration->compilation_device_name); + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; + return false; + } - // We can always *compile* resource operations, stateful RNGs and dummy ops, - // even if we are sometimes unable to auto-cluster them. - RecursiveCompilabilityChecker::OperationFilter op_filter; - op_filter.allow_resource_ops_in_called_functions = true; - op_filter.allow_non_resource_var_resource_ops = true; - op_filter.allow_resource_producing_ops = true; - op_filter.allow_stateful_rng_ops = true; - op_filter.allow_control_trigger = true; - op_filter.allow_dummy_ops = true; - op_filter.allow_ops_producing_or_consuming_variant = true; + // If there is a _XlaCompile annotation, use its value. 
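+  // The attribute on the node itself takes precedence; if the node does not
+  // carry one, the attribute on the function it calls (if any) is consulted
+  // below via flib_def_.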
+ bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") is false."; + } + return !compile; + } - return RecursiveCompilabilityChecker{&op_filter, &jit_device_type} - .IsCompilableCall(ndef, flr); + status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile); + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") on callee is false."; + } + return !compile; + } + + return false; } -Status MarkForCompilationPass::Run( - const GraphOptimizationPassOptions& options) { - // TODO(phawkins): precompute the "GetCompilationDevice" properties of each - // device ahead of time. - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; - - const FunctionLibraryDefinition* fld = options.flib_def; - - // Deadness analysis expects a graph with source and sink edges properly - // connected but sometimes the incoming graph does not follow this invariant. - // So fix up the source and sink edges before calling into deadness analysis. - FixupSourceAndSinkEdges(options.graph->get()); - - // See explanation on `kXlaAlreadyClustered`. - for (Node* n : options.graph->get()->nodes()) { - if (n->attrs().Find(kXlaAlreadyClustered)) { - return Status::OK(); - } - } - - std::unique_ptr deadness; - { - XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); - TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); - } - - bool deadness_analysis_disabled = - GetMarkForCompilationPassFlags() - ->tf_xla_disable_deadness_safety_checks_for_debugging; - - if (deadness_analysis_disabled) { - LOG(WARNING) << "Deadness analysis was manually disabled via " - "--tf_xla_disable_deadness_safety_checks_for_debugging; " - "auto-clustering " - "is unsound!"; - } - - auto is_compilable = [&](const Node* node, const DeviceType& device_type) { - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), - ®istration)) { - VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; - return false; - } - - // If there is a _XlaCompile annotation, use its value. - bool compile = false; - Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok()) { - if (!compile) { - VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" - << kXlaCompileAttr << ") is false."; - } - return compile; - } - - status = fld->GetAttr(*node, kXlaCompileAttr, &compile); - if (status.ok()) { - if (!compile) { - VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" - << kXlaCompileAttr << ") on callee is false."; - } - return compile; - } - - // If inputs to `node` can have conflicting deadness (i.e. some are alive - // and some are dead) then don't compile it. XLA cannot represent the - // deadness semantics of these nodes correctly and auto-clustering these - // nodes can cause deadness to propagate to nodes that should be live. - if (!deadness_analysis_disabled) { - if (node->IsMerge() || - deadness->HasInputsWithMismatchingDeadness(*node)) { - VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; - return false; - } - } - - return true; - }; - - return RunImpl(options, is_compilable); +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? 
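+// Run() uses this to avoid making such an op the root of a cluster.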
+bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "Rank" || + node.type_string() == "Size"; } -static string RatioToString(int numerator, int denominator) { +Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { + // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then + // ignore it during resource operation safety analysis. We need this hack + // because of two reasons: + // + // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. + // 2. We don't support live-out values of type DT_RESOURCE and live-in values + // of type DT_RESOURCE that are not resource variables. + // + // Together these imply we cannot let resource variable safety analysis + // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different + // clusters: both of them will have to be clustered because of (1) and we + // won't be able to keep the edge between the two as neither the input to the + // second XLA cluster nor the output from the first XLA cluster are supported + // because of (2). + // + // TODO(b/113100872): This can be fixed if the TensorFlow representation for + // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then + // (2) would no longer hold. + + if (n.assigned_device_name().empty()) { + *ignore = false; + return Status::OK(); + } + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n.assigned_device_name(), &device_type)); + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + *ignore = true; + } else { + *ignore = registration->compile_all_resource_ops; + } + return Status::OK(); +} + +Status MarkForCompilationPassImpl::Run() { + static std::atomic cluster_sequence_num; + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + + OrderedNodeSet compilation_candidates; + absl::flat_hash_set isolated_nodes; + TF_RETURN_IF_ERROR( + FindCompilationCandidates(&compilation_candidates, &isolated_nodes)); + + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); + } + + GraphCycles cycles; + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(graph_, &cycles)); + if (!cycle_detection_graph_ok) { + return Status::OK(); + } + + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + graph_, flib_def_, IgnoreResourceOpForSafetyAnalysis, &cycles)); + + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + std::vector> clusters; + std::deque*> worklist; + BuildInitialClusterSet(compilation_candidates, &clusters, &worklist); + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. + // + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). + while (!worklist.empty()) { + Cluster* cluster_from = &worklist.front()->Get(); + int from = cluster_from->representative; + worklist.pop_front(); + + Node* node_from = graph_->FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. 
+ return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); + } + + if (isolated_nodes.count(node_from)) { + continue; + } + + for (int to : cycles.Successors(from)) { + if (to >= graph_->num_node_ids()) { + // Node is a fictitious node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + + const Cluster& cluster_to = clusters[to].Get(); + Node* node_to = graph_->FindNodeId(to); + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { + continue; + } + + bool devices_compatible; + TF_RETURN_IF_ERROR( + AreDevicesCompatible(*cluster_from, cluster_to, &devices_compatible)); + if (!devices_compatible) { + continue; + } + + if (isolated_nodes.count(node_to)) { + continue; + } + + if (HasMismatchingXlaScope(node_from, node_to)) { + continue; + } + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // Don't exceed the maximum cluster size. + if (clusters[from].Size() + clusters[to].Size() > + debug_options_.max_cluster_size) { + continue; + } + + TF_ASSIGN_OR_RETURN( + bool will_introduce_cross_device_dependency, + ClusteringWillIntroduceInterDeviceDependency( + to, compilation_candidates, absl::MakeSpan(clusters), cycles)); + + if (will_introduce_cross_device_dependency) { + continue; + } + + // If contracting the edge would create a cycle, bail out. However, just + // because we can't merge the clusters now does not mean we won't be able + // to merge them in the future. e.g., if we have edges 1->2, 2->3 and + // 1->3, we cannot contract edge 1->3. But if we first contract 1->2 then + // we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) { + continue; + } + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + cluster_from->devices.insert(cluster_to.devices.begin(), + cluster_to.devices.end()); + if (!cluster_to.resource_op_device.empty()) { + cluster_from->resource_op_device = cluster_to.resource_op_device; + } + cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr; + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph_->num_node_ids()); + + // has_functional_control_flow remembers if a cluster contains a functional + // control flow node. + std::vector has_functional_control_flow(graph_->num_node_ids()); + + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + // We want clusters to be big enough that the benefit from XLA's + // optimizations offsets XLA related overhead (for instance we add some + // Switch/Merge nodes into the graph to implement lazy compilation). To + // this end, we don't count Identity and Constant nodes because they do not + // enable interesting optimizations by themselves. + if (!n->IsIdentity() && !n->IsConstant()) { + effective_cluster_sizes[cluster]++; + } + if (n->type_string() == "While" || n->type_string() == "If") { + has_functional_control_flow[cluster] = true; + } + } + + // Names for each cluster. 
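+  // Keyed by the representative node id of each cluster.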
+  std::unordered_map<int, string> cluster_names;
+
+  if (debug_options_.dump_graphs) {
+    DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
+  }
+
+  // Mark clusters for compilation that:
+  // * are placed on a device that requires compilation (an XlaDevice),
+  // * are explicitly marked for compilation (_XlaCompile=true), or
+  // * have more than debug_options_.min_cluster_size elements (applicable
+  //   only if compilation is enabled, otherwise there will be no such
+  //   candidates).
+  for (Node* n : compilation_candidates) {
+    const Cluster& cluster = clusters[n->id()].Get();
+    bool should_compile;
+    string device;
+    TF_RETURN_IF_ERROR(ShouldCompileCluster(cluster, &should_compile, &device));
+    if (!should_compile) {
+      continue;
+    }
+
+    int cluster_repr = cluster.representative;
+
+    // Compile if the user marked this node _XlaCompile=true
+    bool compile_attr = false;
+    bool marked_for_compilation = false;
+    if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
+      marked_for_compilation = compile_attr;
+    } else if (flib_def_->GetAttr(*n, kXlaCompileAttr, &compile_attr).ok()) {
+      marked_for_compilation = compile_attr;
+    }
+
+    // We assume that functional If and While nodes have at least
+    // min_cluster_size non-trivial nodes in them. It would be more principled
+    // to (recursively) verify this fact, but that's probably not worth the
+    // trouble.
+
+    if (effective_cluster_sizes[cluster_repr] >=
+            debug_options_.min_cluster_size ||
+        has_functional_control_flow[cluster_repr] || marked_for_compilation) {
+      string& name = cluster_names[cluster_repr];
+
+      if (name.empty()) {
+        name = absl::StrCat("cluster_", cluster_sequence_num++);
+      }
+      n->AddAttr(kXlaClusterAttr, name);
+      n->AddAttr(kXlaAlreadyClustered, true);
+      VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
+    }
+  }
+
+  if (debug_options_.dump_graphs) {
+    DumpPostClusteringGraphs();
+  }
+
+  VLogClusteringSummary();
+
+  return Status::OK();
+}
+
+void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
+  DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
+
+  // We also dump out an annotated version of the TF graph where the node
+  // names are prefixed with the cluster names. This can help visualize the
+  // clustering decisions on TensorBoard.
+  Graph new_graph(graph_->op_registry());
+  CopyGraph(*graph_, &new_graph);
+
+  for (Node* n : new_graph.nodes()) {
+    if (absl::optional<absl::string_view> cluster_name =
+            GetXlaClusterForNode(*n)) {
+      n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
+    } else if (n->type_string() == "VarHandleOp") {
+      n->set_name(absl::StrCat("varhandle/", n->name()));
+    } else {
+      // There is room for improvement here. In particular, it may help to
+      // split these unclustered nodes into classes where every node in a
+      // specific class has edges to and from the same set of clusters.
+ n->set_name(absl::StrCat("unclustered/", n->name())); + } + } + + DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_); +} + +string RatioToString(int numerator, int denominator) { return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } -static void VLogClusteringSummary(const Graph& g) { +void MarkForCompilationPassImpl::VLogClusteringSummary() { if (!VLOG_IS_ON(2)) { return; } @@ -659,7 +1053,7 @@ static void VLogClusteringSummary(const Graph& g) { std::map unclustered_op_histogram; int clustered_node_count = 0; - for (Node* n : g.nodes()) { + for (Node* n : graph_->nodes()) { absl::optional cluster_name = GetXlaClusterForNode(*n); if (cluster_name) { clustered_node_count++; @@ -670,17 +1064,17 @@ static void VLogClusteringSummary(const Graph& g) { } } - int unclustered_node_count = g.num_nodes() - clustered_node_count; + int unclustered_node_count = graph_->num_nodes() - clustered_node_count; - VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes(); + VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes(); VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size " - << RatioToString(clustered_node_count, g.num_nodes()); + << RatioToString(clustered_node_count, graph_->num_nodes()); for (const auto& cluster_name_size_pair : cluster_name_to_size) { absl::string_view cluster_name = cluster_name_size_pair.first; int size = cluster_name_size_pair.second; VLOG(2) << " " << cluster_name << " " - << RatioToString(size, g.num_nodes()); + << RatioToString(size, graph_->num_nodes()); for (const auto& op_count_pair : cluster_name_to_op_histogram[cluster_name]) { VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second @@ -690,7 +1084,7 @@ static void VLogClusteringSummary(const Graph& g) { if (!unclustered_op_histogram.empty()) { VLOG(2) << " Unclustered nodes: " - << RatioToString(unclustered_node_count, g.num_nodes()); + << RatioToString(unclustered_node_count, graph_->num_nodes()); for (const auto& pair : unclustered_op_histogram) { VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; } @@ -721,7 +1115,7 @@ static void VLogClusteringSummary(const Graph& g) { std::set cluster_names_to_print; - for (const Edge* e : g.edges()) { + for (const Edge* e : graph_->edges()) { const Node* from = e->src(); absl::optional from_cluster_name = GetXlaClusterForNode(*from); @@ -776,59 +1170,8 @@ static void VLogClusteringSummary(const Graph& g) { } } -// Is 'node' an operator that consumes only the shape of its input, not the -// data itself? -static bool IsShapeConsumerOp(const Node& node) { - return node.type_string() == "Shape" || node.type_string() == "Rank" || - node.type_string() == "Size"; -} - -static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { - // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then - // ignore it during resource operation safety analysis. We need this hack - // because of two reasons: - // - // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. - // 2. We don't support live-out values of type DT_RESOURCE and live-in values - // of type DT_RESOURCE that are not resource variables. - // - // Together these imply we cannot let resource variable safety analysis - // constrain e.g. 
a TensorArrayV3->TensorArrayAssignV3 edge to be in different
-  // clusters: both of them will have to be clustered because of (1) and we
-  // won't be able to keep the edge between the two as neither the input to the
-  // second XLA cluster nor the output from the first XLA cluster are supported
-  // because of (2).
-  //
-  // TODO(b/113100872): This can be fixed if the TensorFlow representation for
-  // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
-  // (2) would no longer hold.
-
-  if (n.assigned_device_name().empty()) {
-    *ignore = false;
-    return Status::OK();
-  }
-  DeviceType device_type("");
-  TF_RETURN_IF_ERROR(
-      DeviceToDeviceType(n.assigned_device_name(), &device_type));
-
-  const XlaOpRegistry::DeviceRegistration* registration;
-  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
-    *ignore = true;
-  } else {
-    *ignore = registration->compile_all_resource_ops;
-  }
-  return Status::OK();
-}
-
-// Sequence number generator to ensure clusters have unique names.
-static std::atomic<int64> cluster_sequence_num;
-
-// Returns true if the devices in `cluster_a` and `cluster_b` are compatible and
-// therefore not a hindrance for combining the two clusters into a larger
-// cluster.
-static Status AreDevicesCompatible(
-    const Cluster& cluster_a, const Cluster& cluster_b,
-    OptimizerOptions::GlobalJitLevel global_jit_level, bool* result) {
+Status MarkForCompilationPassImpl::AreDevicesCompatible(
+    const Cluster& cluster_a, const Cluster& cluster_b, bool* result) {
   std::vector<string> devices;
   absl::c_remove_copy(cluster_a.devices, std::back_inserter(devices), "");
   absl::c_remove_copy(cluster_b.devices, std::back_inserter(devices), "");
@@ -871,9 +1214,9 @@ static Status AreDevicesCompatible(
   // We will check this again later, but here we prune out clusters that would
   // never have been sent to XLA to save compile time. Without this change we
-  // will e.g. create a CPU cluster only to later notice that the user did not
-  // enable the CPU JIT via --tf_xla_cpu_global_jit. With this change we avoid
-  // creating the cluster to begin with.
+  // will e.g. create a CPU cluster only to later notice that the user did
+  // not enable the CPU JIT via --tf_xla_cpu_global_jit. With this change we
+  // avoid creating the cluster to begin with.
   //
   // TODO(b/126629785): It is possible that this is just papering over O(n^2)
   // behavior in our clustering algorithm.
@@ -891,15 +1234,14 @@ static Status AreDevicesCompatible(
           XlaOpRegistry::AutoclusteringPolicy::kAlways ||
       (registration->autoclustering_policy ==
           XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
-       global_jit_level != OptimizerOptions::OFF);
+       global_jit_level_ != OptimizerOptions::OFF);
 
   return Status::OK();
 }
 
 // Returns `true` iff we should compile `cluster`.
-static Status ShouldCompileClusterImpl( - const Cluster& cluster, OptimizerOptions::GlobalJitLevel global_jit_level, - bool* should_compile, string* device) { +Status MarkForCompilationPassImpl::ShouldCompileClusterImpl( + const Cluster& cluster, bool* should_compile, string* device) { std::vector devices; absl::c_remove_copy(cluster.devices, std::back_inserter(devices), ""); absl::c_sort(devices); @@ -923,7 +1265,7 @@ static Status ShouldCompileClusterImpl( XlaOpRegistry::AutoclusteringPolicy::kAlways || (registration->autoclustering_policy == XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && - global_jit_level != OptimizerOptions::OFF); + global_jit_level_ != OptimizerOptions::OFF); if (!*should_compile && registration->autoclustering_policy == @@ -955,324 +1297,112 @@ static Status ShouldCompileClusterImpl( return Status::OK(); } -static Status ShouldCompileCluster( - absl::flat_hash_map>* cache, - OptimizerOptions::GlobalJitLevel global_jit_level, const Cluster& cluster, - bool* should_compile, string* device) { - auto it = cache->find(cluster.representative); - if (it != cache->end()) { +Status MarkForCompilationPassImpl::ShouldCompileCluster(const Cluster& cluster, + bool* should_compile, + string* device) { + auto it = should_compile_cluster_cache_.find(cluster.representative); + if (it != should_compile_cluster_cache_.end()) { *should_compile = it->second.first; *device = it->second.second; return Status::OK(); } string device_s; - TF_RETURN_IF_ERROR(ShouldCompileClusterImpl(cluster, global_jit_level, - should_compile, &device_s)); - cache->insert({cluster.representative, {*should_compile, device_s}}); + TF_RETURN_IF_ERROR( + ShouldCompileClusterImpl(cluster, should_compile, &device_s)); + should_compile_cluster_cache_.insert( + {cluster.representative, {*should_compile, device_s}}); *device = std::move(device_s); return Status::OK(); } -Status MarkForCompilationPass::RunImpl( +Status MarkForCompilation( const GraphOptimizationPassOptions& options, - const std::function& - is_compilable_fn) { - VLOG(1) << "MarkForCompilationPass::Run"; - - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - + const MarkForCompilationPassImpl::DebugOptions& debug_options) { Graph* graph = options.graph->get(); + FunctionLibraryDefinition* flib_def = options.flib_def; - OrderedNodeSet compilation_candidates; - absl::flat_hash_set isolated_nodes; - TF_RETURN_IF_ERROR(FindCompilationCandidates( - *graph, options.flib_def, - (options.session_options != nullptr) ? options.session_options->env - : Env::Default(), - is_compilable_fn, &compilation_candidates, &isolated_nodes)); + // Deadness analysis expects a graph with source and sink edges properly + // connected but sometimes the incoming graph does not follow this invariant. + // So fix up the source and sink edges before calling into deadness analysis. + FixupSourceAndSinkEdges(graph); - if (compilation_candidates.empty()) { - VLOG(2) << "No compilable candidates"; - return Status::OK(); - } - - GraphCycles cycles; - TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, - CreateCycleDetectionGraph(graph, &cycles)); - if (!cycle_detection_graph_ok) { - return Status::OK(); - } - TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( - graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); - - // Each compilation candidate belongs to a cluster. The cluster's - // representative - // names the node in the 'cycles' graph that represents the cluster. 
- std::vector> clusters(graph->num_node_ids()); - std::deque*> worklist; - for (Node* node : compilation_candidates) { - Cluster& cluster = clusters[node->id()].Get(); - cluster.representative = node->id(); - const string& device = !node->assigned_device_name().empty() - ? node->assigned_device_name() - : node->requested_device(); - if (HasResourceInput(*node) || HasResourceOutput(*node)) { - cluster.resource_op_device = device; - } - cluster.has_xla_compile_attr = false; - bool xla_compile_attr; - if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) { - cluster.has_xla_compile_attr |= xla_compile_attr; - } - if (options.flib_def->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr) - .ok()) { - cluster.has_xla_compile_attr |= xla_compile_attr; - } - - cluster.devices.insert(device); - worklist.push_back(&clusters[node->id()]); - } - - OptimizerOptions::GlobalJitLevel global_jit_level = - GetGlobalJitLevel(options); - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - - // Repeatedly contract edges between clusters that are on the same device, - // provided the contraction would not create a cycle. - // - // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for - // example, from the Grappler fusion pass). - while (!worklist.empty()) { - Cluster* cluster_from = &worklist.front()->Get(); - int from = cluster_from->representative; - worklist.pop_front(); - - Node* node_from = graph->FindNodeId(from); - if (node_from->IsControlFlow()) { - // Control flow nodes aren't compilation candidates and should never - // appear. - return errors::Internal( - "Found control flow node in clustering worklist: ", - node_from->type_string()); - } - - if (isolated_nodes.count(node_from)) { - continue; - } - - string from_scope; - string to_scope; - for (int to : cycles.Successors(from)) { - if (to >= graph->num_node_ids()) { - // Node is a fictitious node that is present only in the cycle detection - // graph. No clustering is possible. - continue; - } - - const Cluster& cluster_to = clusters[to].Get(); - Node* node_to = graph->FindNodeId(to); - if (compilation_candidates.find(node_to) == - compilation_candidates.cend()) { - continue; - } - bool devices_compatible; - TF_RETURN_IF_ERROR(AreDevicesCompatible( - *cluster_from, cluster_to, global_jit_level, &devices_compatible)); - if (!devices_compatible) { - continue; - } - if (isolated_nodes.count(node_to)) { - continue; - } - // Look for an _XlaScope on both nodes. If both nodes have a - // scope and the scopes do not match, do not cluster along this - // edge. This restriction is overridden if the global_jit_level is ON. If - // even one of the nodes lacks an _XlaScope attribute, - // then it is treated as a "bridge" and a cluster may be created - // along it. We may want to restrict this behavior to require - // all nodes marked with _XlaCompile=true to also have a - // _XlaScope property set (and raise an error otherwise); but - // for now we don't do this. - if (global_jit_level == OptimizerOptions::OFF && - GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && - GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && - from_scope != to_scope) { - continue; - } - - // Ops that consume shapes cannot be the root of a cluster. This is an - // optimization. - if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { - continue; - } - - // Don't exceed the maximum cluster size. 
- if (clusters[from].Size() + clusters[to].Size() > - flags->tf_xla_max_cluster_size) { - continue; - } - - // If any of the consumer's producers are on a different device, do not - // cluster these nodes. This prevents other work on this device from being - // delayed by work on other devices. We consider predecessors of the - // entire cluster rather than just the inputs to the node to prevent the - // cluster still being combined in cases where the 'to' cluster has - // multiple dependencies on the 'from' cluster and another dependency - // leads to a merging of the clusters. - // - // TODO(b/117085735): We probably want to handle the reciprocal of this - // case where a cluster is producing data for multiple devices. - bool found_split = false; - for (const auto& in_id : cycles.Predecessors(to)) { - if (in_id >= graph->num_node_ids()) continue; - - Node* in = graph->FindNodeId(in_id); - const Cluster& cluster_in = clusters[in_id].Get(); - if (compilation_candidates.find(in) != compilation_candidates.cend()) { - bool devices_compatible; - TF_RETURN_IF_ERROR(AreDevicesCompatible( - cluster_to, cluster_in, global_jit_level, &devices_compatible)); - if (!devices_compatible) { - found_split = true; - } - } - } - if (found_split) continue; - - // If contracting the edge would create a cycle, bail out. - // However, just because we can't merge the clusters now does not mean - // we won't be able to merge them in the future. - // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge - // 1->3. But if we first contract 1->2 then we can later contract 1->3. - if (!cycles.ContractEdge(from, to)) continue; - - // Merge the clusters. ContractEdge uses 'from' as the number of the - // merged node, so make sure 'from' is the chosen representative. - cluster_from->devices.insert(cluster_to.devices.begin(), - cluster_to.devices.end()); - if (!cluster_to.resource_op_device.empty()) { - cluster_from->resource_op_device = cluster_to.resource_op_device; - } - cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr; - clusters[from].Merge(&clusters[to]); - - worklist.push_back(&clusters[from]); - break; + // See explanation on `kXlaAlreadyClustered`. + for (Node* n : graph->nodes()) { + if (n->attrs().Find(kXlaAlreadyClustered)) { + return Status::OK(); } } - // Count the number of non-trivial elements in each cluster. - std::vector effective_cluster_sizes(graph->num_node_ids()); - - // has_functional_control_flow remembers if a cluster contains a functional - // control flow node. - std::vector has_functional_control_flow(graph->num_node_ids()); - - for (const Node* n : compilation_candidates) { - int cluster = clusters[n->id()].Get().representative; - // We want clusters to be big enough that the benefit from XLA's - // optimizations offsets XLA related overhead (for instance we add some - // Switch/Merge nodes into the graph to implement lazy compilation). To - // this end, we don't count Identity and Constant nodes because they do not - // enable interesting optimizations by themselves. - if (!n->IsIdentity() && !n->IsConstant()) { - effective_cluster_sizes[cluster]++; - } - if (n->type_string() == "While" || n->type_string() == "If") { - has_functional_control_flow[cluster] = true; - } - } - - // Names for each cluster. 
- std::unordered_map cluster_names; - - if (flags->tf_xla_clustering_debug) { - DumpGraphToFile("before_mark_for_compilation", **options.graph, - options.flib_def); - } - - absl::flat_hash_map> - should_compile_cluster_cache; - - // Mark clusters for compilation that: - // * are placed on a device that requires compilation (an XlaDevice), - // * are explicitly marked for compilation (_XlaCompile=true), or - // * have more than flags->tf_xla_min_cluster_size elements (applicable only - // if compilation is enabled, otherwise there will be no such candidates). - const int min_cluster_size = flags->tf_xla_min_cluster_size; - for (Node* n : compilation_candidates) { - const Cluster& cluster = clusters[n->id()].Get(); - bool should_compile; - string device; - TF_RETURN_IF_ERROR(ShouldCompileCluster(&should_compile_cluster_cache, - global_jit_level, cluster, - &should_compile, &device)); - if (!should_compile) { - continue; - } - - int cluster_repr = cluster.representative; - - // Compile if the user marked this node _XlaCompile=true - bool compile_attr = false; - bool marked_for_compilation = false; - if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) { - marked_for_compilation = compile_attr; - } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr) - .ok()) { - marked_for_compilation = compile_attr; - } - - // We assume that functional If and While nodes have at least - // min_cluster_size non-trivial nodes in them. It would be more principled - // to (recursively) verify this fact, but that's probably not worth the - // trouble. - - if (effective_cluster_sizes[cluster_repr] >= min_cluster_size || - has_functional_control_flow[cluster_repr] || marked_for_compilation) { - string& name = cluster_names[cluster_repr]; - - if (name.empty()) { - name = absl::StrCat("cluster_", cluster_sequence_num++); - } - n->AddAttr(kXlaClusterAttr, name); - n->AddAttr(kXlaAlreadyClustered, true); - VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; - } - } - - if (flags->tf_xla_clustering_debug) { - DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); - - // We also dump out an annoated version of the TF graph where the nodes - // names are prefixed with the cluster names. This can help visualizing the - // clustering decisions on TensorBoard. - Graph new_graph((*options.graph)->op_registry()); - CopyGraph(**options.graph, &new_graph); - - for (Node* n : new_graph.nodes()) { - if (absl::optional cluster_name = - GetXlaClusterForNode(*n)) { - n->set_name(absl::StrCat(*cluster_name, "/", n->name())); - } else if (n->type_string() == "VarHandleOp") { - n->set_name(absl::StrCat("varhandle/", n->name())); - } else { - // There is room for improvement here. In particular, it may help to - // split these unclustered nodes into classes where every node in a - // specific class has edges to and from the same set of clusters. - n->set_name(absl::StrCat("unclustered/", n->name())); - } - } - - DumpGraphToFile("mark_for_compilation_annotated", new_graph, - options.flib_def); - } - - VLogClusteringSummary(*graph); - - return Status::OK(); + return MarkForCompilationPassImpl{debug_options, graph, flib_def, + options.session_options != nullptr + ? 
options.session_options->env + : Env::Default(), + GetGlobalJitLevel(options)} + .Run(); } +std::atomic* GetPointerToFuel(int64 initial_value) { + static std::atomic* fuel = [&]() { + std::atomic* fuel = new std::atomic; + *fuel = initial_value; + return fuel; + }(); + + return fuel; +} +} // anonymous namespace + +bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { + Device* device = flr->device(); + const XlaOpRegistry::DeviceRegistration* registration; + CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), + ®istration)); + DeviceType jit_device_type(registration->compilation_device_name); + + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. + RecursiveCompilabilityChecker::OperationFilter op_filter; + op_filter.allow_resource_ops_in_called_functions = true; + op_filter.allow_non_resource_var_resource_ops = true; + op_filter.allow_resource_producing_ops = true; + op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_dummy_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; + + return RecursiveCompilabilityChecker{&op_filter, &jit_device_type} + .IsCompilableCall(ndef, flr); +} + +Status MarkForCompilationPass::Run( + const GraphOptimizationPassOptions& options) { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + + MarkForCompilationPassImpl::DebugOptions debug_options; + debug_options.ignore_deadness_checks = + flags->tf_xla_disable_deadness_safety_checks_for_debugging; + debug_options.ignore_xla_compile_attr = false; + debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; + debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; + debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); + debug_options.dump_graphs = flags->tf_xla_clustering_debug; + + return MarkForCompilation(options, debug_options); +} + +Status MarkForCompilationPass::RunForTest( + const GraphOptimizationPassOptions& options) { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + + MarkForCompilationPassImpl::DebugOptions debug_options; + debug_options.ignore_deadness_checks = true; + debug_options.ignore_xla_compile_attr = true; + debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; + debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; + debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); + debug_options.dump_graphs = flags->tf_xla_clustering_debug; + + return MarkForCompilation(options, debug_options); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index f1137af3c1e..f0c46c91459 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -41,9 +41,7 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; private: - Status RunImpl(const GraphOptimizationPassOptions& options, - const std::function& - is_compilable_fn = {}); + Status RunForTest(const GraphOptimizationPassOptions& options); friend class MarkForCompilationPassTestHelper; }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 5f0ebe150fa..9f767b50082 100644 --- 
a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -49,7 +49,7 @@ namespace tensorflow { opt_options.session_options = &session_options; opt_options.flib_def = flib_def; MarkForCompilationPass pass; - return pass.RunImpl(opt_options); + return pass.RunForTest(opt_options); } /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(