[NFC] Clarify documentation of JIT passes
PiperOrigin-RevId: 286427776 Change-Id: I424ffc406bf9a963f2c52c6d6f61408273f0230f
This commit is contained in:
parent
9685d6a218
commit
2ae84120e7
|
@ -22,8 +22,9 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
|
||||
// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
|
||||
// Replaces TF function calls marked with `_XlaCompiledKernel` with _XlaCompile
|
||||
// and _XlaRun nodes (which compile and launch, respectively, the corresponding
|
||||
// HLO module).
|
||||
class BuildXlaOpsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
// If enable_lazy_compilation is not nullopt then *enable_lazy_compilation
|
||||
|
|
|
@ -27,6 +27,15 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
// EncapsulateSubgraphs pass takes all the nodes with the same cluster ID
|
||||
// (derived from kXlaClusterAttr=ID (kXlaClusterAttr) attribute), puts them into
|
||||
// a TF function, and replaces the subgraph in the main graph with a call to
|
||||
// that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel).
|
||||
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
// A rewriting function to apply to each subgraph during encapsulation.
|
||||
// 'arg_source_tensors' are the tensors corresponding to the arguments in the
|
||||
// original source graph (*not* 'graph').
|
||||
|
@ -100,11 +109,6 @@ extern const char* const kXlaHasReferenceVarsAttr;
|
|||
// TODO(hpucha): Move the utilities to a more appropriate place.
|
||||
void SortControlInputs(GraphDef* gdef);
|
||||
|
||||
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorflow {
|
||||
|
||||
// Encapsulates nodes marked with the _xla_compile_id attribute into
|
||||
// XlaLaunch operators.
|
||||
|
|
|
@ -34,8 +34,9 @@ extern const char* const kXlaClusterAttr;
|
|||
// compilation by the encapsulate subgraphs pass.
|
||||
extern const char* const kXlaOutsideCompilationAttr;
|
||||
|
||||
// Pass that marks a subset of operators in the graph with attribute
|
||||
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
|
||||
// Marks a subset of nodes in the graph which are to be clustered
|
||||
// with an attribute _XlaCluster=<cluster id> so they are picked up by the
|
||||
// EncapsulateSubgraphsPass.
|
||||
class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
public:
|
||||
MarkForCompilationPass() = default;
|
||||
|
|
Loading…
Reference in New Issue