NFC: Refactored existing runtime code to prepare for TFRT/TF2 training integration points.

PiperOrigin-RevId: 347678597
Change-Id: I852639a438b33618f0e6c4a5991099ff11fcd33c
Mingsheng Hong 2020-12-15 13:16:43 -08:00 committed by TensorFlower Gardener
parent 849bcce3b0
commit 2584650b59
8 changed files with 82 additions and 57 deletions

tensorflow/core/common_runtime/eager/BUILD

@@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
     "//tensorflow/core:protos_all_cc",
     "//tensorflow/core/profiler/lib:annotated_traceme",
     "//tensorflow/core/profiler/lib:traceme",
+    "//tensorflow/core/grappler:grappler_item",
     "//tensorflow/core/grappler/optimizers:meta_optimizer",
 ]

tensorflow/core/common_runtime/eager/kernel_and_device.cc

@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
 #if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
 #endif  // !IS_MOBILE_PLATFORM
@@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
         "Failed to parse config_proto attribute as tensorflow::ConfigProto "
         "proto.");
   }
-
-  grappler::GrapplerItem::OptimizationOptions optimization_options;
-  // Tensorflow 2.0 in eager mode with automatic control dependencies will
-  // prune all nodes that are not in the transitive fanin of the fetch nodes.
-  // However because the function will be executed via FunctionLibraryRuntime,
-  // and current function implementation does not prune stateful and dataset
-  // ops, we rely on Grappler to do the correct graph pruning.
-  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
-
-  optimization_options.is_eager_mode = true;
-
-  // All the nested function calls will be executed and optimized via
-  // PartitionedCallOp, there is no need to optimize functions now.
-  optimization_options.optimize_function_library = false;
+  grappler::GrapplerItem::OptimizationOptions optimization_options =
+      grappler::CreateOptOptionsForEager();
 
   options.optimize_graph_fn = std::bind(
       grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
@@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
   // In Eager mode we always inline all functions into the top-level
   // function body graph, to get a single executable graph, that could be
-  // optimized across function boundaries (e.g. prune unused inputs and outputs
-  // in a function call chain). This is required to mimic graph mode execution,
-  // with aggressive pruning of nodes not in the transitive fanin of fetches.
+  // optimized across function boundaries (e.g. prune unused inputs and
+  // outputs in a function call chain). This is required to mimic graph mode
+  // execution, with aggressive pruning of nodes not in the transitive fanin
+  // of fetches.
   options.config_proto.mutable_graph_options()
       ->mutable_optimizer_options()
       ->set_do_function_inlining(true);
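Aside, on the call-site pattern above: `options.optimize_graph_fn` is populated with `std::bind`, which leaves the leading arguments open as placeholders and fixes the trailing ones at bind time. A minimal standalone sketch of that binding pattern (toy types and values, not the real TF signatures):

#include <functional>
#include <iostream>

// Toy stand-in for grappler::OptimizeGraph; the real one takes TF types.
int OptimizeGraph(int a, int b, int bound) { return a + b + bound; }

int main() {
  // _1 and _2 stay open; the trailing argument is fixed at bind time,
  // mirroring how optimize_graph_fn is set above.
  std::function<int(int, int)> fn = std::bind(
      OptimizeGraph, std::placeholders::_1, std::placeholders::_2, 10);
  std::cout << fn(1, 2) << "\n";  // prints 13
  return 0;
}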

tensorflow/core/common_runtime/partitioning_utils.cc

@@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
 }
 
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(
   // Find the Arg and Retval nodes, along with their corresponding indices
   // in the original function.
-  for (Node* node : subgraph->op_nodes()) {
+  for (Node* node : graph->op_nodes()) {
     if (node->IsArg()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
@@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata(
     Node* arg = arg_nodes[i].first;
     arg->AddAttr("index", i);
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (arg_alloc_attrs != nullptr) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      arg_alloc_attrs->push_back(alloc_attr);
     }
-    arg_alloc_attrs->push_back(alloc_attr);
   }
   for (int i = 0; i < ret_nodes.size(); ++i) {
     Node* ret = ret_nodes[i].first;
     ret->AddAttr("index", i);
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (ret_alloc_attrs) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      ret_alloc_attrs->push_back(alloc_attr);
     }
-    ret_alloc_attrs->push_back(alloc_attr);
   }
 
   return Status::OK();
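Aside: the change above makes the two attribute vectors optional outputs. A minimal standalone sketch of the pattern (illustrative names, not TF code), so a caller that has no use for allocator attributes can simply pass nullptr:

#include <iostream>
#include <vector>

// Illustrative stand-in for the TF type.
struct AllocatorAttributes {
  bool on_host = false;
  void set_on_host(bool v) { on_host = v; }
};

// Fills `alloc_attrs` only when the caller asked for it (non-null),
// mirroring the arg/ret loops in UpdateArgAndRetvalMetadata.
void CollectAllocAttrs(const std::vector<bool>& host_memory,
                       std::vector<AllocatorAttributes>* alloc_attrs) {
  for (bool on_host : host_memory) {
    if (alloc_attrs != nullptr) {
      AllocatorAttributes attr;
      if (on_host) attr.set_on_host(true);
      alloc_attrs->push_back(attr);
    }
  }
}

int main() {
  std::vector<AllocatorAttributes> attrs;
  CollectAllocAttrs({true, false}, &attrs);   // caller that wants the metadata
  CollectAllocAttrs({true, false}, nullptr);  // caller that opts out
  std::cout << attrs.size() << "\n";          // prints 2
  return 0;
}

The null check could equally be hoisted out of the loop; keeping it per iteration, as the diff does, keeps the change minimal.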

tensorflow/core/common_runtime/partitioning_utils.h

@@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
 
-// Each subgraph produced by partitioning the function body contains a subset
-// of the original `Arg` and `Retval` nodes. This function performs
-// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
-// particular device / subgraph.
+// This function performs bookkeeping to track which `Arg` and `Retval` nodes
+// were placed on a particular device / graph.
 //
 // More specifically, this function
-// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
-//     on a particular device. When a function is partitioned, each
-//     partition `subgraph` gets a subset of the arguments and
-//     return values. The `index` attributes of these _Arg and _Retval
-//     nodes reflect the indices of these parameters in the original
-//     function. To convert `subgraph` to a function, we need to replace
-//     their original indices with 0, 1, 2, ... .
-//
-// The argument and return value order in the partitioned function is
-// determined by the argument and return value order in the original
-// function. This stability is important because it enables us to treat
-// a single-partition function as having the same signature as the
-// subgraph.
+// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
+//     consecutive.
+//
+//     These indices might not be consecutive after grappler's pruning
+//     optimization (e.g. removing redundant Args), or graph partitioning. In
+//     the latter case, the nodes in `graph` are placed on `device_type`, and
+//     each such graph partition gets a subset of the arguments and return
+//     values. The `index` attributes of these _Arg and _Retval nodes reflect
+//     the indices of these parameters in the original function. To convert
+//     `subgraph` to a function, we need to replace their original indices with
+//     0, 1, 2, ... .
+//
+//     The argument and return value order in `graph` is determined by the
+//     argument and return value order in the original function. This stability
+//     is important because it enables us to treat a single-partition function
+//     as having the same signature as the subgraph.
 //
 // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
 //     device in `*_indices`, and
 // (3) records which `Arg` and `Retval` nodes live in host memory in
-//     `*_alloc_attrs`.
+//     `*_alloc_attrs`. If these vectors are NULL, do nothing here.
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs);
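Aside: a standalone sketch of point (1), renumbering `index` attributes to 0, 1, 2, ... while preserving the original relative order (illustrative types; the real code operates on `_Arg`/`_Retval` nodes):

#include <algorithm>
#include <iostream>
#include <vector>

struct Node {
  int index;  // stand-in for an _Arg/_Retval "index" attr
};

// After pruning or partitioning, the surviving indices may have gaps;
// rewrite them to be consecutive, keeping the original order stable.
void RenumberConsecutively(std::vector<Node*>* nodes) {
  std::sort(nodes->begin(), nodes->end(),
            [](const Node* a, const Node* b) { return a->index < b->index; });
  for (int i = 0; i < static_cast<int>(nodes->size()); ++i) {
    (*nodes)[i]->index = i;
  }
}

int main() {
  Node a{0}, b{2}, c{5};  // only args 0, 2, 5 survived pruning
  std::vector<Node*> args = {&c, &a, &b};
  RenumberConsecutively(&args);
  for (const Node* n : args) std::cout << n->index << " ";  // prints: 0 1 2
  return 0;
}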

tensorflow/core/common_runtime/process_function_library_runtime.cc

@@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
   return Status::OK();
 }
 
+}  // anonymous namespace
+
 Status GetGraphAndArgRets(
     const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
     const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,

@@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
   return Status::OK();
 }
 
-}  // anonymous namespace
-
 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     const string& function_name, AttrSlice attrs,
     const FunctionLibraryRuntime::InstantiateOptions& options,

tensorflow/core/framework/graph_to_functiondef.h

@@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
                           const std::vector<std::string>& output_names,
                           FunctionDef* fdef);
 
+Status GetGraphAndArgRets(
+    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
+    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
+    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
+    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
+    std::vector<string>* control_ret_node_names);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
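Aside: the two process_function_library_runtime.cc hunks above move `GetGraphAndArgRets` out of the anonymous namespace so that this header declaration can bind to it. A name in an anonymous namespace has internal linkage, so a header declaration could never resolve to it from another translation unit. A single-file sketch of the distinction (illustrative names):

namespace {
// Internal linkage: private to this translation unit. A declaration in a
// shared header could never bind to this definition at link time.
int InternalHelper() { return 1; }
}  // anonymous namespace

// External linkage: this is what a header declaration (`int SharedHelper();`)
// in another file would resolve to, as with GetGraphAndArgRets above.
int SharedHelper() { return InternalHelper() + 1; }

int main() { return SharedHelper() == 2 ? 0 : 1; }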

tensorflow/core/grappler/grappler_item.cc

@@ -31,6 +31,24 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
+  GrapplerItem::OptimizationOptions optimization_options;
+  // Tensorflow 2.0 in eager mode with automatic control dependencies will
+  // prune all nodes that are not in the transitive fanin of the fetch nodes.
+  // However because the function will be executed via FunctionLibraryRuntime,
+  // and current function implementation does not prune stateful and dataset
+  // ops, we rely on Grappler to do the correct graph pruning.
+  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
+
+  optimization_options.is_eager_mode = true;
+
+  // All the nested function calls will be executed and optimized via
+  // PartitionedCallOp, there is no need to optimize functions now.
+  optimization_options.optimize_function_library = false;
+
+  return optimization_options;
+}
+
 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
   GrapplerItem item;
   item.id = id;

tensorflow/core/grappler/grappler_item.h

@@ -133,6 +133,8 @@ struct GrapplerItem {
   OptimizationOptions optimization_options_;
 };
 
+GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
+
 }  // end namespace grappler
 }  // end namespace tensorflow
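Aside: with the declaration above, a new call site (for instance, the TFRT/TF2 training integration this change prepares for) can reuse the eager defaults instead of duplicating them. A hedged usage sketch; `ConfigureEagerOptimization` is a hypothetical wrapper, and only `CreateOptOptionsForEager` and the three option fields come from this change:

#include "tensorflow/core/grappler/grappler_item.h"

// Hypothetical call site; everything but the helper itself is illustrative.
void ConfigureEagerOptimization() {
  tensorflow::grappler::GrapplerItem::OptimizationOptions opts =
      tensorflow::grappler::CreateOptOptionsForEager();
  // opts.allow_pruning_stateful_and_dataset_ops == true
  // opts.is_eager_mode == true
  // opts.optimize_function_library == false
  // ... hand `opts` to the Grappler meta-optimizer setup ...
}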