NFC: Refactored existing runtime code to prepare for TFRT/TF2 training's integration points.
PiperOrigin-RevId: 347678597 Change-Id: I852639a438b33618f0e6c4a5991099ff11fcd33c
This commit is contained in:
parent
849bcce3b0
commit
2584650b59
@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||
]
|
||||
|
||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
|
||||
"Failed to parse config_proto attribute as tensorflow::ConfigProto "
|
||||
"proto.");
|
||||
}
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
||||
|
||||
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
||||
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
||||
// However because the function will be executed via FunctionLibraryRuntime,
|
||||
// and current function implementation does not prune stateful and dataset
|
||||
// ops, we rely on Grappler to do the correct graph pruning.
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
||||
|
||||
optimization_options.is_eager_mode = true;
|
||||
|
||||
// All the nested function calls will be executed and optimized via
|
||||
// PartitionedCallOp, there is no need to optimize functions now.
|
||||
optimization_options.optimize_function_library = false;
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options =
|
||||
grappler::CreateOptOptionsForEager();
|
||||
|
||||
options.optimize_graph_fn = std::bind(
|
||||
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
|
||||
@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
|
||||
|
||||
// In Eager mode we always inline all functions into the top-level
|
||||
// function body graph, to get a single executable graph, that could be
|
||||
// optimized across function boundaries (e.g. prune unused inputs and outputs
|
||||
// in a function call chain). This is required to mimic graph mode execution,
|
||||
// with aggressive pruning of nodes not in the transitive fanin of fetches.
|
||||
// optimized across function boundaries (e.g. prune unused inputs and
|
||||
// outputs in a function call chain). This is required to mimic graph mode
|
||||
// execution, with aggressive pruning of nodes not in the transitive fanin
|
||||
// of fetches.
|
||||
options.config_proto.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_do_function_inlining(true);
|
||||
|
@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
|
||||
}
|
||||
|
||||
Status UpdateArgAndRetvalMetadata(
|
||||
Graph* subgraph, const string& device_type,
|
||||
Graph* graph, const string& device_type,
|
||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
|
||||
@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(
|
||||
|
||||
// Find the Arg and Retval nodes, along with their corresponding indices
|
||||
// in the original function.
|
||||
for (Node* node : subgraph->op_nodes()) {
|
||||
for (Node* node : graph->op_nodes()) {
|
||||
if (node->IsArg()) {
|
||||
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
|
||||
int index = static_cast<int>(attr_value->i());
|
||||
@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata(
|
||||
Node* arg = arg_nodes[i].first;
|
||||
arg->AddAttr("index", i);
|
||||
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
if (arg_alloc_attrs != nullptr) {
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
}
|
||||
arg_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
arg_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
for (int i = 0; i < ret_nodes.size(); ++i) {
|
||||
Node* ret = ret_nodes[i].first;
|
||||
ret->AddAttr("index", i);
|
||||
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
if (ret_alloc_attrs) {
|
||||
AllocatorAttributes alloc_attr;
|
||||
DataType type = attr_value->type();
|
||||
MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
|
||||
device_type == "XLA_GPU")
|
||||
? MTypeFromDTypeIntsOnDevice(type)
|
||||
: MTypeFromDType(type);
|
||||
if (mtype == HOST_MEMORY) {
|
||||
alloc_attr.set_on_host(true);
|
||||
}
|
||||
ret_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
ret_alloc_attrs->push_back(alloc_attr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
|
||||
const DeviceSet& device_set, std::unique_ptr<Graph> graph,
|
||||
std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
|
||||
|
||||
// Each subgraph produced by partitioning the function body contains a subset
|
||||
// of the original `Arg` and `Retval` nodes. This function performs
|
||||
// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
|
||||
// particular device / subgraph.
|
||||
// This function performs bookkeeping to track which `Arg` and `Retval` nodes
|
||||
// were placed on a particular device / graph.
|
||||
//
|
||||
// More specifically, this function
|
||||
// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
|
||||
// on a particular device. When a function is partitioned, each
|
||||
// partition `subgraph` gets a subset of the arguments and
|
||||
// return values. The `index` attributes of these _Arg and _Retval
|
||||
// nodes reflect the indices of these parameters in the original
|
||||
// function. To convert `subgraph` to a function, we need to replace
|
||||
// there original indices with 0, 1, 2, ... .
|
||||
//
|
||||
// The argument and return value order in the partitioned function is
|
||||
// determined by the argument and return value order in the original
|
||||
// function. This stability is important because it enables us to treat
|
||||
// a single-partition function as having the same signature as the
|
||||
// subgraph.
|
||||
// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
|
||||
// consecutive.
|
||||
//
|
||||
// These indices might not be consecutive after grappler's pruning
|
||||
// optimization (e.g. removing redundant Args), or graph partitioning. In
|
||||
// the latter case, the nodes in `graph` are placed on `device_type`, and
|
||||
// each such graph partition gets a subset of the arguments and return
|
||||
// values. The `index` attributes of these _Arg and _Retval nodes reflect
|
||||
// the indices of these parameters in the original function. To convert
|
||||
// `subgraph` to a function, we need to replace there original indices with
|
||||
// 0, 1, 2, ... .
|
||||
//
|
||||
// The argument and return value order in `graph` is determined by the
|
||||
// argument and return value order in the original function. This stability
|
||||
// is important because it enables us to treat a single-partition function
|
||||
// as having the same signature as the subgraph.
|
||||
//
|
||||
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
|
||||
// device in `*_indices`, and
|
||||
// (3) records which `Arg` and `Retval` nodes live in host memory in
|
||||
// `*_alloc_attrs`.
|
||||
// `*_alloc_attrs`. If these vectors are NULL, do nothing here.
|
||||
Status UpdateArgAndRetvalMetadata(
|
||||
Graph* subgraph, const string& device_type,
|
||||
Graph* graph, const string& device_type,
|
||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||
std::vector<AllocatorAttributes>* ret_alloc_attrs);
|
||||
|
@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status GetGraphAndArgRets(
|
||||
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
|
||||
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
||||
@ -644,8 +646,6 @@ Status GetGraphAndArgRets(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
const string& function_name, AttrSlice attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& options,
|
||||
|
@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
const std::vector<std::string>& output_names,
|
||||
FunctionDef* fdef);
|
||||
|
||||
Status GetGraphAndArgRets(
|
||||
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
|
||||
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
||||
std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
|
||||
std::vector<string>* ret_node_names, DataTypeVector* ret_types,
|
||||
std::vector<string>* control_ret_node_names);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
|
||||
|
@ -31,6 +31,24 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
|
||||
GrapplerItem::OptimizationOptions optimization_options;
|
||||
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
||||
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
||||
// However because the function will be executed via FunctionLibraryRuntime,
|
||||
// and current function implementation does not prune stateful and dataset
|
||||
// ops, we rely on Grappler to do the correct graph pruning.
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
||||
|
||||
optimization_options.is_eager_mode = true;
|
||||
|
||||
// All the nested function calls will be executed and optimized via
|
||||
// PartitionedCallOp, there is no need to optimize functions now.
|
||||
optimization_options.optimize_function_library = false;
|
||||
|
||||
return optimization_options;
|
||||
}
|
||||
|
||||
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
|
||||
GrapplerItem item;
|
||||
item.id = id;
|
||||
|
@ -133,6 +133,8 @@ struct GrapplerItem {
|
||||
OptimizationOptions optimization_options_;
|
||||
};
|
||||
|
||||
GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user