STT-tensorflow/tensorflow/compiler/jit/encapsulate_util.h

157 lines
7.4 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains some utility functions for encapsulating XLA computation
// in host graph and encapsulating outside compilation in XLA computation.
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
// Attribute marking output tensor shapes inferred by XLA. The attribute value
// is a list of PartialTensorShape objects. It is populated by
// PerformStaticShapeInferenceBeforeEncapsulation() (declared below).
extern const char kXlaInferredShapesAttrName[];
// Infers output shapes for all nodes in graph `g` and stores them in each
// node's `kXlaInferredShapesAttrName` attribute, so `g` is modified in place.
//
// Shape inference has to run before encapsulation: after encapsulation, some
// nodes are wrapped into function calls, and shape inference does not
// currently handle function calls.
//
// Returns a non-OK Status on failure; the exact failure conditions are defined
// by the implementation.
Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g);
// Attribute indicating that some ops in this node's XLA computation have a
// control dependency on this node. The attribute value is always "true".
// Set on outside compilation nodes by
// PreprocessEdgesBetweenOutsideCompilations() (step 1b).
extern const char kXlaConnectedToXlaComputationAttrName[];
// Attribute indicating that this node has a control dependency on some ops in
// this node's XLA computation. The attribute value is always "true".
// Set on outside compilation nodes by
// PreprocessEdgesBetweenOutsideCompilations() (step 1b).
extern const char kXlaConnectedFromXlaComputationAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. The attribute value is
// a string: the name of the original input (source) node.
// Such Placeholders are created by PreprocessEdgesBetweenOutsideCompilations()
// (step 2) and removed by PostprocessEdgesBetweenOutsideCompilations()
// (step 1).
extern const char kOutsideCompilationOriginalNodeAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. The attribute value is
// an int: the src_output of the original edge.
extern const char kOutsideCompilationSrcOutputAttrName[];
// Attribute indicating that this node has control dependencies on some other
// nodes within the same XLA cluster. The attribute value is a list of strings
// (node names). Recorded by PreprocessEdgesBetweenOutsideCompilations()
// (step 1a) and turned back into control edges by
// PostprocessEdgesBetweenOutsideCompilations() (step 2a).
extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
// Attribute indicating that this node is an outside compilation node that has
// been lifted out of an If/While/function node. The attribute value is always
// the boolean value "true".
extern const char kXlaIsLiftedArgAttrName[];
// Attribute indicating that this node is a Placeholder node for an _Arg node
// lifted out of an If/While/function node. The attribute value is a string:
// the name of the outside compilation cluster that sends the lifted arg node
// to the host.
extern const char kXlaLiftedArgOutsideCompilationAttrName[];
// Attribute indicating that this is an IdentityN node receiving inputs for an
// outside compilation Placeholder node. (The original outside compilation node
// is moved out of the TPU computation, and a Placeholder node is left in its
// place.) The attribute value is a string: the outside compilation cluster
// name of that outside compilation Placeholder node.
extern const char kXlaOutsideCompilationInputsAttrName[];
// Attribute indicating that this is a Placeholder node for an _Arg node used
// in outside compilation. Such a node must not be moved out of the XLA
// computation. The attribute value is always the boolean value "true".
extern const char kXlaIsPlaceholderForArg[];
// Information for XLA computation.
struct XlaClusterInfo {
  // Add an explicitly-defined default constructor for this class.
  //
  // The compiler may delete the default constructor here because
  // host_compute_core is a const member whose type (std::map) doesn't
  // necessarily have a user provided constructor -- while libc++ and
  // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at
  // least >= 7.3 does not. See also c++11 [class.ctor] p5.
  //
  // Every member is explicitly value-initialized (`()`), which:
  //  - gives `node` a well-defined nullptr instead of an indeterminate
  //    pointer value in a default-constructed instance, and
  //  - avoids default-initializing the const class-type members, which would
  //    require their types to be const-default-constructible.
  //
  // TODO(klimek): In c++17 we'll be able to initialize host_compute_core
  // without losing aggregate initialization, which allows us to get rid of
  // the constructor definitions again.
  XlaClusterInfo()
      : cluster_name(),
        func_name_attrs(),
        node(nullptr),
        host_compute_core() {}

  // Copies every argument into the corresponding member. `node` is stored as a
  // non-owning pointer; the struct does not manage its lifetime.
  XlaClusterInfo(const string& cluster_name,
                 const NameAttrList& func_name_attrs, Node* node,
                 const std::map<string, int>& host_compute_core)
      : cluster_name(cluster_name),
        func_name_attrs(func_name_attrs),
        node(node),
        host_compute_core(host_compute_core) {}

  // XLA cluster name. It might be different from the function name stored in
  // `func_name_attrs`.
  const string cluster_name;
  // Name and attributes of XLA computation function.
  const NameAttrList func_name_attrs;
  // The XLA computation node in the graph (non-owning; nullptr when
  // default-constructed).
  Node* node;
  // A mapping from outside compilation cluster name to its device assignment.
  const std::map<string, int> host_compute_core;
};
// Finds dependencies between outside compilation clusters in graph `g`,
// including both data dependencies and control dependencies.
//
// `outside_compilation_attr_name` is the node attribute that identifies a
// node's outside compilation cluster (presumably its value is the cluster
// name -- confirm against the implementation).
//
// On success, returns an owned map from the name of an outside compilation
// cluster to the list of names of the outside compilation clusters that it
// depends on. On failure, returns a non-OK status.
stream_executor::port::StatusOr<
std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name);
// Preprocesses edges within the same XLA cluster, modifying `g` in place. It
// performs the following operations in order:
//
// 0. Remove edges from the source node to outside compilation nodes, and edges
//    from outside compilation nodes to the sink node.
// 1a. For control edges between different outside compilation clusters,
//     remove the edge and record the src node name in the dst node's
//     `kXlaControlDependenciesWithinXlaClusterAttrName` attribute.
// 1b. For control edges between an outside compilation and its XLA
//     computation, add attr "kXlaConnected{From, To}XlaComputationAttrName =
//     true" to the outside compilation node.
// 2. For data edges between different outside compilations, remove the edge
//    and create a Placeholder node as the dst node's input.
//
// `outside_compilation_attr_name` is the node attribute that identifies a
// node's outside compilation cluster. Steps 1a and 2 are reverted by
// PostprocessEdgesBetweenOutsideCompilations(); step 1b is not.
Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name);
// Postprocesses edges within the same XLA cluster, modifying `g` in place. It
// reverts what `PreprocessEdgesBetweenOutsideCompilations` did by performing
// the following operations in order:
//
// 1. Remove the Placeholder nodes between different outside compilations
//    (created in `PreprocessEdgesBetweenOutsideCompilations` step 2).
// 2a. Reconnect control edges between different outside compilations (marked
//     by `PreprocessEdgesBetweenOutsideCompilations` step 1a).
// Note that control edges marked by
// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here;
// they are handled in `RewriteOutsideCompilationSubgraphFn`.
//
// `outside_compilation_attr_name` must be the same attribute name that was
// passed to `PreprocessEdgesBetweenOutsideCompilations` (assumed; confirm
// against callers).
Status PostprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_