Integrated graph optimization passes with TFRT for TF2 training support.
These GraphDef based passes are performed before we import the graph into MLIR TF dialect. A few notes: 1. Refactored and reused code from the existing stack when applicable, to ensure consistent behavior (e.g. see CreateOptOptionsForEager()). 2. Used existing grappler passes as in TF2 eager, but instructed grappler not to "lower V2 control flow to V1". 3. One challenge is that the graph optimization (specifically the grappler ModelPruner pass) might remove unused input parameters to the function (aka the "Args" nodes in GraphDef), which affects the function calling protocol and thus requires corresonding changes to how we feed it the input. To do so, we track the arg index mapping between the original function and the optimized/pruned one, in FunctionState::arg_indices_, similar to ComponentFunctionData::arg_indices in the current runtime. 4. The graph optimizations added by this CL do not involve graph/device partitioning. Here we assume the input graph function is on a single host, but might involve multiple devices. 5. Made FunctionCache::GetOrAddFunction() return tensorflow::Status instead of llvm::Expected, to improve e2e error propagation (that function calls some other functions that return Status, and previously drops the error info when translating to llvm::Expected). This also made some code simpler (e.g. can use TF_RETURN_IF_ERROR). 6. PRE_PLACEMENT optimization passes are not yet added. Calling them caused more test failures, and need more investigation. PiperOrigin-RevId: 347035680 Change-Id: I969154f5888fe665eea336ac06fffe50a897c13e
This commit is contained in:
parent
4279f36afb
commit
cae04d6d4d
@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||
]
|
||||
|
||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
|
||||
"Failed to parse config_proto attribute as tensorflow::ConfigProto "
|
||||
"proto.");
|
||||
}
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
||||
|
||||
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
||||
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
||||
// However because the function will be executed via FunctionLibraryRuntime,
|
||||
// and current function implementation does not prune stateful and dataset
|
||||
// ops, we rely on Grappler to do the correct graph pruning.
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
||||
|
||||
optimization_options.is_eager_mode = true;
|
||||
|
||||
// All the nested function calls will be executed and optimized via
|
||||
// PartitionedCallOp, there is no need to optimize functions now.
|
||||
optimization_options.optimize_function_library = false;
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options =
|
||||
grappler::CreateOptOptionsForEager();
|
||||
|
||||
options.optimize_graph_fn = std::bind(
|
||||
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
|
||||
@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
|
||||
|
||||
// In Eager mode we always inline all functions into the top-level
|
||||
// function body graph, to get a single executable graph, that could be
|
||||
// optimized across function boundaries (e.g. prune unused inputs and outputs
|
||||
// in a function call chain). This is required to mimic graph mode execution,
|
||||
// with aggressive pruning of nodes not in the transitive fanin of fetches.
|
||||
// optimized across function boundaries (e.g. prune unused inputs and
|
||||
// outputs in a function call chain). This is required to mimic graph mode
|
||||
// execution, with aggressive pruning of nodes not in the transitive fanin
|
||||
// of fetches.
|
||||
options.config_proto.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_do_function_inlining(true);
|
||||
|
@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
|
||||
}
|
||||
|
||||
Status UpdateArgAndRetvalMetadata(
|
||||
Graph* subgraph, const string& device_type,
|
||||
Graph* graph, const string& device_type,
|
||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
|
||||
@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(
|
||||
|
||||
// Find the Arg and Retval nodes, along with their corresponding indices
|
||||
// in the original function.
|
||||
for (Node* node : subgraph->op_nodes()) {
|
||||
for (Node* node : graph->op_nodes()) {
|
||||
if (node->IsArg()) {
|
||||
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
|
||||
int index = static_cast<int>(attr_value->i());
|
||||
@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata(
|
||||
Node* arg = arg_nodes[i].first;
|
||||
arg->AddAttr("index", i);
|
||||
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
if (arg_alloc_attrs != nullptr) {
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
}
|
||||
arg_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
arg_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
for (int i = 0; i < ret_nodes.size(); ++i) {
|
||||
Node* ret = ret_nodes[i].first;
|
||||
ret->AddAttr("index", i);
|
||||
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
if (ret_alloc_attrs) {
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
}
|
||||
ret_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
ret_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
|
||||
const DeviceSet& device_set, std::unique_ptr<Graph> graph,
|
||||
std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
|
||||
|
||||
// Each subgraph produced by partitioning the function body contains a subset
|
||||
// of the original `Arg` and `Retval` nodes. This function performs
|
||||
// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
|
||||
// particular device / subgraph.
|
||||
// This function performs bookkeeping to track which `Arg` and `Retval` nodes
|
||||
// were placed on a particular device / graph.
|
||||
//
|
||||
// More specifically, this function
|
||||
// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
|
||||
// on a particular device. When a function is partitioned, each
|
||||
// partition `subgraph` gets a subset of the arguments and
|
||||
// return values. The `index` attributes of these _Arg and _Retval
|
||||
// nodes reflect the indices of these parameters in the original
|
||||
// function. To convert `subgraph` to a function, we need to replace
|
||||
// there original indices with 0, 1, 2, ... .
|
||||
//
|
||||
// The argument and return value order in the partitioned function is
|
||||
// determined by the argument and return value order in the original
|
||||
// function. This stability is important because it enables us to treat
|
||||
// a single-partition function as having the same signature as the
|
||||
// subgraph.
|
||||
// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
|
||||
// consecutive.
|
||||
//
|
||||
// These indices might not be consecutive after grappler's pruning
|
||||
// optimization (e.g. removing redundant Args), or graph partitioning. In
|
||||
// the latter case, the nodes in `graph` are placed on `device_type`, and
|
||||
// each such graph partition gets a subset of the arguments and return
|
||||
// values. The `index` attributes of these _Arg and _Retval nodes reflect
|
||||
// the indices of these parameters in the original function. To convert
|
||||
// `subgraph` to a function, we need to replace there original indices with
|
||||
// 0, 1, 2, ... .
|
||||
//
|
||||
// The argument and return value order in `graph` is determined by the
|
||||
// argument and return value order in the original function. This stability
|
||||
// is important because it enables us to treat a single-partition function
|
||||
// as having the same signature as the subgraph.
|
||||
//
|
||||
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
|
||||
// device in `*_indices`, and
|
||||
// (3) records which `Arg` and `Retval` nodes live in host memory in
|
||||
// `*_alloc_attrs`.
|
||||
// `*_alloc_attrs`. If these vectors are NULL, do nothing here.
|
||||
Status UpdateArgAndRetvalMetadata(
|
||||
Graph* subgraph, const string& device_type,
|
||||
Graph* graph, const string& device_type,
|
||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||
std::vector<AllocatorAttributes>* ret_alloc_attrs);
|
||||
|
@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status GetGraphAndArgRets(
|
||||
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
|
||||
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
||||
@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
const string& function_name, AttrSlice attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& options,
|
||||
|
@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
const std::vector<std::string>& output_names,
|
||||
FunctionDef* fdef);
|
||||
|
||||
Status GetGraphAndArgRets(
|
||||
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
|
||||
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
||||
std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
|
||||
std::vector<string>* ret_node_names, DataTypeVector* ret_types,
|
||||
std::vector<string>* control_ret_node_names);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
|
||||
|
@ -31,6 +31,24 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
|
||||
GrapplerItem::OptimizationOptions optimization_options;
|
||||
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
||||
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
||||
// However because the function will be executed via FunctionLibraryRuntime,
|
||||
// and current function implementation does not prune stateful and dataset
|
||||
// ops, we rely on Grappler to do the correct graph pruning.
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
||||
|
||||
optimization_options.is_eager_mode = true;
|
||||
|
||||
// All the nested function calls will be executed and optimized via
|
||||
// PartitionedCallOp, there is no need to optimize functions now.
|
||||
optimization_options.optimize_function_library = false;
|
||||
|
||||
return optimization_options;
|
||||
}
|
||||
|
||||
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
|
||||
GrapplerItem item;
|
||||
item.id = id;
|
||||
|
@ -29,6 +29,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/queue_runner.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class ConfigProto;
|
||||
|
||||
namespace grappler {
|
||||
|
||||
// A TensorFlow model to optimize.
|
||||
@ -133,6 +136,8 @@ struct GrapplerItem {
|
||||
OptimizationOptions optimization_options_;
|
||||
};
|
||||
|
||||
GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -2804,8 +2804,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
# Grappler fallback to use the CPU impl even called with GPU function.
|
||||
self.assertEqual(y_value, 3.0)
|
||||
|
||||
@test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
|
||||
'equivalent to implementation_selector for function')
|
||||
def testSwapImplementationInEager(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('eager only')
|
||||
|
Loading…
Reference in New Issue
Block a user