From cae04d6d4dab50ff06c4b3b665b19fe1f5e82e64 Mon Sep 17 00:00:00 2001 From: Mingsheng Hong Date: Fri, 11 Dec 2020 11:10:44 -0800 Subject: [PATCH] Integrated graph optimization passes with TFRT for TF2 training support. These GraphDef based passes are performed before we import the graph into MLIR TF dialect. A few notes: 1. Refactored and reused code from the existing stack when applicable, to ensure consistent behavior (e.g. see CreateOptOptionsForEager()). 2. Used existing grappler passes as in TF2 eager, but instructed grappler not to "lower V2 control flow to V1". 3. One challenge is that the graph optimization (specifically the grappler ModelPruner pass) might remove unused input parameters to the function (aka the "Args" nodes in GraphDef), which affects the function calling protocol and thus requires corresonding changes to how we feed it the input. To do so, we track the arg index mapping between the original function and the optimized/pruned one, in FunctionState::arg_indices_, similar to ComponentFunctionData::arg_indices in the current runtime. 4. The graph optimizations added by this CL do not involve graph/device partitioning. Here we assume the input graph function is on a single host, but might involve multiple devices. 5. Made FunctionCache::GetOrAddFunction() return tensorflow::Status instead of llvm::Expected, to improve e2e error propagation (that function calls some other functions that return Status, and previously drops the error info when translating to llvm::Expected). This also made some code simpler (e.g. can use TF_RETURN_IF_ERROR). 6. PRE_PLACEMENT optimization passes are not yet added. Calling them caused more test failures, and need more investigation. PiperOrigin-RevId: 347035680 Change-Id: I969154f5888fe665eea336ac06fffe50a897c13e --- tensorflow/core/common_runtime/eager/BUILD | 1 + .../common_runtime/eager/kernel_and_device.cc | 24 +++------- .../core/common_runtime/partitioning_utils.cc | 44 ++++++++++--------- .../core/common_runtime/partitioning_utils.h | 39 ++++++++-------- .../process_function_library_runtime.cc | 4 +- .../core/framework/graph_to_functiondef.h | 7 +++ tensorflow/core/grappler/grappler_item.cc | 18 ++++++++ tensorflow/core/grappler/grappler_item.h | 5 +++ tensorflow/python/eager/function_test.py | 2 - 9 files changed, 85 insertions(+), 59 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 2ae079be53c..cde26d76611 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [ "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:meta_optimizer", ] diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 3b73f018c0d..79b9179de6b 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h" #if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #endif // !IS_MOBILE_PLATFORM @@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx, "Failed to parse config_proto attribute as tensorflow::ConfigProto " "proto."); } - grappler::GrapplerItem::OptimizationOptions optimization_options; - - // Tensorflow 2.0 in eager mode with automatic control dependencies will - // prune all nodes that are not in the transitive fanin of the fetch nodes. - // However because the function will be executed via FunctionLibraryRuntime, - // and current function implementation does not prune stateful and dataset - // ops, we rely on Grappler to do the correct graph pruning. - optimization_options.allow_pruning_stateful_and_dataset_ops = true; - - optimization_options.is_eager_mode = true; - - // All the nested function calls will be executed and optimized via - // PartitionedCallOp, there is no need to optimize functions now. - optimization_options.optimize_function_library = false; + grappler::GrapplerItem::OptimizationOptions optimization_options = + grappler::CreateOptOptionsForEager(); options.optimize_graph_fn = std::bind( grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2, @@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx, // In Eager mode we always inline all functions into the top-level // function body graph, to get a single executable graph, that could be - // optimized across function boundaries (e.g. prune unused inputs and outputs - // in a function call chain). This is required to mimic graph mode execution, - // with aggressive pruning of nodes not in the transitive fanin of fetches. + // optimized across function boundaries (e.g. prune unused inputs and + // outputs in a function call chain). This is required to mimic graph mode + // execution, with aggressive pruning of nodes not in the transitive fanin + // of fetches. options.config_proto.mutable_graph_options() ->mutable_optimizer_options() ->set_do_function_inlining(true); diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc index 6fb7526c512..6cdc9704eb1 100644 --- a/tensorflow/core/common_runtime/partitioning_utils.cc +++ b/tensorflow/core/common_runtime/partitioning_utils.cc @@ -74,7 +74,7 @@ Status PartitionFunctionGraph( } Status UpdateArgAndRetvalMetadata( - Graph* subgraph, const string& device_type, + Graph* graph, const string& device_type, std::vector* arg_indices, std::vector* ret_indices, std::vector* arg_alloc_attrs, std::vector* ret_alloc_attrs) { @@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata( // Find the Arg and Retval nodes, along with their corresponding indices // in the original function. - for (Node* node : subgraph->op_nodes()) { + for (Node* node : graph->op_nodes()) { if (node->IsArg()) { TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value)); int index = static_cast(attr_value->i()); @@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata( Node* arg = arg_nodes[i].first; arg->AddAttr("index", i); TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value)); - AllocatorAttributes alloc_attr; - DataType type = attr_value->type(); - MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" || - device_type == "XLA_GPU") - ? MTypeFromDTypeIntsOnDevice(type) - : MTypeFromDType(type); - if (mtype == HOST_MEMORY) { - alloc_attr.set_on_host(true); + if (arg_alloc_attrs != nullptr) { + AllocatorAttributes alloc_attr; + DataType type = attr_value->type(); + MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" || + device_type == "XLA_GPU") + ? MTypeFromDTypeIntsOnDevice(type) + : MTypeFromDType(type); + if (mtype == HOST_MEMORY) { + alloc_attr.set_on_host(true); + } + arg_alloc_attrs->push_back(alloc_attr); } - arg_alloc_attrs->push_back(alloc_attr); } for (int i = 0; i < ret_nodes.size(); ++i) { Node* ret = ret_nodes[i].first; ret->AddAttr("index", i); TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value)); - AllocatorAttributes alloc_attr; - DataType type = attr_value->type(); - MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" || - device_type == "XLA_GPU") - ? MTypeFromDTypeIntsOnDevice(type) - : MTypeFromDType(type); - if (mtype == HOST_MEMORY) { - alloc_attr.set_on_host(true); + if (ret_alloc_attrs) { + AllocatorAttributes alloc_attr; + DataType type = attr_value->type(); + MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" || + device_type == "XLA_GPU") + ? MTypeFromDTypeIntsOnDevice(type) + : MTypeFromDType(type); + if (mtype == HOST_MEMORY) { + alloc_attr.set_on_host(true); + } + ret_alloc_attrs->push_back(alloc_attr); } - ret_alloc_attrs->push_back(alloc_attr); } return Status::OK(); diff --git a/tensorflow/core/common_runtime/partitioning_utils.h b/tensorflow/core/common_runtime/partitioning_utils.h index 1eb17423de0..32bc36bcdae 100644 --- a/tensorflow/core/common_runtime/partitioning_utils.h +++ b/tensorflow/core/common_runtime/partitioning_utils.h @@ -34,31 +34,34 @@ Status PartitionFunctionGraph( const DeviceSet& device_set, std::unique_ptr graph, std::unordered_map>* subgraphs); -// Each subgraph produced by partitioning the function body contains a subset -// of the original `Arg` and `Retval` nodes. This function performs -// bookkeeping to track which `Arg` and `Retval` nodes were placed on a -// particular device / subgraph. +// This function performs bookkeeping to track which `Arg` and `Retval` nodes +// were placed on a particular device / graph. // // More specifically, this function -// (1) rewrites the indices of the `Arg` and `Retval` nodes placed -// on a particular device. When a function is partitioned, each -// partition `subgraph` gets a subset of the arguments and -// return values. The `index` attributes of these _Arg and _Retval -// nodes reflect the indices of these parameters in the original -// function. To convert `subgraph` to a function, we need to replace -// there original indices with 0, 1, 2, ... . // -// The argument and return value order in the partitioned function is -// determined by the argument and return value order in the original -// function. This stability is important because it enables us to treat -// a single-partition function as having the same signature as the -// subgraph. +// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be +// consecutive. +// +// These indices might not be consecutive after grappler's pruning +// optimization (e.g. removing redundant Args), or graph partitioning. In +// the latter case, the nodes in `graph` are placed on `device_type`, and +// each such graph partition gets a subset of the arguments and return +// values. The `index` attributes of these _Arg and _Retval nodes reflect +// the indices of these parameters in the original function. To convert +// `subgraph` to a function, we need to replace there original indices with +// 0, 1, 2, ... . +// +// The argument and return value order in `graph` is determined by the +// argument and return value order in the original function. This stability +// is important because it enables us to treat a single-partition function +// as having the same signature as the subgraph. +// // (2) records the subsets of `Arg` and `Retval` nodes assigned to the // device in `*_indices`, and // (3) records which `Arg` and `Retval` nodes live in host memory in -// `*_alloc_attrs`. +// `*_alloc_attrs`. If these vectors are NULL, do nothing here. Status UpdateArgAndRetvalMetadata( - Graph* subgraph, const string& device_type, + Graph* graph, const string& device_type, std::vector* arg_indices, std::vector* ret_indices, std::vector* arg_alloc_attrs, std::vector* ret_alloc_attrs); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 50f3b52e4c6..e8b5678afe0 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions( return Status::OK(); } +} // anonymous namespace + Status GetGraphAndArgRets( const string& function_name, AttrSlice attrs, const FunctionDef* fdef, const FunctionLibraryDefinition* lib_def, std::unique_ptr* graph, @@ -644,8 +646,6 @@ Status GetGraphAndArgRets( return Status::OK(); } -} // anonymous namespace - Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( const string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, diff --git a/tensorflow/core/framework/graph_to_functiondef.h b/tensorflow/core/framework/graph_to_functiondef.h index 834bf50accc..83e56caed77 100644 --- a/tensorflow/core/framework/graph_to_functiondef.h +++ b/tensorflow/core/framework/graph_to_functiondef.h @@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, const std::vector& output_names, FunctionDef* fdef); +Status GetGraphAndArgRets( + const string& function_name, AttrSlice attrs, const FunctionDef* fdef, + const FunctionLibraryDefinition* lib_def, std::unique_ptr* graph, + std::vector* arg_nodes, std::vector* ret_nodes, + std::vector* ret_node_names, DataTypeVector* ret_types, + std::vector* control_ret_node_names); + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_ diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 4b5845698d8..89892453475 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -31,6 +31,24 @@ limitations under the License. namespace tensorflow { namespace grappler { +GrapplerItem::OptimizationOptions CreateOptOptionsForEager() { + GrapplerItem::OptimizationOptions optimization_options; + // Tensorflow 2.0 in eager mode with automatic control dependencies will + // prune all nodes that are not in the transitive fanin of the fetch nodes. + // However because the function will be executed via FunctionLibraryRuntime, + // and current function implementation does not prune stateful and dataset + // ops, we rely on Grappler to do the correct graph pruning. + optimization_options.allow_pruning_stateful_and_dataset_ops = true; + + optimization_options.is_eager_mode = true; + + // All the nested function calls will be executed and optimized via + // PartitionedCallOp, there is no need to optimize functions now. + optimization_options.optimize_function_library = false; + + return optimization_options; +} + GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const { GrapplerItem item; item.id = id; diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index 99d6d2c4566..bab435cbfac 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/core/protobuf/queue_runner.pb.h" namespace tensorflow { + +class ConfigProto; + namespace grappler { // A TensorFlow model to optimize. @@ -133,6 +136,8 @@ struct GrapplerItem { OptimizationOptions optimization_options_; }; +GrapplerItem::OptimizationOptions CreateOptOptionsForEager(); + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index c2bdc397006..2027434b2b1 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -2804,8 +2804,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): # Grappler fallback to use the CPU impl even called with GPU function. self.assertEqual(y_value, 3.0) - @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior ' - 'equivalent to implementation_selector for function') def testSwapImplementationInEager(self): if not context.executing_eagerly(): self.skipTest('eager only')