Integrated graph optimization passes with TFRT for TF2 training support.
These GraphDef based passes are performed before we import the graph into MLIR TF dialect. A few notes: 1. Refactored and reused code from the existing stack when applicable, to ensure consistent behavior (e.g. see CreateOptOptionsForEager()). 2. Used existing grappler passes as in TF2 eager, but instructed grappler not to "lower V2 control flow to V1". 3. One challenge is that the graph optimization (specifically the grappler ModelPruner pass) might remove unused input parameters to the function (aka the "Args" nodes in GraphDef), which affects the function calling protocol and thus requires corresponding changes to how we feed it the input. To do so, we track the arg index mapping between the original function and the optimized/pruned one, in FunctionState::arg_indices_, similar to ComponentFunctionData::arg_indices in the current runtime. 4. The graph optimizations added by this CL do not involve graph/device partitioning. Here we assume the input graph function is on a single host, but might involve multiple devices. 5. Made FunctionCache::GetOrAddFunction() return tensorflow::Status instead of llvm::Expected, to improve e2e error propagation (that function calls some other functions that return Status, and previously drops the error info when translating to llvm::Expected). This also made some code simpler (e.g. can use TF_RETURN_IF_ERROR). 6. PRE_PLACEMENT optimization passes are not yet added. Calling them caused more test failures, and this needs more investigation. PiperOrigin-RevId: 347035680 Change-Id: I969154f5888fe665eea336ac06fffe50a897c13e
This commit is contained in:
parent
4279f36afb
commit
cae04d6d4d
@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
"//tensorflow/core/profiler/lib:annotated_traceme",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
||||||
#if !defined(IS_MOBILE_PLATFORM)
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
|
|
||||||
@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
|
|||||||
"Failed to parse config_proto attribute as tensorflow::ConfigProto "
|
"Failed to parse config_proto attribute as tensorflow::ConfigProto "
|
||||||
"proto.");
|
"proto.");
|
||||||
}
|
}
|
||||||
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
grappler::GrapplerItem::OptimizationOptions optimization_options =
|
||||||
|
grappler::CreateOptOptionsForEager();
|
||||||
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
|
||||||
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
|
||||||
// However because the function will be executed via FunctionLibraryRuntime,
|
|
||||||
// and current function implementation does not prune stateful and dataset
|
|
||||||
// ops, we rely on Grappler to do the correct graph pruning.
|
|
||||||
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
|
||||||
|
|
||||||
optimization_options.is_eager_mode = true;
|
|
||||||
|
|
||||||
// All the nested function calls will be executed and optimized via
|
|
||||||
// PartitionedCallOp, there is no need to optimize functions now.
|
|
||||||
optimization_options.optimize_function_library = false;
|
|
||||||
|
|
||||||
options.optimize_graph_fn = std::bind(
|
options.optimize_graph_fn = std::bind(
|
||||||
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
|
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
|
||||||
@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
|
|||||||
|
|
||||||
// In Eager mode we always inline all functions into the top-level
|
// In Eager mode we always inline all functions into the top-level
|
||||||
// function body graph, to get a single executable graph, that could be
|
// function body graph, to get a single executable graph, that could be
|
||||||
// optimized across function boundaries (e.g. prune unused inputs and outputs
|
// optimized across function boundaries (e.g. prune unused inputs and
|
||||||
// in a function call chain). This is required to mimic graph mode execution,
|
// outputs in a function call chain). This is required to mimic graph mode
|
||||||
// with aggressive pruning of nodes not in the transitive fanin of fetches.
|
// execution, with aggressive pruning of nodes not in the transitive fanin
|
||||||
|
// of fetches.
|
||||||
options.config_proto.mutable_graph_options()
|
options.config_proto.mutable_graph_options()
|
||||||
->mutable_optimizer_options()
|
->mutable_optimizer_options()
|
||||||
->set_do_function_inlining(true);
|
->set_do_function_inlining(true);
|
||||||
|
@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status UpdateArgAndRetvalMetadata(
|
Status UpdateArgAndRetvalMetadata(
|
||||||
Graph* subgraph, const string& device_type,
|
Graph* graph, const string& device_type,
|
||||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||||
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
|
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
|
||||||
@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(
|
|||||||
|
|
||||||
// Find the Arg and Retval nodes, along with their corresponding indices
|
// Find the Arg and Retval nodes, along with their corresponding indices
|
||||||
// in the original function.
|
// in the original function.
|
||||||
for (Node* node : subgraph->op_nodes()) {
|
for (Node* node : graph->op_nodes()) {
|
||||||
if (node->IsArg()) {
|
if (node->IsArg()) {
|
||||||
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
|
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
|
||||||
int index = static_cast<int>(attr_value->i());
|
int index = static_cast<int>(attr_value->i());
|
||||||
@ -124,6 +124,7 @@ Status UpdateArgAndRetvalMetadata(
|
|||||||
Node* arg = arg_nodes[i].first;
|
Node* arg = arg_nodes[i].first;
|
||||||
arg->AddAttr("index", i);
|
arg->AddAttr("index", i);
|
||||||
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
|
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
|
||||||
|
if (arg_alloc_attrs != nullptr) {
|
||||||
AllocatorAttributes alloc_attr;
|
AllocatorAttributes alloc_attr;
|
||||||
DataType type = attr_value->type();
|
DataType type = attr_value->type();
|
||||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||||
@ -135,10 +136,12 @@ Status UpdateArgAndRetvalMetadata(
|
|||||||
}
|
}
|
||||||
arg_alloc_attrs->push_back(alloc_attr);
|
arg_alloc_attrs->push_back(alloc_attr);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for (int i = 0; i < ret_nodes.size(); ++i) {
|
for (int i = 0; i < ret_nodes.size(); ++i) {
|
||||||
Node* ret = ret_nodes[i].first;
|
Node* ret = ret_nodes[i].first;
|
||||||
ret->AddAttr("index", i);
|
ret->AddAttr("index", i);
|
||||||
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
|
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
|
||||||
|
if (ret_alloc_attrs) {
|
||||||
AllocatorAttributes alloc_attr;
|
AllocatorAttributes alloc_attr;
|
||||||
DataType type = attr_value->type();
|
DataType type = attr_value->type();
|
||||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||||
@ -150,6 +153,7 @@ Status UpdateArgAndRetvalMetadata(
|
|||||||
}
|
}
|
||||||
ret_alloc_attrs->push_back(alloc_attr);
|
ret_alloc_attrs->push_back(alloc_attr);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
|
|||||||
const DeviceSet& device_set, std::unique_ptr<Graph> graph,
|
const DeviceSet& device_set, std::unique_ptr<Graph> graph,
|
||||||
std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
|
std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
|
||||||
|
|
||||||
// Each subgraph produced by partitioning the function body contains a subset
|
// This function performs bookkeeping to track which `Arg` and `Retval` nodes
|
||||||
// of the original `Arg` and `Retval` nodes. This function performs
|
// were placed on a particular device / graph.
|
||||||
// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
|
|
||||||
// particular device / subgraph.
|
|
||||||
//
|
//
|
||||||
// More specifically, this function
|
// More specifically, this function
|
||||||
// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
|
|
||||||
// on a particular device. When a function is partitioned, each
|
|
||||||
// partition `subgraph` gets a subset of the arguments and
|
|
||||||
// return values. The `index` attributes of these _Arg and _Retval
|
|
||||||
// nodes reflect the indices of these parameters in the original
|
|
||||||
// function. To convert `subgraph` to a function, we need to replace
|
|
||||||
// there original indices with 0, 1, 2, ... .
|
|
||||||
//
|
//
|
||||||
// The argument and return value order in the partitioned function is
|
// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
|
||||||
// determined by the argument and return value order in the original
|
// consecutive.
|
||||||
// function. This stability is important because it enables us to treat
|
//
|
||||||
// a single-partition function as having the same signature as the
|
// These indices might not be consecutive after grappler's pruning
|
||||||
// subgraph.
|
// optimization (e.g. removing redundant Args), or graph partitioning. In
|
||||||
|
// the latter case, the nodes in `graph` are placed on `device_type`, and
|
||||||
|
// each such graph partition gets a subset of the arguments and return
|
||||||
|
// values. The `index` attributes of these _Arg and _Retval nodes reflect
|
||||||
|
// the indices of these parameters in the original function. To convert
|
||||||
|
// `subgraph` to a function, we need to replace there original indices with
|
||||||
|
// 0, 1, 2, ... .
|
||||||
|
//
|
||||||
|
// The argument and return value order in `graph` is determined by the
|
||||||
|
// argument and return value order in the original function. This stability
|
||||||
|
// is important because it enables us to treat a single-partition function
|
||||||
|
// as having the same signature as the subgraph.
|
||||||
|
//
|
||||||
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
|
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
|
||||||
// device in `*_indices`, and
|
// device in `*_indices`, and
|
||||||
// (3) records which `Arg` and `Retval` nodes live in host memory in
|
// (3) records which `Arg` and `Retval` nodes live in host memory in
|
||||||
// `*_alloc_attrs`.
|
// `*_alloc_attrs`. If these vectors are NULL, do nothing here.
|
||||||
Status UpdateArgAndRetvalMetadata(
|
Status UpdateArgAndRetvalMetadata(
|
||||||
Graph* subgraph, const string& device_type,
|
Graph* graph, const string& device_type,
|
||||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||||
std::vector<AllocatorAttributes>* ret_alloc_attrs);
|
std::vector<AllocatorAttributes>* ret_alloc_attrs);
|
||||||
|
@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
Status GetGraphAndArgRets(
|
Status GetGraphAndArgRets(
|
||||||
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
|
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
|
||||||
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
||||||
@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
|
||||||
|
|
||||||
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||||
const string& function_name, AttrSlice attrs,
|
const string& function_name, AttrSlice attrs,
|
||||||
const FunctionLibraryRuntime::InstantiateOptions& options,
|
const FunctionLibraryRuntime::InstantiateOptions& options,
|
||||||
|
@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
|||||||
const std::vector<std::string>& output_names,
|
const std::vector<std::string>& output_names,
|
||||||
FunctionDef* fdef);
|
FunctionDef* fdef);
|
||||||
|
|
||||||
|
Status GetGraphAndArgRets(
|
||||||
|
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
|
||||||
|
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
||||||
|
std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
|
||||||
|
std::vector<string>* ret_node_names, DataTypeVector* ret_types,
|
||||||
|
std::vector<string>* control_ret_node_names);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
|
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
|
||||||
|
@ -31,6 +31,24 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
|
GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
|
||||||
|
GrapplerItem::OptimizationOptions optimization_options;
|
||||||
|
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
||||||
|
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
||||||
|
// However because the function will be executed via FunctionLibraryRuntime,
|
||||||
|
// and current function implementation does not prune stateful and dataset
|
||||||
|
// ops, we rely on Grappler to do the correct graph pruning.
|
||||||
|
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
||||||
|
|
||||||
|
optimization_options.is_eager_mode = true;
|
||||||
|
|
||||||
|
// All the nested function calls will be executed and optimized via
|
||||||
|
// PartitionedCallOp, there is no need to optimize functions now.
|
||||||
|
optimization_options.optimize_function_library = false;
|
||||||
|
|
||||||
|
return optimization_options;
|
||||||
|
}
|
||||||
|
|
||||||
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
|
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.id = id;
|
item.id = id;
|
||||||
|
@ -29,6 +29,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/protobuf/queue_runner.pb.h"
|
#include "tensorflow/core/protobuf/queue_runner.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class ConfigProto;
|
||||||
|
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
// A TensorFlow model to optimize.
|
// A TensorFlow model to optimize.
|
||||||
@ -133,6 +136,8 @@ struct GrapplerItem {
|
|||||||
OptimizationOptions optimization_options_;
|
OptimizationOptions optimization_options_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
|
||||||
|
|
||||||
} // end namespace grappler
|
} // end namespace grappler
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
@ -2804,8 +2804,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
# Grappler fallback to use the CPU impl even called with GPU function.
|
# Grappler fallback to use the CPU impl even called with GPU function.
|
||||||
self.assertEqual(y_value, 3.0)
|
self.assertEqual(y_value, 3.0)
|
||||||
|
|
||||||
@test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
|
|
||||||
'equivalent to implementation_selector for function')
|
|
||||||
def testSwapImplementationInEager(self):
|
def testSwapImplementationInEager(self):
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
self.skipTest('eager only')
|
self.skipTest('eager only')
|
||||||
|
Loading…
Reference in New Issue
Block a user