[Build cleanup] Split "core_cpu_impl" into fine-grained targets (3/n).
This change splits many (but not all) of the function-related targets into
separate cc_library targets. The main changes are:

* Move "graph/graph_constructor.{h,cc}" to
  "common_runtime/graph_constructor.{h,cc}" and leave a forwarding header.
  This code depends on common_runtime and is built as part of it, so it makes
  sense to move it across. The "graph_constructor" library includes
  "shape_refiner.{h,cc}", "graph_runner.{h,cc}", and "eval_const_tensor.{h,cc}"
  because of a circular dependency between these modules.
* Split "function.{h,cc}" into "function_body.{h,cc}", "function_utils.{h,cc}",
  and "inline_function_utils.{h,cc}" (plus the original, slimmed-down module).
  This enables other targets in common_runtime to depend on just the function
  utilities they need, without the whole runtime, which breaks some cycles.
* New fine-grained targets for "constant_folding",
  "function_optimization_registry", and "graph_optimizer".

PiperOrigin-RevId: 308651243
Change-Id: Iac59c1db4ebdd16609f89d6caee6b7e6ba7ff0a1
parent 7c977c938e
commit d6027bd76a
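The first bullet above mentions leaving a forwarding header behind at the old
location. The commit's visible hunks don't reproduce that header, so the
following is a hedged sketch of what such a forwarding shim conventionally
looks like; the guard name and comment wording are assumptions, not quoted
from the commit.

// tensorflow/core/graph/graph_constructor.h -- forwarding header (sketch).
// Kept at the old path so existing #includes keep compiling; the real
// declarations now live in common_runtime.
#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
#define TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_

#include "tensorflow/core/common_runtime/graph_constructor.h"

#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_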
@@ -604,6 +604,7 @@ tf_cc_test(
        ":c_api",
        ":c_api_internal",
        ":c_test_util",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/hash/hash.h"

@@ -491,6 +491,7 @@ cc_library(
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime:core_cpu",
        "//tensorflow/core/grappler/costs:graph_properties",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",

@@ -58,6 +58,7 @@ tf_cuda_library(
        "device_factory.h",
        "function.h",
        "function_optimization_registry.h",
        "graph_constructor.h",
        "optimization_registry.h",
        "shape_refiner.h",
        "//tensorflow/core/graph:core_cpu_headers",
@@ -153,7 +154,12 @@ filegroup(
        "device_set.h",
        "eval_const_tensor.h",
        "function.h",
        "function_body.h",
        "function_utils.h",
        "graph_constructor.h",
        "graph_optimizer.h",
        "graph_runner.h",
        "inline_function_utils.h",
        "metrics.h",
        "process_function_library_runtime.h",
        "scoped_allocator.h",
@@ -167,9 +173,6 @@ filegroup(
tf_cuda_library(
    name = "core_cpu_base_no_ops",
    srcs = [
        "eval_const_tensor.cc",
        "graph_optimizer.h",
        "shape_refiner.cc",
        "//tensorflow/core/graph:core_cpu_base_no_ops_srcs",
        "//tensorflow/core/public:session_options.h",
        "//tensorflow/core/public:version.h",
@@ -190,6 +193,7 @@ tf_cuda_library(
        "@com_google_absl//absl/container:flat_hash_set",
        "//third_party/eigen3",
    ] + if_static([
        ":graph_constructor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ]),
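The hunk above makes the ":graph_constructor" dependency conditional on
if_static. A minimal sketch of the pattern, with an assumed target name and
load path (both may differ from the actual BUILD file): in static
("monolithic") builds the dependency is linked directly, while in shared
builds those symbols are expected to come from the framework shared object,
avoiding duplicate definitions.

# Sketch only: target name and load path are assumptions for illustration.
load("//tensorflow/core/platform:build_config_root.bzl", "if_static")

cc_library(
    name = "core_cpu_base_example",  # assumed name
    deps = [
        "//third_party/eigen3",
    ] + if_static([
        # Linked directly only in static builds; shared builds resolve these
        # symbols against the framework shared library instead.
        ":graph_constructor",
    ]),
)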
@@ -221,7 +225,6 @@ filegroup(
        "executor.h",
        "executor_factory.h",
        "function_optimization_registry.h",
        "graph_optimizer.h",
        "input_colocation_exemption_registry.h",
        "isolate_placer_inspection_required_ops_pass.h",
        "local_device.h",
@@ -390,6 +393,27 @@ cc_library(
    ],
)

cc_library(
    name = "constant_folding",
    srcs = ["constant_folding.cc"],
    hdrs = ["constant_folding.h"],
    copts = tf_copts(),
    deps = [
        ":device",
        ":device_factory",
        ":executor",
        ":function_utils",
        ":graph_constructor",
        ":memory_types",
        ":rendezvous_mgr",
        ":session_options",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

cc_library(
    name = "costmodel_manager",
    srcs = ["costmodel_manager.cc"],
@@ -524,19 +548,6 @@ cc_library(
    ],
)

cc_library(
    name = "graph_view",
    srcs = ["graph_view.cc"],
    hdrs = ["graph_view.h"],
    copts = tf_copts(),
    deps = [
        ":device",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "device_set",
    srcs = ["device_set.cc"],
@@ -557,6 +568,112 @@ cc_library(
    deps = ["//tensorflow/core:framework"],
)

cc_library(
    name = "function_body",
    srcs = ["function_body.cc"],
    hdrs = ["function_body.h"],
    copts = tf_copts(),
    deps = [
        ":device",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "function_optimization_registry",
    srcs = ["function_optimization_registry.cc"],
    hdrs = ["function_optimization_registry.h"],
    deps = [
        ":device_set",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "function_utils",
    srcs = ["function_utils.cc"],
    hdrs = ["function_utils.h"],
    deps = [
        ":function_body",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

# This library also includes "eval_const_tensor", "graph_runner", and
# "shape_refiner", because there are circular dependencies between these
# modules.
cc_library(
    name = "graph_constructor",
    srcs = [
        "eval_const_tensor.cc",
        "graph_constructor.cc",
        "graph_runner.cc",
        "shape_refiner.cc",
        "//tensorflow/core/framework:versions.h",
    ],
    hdrs = [
        "eval_const_tensor.h",
        "graph_constructor.h",
        "graph_runner.h",
        "shape_refiner.h",
    ],
    copts = tf_copts(),
    deps = [
        ":device",
        ":device_factory",
        ":executor",
        ":function_utils",
        ":memory_types",
        ":rendezvous_mgr",
        ":session_options",
        ":single_threaded_cpu_device",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
    ],
)
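The comment above the "graph_constructor" target names the key constraint:
Bazel rejects cycles between cc_library targets, so mutually dependent
translation units must live in a single target. A hypothetical sketch of the
problem and the fix (the exact direction of the includes among shape_refiner,
eval_const_tensor, and graph_runner is an assumption for illustration):

# Rejected by Bazel with "cycle in dependency graph":
#   cc_library(name = "shape_refiner", deps = [":eval_const_tensor"], ...)
#   cc_library(name = "eval_const_tensor", deps = [":shape_refiner"], ...)
#
# Merging the sources into one target dissolves the target-level cycle while
# leaving the header-level dependencies intact, which is what the
# "graph_constructor" library above does.
cc_library(
    name = "merged_example",  # hypothetical target for illustration
    srcs = ["shape_refiner.cc", "eval_const_tensor.cc"],
    hdrs = ["shape_refiner.h", "eval_const_tensor.h"],
)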

cc_library(
    name = "graph_optimizer",
    srcs = ["graph_optimizer.cc"],
    hdrs = ["graph_optimizer.h"],
    copts = tf_copts(),
    deps = [
        ":constant_folding",
        ":function_utils",
        ":graph_constructor",
        ":inline_function_utils",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "graph_view",
    srcs = ["graph_view.cc"],
    hdrs = ["graph_view.h"],
    copts = tf_copts(),
    deps = [
        ":device",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "hierarchical_tree_broadcaster",
    srcs = ["hierarchical_tree_broadcaster.cc"],
@@ -592,6 +709,29 @@ cc_library(
    ],
)

cc_library(
    name = "inline_function_utils",
    srcs = ["inline_function_utils.cc"],
    hdrs = ["inline_function_utils.h"],
    copts = tf_copts(),
    deps = [
        ":device",
        ":function_body",
        ":function_utils",
        ":graph_constructor",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/profiler/lib:traceme",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "input_colocation_exemption_registry",
    srcs = ["input_colocation_exemption_registry.cc"],
@@ -685,6 +825,7 @@ cc_library(
    copts = tf_copts(),
    deps = [
        ":device_set",
        ":graph_constructor",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
@@ -1062,11 +1203,7 @@ tf_cuda_library(
    srcs = [
        "colocation_graph.cc",
        "composite_device.cc",
        "constant_folding.cc",
        "function.cc",
        "function_optimization_registry.cc",
        "graph_optimizer.cc",
        "graph_runner.cc",
        "inspecting_placer.cc",
        "isolate_placer_inspection_required_ops_pass.cc",
        "lower_case_op.cc",
@@ -1087,9 +1224,14 @@
        ":entry",
        ":executor",
        ":executor_factory",
        ":function_body",
        ":function_optimization_registry",
        ":graph_constructor",
        ":graph_optimizer",
        ":graph_view",
        ":local_executor_params",
        ":immutable_executor_state",
        ":inline_function_utils",
        ":input_colocation_exemption_registry",
        ":pending_counts",
        ":propagator_debug_utils",

@@ -23,7 +23,7 @@ limitations under the License.

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/metrics.h"
@@ -49,7 +50,6 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"

@@ -19,6 +19,7 @@ limitations under the License.

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"

[File diff suppressed because it is too large]

@@ -22,7 +22,10 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -30,8 +33,6 @@ limitations under the License.

namespace tensorflow {

static constexpr const char* const kNoInlineAttr = "_noinline";

// Get default customizable kernel creator if set
const CustomKernelCreator* GetDefaultCustomKernelCreator();
@@ -67,88 +68,6 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent);

// FunctionLibraryRuntime::GetFunctionBody returns a description of an
// instantiated function that is represented as a Graph with arg/ret
// nodes annotated.
struct FunctionBody {
  FunctionDef fdef;
  Graph* graph = nullptr;  // owned.
  DataTypeVector arg_types;
  DataTypeVector ret_types;
  // arg_nodes[i] contains the i'th function input. In other words,
  // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i.
  gtl::InlinedVector<Node*, 4> arg_nodes;
  // ret_nodes[i] contains the i'th function output. In other words,
  // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i.
  gtl::InlinedVector<Node*, 4> ret_nodes;
  gtl::InlinedVector<Node*, 4> control_ret_nodes;

  FunctionBody() {}
  FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
               DataTypeSlice ret_types, Graph* g);
  ~FunctionBody();
};

// Debugging facility. Returns a debug string for a graph
// representing an instantiated function.
string DebugString(const Graph* g);

// A few hand-crafted optimization on the instantiated function body
// (a Graph*).

// Removes nodes that are
//   1. not stateful; and
//   2. not _Arg; and
//   3. not reachable from _Retval.
//
// This function is triggered by function inlining, unlike 'PruneFunctionBody'
// it doesn't preserve nodes that are reachable from control returns. Function
// inlining is responsible for connecting control return nodes with the nodes
// that have input control edges from the inlined function call node.
//
// Assuming that automatic control dependency tracking is correct, absence of
// outgoing control edge from the function call node means that no one needs to
// observe side-effect that might have been generated by the function (see
// documentation in common_runtime/function.cc for details).
//
// Returns true iff any node is removed from "g".
bool RemoveDeadNodes(Graph* g);

// Find a pattern:
//   src -(in)-> node -(out)-> dst, where
// 1) node is an identity node;
// 2) in is the only incoming data edge;
// 3) out is the only outgoing data edge;
//
// Rewrites the above pattern with src->dst and relevant data
// dependencies updated. Repeat the process until no such pattern
// left.
bool RemoveIdentityNodes(Graph* g);

// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
bool RemoveListArrayConverter(Graph* g);

// Dump the contents of the "graph" to log files if the logging level is
// sufficiently high.
void DumpGraph(StringPiece label, const Graph* g);

// Applies graph rewrite optimization such as inlining, dead code
// removal, etc.
//
// **g is a graph constructed based on the runtime library 'lib'.
// OptimizeGraph mutates **g extensively and replaces '*g' with a
// complete copy. Therefore, the caller should not keep any references
// to nodes *g.
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
                   const GraphOptimizer::Options& graph_optimizer_options);
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);

// Convert the Graph of a function to a GraphDef.
//
// Handles renaming of nodes to avoid duplicate names which may
// be present after various rewriting operations.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);

// Given a numerical function "f", returns another numerical function
// "g", such that if "f" takes N inputs and produces M outputs, "g"
// takes N + M inputs and produces N outputs. I.e., if
@@ -161,221 +80,6 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
// TODO(zhifengc): Asks math expert to say the comment again.
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);

// Optionally override device assignment for nodes added to the graph for
// inlined functions:
// (1) Identity nodes added in place of function input arguments.
// (2) Identity nodes added in place of function return values.
// (3) Special NoOp nodes that enforce side-effects execution order.
// (4) All nodes inside function body specified in FunctionDef.
class InlinedFunctionBodyPlacer {
 public:
  virtual ~InlinedFunctionBodyPlacer() = default;

  virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
  virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
  // Returns true if the added input/output identity nodes should be colocated
  // with the corresponding input/output from the function body.
  virtual bool ColocateInputOutputIdentities() const = 0;
  virtual absl::optional<string> ControlNodeDevice() const = 0;
  virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;

  // Place input nodes on the same device as the corresponding caller input
  // node. Do not specify any placement for all other nodes.
  static std::unique_ptr<InlinedFunctionBodyPlacer> DefaultPlacer(
      const Graph& graph, const Node& caller);

  // Place all nodes on the same device as caller node.
  static std::unique_ptr<InlinedFunctionBodyPlacer> SingleDevicePlacer(
      const Graph& graph, const Node& caller);

  // Place input nodes on the same device as the corresponding caller input
  // node. Do not place output node. Place control nodes on the same device as
  // caller node. For all function body nodes overrides job, replica and task
  // parts of the device assignment to match function caller node.
  static std::unique_ptr<InlinedFunctionBodyPlacer> MultiDevicePlacer(
      const Graph& graph, const Node& caller);

  using Factory = std::function<std::unique_ptr<InlinedFunctionBodyPlacer>(
      const Graph&, const Node&)>;

  struct Config {
    string name;
    Factory get;
  };

  static Config Default() { return {"default", DefaultPlacer}; }
  static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; }
  static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; }
};

struct InlineFunctionBodyOptions {
  // All nodes that have incoming control edge *from* the function call node,
  // will be forwarded to the "output control node". There are two options for
  // choosing which nodes will have a control edge *to* the "output control
  // node":
  //   a) control returns (`control_ret` field in FunctionDef)
  //   b) data returns (`ret` field in FunctionDef)
  enum class OutputControlSource { kDataOutputs, kControlOutputs };

  // Keep a node in a graph with the same name as the function call node:
  //
  // a) DoNotKeep: Function call node is fully inlined, and there is no node in
  //    a graph with the same name.
  //
  // b) Fetchable: Add an IdentityN node to the graph in place of the inlined
  //    function call node. It will have a control edge from inlined
  //    'output_control_node' and data edges from function output nodes.
  //    The IdentityN node will be placed on the same device as the caller node.
  //
  //    This is mostly for compatibility with Tensorflow v1 and sessions.
  //    When we prepare a graph for execution in
  //    GraphExecutionState::MakeForBaseGraph we don't know what nodes will be
  //    fetched, so we can't safely remove any of them. When graph executed as a
  //    function it has 'Retval' nodes for all fetched tensors, and we can
  //    safely inline function calls.
  //
  // c) Targetable: Add a NoOp node to the graph in place of the inlined
  //    function call node. It will have a control edge from inline
  //    'output_control_node' and no data edges. NoOp node will be placed on the
  //    same device as the caller node. This will keep the inlined function call
  //    node a valid 'session.run' target, and also will keep it a valid control
  //    output node.
  enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable };

  // If 'true' function inlining is completely disabled. This allows to control
  // function inlining for different types of function calls (see
  // 'ExpandInlineFunctionsOptions' below).
  bool disable_inlining = false;
  // Ignore '_noinline' function attribute.
  bool ignore_noinline = false;
  // If 'true' function inlining will inline functions in implementation
  // selection group. Normally those functions should not be inlined; they will
  // be handled by Grappler.
  bool inline_impl_selection_group_functions = false;
  // Controls if we want to keep a node with the name as the function call node
  // in a graph after function inlining.
  KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
  // For compatibility with Tensorflow v1 by default we will use data outputs.
  // Control returns were added to Tensorflow v2 with automatic control
  // dependencies tracking in Eager mode.
  OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
  // Inlined function body placer decides what requested device assignments
  // should be added to the nodes added to the graph. See documentation above
  // for available strategies.
  InlinedFunctionBodyPlacer::Config inlined_function_body_placer =
      InlinedFunctionBodyPlacer::Default();
  // If true, frame names in the function body will be
  // made unique in the resulting graph (e.g. by prepending a unique prefix).
  // NOTE(mrry): Only set this option to false when there is a single function
  // call in the graph (e.g. when making a remote function call via
  // ClusterFunctionLibraryRuntime). This option is provided because the graph
  // partitioner generates frame names that must remain unmodified across all
  // partitions of a multi-device function.
  bool uniquify_frame_names = true;

  // A human-readable debug string for this options.
  string DebugString() const;
};

// Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
// based on the type signature of 'node' and 'fbody':
//
// (1) Caller node has the same number of inputs and outputs as the function.
// (2) Caller node inputs and outputs have the same data types as function
//     inputs and returns.
// (3) Validation rules defined in InlineFunctionBodyOptions.
//
// If function can't be safely inlined, returns error message with details why
// inlining is not possible or safe.
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options);

// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
// edges properly. "override_device" specifies whether inlining should replace
// explicitly specified devices inside fbody with the callee's device.
//
// Returns 'Status::OK()' if function was successfully inlined into the graph.
// If function inlining is not possible returns an error with a reason, and
// leaves the graph in unmodified state.
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options);

// There are three types of function calls that could be invoked during
// *Tensorflow graph execution*:
//
// 1) Native function call (node.type_string() is the function name). These
//    functions are always executed on a single-device, which is the device of
//    the function call node.
//
// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
//    ops) can execute on multiple devices and accept DT_RESOURCE inputs that
//    belong to different devices. This type of functions was added in
//    Tensorflow 2.0 Eager mode, and it has control outputs to represent
//    side-effects that must always execute (see `control_ret` in FunctionDef).
//
// 3) SymbolicGradient has been deprecated for a while, but we still keep it and
//    use `native` options for inlining for compatibility.
//
// We need to have distinct inlining rules for compatibility with Tensorflow v1.
//
// There are few other places in Tensorflow that could execute functions:
//
// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
//    functions directly via function library runtime, without going through
//    the graph.
// 2) tf.data pipelines - also execute functions directly via function library
//    runtime with custom executors.
struct ExpandInlineFunctionsOptions {
  ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
    using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
    multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
  }

  InlineFunctionBodyOptions native_options;
  InlineFunctionBodyOptions multi_device_options;
};

// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
// workaround that will be enabled only during the function inlining unification
// (b/126811947). Contact ezhulenev@ if you think you need it.
// TODO(ezhulenev): Delete this function.
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
                           const ExpandInlineFunctionsOptions& options);

// For each node in "graph", if "lib" indicates that the node is a
// function call, inline the function body. Returns true if at least
// one node is inlined.
//
// This routine goes through "graph" nodes once and applies the
// inlining. The caller may decide to apply the inlining on "graph"
// multiple times by calling ExpandInlineFunctions a few times.
//
// Function calls that can't be safely inlined into the graph (ValidateInlining
// returns error), are ignored.
//
// TODO(ezhulenev): We do not need FunctionLibraryRuntime for this. We need just
// the FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this
// (see lower_function_call.cc).
inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
  return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
}

// Extracts function name and attributes from `call_def`
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
                                    NameAttrList* function);

// Extracts function name and attributes from `call_def` and invokes
// flr->Instantiate(name, attrs, handle).
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status InstantiateFunctionCall(const NodeDef& call_def,
                               FunctionLibraryRuntime* flr,
                               FunctionLibraryRuntime::Handle* handle);

// Returns true iff `n` represents a function call. `n` can be a native
// function call (n.type_string() is the function name),
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which

@@ -0,0 +1,64 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/function_body.h"

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
                           DataTypeSlice ret_t, Graph* g)
    : fdef(f),
      graph(g),
      arg_types(arg_t.begin(), arg_t.end()),
      ret_types(ret_t.begin(), ret_t.end()) {
  // 1. Find regular Arg/Ret nodes.
  this->arg_nodes.resize(arg_types.size());
  this->ret_nodes.resize(ret_types.size());
  for (Node* n : this->graph->op_nodes()) {
    gtl::InlinedVector<Node*, 4>* node_vec;
    if (n->type_string() == FunctionLibraryDefinition::kRetOp ||
        n->type_string() == FunctionLibraryDefinition::kDeviceRetOp) {
      node_vec = &this->ret_nodes;
    } else if (n->type_string() == FunctionLibraryDefinition::kArgOp ||
               n->type_string() == FunctionLibraryDefinition::kDeviceArgOp) {
      node_vec = &this->arg_nodes;
    } else {
      continue;
    }
    int index;
    TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
    CHECK_LE(0, index);
    CHECK_LT(index, node_vec->size());
    (*node_vec)[index] = n;
  }
  // 2. Find ControlRet nodes that must be always executed.
  std::unordered_set<StringPiece, StringPieceHasher> control_ret_node_names;
  for (const auto& control_ret : fdef.control_ret()) {
    control_ret_node_names.insert(control_ret.second);
  }
  this->control_ret_nodes.reserve(control_ret_node_names.size());
  for (Node* n : this->graph->op_nodes()) {
    if (control_ret_node_names.count(n->name()) > 0) {
      this->control_ret_nodes.push_back(n);
    }
  }
}

FunctionBody::~FunctionBody() { delete this->graph; }

}  // end namespace tensorflow

@@ -0,0 +1,52 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

class Graph;
class Node;

// FunctionLibraryRuntime::GetFunctionBody returns a description of an
// instantiated function that is represented as a Graph with arg/ret
// nodes annotated.
struct FunctionBody {
  FunctionDef fdef;
  Graph* graph = nullptr;  // owned.
  DataTypeVector arg_types;
  DataTypeVector ret_types;
  // arg_nodes[i] contains the i'th function input. In other words,
  // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i.
  gtl::InlinedVector<Node*, 4> arg_nodes;
  // ret_nodes[i] contains the i'th function output. In other words,
  // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i.
  gtl::InlinedVector<Node*, 4> ret_nodes;
  gtl::InlinedVector<Node*, 4> control_ret_nodes;

  FunctionBody() {}
  FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
               DataTypeSlice ret_types, Graph* g);
  ~FunctionBody();
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_
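For orientation, here is a minimal sketch of how a FunctionBody is typically
obtained and consumed. The FunctionDefToBodyHelper helper and the exact
error-handling boilerplate are assumptions based on the surrounding runtime
code of this era, not part of this commit, so treat the snippet as
illustrative only.

#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/function.h"

namespace tensorflow {

// Sketch: inspect the instantiated body of a function from a library.
// FunctionDefToBodyHelper is assumed to be available from the runtime; it
// instantiates a FunctionDef into a Graph owned by the FunctionBody.
Status InspectFunction(const FunctionLibraryDefinition& lib_def,
                       const string& name) {
  const FunctionDef* fdef = lib_def.Find(name);
  if (fdef == nullptr) return errors::NotFound("No function: ", name);

  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&fdef->attr()), &lib_def, &fbody));

  // arg_nodes[i] is the node carrying the i'th input, per the header above.
  LOG(INFO) << name << " has " << fbody->arg_nodes.size() << " args and "
            << fbody->ret_nodes.size() << " returns.";
  return Status::OK();
}

}  // namespace tensorflow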

@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
@@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/function.h"

#include <atomic>
#include <utility>

@@ -25,7 +23,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
@@ -33,7 +33,6 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -0,0 +1,368 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/function_utils.h"

#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

static constexpr const char* const kNodeLabel = "Func";

// Represents the index-th output of a node.
struct Endpoint {
  Node* node;
  int index;

  // Returns the string name represents this endpoint.
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  DataType dtype() const { return node->output_type(index); }
};

// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(StringPiece name, Graph* g) {
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("NoOp");
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
  DCHECK_LT(0, input.dtype());
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("Identity");
  ndef.add_input(input.name());
  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

void DumpGraph(StringPiece label, const Graph* g) {
  // TODO(zhifengc): Change Graph to record #nodes.
  VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
          << g->num_edges();
  if (VLOG_IS_ON(5)) {
    for (const auto& line : str_util::Split(DebugString(g), '\n')) {
      VLOG(5) << "|| " << line;
    }
  }
}

bool RemoveDeadNodes(Graph* g) {
  VLOG(2) << "Removing dead nodes";
  std::unordered_set<const Node*> nodes;
  for (auto n : g->nodes()) {
    if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
        n->op_def().is_stateful()) {
      nodes.insert(n);
    }
  }
  return PruneForReverseReachability(g, std::move(nodes));
}

namespace {
// If 'edges' contains only 1 non-control edge, returns it. Otherwise,
// returns a nullptr.
const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
  const Edge* ret = nullptr;
  for (const Edge* e : edges) {
    if (e->IsControlEdge() || ret) {
      // Don't touch it if there is a control edge.
      return nullptr;
    }
    if (IsRefType(e->src()->output_type(e->src_output()))) {
      // Don't touch it if the identity node is effectively de-reffing
      // a ref.
      return nullptr;
    }
    if (IsRecv(e->src()) || IsSwitch(e->src())) {
      // Don't touch it if the identity is introduced for control flow.
      // Recv disables all its successors if it receives a dead signal.
      // When Recv has an outgoing control edge, the current executor
      // would not disable the destination. The current solution (see
      // graph_partition.cc) is to add an identity after Recv and change
      // the control edge to be from this identity node. So the identity
      // can't be removed.
      return nullptr;
    }
    ret = e;
  }
  return ret;
}
}  // end namespace

bool RemoveIdentityNodes(Graph* g) {
  VLOG(2) << "Removing identity nodes";
  bool removed_any = false;
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if (!n->IsIdentity()) continue;
    if (!GetTheOnlyDataEdge(n->in_edges())) continue;

    // Some identity nodes are used as sink nodes to give names to output
    // tensors. These nodes are not going to be executed unless they are in the
    // fetch set. But if they are in the fetch set we don't want to remove them.
    if (n->out_edges().empty()) continue;

    matches.push_back(n);
  }
  if (!matches.empty()) {
    for (Node* n : matches) {
      const Edge* in = GetTheOnlyDataEdge(n->in_edges());
      for (const Edge* out : n->out_edges()) {
        if (out->IsControlEdge()) {
          g->AddControlEdge(in->src(), out->dst());
        } else {
          g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
        }
      }
      VLOG(2) << "Remove Identity: " << n->DebugString();
      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}

bool RemoveListArrayConverter(Graph* g) {
  VLOG(2) << "Removing list array converter";
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if ((n->type_string() == "_ListToArray") ||
        (n->type_string() == "_ArrayToList")) {
      matches.push_back(n);
    }
  }
  bool removed_any = false;
  if (!matches.empty()) {
    for (Node* n : matches) {
      if (n->num_inputs() != n->num_outputs()) {
        continue;  // Not expected. Skip.
      }
      gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);

      const auto no_op = [&](StringPiece name) -> Node* {
        return AddNoOp(absl::StrCat(n->name(), "/", name), g);
      };

      const auto identity = [&](StringPiece name, Endpoint input) -> Node* {
        Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
        node->set_requested_device(input.node->def().device());
        return node;
      };

      // Process input edges first.
      Node* input_control_node = nullptr;
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) {
          if (input_control_node == nullptr) {
            // If node "n" has any control dependencies, adds a no-op
            // node (input_control_node) which the additional Identity
            // nodes depends on and the input_control_node depends on
            // the node "n"s control dependencies.
            input_control_node = no_op("input_control_node");
          }
          g->AddControlEdge(e->src(), input_control_node);
        } else {
          const int index = e->dst_input();
          Node** id_node = &identity_nodes[index];
          if (*id_node != nullptr) {
            LOG(ERROR)
                << "RemoveListArrayConverter unexpected duplicated input: "
                << e->dst_input();
            return removed_any;
          }
          *id_node = identity("input", {e->src(), e->src_output()});
        }
      }

      // If node "n" has any control dependencies, the added identity
      // nodes should have control dependencies on input_control_node.
      if (input_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(input_control_node, id);
        }
      }

      Node* output_control_node = nullptr;
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) {
          if (output_control_node == nullptr) {
            // If node "n" is control-depended upon by other nodes,
            // adds a no-op node (output_control_node) which those
            // nodes will depend on and output_control_node depends on
            // all Identity nodes.
            output_control_node = no_op("output_control_node");
          }
          g->AddControlEdge(output_control_node, e->dst());
        } else {
          Node* id_node = identity_nodes[e->src_output()];
          if (id_node == nullptr) {
            LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
                       << e->src_output();
            return removed_any;
          }
          CHECK(id_node);
          g->AddEdge(id_node, 0, e->dst(), e->dst_input());
        }
      }

      // If any nodes have control dependencies on node "n", those
      // nodes should have control dependencies on
      // output_control_node.
      if (output_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(id, output_control_node);
        }
      }

      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}

Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
                                    NameAttrList* function) {
  if (call_def.op() == "PartitionedCall" ||
      call_def.op() == "StatefulPartitionedCall") {
    TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function));
  } else {
    function->set_name(call_def.op());
    *function->mutable_attr() = call_def.attr();
  }
  return Status::OK();
}

Status InstantiateFunctionCall(const NodeDef& call_def,
                               FunctionLibraryRuntime* flr,
                               FunctionLibraryRuntime::Handle* handle) {
  NameAttrList function;
  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function));
  return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle);
}

bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
                    const Node& node) {
  return node.IsFunctionCall();
}

string NewName(const Node* n, bool pretty) {
  if (pretty) {
    return strings::StrCat(n->type_string(), n->id());
  } else {
    return strings::StrCat("n", n->id());
  }
}

// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
// and stash the original NodeDef name as an attr for documentation
// purpose.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
  // We visit nodes in forward topological sort order, which is a
  // possible execution order of the graph.
  gtl::InlinedVector<const Edge*, 4> inputs;
  gdef->Clear();
  *gdef->mutable_versions() = g->versions();

  std::vector<Node*> start_nodes;
  for (Node* n : g->nodes()) {
    if (n->out_edges().empty()) {
      start_nodes.push_back(n);
    }
  }

  ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
    if (!n->IsOp()) return;
    NodeDef* ndef = gdef->add_node();
    ndef->set_name(NewName(n, pretty));
    ndef->set_op(n->type_string());
    for (const auto& attr : n->attrs()) {
      (*ndef->mutable_attr())[attr.first] = attr.second;
    }

    if (!n->assigned_device_name().empty()) {
      ndef->set_device(n->assigned_device_name());
    } else {
      ndef->set_device(n->requested_device());
    }

    inputs.clear();
    inputs.resize(n->num_inputs());
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        inputs.push_back(e);
      } else {
        if (inputs[e->dst_input()] == nullptr) {
          inputs[e->dst_input()] = e;
        } else {
          LOG(WARNING) << "Malformed graph node. multiple input edges: "
                       << n->DebugString();
        }
      }
    }
    // node->name() is merely NodeDef::name, which are not guaranteed
    // to be unique and stable after optimization rewrites. Therefore,
    // we use "n<node id>" instead.
    for (const Edge* e : inputs) {
      if (e == nullptr) {
        ndef->add_input("unknown");
        continue;
      }
      const string srcname = NewName(e->src(), pretty);
      if (!e->src()->IsOp()) {
      } else if (e->IsControlEdge()) {
        ndef->add_input(strings::StrCat("^", srcname));
      } else if (e->src_output() == 0) {
        ndef->add_input(srcname);
      } else {
        ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
      }
    }
  });
}

string DebugString(const Graph* g) {
  GraphDef gdef;
  ToGraphDef(g, &gdef);
  return DebugString(gdef);
}

}  // end namespace tensorflow
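A small usage sketch for the call-instantiation helpers defined above. Only
the two functions from this file are exercised; how one obtains the
FunctionLibraryRuntime is out of scope and assumed to come from the enclosing
ProcessFunctionLibraryRuntime.

// Sketch: given a node that may be a function call (native or
// PartitionedCall/StatefulPartitionedCall), instantiate it through the
// function library runtime.
Status InstantiateIfCall(const Node* n, FunctionLibraryRuntime* flr,
                         FunctionLibraryRuntime::Handle* handle) {
  NameAttrList func;
  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
  VLOG(1) << "Instantiating call to " << func.name();
  return InstantiateFunctionCall(n->def(), flr, handle);
}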

@@ -0,0 +1,105 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_

#include <functional>
#include <memory>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

class AttrSlice;
class Graph;
class GraphDef;
class NameAttrList;
class Node;
class NodeDef;
class OpDef;

// Debugging facility. Returns a debug string for a graph
// representing an instantiated function.
string DebugString(const Graph* g);

// Dump the contents of the "graph" to log files if the logging level is
// sufficiently high.
void DumpGraph(StringPiece label, const Graph* g);

// Convert the Graph of a function to a GraphDef.
//
// Handles renaming of nodes to avoid duplicate names which may
// be present after various rewriting operations.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);

// Extracts function name and attributes from `call_def`
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
                                    NameAttrList* function);

// A few hand-crafted optimization on the instantiated function body
// (a Graph*).

// Removes nodes that are
//   1. not stateful; and
//   2. not _Arg; and
//   3. not reachable from _Retval.
//
// This function is triggered by function inlining, unlike 'PruneFunctionBody'
// it doesn't preserve nodes that are reachable from control returns. Function
// inlining is responsible for connecting control return nodes with the nodes
// that have input control edges from the inlined function call node.
//
// Assuming that automatic control dependency tracking is correct, absence of
// outgoing control edge from the function call node means that no one needs to
// observe side-effect that might have been generated by the function (see
// documentation in common_runtime/function.cc for details).
//
// Returns true iff any node is removed from "g".
bool RemoveDeadNodes(Graph* g);

// Find a pattern:
//   src -(in)-> node -(out)-> dst, where
// 1) node is an identity node;
// 2) in is the only incoming data edge;
// 3) out is the only outgoing data edge;
//
// Rewrites the above pattern with src->dst and relevant data
// dependencies updated. Repeat the process until no such pattern
// left.
bool RemoveIdentityNodes(Graph* g);

// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
bool RemoveListArrayConverter(Graph* g);

// Extracts function name and attributes from `call_def` and invokes
// flr->Instantiate(name, attrs, handle).
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status InstantiateFunctionCall(const NodeDef& call_def,
                               FunctionLibraryRuntime* flr,
                               FunctionLibraryRuntime::Handle* handle);

// Returns true iff `n` represents a function call. `n` can be a native
// function call (n.type_string() is the function name),
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
// has been deprecated for a while).
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_
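A brief sketch of how the cleanup passes declared in this header compose on an
instantiated function body. The loop-until-fixpoint structure mirrors how
inlining-style rewrites are usually driven; it is an illustration, not code
from this commit.

// Sketch: run the hand-crafted cleanup passes until the graph stops changing,
// logging the graph before and after via DumpGraph (visible at --v=2).
void CleanupFunctionBody(Graph* g) {
  DumpGraph("Before cleanup", g);
  bool changed = true;
  while (changed) {
    changed = false;
    changed |= RemoveListArrayConverter(g);
    changed |= RemoveDeadNodes(g);
    changed |= RemoveIdentityNodes(g);
  }
  DumpGraph("After cleanup", g);
}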

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"

#include <algorithm>
#include <set>
@ -0,0 +1,204 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
class ShapeRefiner;

// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on
// error, in which case *g is left in an incomplete state.
//
// *g is expected to be an empty graph (with no more than source and sink
// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing
// Graph, see ImportGraphDef.
struct GraphConstructorOptions {
  GraphConstructorOptions() {}

  // If true, allows internal ops in the GraphDef.
  bool allow_internal_ops = false;

  // If true, the graph def is expected to have fully specified
  // devices for all nodes. A node in the resulting graph "g" has the
  // device name set accordingly.
  //
  // TODO(zhifengc): if possible, consider removing this option.
  bool expect_device_spec = false;

  // If true, validates that nodes being converted have all expected attrs
  // set and no unknown attrs set by calling ValidateNodeDef().
  // Setting validate_nodes without add_default_attributes will fail if
  // the GraphDef does not have all required attributes set.
  bool validate_nodes = false;

  // If true, GraphConstructor will add attributes with their default
  // value to the Node when they are missing from the NodeDef.
  bool add_default_attributes = true;
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                                     const GraphDef& gdef, Graph* g);
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                                     GraphDef&& gdef, Graph* g);
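
// A minimal usage sketch (illustrative only; assumes `gdef` is a GraphDef
// already populated by the caller):
//
//   Graph graph(OpRegistry::Global());
//   GraphConstructorOptions opts;
//   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));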

// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
// instantiation.
// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
                                     gtl::ArraySlice<NodeDef> nodes, Graph* g);

// Options for calling ImportGraphDef().
struct ImportGraphDefOptions {
  ImportGraphDefOptions()
      : uniquify_names(false),
        uniquify_prefix(false),
        skip_mapped_nodes(false),
        validate_shape(true) {}

  // Name prefix to use for nodes imported from the GraphDef. For example, if
  // prefix="animals" and GraphDef contains a node "bunny" then the node will be
  // named "animals/bunny" in *g. Must not be already used as a node name or
  // prefix in the graph.
  string prefix;

  // If true, imported node names will be modified if their name already exists
  // in the graph. If false, conflicting names will be treated as an error. Note
  // that this option has no effect if `prefix` is specified, since `prefix`
  // will guarantee all node names are unique.
  bool uniquify_names;

  // If true, `prefix` will be modified if it already exists as a node name or
  // prefix in the graph. If false, a conflicting prefix will be treated as an
  // error. This option has no effect if `prefix` isn't specified.
  bool uniquify_prefix;

  // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
  // corresponding to `input_map` keys will be remapped to the nodes in `g`
  // corresponding to the values.
  //
  // Keys should not include `prefix`, i.e., a key ID's name should be the name
  // as it originally appears in `gdef`.
  //
  // If this is non-empty, ImportGraphDef must be called with the shape refiner
  // used to create the existing nodes referenced in `input_map`.
  // TODO(skyewm): can we remove this requirement? How do we access the original
  // shape refiner?
  std::map<SafeTensorId, SafeTensorId> input_map;

  // If true, nodes that will have all output edges removed because of
  // overrides in `input_map` will not be imported.
  bool skip_mapped_nodes;

  // The names of existing nodes in `g` that the imported graph should have
  // control dependencies on.
  //
  // Note that to avoid creating many redundant control edges, ImportGraphDef()
  // won't add control edges to nodes that will inherit the dependencies from
  // other nodes in `gdef`.
  std::vector<string> control_dependencies;

  // Tensors in `gdef` that will be returned via the ImportGraphDefResults
  // output parameter of `ImportGraphDef()`. If this list is non-empty, the
  // caller must pass a results object to `ImportGraphDef()`. The
  // `return_tensors` field will be populated with the imported nodes in `g`.
  //
  // Entries should not include `prefix`, i.e., each ID's name should be the
  // name as it originally appears in `gdef`.
  //
  // If this contains a tensor that's also being remapped via `input_map`, the
  // corresponding existing tensor in `g` will be returned.
  std::vector<SafeTensorId> return_tensors;

  // The names of nodes in `gdef` that will be returned via the
  // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
  // is non-empty, the caller must pass a results object to
  // `ImportGraphDef()`. The `return_nodes` field will be populated with the
  // imported nodes in `g`.
  //
  // Entries should not include `prefix`, i.e., each node's name should be the
  // name as it originally appears in `gdef`.
  //
  // Unlike `return_tensors`, `input_map` has no effect on the nodes
  // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
  // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
  std::vector<string> return_nodes;

  // If true, checks that all colocation constraints are nodes in the GraphDef.
  bool validate_colocation_constraints = true;

  // If false, skips shape validation.
  bool validate_shape;

  // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
  // with ops that are not defined in the binary calling ImportGraphDef.
  // Similar to the producer_op_list argument to import_graph_def in the
  // python API.

  // Try to set the default execution device for this graph.
  string default_device;
};

// Optional results that may be returned by ImportGraphDef.
struct ImportGraphDefResults {
  // The requested tensors associated with
  // ImportGraphDefOptions::return_tensors. Note that the index may be different
  // than the requested index if the returned tensor has been remapped according
  // to `input_map`.
  typedef int Index;
  std::vector<std::pair<Node*, Index>> return_tensors;

  // The requested nodes associated with ImportGraphDefOptions::return_nodes.
  std::vector<Node*> return_nodes;

  // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
  // weren't used as an input to any node in `gdef`. These keys are likely due
  // to typos, and callers may wish to treat their existence as an error.
  std::vector<SafeTensorId> missing_unused_input_map_keys;
};

// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
//
// `refiner` can be null. It should be non-null if the caller
// intends to add additional nodes to the graph after the import. This
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
// `results` must be non-null if `opts.return_tensors` or `opts.return_nodes`
// is non-empty. It can also be set to fetch the unused input map keys. If it's
// non-null, all the vector fields must be empty.
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
                             const GraphDef& gdef, Graph* g,
                             ShapeRefiner* refiner,
                             ImportGraphDefResults* results = nullptr);
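
// An illustrative sketch (the node name "output" and the prefix are
// assumptions, not part of this API):
//
//   ImportGraphDefOptions opts;
//   opts.prefix = "import";
//   opts.return_tensors.push_back(SafeTensorId("output", 0));
//   ImportGraphDefResults results;
//   TF_RETURN_IF_ERROR(
//       ImportGraphDef(opts, gdef, &graph, /*refiner=*/nullptr, &results));
//   Node* imported = results.return_tensors[0].first;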

// Make a copy of "src" into "*dest".
//
// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
// other than the implicit Source/Sink nodes.
extern void CopyGraph(const Graph& src, Graph* dest);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"

@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/graph/validate.h"

@ -16,9 +16,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"

@ -144,4 +145,19 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
      options.inline_with_single_device_body_placer);
}

void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
                   const GraphOptimizer::Options& graph_optimizer_options) {
  OptimizerOptions opts;
  opts.set_do_common_subexpression_elimination(true);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  optimizer.Optimize(lib, lib->env(), lib->device(), g,
                     graph_optimizer_options);
}

void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
  OptimizeGraph(lib, g, GraphOptimizer::Options());
}

}  // end namespace tensorflow

@ -91,6 +91,17 @@ class GraphOptimizer {
  TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizer);
};

// Applies graph rewrite optimization such as inlining, dead code
// removal, etc.
//
// **g is a graph constructed based on the runtime library 'lib'.
// OptimizeGraph mutates **g extensively and replaces '*g' with a
// complete copy. Therefore, the caller should not keep any references
// to nodes in *g.
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
                   const GraphOptimizer::Options& graph_optimizer_options);
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
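
// A hedged usage sketch (assumes `flr` is a FunctionLibraryRuntime* whose
// function library matches the graph; names are illustrative):
//
//   std::unique_ptr<Graph> graph = ...;  // built against flr's library
//   OptimizeGraph(flr, &graph);
//   // `graph` now owns a fresh, optimized copy; previously held Node*
//   // pointers are invalid.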

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_

@ -20,9 +20,10 @@ limitations under the License.

#include "tensorflow/core/common_runtime/graph_runner.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/single_threaded_cpu_device.h"

@ -31,11 +32,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

@ -20,15 +20,16 @@ limitations under the License.
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

class Device;
class Env;
class Graph;

// GraphRunner takes a Graph, some inputs to feed, and some outputs
// to fetch and executes the graph required to feed and fetch the
// inputs and outputs.

@ -0,0 +1,865 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/inline_function_utils.h"

#include <deque>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {

namespace {
// A few string constants used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;

// Represents the index-th output of a node.
struct Endpoint {
  Node* node;
  int index;

  // Returns the string name that represents this endpoint.
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  DataType dtype() const { return node->output_type(index); }
};

struct EndpointHash {
  uint64 operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(StringPiece name, Graph* g) {
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("NoOp");
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
  DCHECK_LT(0, input.dtype());
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("Identity");
  ndef.add_input(input.name());
  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

std::vector<string> InputDevices(const Node& caller) {
  std::vector<string> input_devices(caller.in_edges().size());
  std::vector<string> input_tensors(caller.in_edges().size());

  for (const Edge* edge : caller.in_edges()) {
    if (edge->IsControlEdge()) continue;
    const string& input_device = edge->src()->has_assigned_device_name()
                                     ? edge->src()->assigned_device_name()
                                     : edge->src()->requested_device();
    input_devices[edge->dst_input()] = input_device;
    input_tensors[edge->dst_input()] =
        absl::StrCat(edge->src()->name(), ":", edge->src_output());
  }

  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Function instantiation input devices:";
    for (int i = 0; i < input_devices.size(); ++i) {
      if (input_tensors[i].empty()) continue;  // skip control edges
      VLOG(4) << " [index " << i << "]"
              << " device: " << input_devices[i]
              << " (input: " << input_tensors[i] << ")";
    }
  }

  return input_devices;
}

// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for all other nodes.
class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit DefaultFunctionBodyPlacer(const Node& caller)
      : input_devices_(InputDevices(caller)) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return absl::nullopt;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return absl::nullopt;
  }

 private:
  const std::vector<string> input_devices_;
};

// Place all nodes on the same device as the caller node.
class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit SingleDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return caller_device_;
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return caller_device_;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return caller_device_;
  }

 private:
  const string caller_device_;
};

// Place input nodes on the same device as the corresponding caller input
// node. Do not place output nodes. Place control nodes on the same device as
// the caller node. For all function body nodes, overrides the job, replica
// and task parts of the device assignment to match the function caller node.
class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit MultiDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()),
        input_devices_(InputDevices(caller)) {
    has_parsed_caller_device_ =
        DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_);
  }

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return true; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    // TODO(ezhulenev): If function would have been instantiated as a
    // multi-device function and executed via FunctionLibraryRuntime, it could
    // be potentially placed on any available device. However there are
    // multiple tests relying on this assumption. Fix them, and remove this
    // line.
    if (ndef.device().empty()) return caller_device_;

    if (!has_parsed_caller_device_) return ndef.device();

    DeviceNameUtils::ParsedName ndef_parsed_device;
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
      return ndef.device();

    if (caller_parsed_device_.has_job) {
      ndef_parsed_device.has_job = caller_parsed_device_.has_job;
      ndef_parsed_device.job = caller_parsed_device_.job;
    }

    if (caller_parsed_device_.has_replica) {
      ndef_parsed_device.has_replica = caller_parsed_device_.has_replica;
      ndef_parsed_device.replica = caller_parsed_device_.replica;
    }

    if (caller_parsed_device_.has_task) {
      ndef_parsed_device.has_task = caller_parsed_device_.has_task;
      ndef_parsed_device.task = caller_parsed_device_.task;
    }
    return DeviceNameUtils::ParsedNameToString(ndef_parsed_device);
  }

 private:
  string caller_device_;
  bool has_parsed_caller_device_;
  DeviceNameUtils::ParsedName caller_parsed_device_;
  std::vector<string> input_devices_;
};

}  // namespace

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
                                         const Node& caller) {
  VLOG(3) << "Create default placer for inlined function body.";
  return absl::make_unique<DefaultFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
                                              const Node& caller) {
  VLOG(3) << "Create single device placer for inlined function body.";
  return absl::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
                                             const Node& caller) {
  VLOG(3) << "Create multi device placer for inlined function body.";
  return absl::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
}

namespace {

Status ValidateNoInline(const FunctionBody* fbody) {
  const auto attr = AttrSlice(&fbody->fdef.attr());
  bool noinline = false;
  if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
    return errors::InvalidArgument(
        "Can't inline function marked with '_noinline'");
  }
  return Status::OK();
}

using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;

// Propagate the debug info of `nodes` in function `func` to the `target` node.
// If the debug info of any node is missing, its node name and function name
// are used.
void PropagateDebugInfoToNode(const string& func,
                              const std::vector<const Node*>& nodes,
                              NodeDef* target) {
  if (nodes.empty() || target->has_experimental_debug_info()) {
    return;
  }
  for (const Node* node : nodes) {
    const auto& node_def = node->def();
    if (node_def.has_experimental_debug_info()) {
      target->mutable_experimental_debug_info()->MergeFrom(
          node_def.experimental_debug_info());
    } else {
      target->mutable_experimental_debug_info()->add_original_node_names(
          node_def.name());
      target->mutable_experimental_debug_info()->add_original_func_names(func);
    }
  }
}
}  // namespace

string InlineFunctionBodyOptions::DebugString() const {
  const auto true_false = [](bool b) { return b ? "true" : "false"; };

  const auto keep_caller_node_str = [this]() -> string {
    switch (keep_caller_node) {
      case KeepCallerNode::kDoNotKeep:
        return "DoNotKeep";
      case KeepCallerNode::kFetchable:
        return "Fetchable";
      case KeepCallerNode::kTargetable:
        return "Targetable";
    }
  };

  return absl::StrCat(
      "disable_inlining=", true_false(disable_inlining),
      ", ignore_noinline=", true_false(ignore_noinline),
      ", inline_impl_selection_group_functions=",
      true_false(inline_impl_selection_group_functions),
      ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
      output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
                                                           : "ControlOutputs",
      ", inlined_function_body_placer=", inlined_function_body_placer.name,
      ", uniquify_frame_names=", true_false(uniquify_frame_names));
}

Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options) {
  // TODO(ezhulenev): Currently common_runtime function inlining can't
  // guarantee that all side-effectful ops will be executed after inlining.
  // See Grappler function_optimizer for details. Unify all function inlining
  // mechanisms. Do not inline if `!fbody->control_ret_nodes.empty()`.

  const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
  const auto num_node_outputs = static_cast<size_t>(node->num_outputs());

  if (num_node_inputs != fbody->arg_types.size() ||
      num_node_inputs != fbody->arg_nodes.size()) {
    return errors::InvalidArgument(
        "Node inputs do not match function arguments: inputs=", num_node_inputs,
        " arg_types=", fbody->arg_types.size(),
        " arg_nodes=", fbody->arg_nodes.size());
  }

  if (num_node_outputs != fbody->ret_types.size() ||
      num_node_outputs != fbody->ret_nodes.size()) {
    return errors::InvalidArgument(
        "Node outputs do not match function returns: outputs=",
        num_node_outputs, " ret_types=", fbody->ret_types.size(),
        " ret_nodes=", fbody->ret_nodes.size());
  }

  for (int i = 0; i < node->num_inputs(); ++i) {
    if (node->input_type(i) != fbody->arg_types[i]) {
      return errors::InvalidArgument(
          "Node input type doesn't match function argument type: ",
          node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
    }
  }
  for (int i = 0; i < node->num_outputs(); ++i) {
    if (node->output_type(i) != fbody->ret_types[i]) {
      return errors::InvalidArgument(
          "Node output type doesn't match function return type: ",
          node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
    }
  }

  if (options.disable_inlining) {
    return errors::InvalidArgument(
        "Function inlining explicitly disabled by 'options.disable_inlining'");
  }

  if (!options.inline_impl_selection_group_functions) {
    bool is_impl_selection_group_function =
        fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end();
    if (is_impl_selection_group_function) {
      return errors::InvalidArgument(
          "Inlining of implementation selection group function ",
          fbody->fdef.signature().name(),
          " is disabled by options.inline_impl_selection_group_functions");
    }
  }

  if (!options.ignore_noinline) {
    TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
  }

  return Status::OK();
}

// Function inlining must preserve function execution semantics with regards to
// side-effects visibility. Tensorflow in Eager mode has an automatic control
// dependencies tracking mechanism, which enforces well-defined execution order
// of all side-effects. Any other frontend (e.g. Swift) must produce graphs
// following the same rules, to ensure that function inlining works correctly.
//
// IMPORTANT: Currently we do not have a true notion of "side-effectful" node;
// we assume that all stateful nodes might have side-effects, though it's not
// true in practice, e.g. `ReadVariableOp` doesn't have an observable
// side-effect.
//
// Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
//
// 1) When a function has a resource (DT_RESOURCE data type) input argument it
//    "captures" the mutable resource. This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
//    touching that resource, and an outgoing control edge to the next
//    side-effectful op using the same resource. This serializes the mutations
//    of the resource to make graph execution deterministic.
//
// 2) All stateful ops inside a function body are guaranteed to execute in
//    program order; this is achieved by adding control edges between stateful
//    ops at graph construction time. Stateful ops (or ops that must execute)
//    should be in the function control return set. Having a data edge to the
//    regular function output might not be enough, because after function
//    inlining it might happen that the data output is unused.
//
// 3) Furthermore, all ops accepting the same resource as an input are
//    guaranteed to run in program order. This is also done by adding control
//    edges at graph construction time. The last op touching the resource
//    must be in a control return set, which will guarantee that all side
//    effects to the resource will happen before function completion.
//
// Function inlining must preserve side-effect visibility:
//
// 1) All side-effects to the captured resources that happened before the
//    function call must be visible to the function body nodes using those
//    resources.
//
// 2) All side-effects to the captured resources that happened inside the
//    function body must be visible to every op/function using that resource
//    after the function call completed.
//
// To guarantee that these properties are preserved after inlining we:
//
// 1) Create an "input_control_node" NoOp. Function call node incoming control
//    edges will be forwarded *to* this node. Function inputs (Identity nodes)
//    will have a control edge *from* this node. If the function body has nodes
//    without inputs, they will have a control edge *from* this node.
//
// 2) Create an "output_control_node" NoOp. All nodes that have an incoming
//    control edge *from* the function call node will be forwarded to this
//    node.
//
// We have two options for choosing which nodes will have a control edge *to*
// the "output control node":
//   a) control returns            (`control_ret` field in FunctionDef)
//   b) data returns               (`ret` field in FunctionDef)
//
// We do a) for multi-device function calls in Tensorflow v2 and b)
// for the rest for compatibility with Tensorflow v1.
//
// Following the automatic control dependencies tracking rules, a node that
// has an incoming control edge from the function call node is dependent on
// the side-effects happening inside the function body. The output control
// node will guarantee side-effects execution order.
//
// If the function call node doesn't have an outgoing control edge, it means
// that no one is interested in observing side-effects that might have
// happened.
//
// Function inlining might leave the graph in a partially-placed state. The
// function inlining caller must call Placer to guarantee that all nodes are
// placed.
//
// Function inlining with `options.override_device=true` will leave the graph
// in a fully placed state, by overriding all inlined nodes' devices with the
// caller node device, but it will make functions always single-device. These
// functions after inlining will not be able to handle resources on multiple
// devices. This is currently acceptable for XLA use cases (an XLA cluster is
// always executed on a single device).
//
// TODO(ezhulenev): Documentation above is ahead of implementation below.
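//
// A small illustrative sketch of the rewrite for a call node `f` with one
// data input `a`, one data output consumed by `b`, and control edges from
// `c` and to `d` ("-->" is a data edge, "..>" is a control edge; all node
// names are made up):
//
//   before:  a --> f --> b        c ..> f ..> d
//
//   after:   c ..> f/input_control_node ..> f/input (Identity)
//            a --> f/input --> <inlined body> --> f/output (Identity) --> b
//            <control returns or f/output> ..> f/output_control_node ..> d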
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options) {
  VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
          << options.DebugString() << "]";

  Status validation = ValidateInlining(caller, fbody, options);
  if (!validation.ok()) {
    return errors::Internal("Inlining mismatch: ", validation.error_message());
  }

  // Placer is responsible for assigning devices for all nodes that we will add
  // to the graph.
  const std::unique_ptr<InlinedFunctionBodyPlacer> placer =
      options.inlined_function_body_placer.get(*g, *caller);

  // We can't possibly introduce a duplicate control edge during function
  // inlining, so we skip this check in calls to the 'g->AddControlEdge(...)'.
  static constexpr bool kDoNotCheckDuplicates = true;

  // ------------------------------------------------------------------------ //
  // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
  // control nodes and inlined function inputs and outputs.

  // Add a NoOp node for function control inputs/outputs.
  const auto no_op = [&](StringPiece name) -> Node* {
    Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
    const absl::optional<string> device = placer->ControlNodeDevice();
    if (device.has_value()) node->set_requested_device(*device);
    return node;
  };

  // Add an Identity node for a function input.
  const auto input_identity = [&](StringPiece name, Endpoint input,
                                  int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->InputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // Add an Identity node for a function output.
  const auto output_identity = [&](StringPiece name, Endpoint input,
                                   int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->OutputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // ------------------------------------------------------------------------ //
  // Input edges. For data edges coming into "caller", we first compute the
  // <src>:<src_output> for the i-th input in "inputs".
  // If "caller" has any input control dependencies, we add a NoOp
  // node "input_control_node", which depends on "caller"'s control inputs.
  std::vector<Endpoint> inputs(caller->num_inputs());
  Node* input_control_node = nullptr;
  for (const Edge* e : caller->in_edges()) {
    if (e->IsControlEdge()) {
      if (input_control_node == nullptr) {
        input_control_node = no_op("input_control_node");
      }
      g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates);
    } else {
      inputs[e->dst_input()] = {e->src(), e->src_output()};
    }
  }
  if (input_control_node != nullptr) {
    VLOG(3) << "Created input control node: " << input_control_node->name();
  }

  // ------------------------------------------------------------------------ //
  // Duplicate fbody->graph into 'g'. First, we copy the nodes of
  // fbody->graph into 'g' except the source and sink nodes. We copy
  // edges among nodes in 'fbody->graph'.
  //
  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
  // remember 'y' in node_map[x->id()].
  std::vector<Node*> node_map(fbody->graph->num_node_ids());
  for (Node* n : fbody->graph->op_nodes()) {
    NodeDef ndef = n->def();

    // Maybe override requested node device assignment.
    const absl::optional<string> device = placer->BodyNodeDevice(ndef);
    if (device.has_value()) ndef.set_device(*device);

    // Add the inlined function name to the inlined node debug information.
    PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);

    // Add the function node name as a prefix:
    //  1) to node name to avoid collisions
    //  2) to frame name to avoid multiple LoopCond nodes in one frame
    //  3) to colocation attribute
    const string prefix = strings::StrCat(caller->name(), "/");
    TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
                                                options.uniquify_frame_names));

    Status added_node;
    Node* clone = g->AddNode(ndef, &added_node);
    TF_CHECK_OK(added_node);
    node_map[n->id()] = clone;

    // If there is an input control node, and one of:
    // a) the node has no data or control inputs, or
    // b) the node is a function call (including SymbolicGradient),
    // then add a control edge from the input control node to the clone (only
    // if it does not already have a control input).
    //
    // We must not execute any nodes if the original function call would not
    // have executed. This is especially critical when the function call is
    // inside a control-flow construct like tf.cond(). Case (a) ensures that
    // such nodes do not run.
    //
    // The purpose of case (b) is to ensure that instances of case (a) created
    // by further inlining steps also receive the control dependency.
    //
    // This edge is required to transfer the execution frame down to all
    // function body nodes of inlined nested function calls.
    if (input_control_node) {
      const auto is_input_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource();
      };
      const auto is_control_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource() && e->IsControlEdge();
      };

      // Forward execution frame if:
      //
      // a) The node has no data or control inputs.
      // b) OR the node is a function call without control inputs (control edge
      //    will be used in nested function inlining to forward execution frame
      //    to constants inside the function body).
      //
      // c) Do not forward control frame to function argument nodes, they will
      //    be connected to the corresponding function input later.
      const bool forward_execution_frame =
          (absl::c_none_of(n->in_edges(), is_input_edge) ||       // (a)
           (n->IsFunctionCall() &&                                // (b)
            absl::c_none_of(n->in_edges(), is_control_edge))) &&  //
          !n->IsArg();                                            // (c)

      if (forward_execution_frame) {
        VLOG(4) << "Add control edge from input control node to: "
                << clone->name();
        g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates);
      }
    }
  }
  for (const Edge* e : fbody->graph->edges()) {
    if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
        e->dst()->IsSink()) {
      continue;
    }
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // ------------------------------------------------------------------------ //
  // Connect input edges.
  //
  // We create one Identity node for each input. Then, we connect inputs[i] to
  // the i-th identity node added. The nodes that previously connected
  // to the j-th output of i-th arg node are reconnected to the i-th
  // identity node.
  //
  // The added identity nodes depend on "input_control_node".
  VLOG(4) << "Add input Identity nodes for each function argument:";
  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg = node_map[fbody->arg_nodes[i]->id()];
    Node* n = input_identity("input", inputs[i], i);
    VLOG(4) << " [index " << i << "] "
            << fbody->fdef.signature().input_arg(i).name() << " as "
            << n->name() << " (input: " << inputs[i].name()
            << ", requested_device: " << n->requested_device() << ")";

    if (input_control_node) {
      g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates);
    }
    for (const Edge* e : arg->out_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates);
      } else {
        g->AddEdge(n, 0, e->dst(), e->dst_input());
      }
    }
    node_map[fbody->arg_nodes[i]->id()] = n;
    g->RemoveNode(arg);  // 'arg' is disconnected.
  }

  // ------------------------------------------------------------------------ //
  // Connect output edges.
  //
  // For the i-th return node in fbody->graph, we add in "g" an identity node
  // (outputs[i-th]). We then reconnect every incoming edge into the i-th
  // return node to the added identity node.
  //
  // For every data edge coming out of "callee"s i-th output, we reconnect it
  // to the i-th identity added above.
  //
  // If "callee" is control-depended upon by any other nodes, we add a NoOp
  // node "output_control_node". "output_control_node" depends on all identity
  // nodes added above or on all control return nodes (controlled by
  // `options.output_control_src` value). Nodes that previously depended on
  // "callee" are changed to depend on "output_control_node".
  //
  // If the caller node is kept in the graph (see `options.keep_caller_node`),
  // we always add an output control node, to guarantee that executing a
  // fetchable node will execute all side-effects.
  VLOG(4) << "Add output Identity nodes for each function output argument:";
  std::vector<Node*> outputs(caller->num_outputs());
  for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
    Node* ret = node_map[fbody->ret_nodes[i]->id()];
    Endpoint data;  // Data input for the ret node.
    for (const Edge* e : ret->in_edges()) {
      if (!e->IsControlEdge()) {
        data = {e->src(), e->src_output()};
        break;
      }
    }
    CHECK(data.node != nullptr);
    Node* n = output_identity("output", data, i);
    outputs[i] = n;
    VLOG(4) << " [index " << i << "] "
            << fbody->fdef.signature().output_arg(i).name() << " as "
            << n->name() << " (ret: " << data.node->name() << ":" << data.index
            << ", requested_device: " << n->requested_device() << ")";
    for (const Edge* e : ret->in_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates);
      }
    }
    g->RemoveNode(ret);  // 'ret' is disconnected.
  }

  Node* output_control_node = nullptr;
  const bool has_control_outputs = absl::c_any_of(
      caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });

  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
  const bool keep_caller_node =
      options.keep_caller_node == KeepCallerNode::kFetchable ||
      options.keep_caller_node == KeepCallerNode::kTargetable;

  if (has_control_outputs || keep_caller_node) {
    output_control_node = no_op("output_control_node");
    VLOG(4) << "Add output control node: " << output_control_node->name();
    if (options.output_control_src == OutputControlSrc::kDataOutputs) {
      for (Node* n : outputs) {
        VLOG(4) << " [data output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    } else {
      for (Node* fbody_node : fbody->control_ret_nodes) {
        Node* n = node_map[fbody_node->id()];
        VLOG(4) << " [control output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    }
  }

  // We can't leave the output control node without incoming control edges,
  // because in this case the outgoing control edge will lose execution frame
  // information. We connect input_control_node and output_control_node with a
  // control edge to forward the execution frame to the controlled nodes.
  // Above we add a control edge to all function calls inside the function
  // body, to guarantee that we will always have input_control_node when we
  // need it.
  if (output_control_node && output_control_node->in_edges().empty()) {
    if (input_control_node) {
      VLOG(4) << "Add a control edge between input and output control nodes: "
              << input_control_node->name() << " to "
              << output_control_node->name();
      g->AddControlEdge(input_control_node, output_control_node,
                        kDoNotCheckDuplicates);
    } else {
      VLOG(4) << "Function inlining potentially dropped execution frame "
                 "information from outgoing control edges.";
    }
  }

  for (const Edge* e : caller->out_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates);
    } else {
      g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }

  // ------------------------------------------------------------------------ //
  // Add an IdentityN or NoOp node in place of the caller node to keep `caller`
  // fetchable or targetable.

  if (keep_caller_node) {
    std::vector<NodeBuilder::NodeOut> output_tensors;
    absl::c_transform(outputs, std::back_inserter(output_tensors),
                      [](Node* n) { return NodeBuilder::NodeOut(n, 0); });

    Node* caller_substitute_node;
    if (options.keep_caller_node == KeepCallerNode::kTargetable ||
        output_tensors.empty()) {
      // An IdentityN node must have at least one data input. If the function
      // has no data outputs, we can't keep it fetchable.
      TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp")
                      .Device(caller->requested_device())
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));

    } else if (options.keep_caller_node == KeepCallerNode::kFetchable) {
      TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
                      .Device(caller->requested_device())
                      .Input(output_tensors)
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));
    }
  }

  // ------------------------------------------------------------------------ //
  // 'caller' is replaced with inlined function body nodes and maybe IdentityN
  // to keep it fetchable.
  VLOG(3) << "Successfully inlined function call node: " << caller->name();
  g->RemoveNode(caller);

  return Status::OK();
}
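
// A hedged sketch of a typical call site (all names are illustrative):
//
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(
//       InstantiateFunctionCall(call_node->def(), flr, &handle));
//   const FunctionBody* fbody = flr->GetFunctionBody(handle);
//   InlineFunctionBodyOptions opts;
//   TF_RETURN_IF_ERROR(InlineFunctionBody(
//       *flr->GetFunctionLibraryDefinition(), graph, call_node, fbody, opts));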

bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
                           const ExpandInlineFunctionsOptions& options) {
  std::vector<std::pair<Node*, const FunctionBody*>> candidates;

  const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();

  for (Node* node : graph->nodes()) {
    // Skip nodes that are not function calls or SymbolicGradient calls.
    if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
      continue;
    }
    // Skip function calls that are marked noinline.
    bool noinline;
    if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
      VLOG(3) << "noinline: " << SummarizeNode(*node);
      continue;
    }
    FunctionLibraryRuntime::Handle handle;
    Status s = InstantiateFunctionCall(node->def(), lib, &handle);
    if (!s.ok()) {
      LOG(ERROR) << "Failed to instantiate a function: " << s.error_message();
      continue;
    }
    const FunctionBody* fbody = lib->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    candidates.emplace_back(node, fbody);
  }

  bool inlined_any = false;
  for (const auto& p : candidates) {
    Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
                                        p.first->IsPartitionedCall()
                                            ? options.multi_device_options
                                            : options.native_options);
    if (inlined.ok()) {
      inlined_any = true;
    } else {
      VLOG(1) << "Failed to inline function call: node=" << p.first->name()
              << " error=" << inlined.error_message();
    }
  }

  // TODO(ezhulenev): Release handles for inlined function calls.

  return inlined_any;
}
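
// Hedged usage sketch (illustrative): callers typically run this to a
// fixpoint, since inlining can expose new inlinable calls.
//
//   ExpandInlineFunctionsOptions opts;
//   while (ExpandInlineFunctions(flr, graph, opts)) {
//   }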

}  // end namespace tensorflow

@ -0,0 +1,236 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_

#include <functional>
#include <memory>

#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {

static constexpr const char* const kNoInlineAttr = "_noinline";

// Optionally override device assignment for nodes added to the graph for
// inlined functions:
// (1) Identity nodes added in place of function input arguments.
// (2) Identity nodes added in place of function return values.
// (3) Special NoOp nodes that enforce side-effects execution order.
// (4) All nodes inside the function body specified in FunctionDef.
class InlinedFunctionBodyPlacer {
 public:
  virtual ~InlinedFunctionBodyPlacer() = default;

  virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
  virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
  // Returns true if the added input/output identity nodes should be colocated
  // with the corresponding input/output from the function body.
  virtual bool ColocateInputOutputIdentities() const = 0;
  virtual absl::optional<string> ControlNodeDevice() const = 0;
  virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;

  // Place input nodes on the same device as the corresponding caller input
  // node. Do not specify any placement for all other nodes.
  static std::unique_ptr<InlinedFunctionBodyPlacer> DefaultPlacer(
      const Graph& graph, const Node& caller);

  // Place all nodes on the same device as the caller node.
  static std::unique_ptr<InlinedFunctionBodyPlacer> SingleDevicePlacer(
      const Graph& graph, const Node& caller);

  // Place input nodes on the same device as the corresponding caller input
  // node. Do not place output nodes. Place control nodes on the same device as
  // the caller node. For all function body nodes, overrides the job, replica
  // and task parts of the device assignment to match the function caller node.
  static std::unique_ptr<InlinedFunctionBodyPlacer> MultiDevicePlacer(
      const Graph& graph, const Node& caller);

  using Factory = std::function<std::unique_ptr<InlinedFunctionBodyPlacer>(
      const Graph&, const Node&)>;

  struct Config {
    string name;
    Factory get;
  };

  static Config Default() { return {"default", DefaultPlacer}; }
  static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; }
  static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; }
};
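
// Illustrative only: a caller selects a strategy via the Config factories,
// e.g. when filling in InlineFunctionBodyOptions (declared below):
//
//   InlineFunctionBodyOptions opts;
//   opts.inlined_function_body_placer =
//       InlinedFunctionBodyPlacer::MultiDevice();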
|
||||
struct InlineFunctionBodyOptions {
  // All nodes that have an incoming control edge *from* the function call
  // node will be forwarded to the "output control node". There are two
  // options for choosing which nodes will have a control edge *to* the
  // "output control node":
  // a) control returns (`control_ret` field in FunctionDef)
  // b) data returns (`ret` field in FunctionDef)
  enum class OutputControlSource { kDataOutputs, kControlOutputs };

  // Keep a node in the graph with the same name as the function call node:
  //
  // a) DoNotKeep: The function call node is fully inlined, and there is no
  // node in the graph with the same name.
  //
  // b) Fetchable: Add an IdentityN node to the graph in place of the inlined
  // function call node. It will have a control edge from the inlined
  // 'output_control_node' and data edges from the function output nodes.
  // The IdentityN node will be placed on the same device as the caller node.
  //
  // This is mostly for compatibility with TensorFlow v1 and sessions. When we
  // prepare a graph for execution in GraphExecutionState::MakeForBaseGraph we
  // don't know which nodes will be fetched, so we can't safely remove any of
  // them. When a graph is executed as a function it has 'Retval' nodes for
  // all fetched tensors, and we can safely inline function calls.
  //
  // c) Targetable: Add a NoOp node to the graph in place of the inlined
  // function call node. It will have a control edge from the inlined
  // 'output_control_node' and no data edges. The NoOp node will be placed on
  // the same device as the caller node. This keeps the inlined function call
  // node a valid 'session.run' target, and also keeps it a valid control
  // output node.
  enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable };

  // If 'true', function inlining is completely disabled. This allows
  // controlling function inlining for different types of function calls (see
  // 'ExpandInlineFunctionsOptions' below).
  bool disable_inlining = false;
  // Ignore the '_noinline' function attribute.
  bool ignore_noinline = false;
  // If 'true', function inlining will inline functions in the implementation
  // selection group. Normally those functions should not be inlined; they
  // will be handled by Grappler.
  bool inline_impl_selection_group_functions = false;
  // Controls whether to keep a node with the same name as the function call
  // node in the graph after function inlining.
  KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
  // For compatibility with TensorFlow v1, by default we use data outputs.
  // Control returns were added in TensorFlow v2 with automatic control
  // dependencies tracking in Eager mode.
  OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
  // The inlined function body placer decides what requested device
  // assignments should be added to the nodes added to the graph. See the
  // documentation above for the available strategies.
  InlinedFunctionBodyPlacer::Config inlined_function_body_placer =
      InlinedFunctionBodyPlacer::Default();
  // If true, frame names in the function body will be made unique in the
  // resulting graph (e.g. by prepending a unique prefix).
  // NOTE(mrry): Only set this option to false when there is a single function
  // call in the graph (e.g. when making a remote function call via
  // ClusterFunctionLibraryRuntime). This option is provided because the graph
  // partitioner generates frame names that must remain unmodified across all
  // partitions of a multi-device function.
  bool uniquify_frame_names = true;

  // A human-readable debug string for these options.
  string DebugString() const;
};
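
// Example (illustrative sketch, not from the original header): options
// approximating TF2 multi-device call semantics; the chosen field values are
// assumptions for illustration only.
//
//   InlineFunctionBodyOptions opts;
//   opts.output_control_src =
//       InlineFunctionBodyOptions::OutputControlSource::kControlOutputs;
//   opts.keep_caller_node =
//       InlineFunctionBodyOptions::KeepCallerNode::kFetchable;
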
// Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
// based on the type signature of 'node' and 'fbody':
//
// (1) The caller node has the same number of inputs and outputs as the
//     function.
// (2) The caller node's inputs and outputs have the same data types as the
//     function's inputs and returns.
// (3) The validation rules defined in InlineFunctionBodyOptions hold.
//
// If the function can't be safely inlined, returns an error message with
// details on why inlining is not possible or safe.
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options);

// Given a "caller" in graph "g", which is a function call of the function
// defined by "fbody", replaces the "caller" with fbody->graph and connects
// edges properly. "override_device" specifies whether inlining should replace
// explicitly specified devices inside fbody with the callee's device.
//
// Returns 'Status::OK()' if the function was successfully inlined into the
// graph. If function inlining is not possible, returns an error with a
// reason, and leaves the graph unmodified.
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options);

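// Example (illustrative sketch, not from the original header): validating and
// then inlining a single call site. `flib_def`, `graph`, `caller`, and
// `fbody` are hypothetical variables assumed to be set up by the surrounding
// pass.
//
//   InlineFunctionBodyOptions opts;
//   if (ValidateInlining(caller, fbody, opts).ok()) {
//     TF_RETURN_IF_ERROR(
//         InlineFunctionBody(flib_def, graph, caller, fbody, opts));
//   }
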
// There are three types of function calls that could be invoked during
// *TensorFlow graph execution*:
//
// 1) Native function call (node.type_string() is the function name). These
//    functions are always executed on a single device, which is the device
//    of the function call node.
//
// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
//    ops) can execute on multiple devices and accept DT_RESOURCE inputs that
//    belong to different devices. This type of function call was added in
//    TensorFlow 2.0 Eager mode, and it has control outputs to represent
//    side effects that must always execute (see `control_ret` in
//    FunctionDef).
//
// 3) SymbolicGradient has been deprecated for a while, but we still keep it
//    and use the `native` options for inlining for compatibility.
//
// We need to have distinct inlining rules for compatibility with TensorFlow
// v1.
//
// There are a few other places in TensorFlow that could execute functions:
//
// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
//    functions directly via the function library runtime, without going
//    through the graph.
// 2) tf.data pipelines - also execute functions directly via the function
//    library runtime with custom executors.
struct ExpandInlineFunctionsOptions {
  ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
    using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
    multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
  }

  InlineFunctionBodyOptions native_options;
  InlineFunctionBodyOptions multi_device_options;
};

// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
// workaround that will be enabled only during the function inlining
// unification (b/126811947). Contact ezhulenev@ if you think you need it.
// TODO(ezhulenev): Delete this function.
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
                           const ExpandInlineFunctionsOptions& options);

// For each node in "graph", if "lib" indicates that the node is a function
// call, inline the function body. Returns true if at least one node is
// inlined.
//
// This routine goes through "graph" nodes once and applies the inlining. The
// caller may decide to apply the inlining on "graph" multiple times by
// calling ExpandInlineFunctions a few times.
//
// Function calls that can't be safely inlined into the graph
// (ValidateInlining returns an error) are ignored.
//
// TODO(ezhulenev): We do not need FunctionLibraryRuntime for this. We need
// just the FunctionLibraryDefinition and FunctionDefToBodyHelper to implement
// this (see lower_function_call.cc).
inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
  return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
}

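// Example (illustrative sketch, not from the original header): callers
// typically iterate to a fixed point, since inlining one function body may
// expose further function calls. `lib` and `graph` are hypothetical and
// assumed non-null.
//
//   while (ExpandInlineFunctions(lib, graph)) {
//     // Keep expanding until no more calls can be inlined.
//   }
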
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
@@ -20,12 +20,12 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_functional_ops.h"

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"

@@ -22,12 +20,13 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -17,6 +17,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"

@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_functional_ops.h"

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"

@@ -22,12 +20,13 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -21,12 +21,12 @@ limitations under the License.
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_functional_ops.h"

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"

@@ -22,12 +20,13 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -13,19 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_functional_ops.h"

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -16,10 +16,10 @@ limitations under the License.

#include <algorithm>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"

namespace tensorflow {

@@ -19,10 +19,10 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"

@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/errors.h"

@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"

@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

@@ -20,7 +20,8 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"

@@ -28,9 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

@@ -163,11 +163,10 @@ filegroup(
    ],
)

# Both of these files depend on common_runtime.
# This file depends on common_runtime.
filegroup(
    name = "core_cpu_base_no_ops_srcs",
    srcs = [
        "graph_constructor.cc",
        "graph_def_builder_util.cc",
    ],
)

@@ -250,7 +249,6 @@ filegroup(
        "gradients.h",
        "graph.cc",
        "graph.h",
        "graph_constructor.cc",
        "graph_constructor.h",
        "graph_def_builder.cc",
        "graph_def_builder.h",

@@ -16,189 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
#define TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
class ShapeRefiner;

// Constructs a Graph *g out of a GraphDef gdef. Returns non-OK on error, in
// which case *g is left in an incomplete state.
//
// *g is expected to be an empty graph (with no more than source and sink
// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing
// Graph, see ImportGraphDef.
struct GraphConstructorOptions {
  GraphConstructorOptions() {}

  // If true, allows internal ops in the GraphDef.
  bool allow_internal_ops = false;

  // If true, the graph def is expected to have fully specified devices for
  // all nodes. A node in the resulting graph "g" has the device name set
  // accordingly.
  //
  // TODO(zhifengc): if possible, consider removing this option.
  bool expect_device_spec = false;

  // If true, validates that nodes being converted have all expected attrs
  // set and no unknown attrs set by calling ValidateNodeDef().
  // Setting validate_nodes without add_default_attributes will fail if
  // the GraphDef does not have all required attributes set.
  bool validate_nodes = false;

  // If true, GraphConstructor will add attributes with their default
  // value to the Node when they are missing from the NodeDef.
  bool add_default_attributes = true;
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                                     const GraphDef& gdef, Graph* g);
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                                     GraphDef&& gdef, Graph* g);

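// Example (illustrative sketch, not from the original header): building a
// Graph from a GraphDef. `graph_def` is a hypothetical, already-populated
// GraphDef.
//
//   GraphConstructorOptions opts;
//   opts.allow_internal_ops = true;
//   Graph graph(OpRegistry::Global());
//   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, &graph));
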
// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
// instantiation.
// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
                                     gtl::ArraySlice<NodeDef> nodes, Graph* g);

// Options for calling ImportGraphDef().
struct ImportGraphDefOptions {
  ImportGraphDefOptions()
      : uniquify_names(false),
        uniquify_prefix(false),
        skip_mapped_nodes(false),
        validate_shape(true) {}

  // Name prefix to use for nodes imported from the GraphDef. For example, if
  // prefix="animals" and the GraphDef contains a node "bunny" then the node
  // will be named "animals/bunny" in *g. Must not already be in use as a node
  // name or prefix in the graph.
  string prefix;

  // If true, imported node names will be modified if their name already
  // exists in the graph. If false, conflicting names will be treated as an
  // error. Note that this option has no effect if `prefix` is specified,
  // since `prefix` will guarantee all node names are unique.
  bool uniquify_names;

  // If true, `prefix` will be modified if it already exists as a node name or
  // prefix in the graph. If false, a conflicting prefix will be treated as an
  // error. This option has no effect if `prefix` isn't specified.
  bool uniquify_prefix;

  // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
  // corresponding to `input_map` keys will be remapped to the nodes in `g`
  // corresponding to the values.
  //
  // Keys should not include `prefix`, i.e., a key ID's name should be the
  // name as it originally appears in `gdef`.
  //
  // If this is non-empty, ImportGraphDef must be called with the shape refiner
  // used to create the existing nodes referenced in `input_map`.
  // TODO(skyewm): can we remove this requirement? How do we access the
  // original shape refiner?
  std::map<SafeTensorId, SafeTensorId> input_map;

  // If true, nodes that will have all output edges removed because of
  // overrides in `input_map` will not be imported.
  bool skip_mapped_nodes;

  // The names of existing nodes in `g` that the imported graph should have
  // control dependencies on.
  //
  // Note that to avoid creating many redundant control edges,
  // ImportGraphDef() won't add control edges to nodes that will inherit the
  // dependencies from other nodes in `gdef`.
  std::vector<string> control_dependencies;

  // Tensors in `gdef` that will be returned via the ImportGraphDefResults
  // output parameter of `ImportGraphDef()`. If this list is non-empty, the
  // caller must pass a results object to `ImportGraphDef()`. The
  // `return_tensors` field will be populated with the imported nodes in `g`.
  //
  // Entries should not include `prefix`, i.e., each ID's name should be the
  // name as it originally appears in `gdef`.
  //
  // If this contains a tensor that's also being remapped via `input_map`, the
  // corresponding existing tensor in `g` will be returned.
  std::vector<SafeTensorId> return_tensors;

  // The names of nodes in `gdef` that will be returned via the
  // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this
  // list is non-empty, the caller must pass a results object to
  // `ImportGraphDef()`. The `return_nodes` field will be populated with the
  // imported nodes in `g`.
  //
  // Entries should not include `prefix`, i.e., each node's name should be the
  // name as it originally appears in `gdef`.
  //
  // Unlike `return_tensors`, `input_map` has no effect on the nodes
  // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
  // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
  std::vector<string> return_nodes;

  // If true, checks that all colocation constraints are nodes in the
  // GraphDef.
  bool validate_colocation_constraints = true;

  // If false, skips shape validation.
  bool validate_shape;

  // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
  // with ops that are not defined in the binary calling ImportGraphDef.
  // Similar to the producer_op_list argument to import_graph_def in the
  // Python API.

  // Try to set the default execution device for this graph.
  string default_device;
};

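// Example (illustrative sketch, not from the original header): options for
// importing under a prefix while remapping one input to an existing tensor.
// The node names used are hypothetical.
//
//   ImportGraphDefOptions opts;
//   opts.prefix = "imported";
//   opts.input_map.emplace(SafeTensorId("input", 0),
//                          SafeTensorId("existing", 0));
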
// Optional results that may be returned by ImportGraphDef.
struct ImportGraphDefResults {
  // The requested tensors associated with
  // ImportGraphDefOptions::return_tensors. Note that the index may be
  // different than the requested index if the returned tensor has been
  // remapped according to `input_map`.
  typedef int Index;
  std::vector<std::pair<Node*, Index>> return_tensors;

  // The requested nodes associated with ImportGraphDefOptions::return_nodes.
  std::vector<Node*> return_nodes;

  // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
  // weren't used as an input to any node in `gdef`. These keys are likely due
  // to typos, and callers may wish to treat their existence as an error.
  std::vector<SafeTensorId> missing_unused_input_map_keys;
};

// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
//
// `refiner` can be null. It should be non-null if the caller intends to add
// additional nodes to the graph after the import. This allows the caller to
// validate the shapes of those nodes (since ShapeRefiner::AddNode must be
// called in topological order).
//
// `results` must be non-null if `opts.return_tensors` or `opts.return_nodes`
// is non-empty. It can also be set to fetch the unused input map keys. If
// it's non-null, all of its vector fields must be empty.
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
                             const GraphDef& gdef, Graph* g,
                             ShapeRefiner* refiner,
                             ImportGraphDefResults* results = nullptr);

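// Example (illustrative sketch, not from the original header): importing a
// GraphDef and fetching one returned node. `graph_def` and `graph` are
// hypothetical, already-constructed objects.
//
//   ImportGraphDefOptions opts;
//   opts.return_nodes.push_back("output");
//   ImportGraphDefResults results;
//   TF_RETURN_IF_ERROR(ImportGraphDef(opts, graph_def, &graph,
//                                     /*refiner=*/nullptr, &results));
//   Node* output_node = results.return_nodes[0];
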
// Make a copy of "src" into "*dest".
//
// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
// other than the implicit Source/Sink nodes.
extern void CopyGraph(const Graph& src, Graph* dest);

}  // namespace tensorflow
#include "tensorflow/core/common_runtime/graph_constructor.h"

#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_

@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

@@ -5951,6 +5951,7 @@ filegroup(
        "//tensorflow/core/common_runtime:core_cpu_rump_impl",  # quantize_training
        "//tensorflow/core/common_runtime:device",  # device_lib, tfe, tf_session
        "//tensorflow/core/common_runtime:device_factory",  # device_lib, tfe, tf_session
        "//tensorflow/core/common_runtime:graph_constructor",  # tf_session
        "//tensorflow/core/common_runtime:session_options",  # device_lib, tfe, tf_session
        "//tensorflow/core/common_runtime:session_state",  # tf_session
        "//tensorflow/core/data/service:server_lib",  # server_lib

@@ -220,7 +220,7 @@ tensorflow::ImportGraphDef
[op_gen_lib]  # tf_session
tensorflow::ApiDefMap::~ApiDefMap

[core_cpu_base_no_ops]  # tf_session
[graph_constructor]  # tf_session
tensorflow::ShapeRefiner::~ShapeRefiner

[python_api]  # tf_session