NFC: Refactored existing runtime code to prepare for TFRT/TF2 training's integration points.
PiperOrigin-RevId: 347678597
Change-Id: I852639a438b33618f0e6c4a5991099ff11fcd33c
parent 849bcce3b0
commit 2584650b59
tensorflow/core
@@ -392,6 +392,7 @@ KERNEL_AND_DEVICE_DEPS = [
     "//tensorflow/core:protos_all_cc",
     "//tensorflow/core/profiler/lib:annotated_traceme",
     "//tensorflow/core/profiler/lib:traceme",
+    "//tensorflow/core/grappler:grappler_item",
     "//tensorflow/core/grappler/optimizers:meta_optimizer",
 ]

@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
 #if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
 #endif  // !IS_MOBILE_PLATFORM

@@ -189,20 +190,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
         "Failed to parse config_proto attribute as tensorflow::ConfigProto "
         "proto.");
   }
-  grappler::GrapplerItem::OptimizationOptions optimization_options;
-
-  // Tensorflow 2.0 in eager mode with automatic control dependencies will
-  // prune all nodes that are not in the transitive fanin of the fetch nodes.
-  // However because the function will be executed via FunctionLibraryRuntime,
-  // and current function implementation does not prune stateful and dataset
-  // ops, we rely on Grappler to do the correct graph pruning.
-  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
-
-  optimization_options.is_eager_mode = true;
-
-  // All the nested function calls will be executed and optimized via
-  // PartitionedCallOp, there is no need to optimize functions now.
-  optimization_options.optimize_function_library = false;
-
+  grappler::GrapplerItem::OptimizationOptions optimization_options =
+      grappler::CreateOptOptionsForEager();
   options.optimize_graph_fn = std::bind(
       grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
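Note: this hunk is only NFC if `CreateOptOptionsForEager()` (added to grappler_item.cc further down in this commit) sets exactly the three flags that were previously assigned inline. A minimal sanity-check sketch, using only fields visible in this diff; the `DCHECK`s are illustrative and not part of the commit:

    grappler::GrapplerItem::OptimizationOptions opts =
        grappler::CreateOptOptionsForEager();
    DCHECK(opts.allow_pruning_stateful_and_dataset_ops);
    DCHECK(opts.is_eager_mode);
    DCHECK(!opts.optimize_function_library);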
@@ -215,9 +204,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,

   // In Eager mode we always inline all functions into the top-level
   // function body graph, to get a single executable graph, that could be
-  // optimized across function boundaries (e.g. prune unused inputs and outputs
-  // in a function call chain). This is required to mimic graph mode execution,
-  // with aggressive pruning of nodes not in the transitive fanin of fetches.
+  // optimized across function boundaries (e.g. prune unused inputs and
+  // outputs in a function call chain). This is required to mimic graph mode
+  // execution, with aggressive pruning of nodes not in the transitive fanin
+  // of fetches.
   options.config_proto.mutable_graph_options()
       ->mutable_optimizer_options()
       ->set_do_function_inlining(true);
@@ -74,7 +74,7 @@ Status PartitionFunctionGraph(
 }

 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -84,7 +84,7 @@ Status UpdateArgAndRetvalMetadata(

   // Find the Arg and Retval nodes, along with their corresponding indices
   // in the original function.
-  for (Node* node : subgraph->op_nodes()) {
+  for (Node* node : graph->op_nodes()) {
     if (node->IsArg()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
@@ -124,31 +124,35 @@ Status UpdateArgAndRetvalMetadata(
     Node* arg = arg_nodes[i].first;
     arg->AddAttr("index", i);
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (arg_alloc_attrs != nullptr) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      arg_alloc_attrs->push_back(alloc_attr);
     }
-    arg_alloc_attrs->push_back(alloc_attr);
   }
   for (int i = 0; i < ret_nodes.size(); ++i) {
     Node* ret = ret_nodes[i].first;
     ret->AddAttr("index", i);
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
-    AllocatorAttributes alloc_attr;
-    DataType type = attr_value->type();
-    MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
-                        device_type == "XLA_GPU")
-                           ? MTypeFromDTypeIntsOnDevice(type)
-                           : MTypeFromDType(type);
-    if (mtype == HOST_MEMORY) {
-      alloc_attr.set_on_host(true);
+    if (ret_alloc_attrs) {
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
+                          device_type == "XLA_GPU")
+                             ? MTypeFromDTypeIntsOnDevice(type)
+                             : MTypeFromDType(type);
+      if (mtype == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      ret_alloc_attrs->push_back(alloc_attr);
     }
-    ret_alloc_attrs->push_back(alloc_attr);
   }

   return Status::OK();
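Note: with the new null checks, `UpdateArgAndRetvalMetadata` can be called purely for index bookkeeping. A call-shape sketch, assuming the surrounding variables exist at the call site (they are not shown in this diff):

    TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
        graph.get(), device_type, &arg_indices, &ret_indices,
        /*arg_alloc_attrs=*/nullptr, /*ret_alloc_attrs=*/nullptr));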
@@ -34,31 +34,34 @@ Status PartitionFunctionGraph(
     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);

-// Each subgraph produced by partitioning the function body contains a subset
-// of the original `Arg` and `Retval` nodes. This function performs
-// bookkeeping to track which `Arg` and `Retval` nodes were placed on a
-// particular device / subgraph.
+// This function performs bookkeeping to track which `Arg` and `Retval` nodes
+// were placed on a particular device / graph.
 //
 // More specifically, this function
-// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
-//     on a particular device. When a function is partitioned, each
-//     partition `subgraph` gets a subset of the arguments and
-//     return values. The `index` attributes of these _Arg and _Retval
-//     nodes reflect the indices of these parameters in the original
-//     function. To convert `subgraph` to a function, we need to replace
-//     there original indices with 0, 1, 2, ... .
 //
-// The argument and return value order in the partitioned function is
-// determined by the argument and return value order in the original
-// function. This stability is important because it enables us to treat
-// a single-partition function as having the same signature as the
-// subgraph.
+// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be
+//     consecutive.
+//
+//     These indices might not be consecutive after grappler's pruning
+//     optimization (e.g. removing redundant Args), or graph partitioning. In
+//     the latter case, the nodes in `graph` are placed on `device_type`, and
+//     each such graph partition gets a subset of the arguments and return
+//     values. The `index` attributes of these _Arg and _Retval nodes reflect
+//     the indices of these parameters in the original function. To convert
+//     `subgraph` to a function, we need to replace there original indices with
+//     0, 1, 2, ... .
+//
+// The argument and return value order in `graph` is determined by the
+// argument and return value order in the original function. This stability
+// is important because it enables us to treat a single-partition function
+// as having the same signature as the subgraph.
+//
 // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
 //     device in `*_indices`, and
 // (3) records which `Arg` and `Retval` nodes live in host memory in
-//     `*_alloc_attrs`.
+//     `*_alloc_attrs`. If these vectors are NULL, do nothing here.
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type,
+    Graph* graph, const string& device_type,
     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs);
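Note: a worked example of the renumbering described in point (1); the concrete indices below are invented for illustration:

    // A partition receives Arg nodes whose original `index` attributes are
    // {0, 2, 5}. After UpdateArgAndRetvalMetadata:
    //   - the nodes' `index` attributes are rewritten to {0, 1, 2}, and
    //   - *arg_indices records {0, 2, 5}, so position 1 in the partitioned
    //     function still maps back to argument 2 of the original function.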
@@ -608,6 +608,8 @@ Status ValidateMultiDeviceOptions(
   return Status::OK();
 }

+}  // anonymous namespace
+
 Status GetGraphAndArgRets(
     const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
     const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
|
||||||
|
|
||||||
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||||
const string& function_name, AttrSlice attrs,
|
const string& function_name, AttrSlice attrs,
|
||||||
const FunctionLibraryRuntime::InstantiateOptions& options,
|
const FunctionLibraryRuntime::InstantiateOptions& options,
|
||||||
|
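Note: the two hunks above are a matched pair; moving the closing brace of the anonymous namespace above `GetGraphAndArgRets` gives that helper external linkage, which the new declaration in graph_to_functiondef.h (next hunk) relies on. A standalone illustration of the linkage rule at play (generic C++, not TensorFlow code):

    #include <iostream>

    namespace {
    // Internal linkage: visible only inside this translation unit, so a
    // header declaration elsewhere could not link against it.
    int InternalOnly() { return 1; }
    }  // anonymous namespace

    // External linkage: a header can declare this and other files can call it.
    int ExternallyVisible() { return InternalOnly() + 1; }

    int main() { std::cout << ExternallyVisible() << "\n"; }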
@@ -60,6 +60,13 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
                           const std::vector<std::string>& output_names,
                           FunctionDef* fdef);

+Status GetGraphAndArgRets(
+    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
+    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
+    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
+    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
+    std::vector<string>* control_ret_node_names);
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
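Note: a call-shape sketch for the newly exported helper, mirroring the declaration above. The wrapper function `BuildFunctionGraph` and its parameters are hypothetical; the in-tree caller remains process_function_library_runtime.cc:

    Status BuildFunctionGraph(const string& function_name, AttrSlice attrs,
                              const FunctionDef* fdef,
                              const FunctionLibraryDefinition* lib_def) {
      std::unique_ptr<Graph> graph;
      std::vector<Node*> arg_nodes, ret_nodes;
      std::vector<string> ret_node_names, control_ret_node_names;
      DataTypeVector ret_types;
      TF_RETURN_IF_ERROR(GetGraphAndArgRets(
          function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
          &ret_node_names, &ret_types, &control_ret_node_names));
      return Status::OK();
    }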
@@ -31,6 +31,24 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {

+GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
+  GrapplerItem::OptimizationOptions optimization_options;
+  // Tensorflow 2.0 in eager mode with automatic control dependencies will
+  // prune all nodes that are not in the transitive fanin of the fetch nodes.
+  // However because the function will be executed via FunctionLibraryRuntime,
+  // and current function implementation does not prune stateful and dataset
+  // ops, we rely on Grappler to do the correct graph pruning.
+  optimization_options.allow_pruning_stateful_and_dataset_ops = true;
+
+  optimization_options.is_eager_mode = true;
+
+  // All the nested function calls will be executed and optimized via
+  // PartitionedCallOp, there is no need to optimize functions now.
+  optimization_options.optimize_function_library = false;
+
+  return optimization_options;
+}
+
 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
   GrapplerItem item;
   item.id = id;
@@ -133,6 +133,8 @@ struct GrapplerItem {
   OptimizationOptions optimization_options_;
 };

+GrapplerItem::OptimizationOptions CreateOptOptionsForEager();
+
 }  // end namespace grappler
 }  // end namespace tensorflow

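Note: with the declaration exported here, any consumer of grappler_item.h can obtain the canonical eager optimization options in one call. A hypothetical consumer sketch (`EagerDefaultOptOptions` is an invented name):

    #include "tensorflow/core/grappler/grappler_item.h"

    tensorflow::grappler::GrapplerItem::OptimizationOptions
    EagerDefaultOptOptions() {
      // Per the definition in grappler_item.cc above: pruning of stateful and
      // dataset ops allowed, eager mode set, function-library optimization off.
      return tensorflow::grappler::CreateOptOptionsForEager();
    }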