diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 2ae079be53c..cde26d76611 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
     "//tensorflow/core:protos_all_cc",
     "//tensorflow/core/profiler/lib:annotated_traceme",
     "//tensorflow/core/profiler/lib:traceme",
+    "//tensorflow/core/grappler:grappler_item",
     "//tensorflow/core/grappler/optimizers:meta_optimizer",
 ]
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 3b73f018c0d..79b9179de6b 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
 #if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
 #endif  // !IS_MOBILE_PLATFORM
@@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
         "Failed to parse config_proto attribute as tensorflow::ConfigProto "
         "proto.");
   }
-  grappler::GrapplerItem::OptimizationOptions optimization_options;
-
-  // Tensorflow 2.0 in eager mode with automatic control dependencies will
-  // prune all nodes that are not in the transitive fanin of the fetch nodes.
-  // However because the function will be executed via FunctionLibraryRuntime,
-  // and current function implementation does not prune stateful and dataset
-  // ops, we rely on Grappler to do the correct graph pruning.
-  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
-
-  optimization_options.is_eager_mode = true;
-
-  // All the nested function calls will be executed and optimized via
-  // PartitionedCallOp, there is no need to optimize functions now.
-  optimization_options.optimize_function_library = false;
+  grappler::GrapplerItem::OptimizationOptions optimization_options =
+      grappler::CreateOptOptionsForEager();

   options.optimize_graph_fn = std::bind(
       grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
@@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,

   // In Eager mode we always inline all functions into the top-level
   // function body graph, to get a single executable graph, that could be
-  // optimized across function boundaries (e.g. prune unused inputs and outputs
-  // in a function call chain). This is required to mimic graph mode execution,
-  // with aggressive pruning of nodes not in the transitive fanin of fetches.
+  // optimized across function boundaries (e.g. prune unused inputs and
+  // outputs in a function call chain). This is required to mimic graph mode
+  // execution, with aggressive pruning of nodes not in the transitive fanin
+  // of fetches.
   options.config_proto.mutable_graph_options()
       ->mutable_optimizer_options()
       ->set_do_function_inlining(true);
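Note: the hunk above keeps the pre-existing ConfigProto knob that forces function inlining in eager mode. As a minimal standalone sketch (not part of the patch; it only uses the stock ConfigProto/GraphOptions/OptimizerOptions protos), the same knob can be set like this:

    #include "tensorflow/core/protobuf/config.pb.h"

    // Builds a ConfigProto that forces function inlining, mirroring the
    // set_do_function_inlining(true) call in KernelAndDeviceFunc::InstantiateFunc.
    tensorflow::ConfigProto MakeInliningConfig() {
      tensorflow::ConfigProto config;
      config.mutable_graph_options()
          ->mutable_optimizer_options()
          ->set_do_function_inlining(true);
      return config;
    }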
diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc
index 6fb7526c512..6cdc9704eb1 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.cc
+++ b/tensorflow/core/common_runtime/partitioning_utils.cc
@@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
 }

 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<int>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(

   // Find the Arg and Retval nodes, along with their corresponding indices
   // in the original function.
-  for (Node* node : subgraph->op_nodes()) {
+  for (Node* node : graph->op_nodes()) {
     if (node->IsArg()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
@@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata(
     Node* arg = arg_nodes[i].first;
     arg->AddAttr("index", i);
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (arg_alloc_attrs != nullptr) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      arg_alloc_attrs->push_back(alloc_attr);
     }
-    arg_alloc_attrs->push_back(alloc_attr);
   }
   for (int i = 0; i < ret_nodes.size(); ++i) {
     Node* ret = ret_nodes[i].first;
     ret->AddAttr("index", i);
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (ret_alloc_attrs != nullptr) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      ret_alloc_attrs->push_back(alloc_attr);
     }
-    ret_alloc_attrs->push_back(alloc_attr);
   }
   return Status::OK();
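Note: both loops now guard the allocator-attribute computation, so a caller that only needs the index bookkeeping may pass null output vectors. A hypothetical call site (the `graph` and `device` variables are placeholders, not from this patch):

    // Only the _Arg/_Retval index bookkeeping is needed here, so the
    // allocator-attribute outputs are skipped entirely by passing nullptr.
    std::vector<int> arg_indices;
    std::vector<int> ret_indices;
    TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
        graph.get(), device->device_type(), &arg_indices, &ret_indices,
        /*arg_alloc_attrs=*/nullptr, /*ret_alloc_attrs=*/nullptr));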
diff --git a/tensorflow/core/common_runtime/partitioning_utils.h b/tensorflow/core/common_runtime/partitioning_utils.h
index 1eb17423de0..32bc36bcdae 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.h
+++ b/tensorflow/core/common_runtime/partitioning_utils.h
@@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);

-// Each subgraph produced by partitioning the function body contains a subset
-// of the original `Arg` and `Retval` nodes. This function performs
-// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
-// particular device / subgraph.
+// This function performs bookkeeping to track which `Arg` and `Retval` nodes
+// were placed on a particular device / graph.
 //
 // More specifically, this function
-// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
-//     on a particular device. When a function is partitioned, each
-//     partition `subgraph` gets a subset of the arguments and
-//     return values. The `index` attributes of these _Arg and _Retval
-//     nodes reflect the indices of these parameters in the original
-//     function. To convert `subgraph` to a function, we need to replace
-//     there original indices with 0, 1, 2, ... .
 //
-//     The argument and return value order in the partitioned function is
-//     determined by the argument and return value order in the original
-//     function. This stability is important because it enables us to treat
-//     a single-partition function as having the same signature as the
-//     subgraph.
+// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
+//     consecutive.
+//
+//     These indices might not be consecutive after grappler's pruning
+//     optimization (e.g. removing redundant Args), or after graph
+//     partitioning. In the latter case, the nodes in `graph` are placed on
+//     `device_type`, and each such graph partition gets a subset of the
+//     arguments and return values. The `index` attributes of these _Arg and
+//     _Retval nodes reflect the indices of these parameters in the original
+//     function. To convert `graph` to a function, we need to replace their
+//     original indices with 0, 1, 2, ... .
+//
+//     The argument and return value order in `graph` is determined by the
+//     argument and return value order in the original function. This
+//     stability is important because it enables us to treat a
+//     single-partition function as having the same signature as the
+//     subgraph.
+//
 // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
 //     device in `*_indices`, and
 // (3) records which `Arg` and `Retval` nodes live in host memory in
-//     `*_alloc_attrs`.
+//     `*_alloc_attrs`. If these vectors are null, this step is skipped.
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<int>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs);
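Note: a hypothetical illustration of step (1). Suppose pruning left a partition holding _Arg nodes whose original `index` attrs are {0, 2, 5}; the pass records the originals in `arg_indices` and rewrites the attrs consecutively. This sketch is standalone C++, not TensorFlow code:

    #include <vector>

    // Renumber surviving argument indices {0, 2, 5} to {0, 1, 2}, keeping the
    // original index of each new position: the bookkeeping from (1) and (2).
    std::vector<int> RenumberArgs(const std::vector<int>& surviving) {
      std::vector<int> arg_indices;  // arg_indices[i] == original index of new arg i
      for (int i = 0; i < static_cast<int>(surviving.size()); ++i) {
        arg_indices.push_back(surviving[i]);
        // The real pass also does: arg_node->AddAttr("index", i);
      }
      return arg_indices;  // for {0, 2, 5}: the nodes become args 0, 1, 2
    }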
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 4a8f37eca1f..60b837995bc 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
   return Status::OK();
 }

+}  // anonymous namespace
+
 Status GetGraphAndArgRets(
     const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
     const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
@@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
   return Status::OK();
 }

-}  // anonymous namespace
-
 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     const string& function_name, AttrSlice attrs,
     const FunctionLibraryRuntime::InstantiateOptions& options,
diff --git a/tensorflow/core/framework/graph_to_functiondef.h b/tensorflow/core/framework/graph_to_functiondef.h
index 834bf50accc..83e56caed77 100644
--- a/tensorflow/core/framework/graph_to_functiondef.h
+++ b/tensorflow/core/framework/graph_to_functiondef.h
@@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
                           const std::vector<string>& output_names,
                           FunctionDef* fdef);

+Status GetGraphAndArgRets(
+    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
+    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
+    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
+    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
+    std::vector<string>* control_ret_node_names);
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
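Note: with the declaration exported in graph_to_functiondef.h, GetGraphAndArgRets becomes callable outside process_function_library_runtime.cc. A hypothetical caller, assuming `function_name`, `attrs`, `fdef`, and `lib_def` are already in scope:

    // Materialize the function body graph plus its _Arg/_Retval bookkeeping
    // in one call, using the signature declared above.
    std::unique_ptr<Graph> graph;
    std::vector<Node*> arg_nodes, ret_nodes;
    std::vector<string> ret_node_names, control_ret_node_names;
    DataTypeVector ret_types;
    TF_RETURN_IF_ERROR(GetGraphAndArgRets(
        function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
        &ret_node_names, &ret_types, &control_ret_node_names));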
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 4b5845698d8..89892453475 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -31,6 +31,24 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {

+GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
+  GrapplerItem::OptimizationOptions optimization_options;
+  // TensorFlow 2.0 in eager mode with automatic control dependencies will
+  // prune all nodes that are not in the transitive fanin of the fetch nodes.
+  // However, because the function will be executed via FunctionLibraryRuntime,
+  // and the current function implementation does not prune stateful and
+  // dataset ops, we rely on Grappler to do the correct graph pruning.
+  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
+
+  optimization_options.is_eager_mode = true;
+
+  // All the nested function calls will be executed and optimized via
+  // PartitionedCallOp; there is no need to optimize functions now.
+  optimization_options.optimize_function_library = false;
+
+  return optimization_options;
+}
+
 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
   GrapplerItem item;
   item.id = id;
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 99d6d2c4566..7a3900f4515 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -133,6 +133,8 @@ struct GrapplerItem {
   OptimizationOptions optimization_options_;
 };

+GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
+
 }  // end namespace grappler
 }  // end namespace tensorflow
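Note: a minimal sketch of the intended use of the new helper, standalone and using only the header touched by this patch (the function and assertion names are illustrative):

    #include <cassert>
    #include "tensorflow/core/grappler/grappler_item.h"

    // The helper centralizes the three eager-mode optimization flags that
    // kernel_and_device.cc previously assembled by hand at its call site.
    void CheckEagerOptOptions() {
      tensorflow::grappler::GrapplerItem::OptimizationOptions opts =
          tensorflow::grappler::CreateOptOptionsForEager();
      assert(opts.allow_pruning_stateful_and_dataset_ops);
      assert(opts.is_eager_mode);
      assert(!opts.optimize_function_library);
    }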