Integrated graph optimization passes with TFRT for TF2 training support.

These GraphDef-based passes run before we import the graph into the MLIR TF
dialect.

A few notes:

1. Refactored and reused code from the existing stack when applicable, to ensure
consistent behavior (e.g. see CreateOptOptionsForEager()).

2. Used existing grappler passes as in TF2 eager, but instructed grappler not to
"lower V2 control flow to V1".

3. One challenge is that graph optimization (specifically the grappler
ModelPruner pass) might remove unused input parameters of the function (the
"Args" nodes in GraphDef). This changes the function's calling convention and
thus requires corresponding changes to how we feed it inputs (see the sketch
below).

To handle this, we track the arg index mapping between the original function
and the optimized/pruned one in FunctionState::arg_indices_, similar to
ComponentFunctionData::arg_indices in the current runtime.
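
A hypothetical sketch of that remapping, flattening FunctionArgIndex down to
plain ints for brevity: if the original signature was (a0, a1, a2, a3) and
pruning dropped a1 and a3, the recorded indices would be {0, 2} and the caller
feeds only those tensors.

#include <vector>

#include "tensorflow/core/framework/tensor.h"

// Hypothetical helper: select the caller-provided inputs that survive
// pruning. arg_indices[i] is the position, in the original signature, of
// the i-th argument of the optimized/pruned function.
std::vector<tensorflow::Tensor> SelectPrunedInputs(
    const std::vector<tensorflow::Tensor>& original_inputs,
    const std::vector<int>& arg_indices) {
  std::vector<tensorflow::Tensor> pruned_inputs;
  pruned_inputs.reserve(arg_indices.size());
  for (int original_index : arg_indices) {
    pruned_inputs.push_back(original_inputs[original_index]);
  }
  return pruned_inputs;
}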

4. The graph optimizations added by this CL do not involve graph/device
partitioning. Here we assume the input graph function is on a single host, but
might involve multiple devices.

5. Made FunctionCache::GetOrAddFunction() return tensorflow::Status instead of
llvm::Expected, to improve end-to-end error propagation (that function calls
other functions that return Status, and the previous translation to
llvm::Expected dropped the error info). This also simplified some code (e.g.
we can now use TF_RETURN_IF_ERROR). A sketch of the difference follows.
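
To illustrate the propagation issue (a sketch, not code from this CL):
converting a Status into llvm::Expected forces the error through a string,
while returning Status end to end keeps the original error code and message
and composes with TF_RETURN_IF_ERROR.

#include "llvm/Support/Error.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

tensorflow::Status DoWork();  // Some callee that already returns Status.

// Before (illustrative): the Status is collapsed into a string error,
// losing the structured error code.
llvm::Expected<int> GetValueExpected() {
  tensorflow::Status s = DoWork();
  if (!s.ok()) {
    return llvm::createStringError(llvm::inconvertibleErrorCode(), "%s",
                                   s.error_message().c_str());
  }
  return 42;
}

// After: the callee's Status flows through unchanged.
tensorflow::Status GetValue(int* out) {
  TF_RETURN_IF_ERROR(DoWork());
  *out = 42;
  return tensorflow::Status::OK();
}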

6. PRE_PLACEMENT optimization passes are not yet added. Calling them caused
additional test failures, which need more investigation.

PiperOrigin-RevId: 347035680
Change-Id: I969154f5888fe665eea336ac06fffe50a897c13e
Authored by Mingsheng Hong on 2020-12-11 11:10:44 -08:00; committed by TensorFlower Gardener
parent 4279f36afb
commit cae04d6d4d
9 changed files with 85 additions and 59 deletions


@@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
     "//tensorflow/core:protos_all_cc",
     "//tensorflow/core/profiler/lib:annotated_traceme",
     "//tensorflow/core/profiler/lib:traceme",
+    "//tensorflow/core/grappler:grappler_item",
     "//tensorflow/core/grappler/optimizers:meta_optimizer",
 ]


@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
 #if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
 #endif  // !IS_MOBILE_PLATFORM
@@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
         "Failed to parse config_proto attribute as tensorflow::ConfigProto "
         "proto.");
   }
-  grappler::GrapplerItem::OptimizationOptions optimization_options;
-
-  // Tensorflow 2.0 in eager mode with automatic control dependencies will
-  // prune all nodes that are not in the transitive fanin of the fetch nodes.
-  // However because the function will be executed via FunctionLibraryRuntime,
-  // and current function implementation does not prune stateful and dataset
-  // ops, we rely on Grappler to do the correct graph pruning.
-  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
-
-  optimization_options.is_eager_mode = true;
-
-  // All the nested function calls will be executed and optimized via
-  // PartitionedCallOp, there is no need to optimize functions now.
-  optimization_options.optimize_function_library = false;
+  grappler::GrapplerItem::OptimizationOptions optimization_options =
+      grappler::CreateOptOptionsForEager();

   options.optimize_graph_fn = std::bind(
       grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
@@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
   // In Eager mode we always inline all functions into the top-level
   // function body graph, to get a single executable graph, that could be
-  // optimized across function boundaries (e.g. prune unused inputs and outputs
-  // in a function call chain). This is required to mimic graph mode execution,
-  // with aggressive pruning of nodes not in the transitive fanin of fetches.
+  // optimized across function boundaries (e.g. prune unused inputs and
+  // outputs in a function call chain). This is required to mimic graph mode
+  // execution, with aggressive pruning of nodes not in the transitive fanin
+  // of fetches.
   options.config_proto.mutable_graph_options()
       ->mutable_optimizer_options()
       ->set_do_function_inlining(true);


@@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
 }

 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(
   // Find the Arg and Retval nodes, along with their corresponding indices
   // in the original function.
-  for (Node* node : subgraph->op_nodes()) {
+  for (Node* node : graph->op_nodes()) {
     if (node->IsArg()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
@@ -124,6 +124,7 @@ Status UpdateArgAndRetvalMetadata(
     Node* arg = arg_nodes[i].first;
     arg->AddAttr("index", i);
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
+    if (arg_alloc_attrs != nullptr) {
       AllocatorAttributes alloc_attr;
       DataType type = attr_value->type();
       MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
@@ -135,10 +136,12 @@ Status UpdateArgAndRetvalMetadata(
       }
       arg_alloc_attrs->push_back(alloc_attr);
     }
+  }
   for (int i = 0; i < ret_nodes.size(); ++i) {
     Node* ret = ret_nodes[i].first;
     ret->AddAttr("index", i);
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
+    if (ret_alloc_attrs) {
       AllocatorAttributes alloc_attr;
       DataType type = attr_value->type();
       MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
@@ -150,6 +153,7 @@ Status UpdateArgAndRetvalMetadata(
     }
     ret_alloc_attrs->push_back(alloc_attr);
   }
+  }
   return Status::OK();
 }


@@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);

-// Each subgraph produced by partitioning the function body contains a subset
-// of the original `Arg` and `Retval` nodes. This function performs
-// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
-// particular device / subgraph.
+// This function performs bookkeeping to track which `Arg` and `Retval` nodes
+// were placed on a particular device / graph.
 //
 // More specifically, this function
-// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
-//     on a particular device. When a function is partitioned, each
-//     partition `subgraph` gets a subset of the arguments and
-//     return values. The `index` attributes of these _Arg and _Retval
-//     nodes reflect the indices of these parameters in the original
-//     function. To convert `subgraph` to a function, we need to replace
-//     there original indices with 0, 1, 2, ... .
 //
-// The argument and return value order in the partitioned function is
-// determined by the argument and return value order in the original
-// function. This stability is important because it enables us to treat
-// a single-partition function as having the same signature as the
-// subgraph.
+// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
+//     consecutive.
+//
+//     These indices might not be consecutive after grappler's pruning
+//     optimization (e.g. removing redundant Args), or graph partitioning. In
+//     the latter case, the nodes in `graph` are placed on `device_type`, and
+//     each such graph partition gets a subset of the arguments and return
+//     values. The `index` attributes of these _Arg and _Retval nodes reflect
+//     the indices of these parameters in the original function. To convert
+//     `subgraph` to a function, we need to replace there original indices with
+//     0, 1, 2, ... .
+//
+//     The argument and return value order in `graph` is determined by the
+//     argument and return value order in the original function. This stability
+//     is important because it enables us to treat a single-partition function
+//     as having the same signature as the subgraph.
+//
 // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
 //     device in `*_indices`, and
 // (3) records which `Arg` and `Retval` nodes live in host memory in
-//     `*_alloc_attrs`.
+//     `*_alloc_attrs`. If these vectors are NULL, do nothing here.
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs);
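
A usage sketch for the relaxed contract above (the wrapper name is ours, not
from this CL): a pruning path that only needs the index bookkeeping can now
pass null allocator-attribute vectors.

#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/partitioning_utils.h"

// Hypothetical wrapper: collect only arg/ret index metadata after pruning,
// skipping the host-memory (allocator attribute) bookkeeping.
tensorflow::Status CollectArgRetIndices(
    tensorflow::Graph* graph, const std::string& device_type,
    std::vector<tensorflow::FunctionArgIndex>* arg_indices,
    std::vector<int>* ret_indices) {
  // Null alloc-attr vectors are now legal and make the function skip that
  // bookkeeping entirely.
  return tensorflow::UpdateArgAndRetvalMetadata(
      graph, device_type, arg_indices, ret_indices,
      /*arg_alloc_attrs=*/nullptr, /*ret_alloc_attrs=*/nullptr);
}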


@@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
   return Status::OK();
 }

+}  // anonymous namespace
+
 Status GetGraphAndArgRets(
     const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
     const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
@@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
   return Status::OK();
 }

-}  // anonymous namespace
-
 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     const string& function_name, AttrSlice attrs,
     const FunctionLibraryRuntime::InstantiateOptions& options,


@@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
                           const std::vector<std::string>& output_names,
                           FunctionDef* fdef);

+Status GetGraphAndArgRets(
+    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
+    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
+    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
+    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
+    std::vector<string>* control_ret_node_names);
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_


@@ -31,6 +31,24 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {

+GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
+  GrapplerItem::OptimizationOptions optimization_options;
+
+  // Tensorflow 2.0 in eager mode with automatic control dependencies will
+  // prune all nodes that are not in the transitive fanin of the fetch nodes.
+  // However because the function will be executed via FunctionLibraryRuntime,
+  // and current function implementation does not prune stateful and dataset
+  // ops, we rely on Grappler to do the correct graph pruning.
+  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
+
+  optimization_options.is_eager_mode = true;
+
+  // All the nested function calls will be executed and optimized via
+  // PartitionedCallOp, there is no need to optimize functions now.
+  optimization_options.optimize_function_library = false;
+  return optimization_options;
+}
+
 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
   GrapplerItem item;
   item.id = id;
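
A sketch of how a second stack (e.g. the TFRT path) would pick up the same
policy via this shared helper. It assumes GrapplerItem exposes a mutable
optimization_options() accessor, which is not shown in this diff.

#include <utility>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"

// Sketch: any caller can populate a GrapplerItem with the same eager-mode
// options as the existing eager runtime, keeping behavior consistent.
tensorflow::grappler::GrapplerItem MakeEagerGrapplerItem(
    tensorflow::GraphDef graph_def) {
  tensorflow::grappler::GrapplerItem item;
  item.graph = std::move(graph_def);
  item.optimization_options() =
      tensorflow::grappler::CreateOptOptionsForEager();
  return item;
}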


@@ -29,6 +29,9 @@ limitations under the License.
 #include "tensorflow/core/protobuf/queue_runner.pb.h"

 namespace tensorflow {
+
+class ConfigProto;
+
 namespace grappler {

 // A TensorFlow model to optimize.
@@ -133,6 +136,8 @@ struct GrapplerItem {
   OptimizationOptions optimization_options_;
 };

+GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
+
 }  // end namespace grappler
 }  // end namespace tensorflow


@@ -2804,8 +2804,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       # Grappler fallback to use the CPU impl even called with GPU function.
       self.assertEqual(y_value, 3.0)

-  @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
-                          'equivalent to implementation_selector for function')
   def testSwapImplementationInEager(self):
     if not context.executing_eagerly():
       self.skipTest('eager only')