NFC: Refactored existing runtime code to prepare for TFRT/TF2 training's integration points.

PiperOrigin-RevId: 347678597
Change-Id: I852639a438b33618f0e6c4a5991099ff11fcd33c
This commit is contained in:
Mingsheng Hong 2020-12-15 13:16:43 -08:00 committed by TensorFlower Gardener
parent 849bcce3b0
commit 2584650b59
8 changed files with 82 additions and 57 deletions

View File

@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:annotated_traceme",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:meta_optimizer", "//tensorflow/core/grappler/optimizers:meta_optimizer",
] ]

View File

@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM) #if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
"Failed to parse config_proto attribute as tensorflow::ConfigProto " "Failed to parse config_proto attribute as tensorflow::ConfigProto "
"proto."); "proto.");
} }
grappler::GrapplerItem::OptimizationOptions optimization_options; grappler::GrapplerItem::OptimizationOptions optimization_options =
grappler::CreateOptOptionsForEager();
// Tensorflow 2.0 in eager mode with automatic control dependencies will
// prune all nodes that are not in the transitive fanin of the fetch nodes.
// However because the function will be executed via FunctionLibraryRuntime,
// and current function implementation does not prune stateful and dataset
// ops, we rely on Grappler to do the correct graph pruning.
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
optimization_options.is_eager_mode = true;
// All the nested function calls will be executed and optimized via
// PartitionedCallOp, there is no need to optimize functions now.
optimization_options.optimize_function_library = false;
options.optimize_graph_fn = std::bind( options.optimize_graph_fn = std::bind(
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2, grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
// In Eager mode we always inline all functions into the top-level // In Eager mode we always inline all functions into the top-level
// function body graph, to get a single executable graph, that could be // function body graph, to get a single executable graph, that could be
// optimized across function boundaries (e.g. prune unused inputs and outputs // optimized across function boundaries (e.g. prune unused inputs and
// in a function call chain). This is required to mimic graph mode execution, // outputs in a function call chain). This is required to mimic graph mode
// with aggressive pruning of nodes not in the transitive fanin of fetches. // execution, with aggressive pruning of nodes not in the transitive fanin
// of fetches.
options.config_proto.mutable_graph_options() options.config_proto.mutable_graph_options()
->mutable_optimizer_options() ->mutable_optimizer_options()
->set_do_function_inlining(true); ->set_do_function_inlining(true);

View File

@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
} }
Status UpdateArgAndRetvalMetadata( Status UpdateArgAndRetvalMetadata(
Graph* subgraph, const string& device_type, Graph* graph, const string& device_type,
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices, std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
std::vector<AllocatorAttributes>* arg_alloc_attrs, std::vector<AllocatorAttributes>* arg_alloc_attrs,
std::vector<AllocatorAttributes>* ret_alloc_attrs) { std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(
// Find the Arg and Retval nodes, along with their corresponding indices // Find the Arg and Retval nodes, along with their corresponding indices
// in the original function. // in the original function.
for (Node* node : subgraph->op_nodes()) { for (Node* node : graph->op_nodes()) {
if (node->IsArg()) { if (node->IsArg()) {
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value)); TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int index = static_cast<int>(attr_value->i()); int index = static_cast<int>(attr_value->i());
@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata(
Node* arg = arg_nodes[i].first; Node* arg = arg_nodes[i].first;
arg->AddAttr("index", i); arg->AddAttr("index", i);
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value)); TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
AllocatorAttributes alloc_attr; if (arg_alloc_attrs != nullptr) {
DataType type = attr_value->type(); AllocatorAttributes alloc_attr;
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" || DataType type = attr_value->type();
device_type == "XLA_GPU") MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
? MTypeFromDTypeIntsOnDevice(type) device_type == "XLA_GPU")
: MTypeFromDType(type); ? MTypeFromDTypeIntsOnDevice(type)
if (mtype == HOST_MEMORY) { : MTypeFromDType(type);
alloc_attr.set_on_host(true); if (mtype == HOST_MEMORY) {
alloc_attr.set_on_host(true);
}
arg_alloc_attrs->push_back(alloc_attr);
} }
arg_alloc_attrs->push_back(alloc_attr);
} }
for (int i = 0; i < ret_nodes.size(); ++i) { for (int i = 0; i < ret_nodes.size(); ++i) {
Node* ret = ret_nodes[i].first; Node* ret = ret_nodes[i].first;
ret->AddAttr("index", i); ret->AddAttr("index", i);
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value)); TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
AllocatorAttributes alloc_attr; if (ret_alloc_attrs) {
DataType type = attr_value->type(); AllocatorAttributes alloc_attr;
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" || DataType type = attr_value->type();
device_type == "XLA_GPU") MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
? MTypeFromDTypeIntsOnDevice(type) device_type == "XLA_GPU")
: MTypeFromDType(type); ? MTypeFromDTypeIntsOnDevice(type)
if (mtype == HOST_MEMORY) { : MTypeFromDType(type);
alloc_attr.set_on_host(true); if (mtype == HOST_MEMORY) {
alloc_attr.set_on_host(true);
}
ret_alloc_attrs->push_back(alloc_attr);
} }
ret_alloc_attrs->push_back(alloc_attr);
} }
return Status::OK(); return Status::OK();

