STT-tensorflow/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
George Karpenkov 9771765f41 [TF/XLA] Force all tensors which need to be constant during the XLA compilation to be located on the host
Otherwise, this leads to strange crashes during the compilation.

PiperOrigin-RevId: 304226917
Change-Id: Ia2f1e77b13a25c7e15f009787af81f93b90e8bca
2020-04-01 11:34:32 -07:00

89 lines
4.0 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {

// Each pass below is registered with a grouping and an integer priority.
// Within one grouping, passes run in increasing order of priority; the
// "must run after" constraints documented in this file rely on that ordering.
// The priorities are named so those constraints can be enforced at compile
// time with static_assert instead of living only in comments.
namespace {

// PRE_PLACEMENT priorities.
constexpr int kEncapsulateXlaComputationsPassPriority = 36;
constexpr int kIntroduceFloatingPointJitterPassPriority = 35;

// POST_REWRITE_FOR_EXEC priorities.
constexpr int kCloneConstantsForBetterClusteringPassPriority = 5;
constexpr int kClusterScopingPassPriority = 9;
constexpr int kMarkForCompilationPassPriority = 10;
constexpr int kForceXlaConstantsOnHostPassPriority = 12;
constexpr int kIncreaseDynamismForAutoJitPassPriority = 20;
constexpr int kPartiallyDeclusterPassPriority = 30;
constexpr int kReportClusteringInfoPassPriority = 40;
constexpr int kEncapsulateSubgraphsPassPriority = 50;
constexpr int kBuildXlaOpsPassPriority = 60;

}  // namespace

// PRE_PLACEMENT passes:

// EncapsulateXlaComputationsPass rewrites computations generated by the
// xla.compile() Python code into XlaLaunch nodes.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT,
                      kEncapsulateXlaComputationsPassPriority,
                      EncapsulateXlaComputationsPass);

// IntroduceFloatingPointJitterPass runs before EncapsulateXlaComputationsPass
// in the PRE_PLACEMENT grouping (lower priority value).
// NOTE(review): judging by its name this pass perturbs floating-point values;
// confirm the exact behavior in introduce_floating_point_jitter_pass.h.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT,
                      kIntroduceFloatingPointJitterPassPriority,
                      IntroduceFloatingPointJitterPass);

// Registered elsewhere, in
// tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc:
// FunctionalizeControlFlowPass, priority 27.
//
// That pass looks at the graph and all associated FunctionDefs, and turns
// traditional control flow structure (Switch/Merge/etc.) into functional
// control flow structure (XlaIf/XlaWhile). Passes that run after it must
// handle those FunctionDefs correctly.

// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA:

REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kCloneConstantsForBetterClusteringPassPriority,
                      CloneConstantsForBetterClusteringPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kClusterScopingPassPriority, ClusterScopingPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kMarkForCompilationPassPriority, MarkForCompilationPass);

// ForceXlaConstantsOnHostPass forces every tensor that must be a compile-time
// constant during XLA compilation to be located on the host; without it,
// compilation can crash.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kForceXlaConstantsOnHostPassPriority,
                      ForceXlaConstantsOnHostPass);

REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kIncreaseDynamismForAutoJitPassPriority,
                      IncreaseDynamismForAutoJitPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kPartiallyDeclusterPassPriority, PartiallyDeclusterPass);

// ReportClusteringInfoPass needs to run after all of the auto-clustering
// passes have run but before encapsulation has run. This way it can easily
// compute a summary of the clustering decisions we made and broadcast it via
// xla_activity_listener.
static_assert(
    kReportClusteringInfoPassPriority >
            kCloneConstantsForBetterClusteringPassPriority &&
        kReportClusteringInfoPassPriority > kClusterScopingPassPriority &&
        kReportClusteringInfoPassPriority > kMarkForCompilationPassPriority &&
        kReportClusteringInfoPassPriority >
            kForceXlaConstantsOnHostPassPriority &&
        kReportClusteringInfoPassPriority >
            kIncreaseDynamismForAutoJitPassPriority &&
        kReportClusteringInfoPassPriority > kPartiallyDeclusterPassPriority,
    "ReportClusteringInfoPass must run after all auto-clustering passes");
static_assert(
    kReportClusteringInfoPassPriority < kEncapsulateSubgraphsPassPriority,
    "ReportClusteringInfoPass must run before EncapsulateSubgraphsPass");
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kReportClusteringInfoPassPriority,
                      ReportClusteringInfoPass);

// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes
// added for fetches. Before the _Send nodes are added, fetch nodes are
// identified by name, and encapsulation might remove that node from the graph.
static_assert(
    kEncapsulateSubgraphsPassPriority > kMarkForCompilationPassPriority,
    "EncapsulateSubgraphsPass must run after MarkForCompilationPass");
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kEncapsulateSubgraphsPassPriority,
                      EncapsulateSubgraphsPass);

// BuildXlaOpsPass must run after EncapsulateSubgraphsPass.
static_assert(kBuildXlaOpsPassPriority > kEncapsulateSubgraphsPassPriority,
              "BuildXlaOpsPass must run after EncapsulateSubgraphsPass");
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
                      kBuildXlaOpsPassPriority, BuildXlaOpsPass);

}  // namespace tensorflow