Integrated graph optimization passes with TFRT for TF2 training support.

These GraphDef-based passes are performed before we import the graph into the
MLIR TF dialect.

A few notes:

1. Refactored and reused code from the existing stack when applicable, to ensure
consistent behavior (e.g. see CreateOptOptionsForEager()).

2. Used existing grappler passes as in TF2 eager, but instructed grappler not to
"lower V2 control flow to V1".

3. One challenge is that the graph optimization (specifically the grappler
ModelPruner pass) might remove unused input parameters of the function (aka the
"Args" nodes in GraphDef), which affects the function calling protocol and thus
requires corresponding changes to how we feed inputs to it.

To handle this, we track the arg index mapping between the original function
and the optimized/pruned one in FunctionState::arg_indices_, similar to
ComponentFunctionData::arg_indices in the current runtime; a sketch follows
below.
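A minimal sketch of how such an arg index mapping can be applied when feeding
inputs (hypothetical names, not the actual TFRT code; assumes a plain
std::vector<int> mapping rather than FunctionArgIndex):

  #include <vector>

  #include "tensorflow/core/framework/tensor.h"

  // arg_indices[i] holds, for argument i of the optimized/pruned function,
  // the index of the corresponding argument in the original function.
  std::vector<tensorflow::Tensor> SelectPrunedInputs(
      const std::vector<tensorflow::Tensor>& original_inputs,
      const std::vector<int>& arg_indices) {
    std::vector<tensorflow::Tensor> pruned_inputs;
    pruned_inputs.reserve(arg_indices.size());
    for (int original_index : arg_indices) {
      // Only the arguments that survived pruning are forwarded.
      pruned_inputs.push_back(original_inputs[original_index]);
    }
    return pruned_inputs;
  }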

4. The graph optimizations added by this CL do not involve graph/device
partitioning. Here we assume the input graph function runs on a single host,
but it might involve multiple devices.

5. Made FunctionCache::GetOrAddFunction() return tensorflow::Status instead of
llvm::Expected, to improve e2e error propagation (that function calls other
functions that return Status, and previously the error info was dropped when
translating to llvm::Expected). This also made some code simpler (e.g. we can
now use TF_RETURN_IF_ERROR); a sketch follows below.

6. PRE_PLACEMENT optimization passes are not yet added. Calling them caused
more test failures and needs more investigation.

PiperOrigin-RevId: 347035680
Change-Id: I969154f5888fe665eea336ac06fffe50a897c13e
Commit cae04d6d4d (parent 4279f36afb), authored by Mingsheng Hong on
2020-12-11 11:10:44 -08:00 and committed by TensorFlower Gardener.
9 changed files with 85 additions and 59 deletions.


@@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:annotated_traceme",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
]


@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif // !IS_MOBILE_PLATFORM
@@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
"Failed to parse config_proto attribute as tensorflow::ConfigProto "
"proto.");
}
grappler::GrapplerItem::OptimizationOptions optimization_options;
// Tensorflow 2.0 in eager mode with automatic control dependencies will
// prune all nodes that are not in the transitive fanin of the fetch nodes.
// However because the function will be executed via FunctionLibraryRuntime,
// and current function implementation does not prune stateful and dataset
// ops, we rely on Grappler to do the correct graph pruning.
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
optimization_options.is_eager_mode = true;
// All the nested function calls will be executed and optimized via
// PartitionedCallOp, there is no need to optimize functions now.
optimization_options.optimize_function_library = false;
grappler::GrapplerItem::OptimizationOptions optimization_options =
grappler::CreateOptOptionsForEager();
options.optimize_graph_fn = std::bind(
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
@@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
// In Eager mode we always inline all functions into the top-level
// function body graph, to get a single executable graph, that could be
// optimized across function boundaries (e.g. prune unused inputs and outputs
// in a function call chain). This is required to mimic graph mode execution,
// with aggressive pruning of nodes not in the transitive fanin of fetches.
// optimized across function boundaries (e.g. prune unused inputs and
// outputs in a function call chain). This is required to mimic graph mode
// execution, with aggressive pruning of nodes not in the transitive fanin
// of fetches.
options.config_proto.mutable_graph_options()
->mutable_optimizer_options()
->set_do_function_inlining(true);


@@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
}
Status UpdateArgAndRetvalMetadata(
Graph* subgraph, const string& device_type,
Graph* graph, const string& device_type,
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
std::vector<AllocatorAttributes>* arg_alloc_attrs,
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(
// Find the Arg and Retval nodes, along with their corresponding indices
// in the original function.
for (Node* node : subgraph->op_nodes()) {
for (Node* node : graph->op_nodes()) {
if (node->IsArg()) {
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int index = static_cast<int>(attr_value->i());
@@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata(
Node* arg = arg_nodes[i].first;
arg->AddAttr("index", i);
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
AllocatorAttributes alloc_attr;
DataType type = attr_value->type();
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
device_type == "XLA_GPU")
? MTypeFromDTypeIntsOnDevice(type)
: MTypeFromDType(type);
if (mtype == HOST_MEMORY) {
alloc_attr.set_on_host(true);
if (arg_alloc_attrs != nullptr) {
AllocatorAttributes alloc_attr;
DataType type = attr_value->type();
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
device_type == "XLA_GPU")
? MTypeFromDTypeIntsOnDevice(type)
: MTypeFromDType(type);
if (mtype == HOST_MEMORY) {
alloc_attr.set_on_host(true);
}
arg_alloc_attrs->push_back(alloc_attr);
}
arg_alloc_attrs->push_back(alloc_attr);
}
for (int i = 0; i < ret_nodes.size(); ++i) {
Node* ret = ret_nodes[i].first;
ret->AddAttr("index", i);
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
AllocatorAttributes alloc_attr;
DataType type = attr_value->type();
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
device_type == "XLA_GPU")
? MTypeFromDTypeIntsOnDevice(type)
: MTypeFromDType(type);
if (mtype == HOST_MEMORY) {
alloc_attr.set_on_host(true);
if (ret_alloc_attrs) {
AllocatorAttributes alloc_attr;
DataType type = attr_value->type();
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
device_type == "XLA_GPU")
? MTypeFromDTypeIntsOnDevice(type)
: MTypeFromDType(type);
if (mtype == HOST_MEMORY) {
alloc_attr.set_on_host(true);
}
ret_alloc_attrs->push_back(alloc_attr);
}
ret_alloc_attrs->push_back(alloc_attr);
}
return Status::OK();


@@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
const DeviceSet& device_set, std::unique_ptr<Graph> graph,
std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
// Each subgraph produced by partitioning the function body contains a subset
// of the original `Arg` and `Retval` nodes. This function performs
// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
// particular device / subgraph.
// This function performs bookkeeping to track which `Arg` and `Retval` nodes
// were placed on a particular device / graph.
//
// More specifically, this function
// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
// on a particular device. When a function is partitioned, each
// partition `subgraph` gets a subset of the arguments and
// return values. The `index` attributes of these _Arg and _Retval
// nodes reflect the indices of these parameters in the original
// function. To convert `subgraph` to a function, we need to replace
// there original indices with 0, 1, 2, ... .
//
// The argument and return value order in the partitioned function is
// determined by the argument and return value order in the original
// function. This stability is important because it enables us to treat
// a single-partition function as having the same signature as the
// subgraph.
// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
// consecutive.
//
// These indices might not be consecutive after grappler's pruning
// optimization (e.g. removing redundant Args), or graph partitioning. In
// the latter case, the nodes in `graph` are placed on `device_type`, and
// each such graph partition gets a subset of the arguments and return
// values. The `index` attributes of these _Arg and _Retval nodes reflect
// the indices of these parameters in the original function. To convert
`subgraph` to a function, we need to replace their original indices with
// 0, 1, 2, ... .
//
// The argument and return value order in `graph` is determined by the
// argument and return value order in the original function. This stability
// is important because it enables us to treat a single-partition function
// as having the same signature as the subgraph.
//
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
// device in `*_indices`, and
// (3) records which `Arg` and `Retval` nodes live in host memory in
// `*_alloc_attrs`.
// `*_alloc_attrs`. If these vectors are NULL, do nothing here.
Status UpdateArgAndRetvalMetadata(
Graph* subgraph, const string& device_type,
Graph* graph, const string& device_type,
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
std::vector<AllocatorAttributes>* arg_alloc_attrs,
std::vector<AllocatorAttributes>* ret_alloc_attrs);
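A tiny standalone illustration (hypothetical, not TensorFlow code) of the
index rewrite described in the comment above: if pruning keeps only the
original arguments 0, 2 and 5, the surviving _Arg nodes are renumbered to
0, 1, 2 and the original positions are recorded so callers can still select
the right inputs:

  #include <cstdio>
  #include <vector>

  int main() {
    // Original argument positions of the _Arg nodes that survived pruning.
    const std::vector<int> arg_indices = {0, 2, 5};
    for (int new_index = 0; new_index < static_cast<int>(arg_indices.size());
         ++new_index) {
      // The node's "index" attr is rewritten to the consecutive new_index.
      std::printf("new arg index %d <- original arg index %d\n", new_index,
                  arg_indices[new_index]);
    }
    return 0;
  }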


@@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
return Status::OK();
}
} // anonymous namespace
Status GetGraphAndArgRets(
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
@@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
return Status::OK();
}
} // anonymous namespace
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
const string& function_name, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,


@@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
const std::vector<std::string>& output_names,
FunctionDef* fdef);
Status GetGraphAndArgRets(
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
std::vector<string>* ret_node_names, DataTypeVector* ret_types,
std::vector<string>* control_ret_node_names);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_


@@ -31,6 +31,24 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
GrapplerItem::OptimizationOptions optimization_options;
// Tensorflow 2.0 in eager mode with automatic control dependencies will
// prune all nodes that are not in the transitive fanin of the fetch nodes.
// However because the function will be executed via FunctionLibraryRuntime,
// and current function implementation does not prune stateful and dataset
// ops, we rely on Grappler to do the correct graph pruning.
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
optimization_options.is_eager_mode = true;
// All the nested function calls will be executed and optimized via
// PartitionedCallOp, there is no need to optimize functions now.
optimization_options.optimize_function_library = false;
return optimization_options;
}
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
GrapplerItem item;
item.id = id;


@@ -29,6 +29,9 @@ limitations under the License.
#include "tensorflow/core/protobuf/queue_runner.pb.h"
namespace tensorflow {
class ConfigProto;
namespace grappler {
// A TensorFlow model to optimize.
@@ -133,6 +136,8 @@ struct GrapplerItem {
OptimizationOptions optimization_options_;
};
GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
} // end namespace grappler
} // end namespace tensorflow


@@ -2804,8 +2804,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
# Grappler fallback to use the CPU impl even called with GPU function.
self.assertEqual(y_value, 3.0)
@test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
'equivalent to implementation_selector for function')
def testSwapImplementationInEager(self):
if not context.executing_eagerly():
self.skipTest('eager only')