View File

@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
const DeviceSet& device_set, std::unique_ptr<Graph> graph, const DeviceSet& device_set, std::unique_ptr<Graph> graph,
std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs); std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
// Each subgraph produced by partitioning the function body contains a subset // This function performs bookkeeping to track which `Arg` and `Retval` nodes
// of the original `Arg` and `Retval` nodes. This function performs // were placed on a particular device / graph.
// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
// particular device / subgraph.
// //
// More specifically, this function // More specifically, this function
// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
// on a particular device. When a function is partitioned, each
// partition `subgraph` gets a subset of the arguments and
// return values. The `index` attributes of these _Arg and _Retval
// nodes reflect the indices of these parameters in the original
// function. To convert `subgraph` to a function, we need to replace
// there original indices with 0, 1, 2, ... .
// //
// The argument and return value order in the partitioned function is // (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
// determined by the argument and return value order in the original // consecutive.
// function. This stability is important because it enables us to treat //
// a single-partition function as having the same signature as the // These indices might not be consecutive after grappler's pruning
// subgraph. // optimization (e.g. removing redundant Args), or graph partitioning. In
// the latter case, the nodes in `graph` are placed on `device_type`, and
// each such graph partition gets a subset of the arguments and return
// values. The `index` attributes of these _Arg and _Retval nodes reflect
// the indices of these parameters in the original function. To convert
// `subgraph` to a function, we need to replace there original indices with
// 0, 1, 2, ... .
//
// The argument and return value order in `graph` is determined by the
// argument and return value order in the original function. This stability
// is important because it enables us to treat a single-partition function
// as having the same signature as the subgraph.
//
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
// device in `*_indices`, and // device in `*_indices`, and
// (3) records which `Arg` and `Retval` nodes live in host memory in // (3) records which `Arg` and `Retval` nodes live in host memory in
// `*_alloc_attrs`. // `*_alloc_attrs`. If these vectors are NULL, do nothing here.
Status UpdateArgAndRetvalMetadata( Status UpdateArgAndRetvalMetadata(
Graph* subgraph, const string& device_type, Graph* graph, const string& device_type,
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices, std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
std::vector<AllocatorAttributes>* arg_alloc_attrs, std::vector<AllocatorAttributes>* arg_alloc_attrs,
std::vector<AllocatorAttributes>* ret_alloc_attrs); std::vector<AllocatorAttributes>* ret_alloc_attrs);

View File

@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
return Status::OK(); return Status::OK();
} }
} // anonymous namespace
Status GetGraphAndArgRets( Status GetGraphAndArgRets(
const string& function_name, AttrSlice attrs, const FunctionDef* fdef, const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph, const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
return Status::OK(); return Status::OK();
} }
} // anonymous namespace
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
const string& function_name, AttrSlice attrs, const string& function_name, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options, const FunctionLibraryRuntime::InstantiateOptions& options,

View File

@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
FunctionDef* fdef); FunctionDef* fdef);
Status GetGraphAndArgRets(
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
std::vector<string>* ret_node_names, DataTypeVector* ret_types,
std::vector<string>* control_ret_node_names);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_ #endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_

View File

@ -31,6 +31,24 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
GrapplerItem::OptimizationOptions optimization_options;
// Tensorflow 2.0 in eager mode with automatic control dependencies will
// prune all nodes that are not in the transitive fanin of the fetch nodes.
// However because the function will be executed via FunctionLibraryRuntime,
// and current function implementation does not prune stateful and dataset
// ops, we rely on Grappler to do the correct graph pruning.
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
optimization_options.is_eager_mode = true;
// All the nested function calls will be executed and optimized via
// PartitionedCallOp, there is no need to optimize functions now.
optimization_options.optimize_function_library = false;
return optimization_options;
}
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const { GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
GrapplerItem item; GrapplerItem item;
item.id = id; item.id = id;

View File

@ -133,6 +133,8 @@ struct GrapplerItem {
OptimizationOptions optimization_options_; OptimizationOptions optimization_options_;
}; };
GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
} // end namespace grappler } // end namespace grappler
} // end namespace tensorflow } // end namespace tensorflow