[Build cleanup] Split "core_cpu_impl" into fine-grained targets (3/n).

This change splits many (but not all) of the function-related sources into separate, fine-grained cc_library targets. The main changes are:

* Move "graph/graph_constructor.{h,cc}" to "common_runtime/graph_constructor.{h,cc}" and leave a forwarding header. This code depends on common_runtime and is built as part of it, so it makes sense to move it across. The "graph_constructor" library includes "shape_refiner.{h,cc}", "graph_runner.{h,cc}", and "eval_const_tensor.{h,cc}" because of a circular dependency between these modules.

* Split "function.{h,cc}" into "function_body.{h,cc}", "function_utils.{h,cc}", and "inline_function_utils.{h,cc}" (plus the original, slimmed-down module). This enables other targets in common_runtime to depend on just the function utilities they need, without the whole runtime, which breaks some cycles.

* New fine-grained targets for "constant_folding", "function_optimization_registry", and "graph_optimizer".
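  The "graph_optimizer" target also picks up the free OptimizeGraph() helpers moved out of "function.h" (see the graph_optimizer.{h,cc} hunks below). A hypothetical call site, assuming an already-constructed FunctionLibraryRuntime, would be:

      #include <memory>

      #include "tensorflow/core/common_runtime/graph_optimizer.h"
      #include "tensorflow/core/graph/graph.h"

      namespace tensorflow {

      // Hypothetical helper: `lib` is an existing FunctionLibraryRuntime and
      // *graph owns the graph to rewrite. With default options this runs CSE,
      // function inlining, and constant folding, replacing *graph with an
      // optimized copy.
      void RunDefaultOptimizations(FunctionLibraryRuntime* lib,
                                   std::unique_ptr<Graph>* graph) {
        OptimizeGraph(lib, graph);
      }

      }  // namespace tensorflow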

PiperOrigin-RevId: 308651243
Change-Id: Iac59c1db4ebdd16609f89d6caee6b7e6ba7ff0a1
Derek Murray 2020-04-27 10:40:33 -07:00 committed by TensorFlower Gardener
parent 7c977c938e
commit d6027bd76a
41 changed files with 2131 additions and 1668 deletions

View File

@ -604,6 +604,7 @@ tf_cc_test(
":c_api",
":c_api_internal",
":c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/hash/hash.h"

View File

@ -491,6 +491,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:core_cpu",
"//tensorflow/core/grappler/costs:graph_properties",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",

View File

@ -58,6 +58,7 @@ tf_cuda_library(
"device_factory.h",
"function.h",
"function_optimization_registry.h",
"graph_constructor.h",
"optimization_registry.h",
"shape_refiner.h",
"//tensorflow/core/graph:core_cpu_headers",
@ -153,7 +154,12 @@ filegroup(
"device_set.h",
"eval_const_tensor.h",
"function.h",
"function_body.h",
"function_utils.h",
"graph_constructor.h",
"graph_optimizer.h",
"graph_runner.h",
"inline_function_utils.h",
"metrics.h",
"process_function_library_runtime.h",
"scoped_allocator.h",
@ -167,9 +173,6 @@ filegroup(
tf_cuda_library(
name = "core_cpu_base_no_ops",
srcs = [
"eval_const_tensor.cc",
"graph_optimizer.h",
"shape_refiner.cc",
"//tensorflow/core/graph:core_cpu_base_no_ops_srcs",
"//tensorflow/core/public:session_options.h",
"//tensorflow/core/public:version.h",
@ -190,6 +193,7 @@ tf_cuda_library(
"@com_google_absl//absl/container:flat_hash_set",
"//third_party/eigen3",
] + if_static([
":graph_constructor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
]),
@ -221,7 +225,6 @@ filegroup(
"executor.h",
"executor_factory.h",
"function_optimization_registry.h",
"graph_optimizer.h",
"input_colocation_exemption_registry.h",
"isolate_placer_inspection_required_ops_pass.h",
"local_device.h",
@ -390,6 +393,27 @@ cc_library(
],
)
cc_library(
name = "constant_folding",
srcs = ["constant_folding.cc"],
hdrs = ["constant_folding.h"],
copts = tf_copts(),
deps = [
":device",
":device_factory",
":executor",
":function_utils",
":graph_constructor",
":memory_types",
":rendezvous_mgr",
":session_options",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
cc_library(
name = "costmodel_manager",
srcs = ["costmodel_manager.cc"],
@ -524,19 +548,6 @@ cc_library(
],
)
cc_library(
name = "graph_view",
srcs = ["graph_view.cc"],
hdrs = ["graph_view.h"],
copts = tf_copts(),
deps = [
":device",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
)
cc_library(
name = "device_set",
srcs = ["device_set.cc"],
@ -557,6 +568,112 @@ cc_library(
deps = ["//tensorflow/core:framework"],
)
cc_library(
name = "function_body",
srcs = ["function_body.cc"],
hdrs = ["function_body.h"],
copts = tf_copts(),
deps = [
":device",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
)
cc_library(
name = "function_optimization_registry",
srcs = ["function_optimization_registry.cc"],
hdrs = ["function_optimization_registry.h"],
deps = [
":device_set",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "function_utils",
srcs = ["function_utils.cc"],
hdrs = ["function_utils.h"],
deps = [
":function_body",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# This library also includes "eval_const_tensor", "graph_runner", and
# "shape_refiner", because there are circular dependencies between these
# modules.
cc_library(
name = "graph_constructor",
srcs = [
"eval_const_tensor.cc",
"graph_constructor.cc",
"graph_runner.cc",
"shape_refiner.cc",
"//tensorflow/core/framework:versions.h",
],
hdrs = [
"eval_const_tensor.h",
"graph_constructor.h",
"graph_runner.h",
"shape_refiner.h",
],
copts = tf_copts(),
deps = [
":device",
":device_factory",
":executor",
":function_utils",
":memory_types",
":rendezvous_mgr",
":session_options",
":single_threaded_cpu_device",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "graph_optimizer",
srcs = ["graph_optimizer.cc"],
hdrs = ["graph_optimizer.h"],
copts = tf_copts(),
deps = [
":constant_folding",
":function_utils",
":graph_constructor",
":inline_function_utils",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
)
cc_library(
name = "graph_view",
srcs = ["graph_view.cc"],
hdrs = ["graph_view.h"],
copts = tf_copts(),
deps = [
":device",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
],
)
cc_library(
name = "hierarchical_tree_broadcaster",
srcs = ["hierarchical_tree_broadcaster.cc"],
@ -592,6 +709,29 @@ cc_library(
],
)
cc_library(
name = "inline_function_utils",
srcs = ["inline_function_utils.cc"],
hdrs = ["inline_function_utils.h"],
copts = tf_copts(),
deps = [
":device",
":function_body",
":function_utils",
":graph_constructor",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "input_colocation_exemption_registry",
srcs = ["input_colocation_exemption_registry.cc"],
@ -685,6 +825,7 @@ cc_library(
copts = tf_copts(),
deps = [
":device_set",
":graph_constructor",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
@ -1062,11 +1203,7 @@ tf_cuda_library(
srcs = [
"colocation_graph.cc",
"composite_device.cc",
"constant_folding.cc",
"function.cc",
"function_optimization_registry.cc",
"graph_optimizer.cc",
"graph_runner.cc",
"inspecting_placer.cc",
"isolate_placer_inspection_required_ops_pass.cc",
"lower_case_op.cc",
@ -1087,9 +1224,14 @@ tf_cuda_library(
":entry",
":executor",
":executor_factory",
":function_body",
":function_optimization_registry",
":graph_constructor",
":graph_optimizer",
":graph_view",
":local_executor_params",
":immutable_executor_state",
":inline_function_utils",
":input_colocation_exemption_registry",
":pending_counts",
":propagator_debug_utils",

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/metrics.h"
@ -49,7 +50,6 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"

File diff suppressed because it is too large

View File

@ -22,7 +22,10 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@ -30,8 +33,6 @@ limitations under the License.
namespace tensorflow {
static constexpr const char* const kNoInlineAttr = "_noinline";
// Get default customizable kernel creator if set
const CustomKernelCreator* GetDefaultCustomKernelCreator();
@ -67,88 +68,6 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const SessionMetadata* session_metadata,
ProcessFunctionLibraryRuntime* parent);
// FunctionLibraryRuntime::GetFunctionBody returns a description of an
// instantiated function that is represented as a Graph with arg/ret
// nodes annotated.
struct FunctionBody {
FunctionDef fdef;
Graph* graph = nullptr; // owned.
DataTypeVector arg_types;
DataTypeVector ret_types;
// arg_nodes[i] contains the i'th function input. In other words,
// GetNodeAttr(arg_nodes[i]->attrs(), "index") == i.
gtl::InlinedVector<Node*, 4> arg_nodes;
// ret_nodes[i] contains the i'th function output. In other words,
// GetNodeAttr(ret_nodes[i]->attrs(), "index") == i.
gtl::InlinedVector<Node*, 4> ret_nodes;
gtl::InlinedVector<Node*, 4> control_ret_nodes;
FunctionBody() {}
FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
DataTypeSlice ret_types, Graph* g);
~FunctionBody();
};
// Debugging facility. Returns a debug string for a graph
// representing an instantiated function.
string DebugString(const Graph* g);
// A few hand-crafted optimizations on the instantiated function body
// (a Graph*).
// Removes nodes that are
// 1. not stateful; and
// 2. not _Arg; and
// 3. not reachable from _Retval.
//
// This function is triggered by function inlining; unlike 'PruneFunctionBody',
// it doesn't preserve nodes that are reachable from control returns. Function
// inlining is responsible for connecting control return nodes with the nodes
// that have input control edges from the inlined function call node.
//
// Assuming that automatic control dependency tracking is correct, the absence of
// an outgoing control edge from the function call node means that no one needs
// to observe side effects that might have been generated by the function (see
// documentation in common_runtime/function.cc for details).
//
// Returns true iff any node is removed from "g".
bool RemoveDeadNodes(Graph* g);
// Find a pattern:
// src -(in)-> node -(out)-> dst, where
// 1) node is an identity node;
// 2) in is the only incoming data edge;
// 3) out is the only outgoing data edge;
//
// Rewrites the above pattern with src->dst and relevant data
// dependencies updated. Repeat the process until no such pattern
// left.
bool RemoveIdentityNodes(Graph* g);
// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
bool RemoveListArrayConverter(Graph* g);
// Dump the contents of the "graph" to log files if the logging level is
// sufficiently high.
void DumpGraph(StringPiece label, const Graph* g);
// Applies graph rewrite optimizations such as inlining, dead code
// removal, etc.
//
// **g is a graph constructed based on the runtime library 'lib'.
// OptimizeGraph mutates **g extensively and replaces '*g' with a
// complete copy. Therefore, the caller should not keep any references
// to nodes in *g.
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
const GraphOptimizer::Options& graph_optimizer_options);
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
// Convert the Graph of a function to a GraphDef.
//
// Handles renaming of nodes to avoid duplicate names which may
// be present after various rewriting operations.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
// Given a numerical function "f", returns another numerical function
// "g", such that if "f" takes N inputs and produces M outputs, "g"
// takes N + M inputs and produces N outputs. I.e., if
@ -161,221 +80,6 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
// TODO(zhifengc): Asks math expert to say the comment again.
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
// Optionally override device assignment for nodes added to the graph for
// inlined functions:
// (1) Identity nodes added in place of function input arguments.
// (2) Identity nodes added in place of function return values.
// (3) Special NoOp nodes that enforce side-effects execution order.
// (4) All nodes inside function body specified in FunctionDef.
class InlinedFunctionBodyPlacer {
public:
virtual ~InlinedFunctionBodyPlacer() = default;
virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
// Returns true if the added input/output identity nodes should be colocated
// with the corresponding input/output from the function body.
virtual bool ColocateInputOutputIdentities() const = 0;
virtual absl::optional<string> ControlNodeDevice() const = 0;
virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;
// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for all other nodes.
static std::unique_ptr<InlinedFunctionBodyPlacer> DefaultPlacer(
const Graph& graph, const Node& caller);
// Place all nodes on the same device as caller node.
static std::unique_ptr<InlinedFunctionBodyPlacer> SingleDevicePlacer(
const Graph& graph, const Node& caller);
// Place input nodes on the same device as the corresponding caller input
// node. Do not place output node. Place control nodes on the same device as
// caller node. For all function body nodes overrides job, replica and task
// parts of the device assignment to match function caller node.
static std::unique_ptr<InlinedFunctionBodyPlacer> MultiDevicePlacer(
const Graph& graph, const Node& caller);
using Factory = std::function<std::unique_ptr<InlinedFunctionBodyPlacer>(
const Graph&, const Node&)>;
struct Config {
string name;
Factory get;
};
static Config Default() { return {"default", DefaultPlacer}; }
static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; }
static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; }
};
struct InlineFunctionBodyOptions {
// All nodes that have incoming control edge *from* the function call node,
// will be forwarded to the "output control node". There are two options for
// choosing which nodes will have a control edge *to* the "output control
// node":
// a) control returns (`control_ret` field in FunctionDef)
// b) data returns (`ret` field in FunctionDef)
enum class OutputControlSource { kDataOutputs, kControlOutputs };
// Keep a node in a graph with the same name as the function call node:
//
// a) DoNotKeep: Function call node is fully inlined, and there is no node in
// a graph with the same name.
//
// b) Fetchable: Add an IdentityN node to the graph in place of the inlined
// function call node. It will have a control edge from inlined
// 'output_control_node' and data edges from function output nodes.
// The IdentityN node will be placed on the same device as the caller node.
//
// This is mostly for compatibility with Tensorflow v1 and sessions.
// When we prepare a graph for execution in
// GraphExecutionState::MakeForBaseGraph we don't know what nodes will be
// fetched, so we can't safely remove any of them. When graph executed as a
// function it has 'Retval' nodes for all fetched tensors, and we can
// safely inline function calls.
//
// c) Targetable: Add a NoOp node to the graph in place of the inlined
// function call node. It will have a control edge from inline
// 'output_control_node' and no data edges. NoOp node will be placed on the
// same device as the caller node. This will keep the inlined function call
// node a valid 'session.run' target, and also will keep it a valid control
// output node.
enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable };
// If 'true' function inlining is completely disabled. This allows to control
// function inlining for different types of function calls (see
// 'ExpandInlineFunctionsOptions' below).
bool disable_inlining = false;
// Ignore '_noinline' function attribute.
bool ignore_noinline = false;
// If 'true' function inlining will inline functions in implementation
// selection group. Normally those functions should not be inlined; they will
// be handled by Grappler.
bool inline_impl_selection_group_functions = false;
// Controls if we want to keep a node with the name as the function call node
// in a graph after function inlining.
KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
// For compatibility with Tensorflow v1 by default we will use data outputs.
// Control returns were added to Tensorflow v2 with automatic control
// dependencies tracking in Eager mode.
OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
// Inlined function body placer decides what requested device assignments
// should be added to the nodes added to the graph. See documentation above
// for available strategies.
InlinedFunctionBodyPlacer::Config inlined_function_body_placer =
InlinedFunctionBodyPlacer::Default();
// If true, frame names in the function body will be
// made unique in the resulting graph (e.g. by prepending a unique prefix).
// NOTE(mrry): Only set this option to false when there is a single function
// call in the graph (e.g. when making a remote function call via
// ClusterFunctionLibraryRuntime). This option is provided because the graph
// partitioner generates frame names that must remain unmodified across all
// partitions of a multi-device function.
bool uniquify_frame_names = true;
// A human-readable debug string for this options.
string DebugString() const;
};
// Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
// based on the type signature of 'node' and 'fbody':
//
// (1) Caller node has the same number of inputs and outputs as the function.
// (2) Caller node inputs and outputs have the same data types as function
// inputs and returns.
// (3) Validation rules defined in InlineFunctionBodyOptions.
//
// If function can't be safely inlined, returns error message with details why
// inlining is not possible or safe.
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options);
// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
// edges properly. "override_device" specifies whether inlining should replace
// explicitly specified devices inside fbody with the callee's device.
//
// Returns 'Status::OK()' if function was successfully inlined into the graph.
// If function inlining is not possible returns an error with a reason, and
// leaves the graph in unmodified state.
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
Node* caller, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options);
// There are three types of function calls that could be invoked during
// *Tensorflow graph execution*:
//
// 1) Native function call (node.type_string() is the function name). These
// functions are always executed on a single-device, which is the device of
// the function call node.
//
// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
// ops) can execute on multiple devices and accept DT_RESOURCE inputs that
// belong to different devices. This type of functions was added in
// Tensorflow 2.0 Eager mode, and it has control outputs to represent
// side-effects that must always execute (see `control_ret` in FunctionDef).
//
// 3) SymbolicGradient has been deprecated for a while, but we still keep it and
// use `native` options for inlining for compatibility.
//
// We need to have distinct inlining rules for compatibility with Tensorflow v1.
//
// There are few other places in Tensorflow that could execute functions:
//
// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
// functions directly via function library runtime, without going through
// the graph.
// 2) tf.data pipelines - also execute functions directly via function library
// runtime with custom executors.
struct ExpandInlineFunctionsOptions {
ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
}
InlineFunctionBodyOptions native_options;
InlineFunctionBodyOptions multi_device_options;
};
// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
// workaround that will be enabled only during the function inlining unification
// (b/126811947). Contact ezhulenev@ if you think you need it.
// TODO(ezhulenev): Delete this function.
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
const ExpandInlineFunctionsOptions& options);
// For each node in "graph", if "lib" indicates that the node is a
// function call, inline the function body. Returns true if at least
// one node is inlined.
//
// This routine goes through "graph" nodes once and applies the
// inlining. The caller may decide to apply the inlining on "graph"
// multiple times by calling ExpandInlineFunctions a few times.
//
// Function calls that can't be safely inlined into the graph (ValidateInlining
// returns error), are ignored.
//
// TODO(ezhulenev): We do not need FunctionLibraryRuntime for this. We need just the
// FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see
// lower_function_call.cc).
inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
}
// Extracts function name and attributes from `call_def`
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
NameAttrList* function);
// Extracts function name and attributes from `call_def` and invokes
// flr->Instantiate(name, attrs, handle).
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status InstantiateFunctionCall(const NodeDef& call_def,
FunctionLibraryRuntime* flr,
FunctionLibraryRuntime::Handle* handle);
// Returns true iff `n` represents a function call. `n` can be a native
// function call (n.type_string() is the function name),
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which

View File

@ -0,0 +1,64 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
DataTypeSlice ret_t, Graph* g)
: fdef(f),
graph(g),
arg_types(arg_t.begin(), arg_t.end()),
ret_types(ret_t.begin(), ret_t.end()) {
// 1. Find regular Arg/Ret nodes.
this->arg_nodes.resize(arg_types.size());
this->ret_nodes.resize(ret_types.size());
for (Node* n : this->graph->op_nodes()) {
gtl::InlinedVector<Node*, 4>* node_vec;
if (n->type_string() == FunctionLibraryDefinition::kRetOp ||
n->type_string() == FunctionLibraryDefinition::kDeviceRetOp) {
node_vec = &this->ret_nodes;
} else if (n->type_string() == FunctionLibraryDefinition::kArgOp ||
n->type_string() == FunctionLibraryDefinition::kDeviceArgOp) {
node_vec = &this->arg_nodes;
} else {
continue;
}
int index;
TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
CHECK_LE(0, index);
CHECK_LT(index, node_vec->size());
(*node_vec)[index] = n;
}
// 2. Find ControlRet nodes that must be always executed.
std::unordered_set<StringPiece, StringPieceHasher> control_ret_node_names;
for (const auto& control_ret : fdef.control_ret()) {
control_ret_node_names.insert(control_ret.second);
}
this->control_ret_nodes.reserve(control_ret_node_names.size());
for (Node* n : this->graph->op_nodes()) {
if (control_ret_node_names.count(n->name()) > 0) {
this->control_ret_nodes.push_back(n);
}
}
}
FunctionBody::~FunctionBody() { delete this->graph; }
} // end namespace tensorflow

View File

@ -0,0 +1,52 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {
class Graph;
class Node;
// FunctionLibraryRuntime::GetFunctionBody returns a description of an
// instantiated function that is represented as a Graph with arg/ret
// nodes annotated.
struct FunctionBody {
FunctionDef fdef;
Graph* graph = nullptr; // owned.
DataTypeVector arg_types;
DataTypeVector ret_types;
// arg_nodes[i] contains the i'th function input. In other words,
// GetNodeAttr(arg_nodes[i]->attrs(), "index") == i.
gtl::InlinedVector<Node*, 4> arg_nodes;
// ret_nodes[i] contains the i'th function output. In other words,
// GetNodeAttr(ret_nodes[i]->attrs(), "index") == i.
gtl::InlinedVector<Node*, 4> ret_nodes;
gtl::InlinedVector<Node*, 4> control_ret_nodes;
FunctionBody() {}
FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
DataTypeSlice ret_types, Graph* g);
~FunctionBody();
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include <atomic>
#include <utility>
@ -25,7 +23,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
@ -33,7 +33,6 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -0,0 +1,368 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
static constexpr const char* const kNodeLabel = "Func";
// Represents the index-th output of a node.
struct Endpoint {
Node* node;
int index;
// Returns the string name that represents this endpoint.
string name() const {
if (index == 0) {
return node->name();
} else {
return strings::StrCat(node->name(), ":", index);
}
}
DataType dtype() const { return node->output_type(index); }
};
// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(StringPiece name, Graph* g) {
NodeDef ndef;
ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
ndef.set_op("NoOp");
Status s;
Node* ret = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
return ret;
}
static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
DCHECK_LT(0, input.dtype());
NodeDef ndef;
ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
ndef.set_op("Identity");
ndef.add_input(input.name());
AddNodeAttr("T", BaseType(input.dtype()), &ndef);
Status s;
Node* ret = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
g->AddEdge(input.node, input.index, ret, 0);
return ret;
}
void DumpGraph(StringPiece label, const Graph* g) {
// TODO(zhifengc): Change Graph to record #nodes.
VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
<< g->num_edges();
if (VLOG_IS_ON(5)) {
for (const auto& line : str_util::Split(DebugString(g), '\n')) {
VLOG(5) << "|| " << line;
}
}
}
bool RemoveDeadNodes(Graph* g) {
VLOG(2) << "Removing dead nodes";
std::unordered_set<const Node*> nodes;
for (auto n : g->nodes()) {
if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
n->op_def().is_stateful()) {
nodes.insert(n);
}
}
return PruneForReverseReachability(g, std::move(nodes));
}
namespace {
// If 'edges' contains only 1 non-control edge, returns it. Otherwise,
// returns a nullptr.
const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
const Edge* ret = nullptr;
for (const Edge* e : edges) {
if (e->IsControlEdge() || ret) {
// Don't touch it if there is a control edge.
return nullptr;
}
if (IsRefType(e->src()->output_type(e->src_output()))) {
// Don't touch it if the identity node is effectively de-reffing
// a ref.
return nullptr;
}
if (IsRecv(e->src()) || IsSwitch(e->src())) {
// Don't touch it if the identity is introduced for control flow.
// Recv disables all its successors if it receives a dead signal.
// When Recv has an outgoing control edge, the current executor
// would not disable the destination. The current solution (see
// graph_partition.cc) is to add an identity after Recv and change
// the control edge to be from this identity node. So the identity
// can't be removed.
return nullptr;
}
ret = e;
}
return ret;
}
} // end namespace
bool RemoveIdentityNodes(Graph* g) {
VLOG(2) << "Removing identity nodes";
bool removed_any = false;
gtl::InlinedVector<Node*, 8> matches;
for (Node* n : g->nodes()) {
if (!n->IsIdentity()) continue;
if (!GetTheOnlyDataEdge(n->in_edges())) continue;
// Some identity nodes are used as sink nodes to give names to output
// tensors. These nodes are not going to be executed unless they are in the
// fetch set. But if they are in the fetch set we don't want to remove them.
if (n->out_edges().empty()) continue;
matches.push_back(n);
}
if (!matches.empty()) {
for (Node* n : matches) {
const Edge* in = GetTheOnlyDataEdge(n->in_edges());
for (const Edge* out : n->out_edges()) {
if (out->IsControlEdge()) {
g->AddControlEdge(in->src(), out->dst());
} else {
g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
}
}
VLOG(2) << "Remove Identity: " << n->DebugString();
g->RemoveNode(n);
removed_any = true;
}
}
return removed_any;
}
bool RemoveListArrayConverter(Graph* g) {
VLOG(2) << "Removing list array converter";
gtl::InlinedVector<Node*, 8> matches;
for (Node* n : g->nodes()) {
if ((n->type_string() == "_ListToArray") ||
(n->type_string() == "_ArrayToList")) {
matches.push_back(n);
}
}
bool removed_any = false;
if (!matches.empty()) {
for (Node* n : matches) {
if (n->num_inputs() != n->num_outputs()) {
continue; // Not expected. Skip.
}
gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
const auto no_op = [&](StringPiece name) -> Node* {
return AddNoOp(absl::StrCat(n->name(), "/", name), g);
};
const auto identity = [&](StringPiece name, Endpoint input) -> Node* {
Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
node->set_requested_device(input.node->def().device());
return node;
};
// Process input edges first.
Node* input_control_node = nullptr;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
if (input_control_node == nullptr) {
// If node "n" has any control dependencies, adds a no-op
// node (input_control_node) which the additional Identity
// nodes depends on and the input_control_node depends on
// the node "n"s control dependencies.
input_control_node = no_op("input_control_node");
}
g->AddControlEdge(e->src(), input_control_node);
} else {
const int index = e->dst_input();
Node** id_node = &identity_nodes[index];
if (*id_node != nullptr) {
LOG(ERROR)
<< "RemoveListArrayConverter unexpected duplicated input: "
<< e->dst_input();
return removed_any;
}
*id_node = identity("input", {e->src(), e->src_output()});
}
}
// If node "n" has any control dependencies, the added identity
// nodes should have control dependencies on input_control_node.
if (input_control_node != nullptr) {
for (Node* id : identity_nodes) {
g->AddControlEdge(input_control_node, id);
}
}
Node* output_control_node = nullptr;
for (const Edge* e : n->out_edges()) {
if (e->IsControlEdge()) {
if (output_control_node == nullptr) {
// If node "n" is control-depended upon by other nodes,
// adds a no-op node (output_control_node) which those
// nodes will depend on and output_control_node depends on
// all Identity nodes.
output_control_node = no_op("output_control_node");
}
g->AddControlEdge(output_control_node, e->dst());
} else {
Node* id_node = identity_nodes[e->src_output()];
if (id_node == nullptr) {
LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
<< e->src_output();
return removed_any;
}
CHECK(id_node);
g->AddEdge(id_node, 0, e->dst(), e->dst_input());
}
}
// If any nodes have control dependencies on node "n", those
// nodes should have control dependencies on
// output_control_node.
if (output_control_node != nullptr) {
for (Node* id : identity_nodes) {
g->AddControlEdge(id, output_control_node);
}
}
g->RemoveNode(n);
removed_any = true;
}
}
return removed_any;
}
Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
NameAttrList* function) {
if (call_def.op() == "PartitionedCall" ||
call_def.op() == "StatefulPartitionedCall") {
TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function));
} else {
function->set_name(call_def.op());
*function->mutable_attr() = call_def.attr();
}
return Status::OK();
}
Status InstantiateFunctionCall(const NodeDef& call_def,
FunctionLibraryRuntime* flr,
FunctionLibraryRuntime::Handle* handle) {
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function));
return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle);
}
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
const Node& node) {
return node.IsFunctionCall();
}
string NewName(const Node* n, bool pretty) {
if (pretty) {
return strings::StrCat(n->type_string(), n->id());
} else {
return strings::StrCat("n", n->id());
}
}
// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
// and stash the original NodeDef name as an attr for documentation
// purpose.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
// We visit nodes in forward topological sort order, which is a
// possible execution order of the graph.
gtl::InlinedVector<const Edge*, 4> inputs;
gdef->Clear();
*gdef->mutable_versions() = g->versions();
std::vector<Node*> start_nodes;
for (Node* n : g->nodes()) {
if (n->out_edges().empty()) {
start_nodes.push_back(n);
}
}
ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
if (!n->IsOp()) return;
NodeDef* ndef = gdef->add_node();
ndef->set_name(NewName(n, pretty));
ndef->set_op(n->type_string());
for (const auto& attr : n->attrs()) {
(*ndef->mutable_attr())[attr.first] = attr.second;
}
if (!n->assigned_device_name().empty()) {
ndef->set_device(n->assigned_device_name());
} else {
ndef->set_device(n->requested_device());
}
inputs.clear();
inputs.resize(n->num_inputs());
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
inputs.push_back(e);
} else {
if (inputs[e->dst_input()] == nullptr) {
inputs[e->dst_input()] = e;
} else {
LOG(WARNING) << "Malformed graph node. multiple input edges: "
<< n->DebugString();
}
}
}
// node->name() is merely NodeDef::name, which is not guaranteed
// to be unique and stable after optimization rewrites. Therefore,
// we use "n<node id>" instead.
for (const Edge* e : inputs) {
if (e == nullptr) {
ndef->add_input("unknown");
continue;
}
const string srcname = NewName(e->src(), pretty);
if (!e->src()->IsOp()) {
} else if (e->IsControlEdge()) {
ndef->add_input(strings::StrCat("^", srcname));
} else if (e->src_output() == 0) {
ndef->add_input(srcname);
} else {
ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
}
}
});
}
string DebugString(const Graph* g) {
GraphDef gdef;
ToGraphDef(g, &gdef);
return DebugString(gdef);
}
} // end namespace tensorflow

View File

@ -0,0 +1,105 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_
#include <functional>
#include <memory>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class AttrSlice;
class Graph;
class GraphDef;
class NameAttrList;
class Node;
class NodeDef;
class OpDef;
// Debugging facility. Returns a debug string for a graph
// representing an instantiated function.
string DebugString(const Graph* g);
// Dump the contents of the "graph" to log files if the logging level is
// sufficiently high.
void DumpGraph(StringPiece label, const Graph* g);
// Convert the Graph of a function to a GraphDef.
//
// Handles renaming of nodes to avoid duplicate names which may
// be present after various rewriting operations.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
// Extracts function name and attributes from `call_def`
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
NameAttrList* function);
// A few hand-crafted optimizations on the instantiated function body
// (a Graph*).
// Removes nodes that are
// 1. not stateful; and
// 2. not _Arg; and
// 3. not reachable from _Retval.
//
// This function is triggered by function inlining; unlike 'PruneFunctionBody',
// it doesn't preserve nodes that are reachable from control returns. Function
// inlining is responsible for connecting control return nodes with the nodes
// that have input control edges from the inlined function call node.
//
// Assuming that automatic control dependency tracking is correct, the absence of
// an outgoing control edge from the function call node means that no one needs
// to observe side effects that might have been generated by the function (see
// documentation in common_runtime/function.cc for details).
//
// Returns true iff any node is removed from "g".
bool RemoveDeadNodes(Graph* g);
// Find a pattern:
// src -(in)-> node -(out)-> dst, where
// 1) node is an identity node;
// 2) in is the only incoming data edge;
// 3) out is the only outgoing data edge;
//
// Rewrites the above pattern with src->dst and relevant data
// dependencies updated. Repeat the process until no such pattern
// left.
bool RemoveIdentityNodes(Graph* g);
// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
bool RemoveListArrayConverter(Graph* g);
// Extracts function name and attributes from `call_def` and invokes
// flr->Instantiate(name, attrs, handle).
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status InstantiateFunctionCall(const NodeDef& call_def,
FunctionLibraryRuntime* flr,
FunctionLibraryRuntime::Handle* handle);
// Returns true iff `n` represents a function call. `n` can be a native
// function call (n.type_string() is the function name),
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
// has been deprecated for a while).
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include <algorithm>
#include <set>

View File

@ -0,0 +1,204 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class ShapeRefiner;
// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on
// error, in which case *g is left in an incomplete state.
//
// *g is expected to be an empty graph (with no more than a source and sink
// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph,
// see ImportGraphDef.
struct GraphConstructorOptions {
GraphConstructorOptions() {}
// If true, allows internal ops in the GraphDef.
bool allow_internal_ops = false;
// If true, the graph def is expected to have fully specified
// devices for all nodes. A node in the resulting graph "g" has the
// device name set accordingly.
//
// TODO(zhifengc): if possible, consider removing this option.
bool expect_device_spec = false;
// If true, validates that nodes being converted have all expected attrs
// set and no unknown attrs set by calling ValidateNodeDef().
// Setting validate_nodes without add_default_attributes will fail if
// the GraphDef does not have all required attributes set.
bool validate_nodes = false;
// If true, GraphConstructor will add attributes with their default
// value to the Node when they are missing from the NodeDef.
bool add_default_attributes = true;
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g);
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
GraphDef&& gdef, Graph* g);
// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
// instantiation.
// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
gtl::ArraySlice<NodeDef> nodes, Graph* g);
// Options for calling ImportGraphDef().
struct ImportGraphDefOptions {
ImportGraphDefOptions()
: uniquify_names(false),
uniquify_prefix(false),
skip_mapped_nodes(false),
validate_shape(true) {}
// Name prefix to use for nodes imported from the GraphDef. For example, if
// prefix="animals" and GraphDef contains a node "bunny" then the node will be
// named "animals/bunny" in *g. Must not be already used as a node name or
// prefix in the graph.
string prefix;
// If true, imported node names will be modified if their name already exists
// in the graph. If false, conflicting names will be treated as an error. Note
// that this option has no effect if `prefix` is specified, since `prefix`
// will guarantee all node names are unique.
bool uniquify_names;
// If true, `prefix` will be modified if it already exists as a node name or
// prefix in the graph. If false, a conflicting prefix will be treated as an
// error. This option has no effect if `prefix` isn't specified.
bool uniquify_prefix;
// Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
// corresponding to `input_map` keys will be remapped to the nodes in `g`
// corresponding to the values.
//
// Keys should not include `prefix`, i.e., a key ID's name should be the name
// as it originally appears in `gdef`.
//
// If this is non-empty, ImportGraphDef must be called with the shape refiner
// used to create the existing nodes referenced in `input_map`.
// TODO(skyewm): can we remove this requirement? How do we access the original
// shape refiner?
std::map<SafeTensorId, SafeTensorId> input_map;
// If true, nodes that will have all output edges removed because of
// overrides in `input_map` will not be imported.
bool skip_mapped_nodes;
// The names of existing nodes in `g` that the imported graph should have
// control dependencies on.
//
// Note that to avoid creating many redundant control edges, ImportGraphDef()
// won't add control edges to nodes that will inherit the dependencies from
// other nodes in `gdef`.
std::vector<string> control_dependencies;
// Tensors in `gdef` that will be returned via the ImportGraphDefResults
// output parameter of `ImportGraphDef()`. If this list is non-empty, the
// caller must pass a results object to `ImportGraphDef()`. The
// `return_tensors` field will be populated with the imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each ID's name should be the
// name as it originally appears in `gdef`.
//
// If this contains a tensor that's also being remapped via `input_map`, the
// corresponding existing tensor in `g` will be returned.
std::vector<SafeTensorId> return_tensors;
// The names of nodes in `gdef` that will be returned via the
// ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
// is non-empty, the caller must pass a results object to
// `ImportGraphDef()`. The `return_nodes` field will be populated with the
// imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each node's name should be the
// name as it originally appears in `gdef`.
//
// Unlike `return_tensors`, `input_map` has no effect on the nodes
// returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
// TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
std::vector<string> return_nodes;
// If true, checks that all colocation constraints are nodes in the GraphDef.
bool validate_colocation_constraints = true;
// If false skips shape validation.
bool validate_shape;
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
// with ops that are not defined in the binary calling ImportGraphDef.
// Similar to the producer_op_list argument to import_graph_def in the
// python API.
// Try to set a default execution device for this graph.
string default_device;
};
// Optional results that may be returned by ImportGraphDef.
struct ImportGraphDefResults {
// The requested tensors associated with
// ImportGraphDefOptions::return_tensors. Note that the index may be different
// than the requested index if the returned tensor has been remapped according
// to `input_map`.
typedef int Index;
std::vector<std::pair<Node*, Index>> return_tensors;
// The requested nodes associated with ImportGraphDefOptions::return_nodes.
std::vector<Node*> return_nodes;
// Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
// weren't used as an input to any node in `gdef`. These keys are likely due
// to typos, and callers may wish to treat their existence as an error.
std::vector<SafeTensorId> missing_unused_input_map_keys;
};
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
//
// `refiner` can be null. It should be non-null if the caller
// intends to add additional nodes to the graph after the import. This
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
// `results` must be non-null if `opts.return_tensors` or `opts.return_nodes` is
// non-empty. It can also be set to fetch the unused input map keys. If it's
// non-null, all the vector fields must be empty.
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
const GraphDef& gdef, Graph* g,
ShapeRefiner* refiner,
ImportGraphDefResults* results = nullptr);
// Make a copy of "src" into "*dest".
//
// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
// other than the implicit Source/Sink nodes.
extern void CopyGraph(const Graph& src, Graph* dest);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/graph/validate.h"

View File

@ -16,9 +16,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
@ -144,4 +145,19 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
options.inline_with_single_device_body_placer);
}
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
const GraphOptimizer::Options& graph_optimizer_options) {
OptimizerOptions opts;
opts.set_do_common_subexpression_elimination(true);
opts.set_do_function_inlining(true);
opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
optimizer.Optimize(lib, lib->env(), lib->device(), g,
graph_optimizer_options);
}
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
OptimizeGraph(lib, g, GraphOptimizer::Options());
}
} // end namespace tensorflow

View File

@ -91,6 +91,17 @@ class GraphOptimizer {
TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizer);
};
// Applies graph rewrite optimizations such as inlining, dead code
// removal, etc.
//
// **g is a graph constructed based on the runtime library 'lib'.
// OptimizeGraph mutates **g extensively and replaces '*g' with a
// complete copy. Therefore, the caller should not keep any references
// to nodes in *g.
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
const GraphOptimizer::Options& graph_optimizer_options);
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
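// Usage sketch (illustrative; `lib` is a placeholder FunctionLibraryRuntime*
// and `graph` a graph built against its function library):
//
//   std::unique_ptr<Graph> graph = ...;
//   OptimizeGraph(lib, &graph);  // default options
//
//   // Or with explicit options:
//   GraphOptimizer::Options opts;
//   OptimizeGraph(lib, &graph, opts);
//
// Any Node* taken from the old graph is invalid afterwards, since *graph is
// replaced with an optimized copy.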
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_

View File

@ -20,9 +20,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/single_threaded_cpu_device.h"
@ -31,11 +32,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {

View File

@ -20,15 +20,16 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
class Device;
class Env;
class Graph;
// GraphRunner takes a Graph, some inputs to feed, and some outputs
// to fetch and executes the graph required to feed and fetch the
// inputs and outputs.
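// Usage sketch (illustrative; assumes the GraphRunner::Run signature declared
// below, with placeholder `graph`, `flr`, and `x_tensor` objects):
//
//   GraphRunner runner(Env::Default());
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(runner.Run(graph, flr, /*inputs=*/{{"x:0", x_tensor}},
//                                 /*output_names=*/{"y:0"}, &outputs));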

View File

@ -0,0 +1,865 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include <deque>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
namespace {
// A few string constants used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
FunctionLibraryDefinition::kFuncAttr;
// Represents the index-th output of a node.
struct Endpoint {
Node* node;
int index;
// Returns the string name that represents this endpoint.
string name() const {
if (index == 0) {
return node->name();
} else {
return strings::StrCat(node->name(), ":", index);
}
}
DataType dtype() const { return node->output_type(index); }
};
struct EndpointHash {
uint64 operator()(const Endpoint& x) const {
return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
x.index);
}
};
struct EndpointEq {
bool operator()(const Endpoint& x, const Endpoint& y) const {
return (x.node == y.node) && (x.index == y.index);
}
};
// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(StringPiece name, Graph* g) {
NodeDef ndef;
ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
ndef.set_op("NoOp");
Status s;
Node* ret = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
return ret;
}
static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
DCHECK_LT(0, input.dtype());
NodeDef ndef;
ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
ndef.set_op("Identity");
ndef.add_input(input.name());
AddNodeAttr("T", BaseType(input.dtype()), &ndef);
Status s;
Node* ret = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
g->AddEdge(input.node, input.index, ret, 0);
return ret;
}
std::vector<string> InputDevices(const Node& caller) {
std::vector<string> input_devices(caller.in_edges().size());
std::vector<string> input_tensors(caller.in_edges().size());
for (const Edge* edge : caller.in_edges()) {
if (edge->IsControlEdge()) continue;
const string& input_device = edge->src()->has_assigned_device_name()
? edge->src()->assigned_device_name()
: edge->src()->requested_device();
input_devices[edge->dst_input()] = input_device;
input_tensors[edge->dst_input()] =
absl::StrCat(edge->src()->name(), ":", edge->src_output());
}
if (VLOG_IS_ON(4)) {
VLOG(4) << "Function instantiation input devices:";
for (int i = 0; i < input_devices.size(); ++i) {
if (input_tensors[i].empty()) continue; // skip control edges
VLOG(4) << " [index " << i << "]"
<< " device: " << input_devices[i]
<< " (input: " << input_tensors[i] << ")";
}
}
return input_devices;
}
// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for all other nodes.
class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
public:
explicit DefaultFunctionBodyPlacer(const Node& caller)
: input_devices_(InputDevices(caller)) {}
absl::optional<string> InputNodeDevice(int input_index) const override {
return input_devices_[input_index];
}
absl::optional<string> OutputNodeDevice(int output_index) const override {
return absl::nullopt;
}
bool ColocateInputOutputIdentities() const override { return false; }
absl::optional<string> ControlNodeDevice() const override {
return absl::nullopt;
}
absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
return absl::nullopt;
}
private:
const std::vector<string> input_devices_;
};
// Place all nodes on the same device as caller node.
class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
public:
explicit SingleDeviceFunctionBodyPlacer(const Node& caller)
: caller_device_(caller.def().device()) {}
absl::optional<string> InputNodeDevice(int input_index) const override {
return caller_device_;
}
absl::optional<string> OutputNodeDevice(int output_index) const override {
return caller_device_;
}
bool ColocateInputOutputIdentities() const override { return false; }
absl::optional<string> ControlNodeDevice() const override {
return caller_device_;
}
absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
return caller_device_;
}
private:
const string caller_device_;
};
// Place input nodes on the same device as the corresponding caller input
// node. Do not place output node. Place control nodes on the same device as
// caller node. For all function body nodes, override the job, replica, and
// task parts of the device assignment to match the function caller node.
class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
public:
explicit MultiDeviceFunctionBodyPlacer(const Node& caller)
: caller_device_(caller.def().device()),
input_devices_(InputDevices(caller)) {
has_parsed_caller_device_ =
DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_);
}
absl::optional<string> InputNodeDevice(int input_index) const override {
return input_devices_[input_index];
}
absl::optional<string> OutputNodeDevice(int output_index) const override {
return absl::nullopt;
}
bool ColocateInputOutputIdentities() const override { return true; }
absl::optional<string> ControlNodeDevice() const override {
return caller_device_;
}
absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
// TODO(ezhulenev): If the function had been instantiated as a multi-device
// function and executed via FunctionLibraryRuntime, it could potentially be
// placed on any available device. However, there are multiple tests relying
// on this assumption. Fix them, and remove this line.
if (ndef.device().empty()) return caller_device_;
if (!has_parsed_caller_device_) return ndef.device();
DeviceNameUtils::ParsedName ndef_parsed_device;
if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
return ndef.device();
if (caller_parsed_device_.has_job) {
ndef_parsed_device.has_job = caller_parsed_device_.has_job;
ndef_parsed_device.job = caller_parsed_device_.job;
}
if (caller_parsed_device_.has_replica) {
ndef_parsed_device.has_replica = caller_parsed_device_.has_replica;
ndef_parsed_device.replica = caller_parsed_device_.replica;
}
if (caller_parsed_device_.has_task) {
ndef_parsed_device.has_task = caller_parsed_device_.has_task;
ndef_parsed_device.task = caller_parsed_device_.task;
}
return DeviceNameUtils::ParsedNameToString(ndef_parsed_device);
}
private:
string caller_device_;
bool has_parsed_caller_device_;
DeviceNameUtils::ParsedName caller_parsed_device_;
std::vector<string> input_devices_;
};
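// Illustrative example of the override above (hypothetical device names): if
// the caller is placed on "/job:worker/replica:0/task:1/device:CPU:0" and a
// body node requests "/device:GPU:0", BodyNodeDevice() returns
// "/job:worker/replica:0/task:1/device:GPU:0": only the job, replica and task
// parts are taken from the caller; the device type and id are kept.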
} // namespace
std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
const Node& caller) {
VLOG(3) << "Create default placer for inlined function body.";
return absl::make_unique<DefaultFunctionBodyPlacer>(caller);
}
std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
const Node& caller) {
VLOG(3) << "Create single device placer for inlined function body.";
return absl::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
}
std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
const Node& caller) {
VLOG(3) << "Create multi device placer for inlined function body.";
return absl::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
}
namespace {
Status ValidateNoInline(const FunctionBody* fbody) {
const auto attr = AttrSlice(&fbody->fdef.attr());
bool noinline = false;
if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
return errors::InvalidArgument(
"Can't inline function marked with '_noinline'");
}
return Status::OK();
}
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
// Propagate the debug info of `nodes` in function `func` to the `target` node.
// If the debug info of any node is missing, its node name and function name
// are used.
void PropagateDebugInfoToNode(const string& func,
const std::vector<const Node*>& nodes,
NodeDef* target) {
if (nodes.empty() || target->has_experimental_debug_info()) {
return;
}
for (const Node* node : nodes) {
const auto& node_def = node->def();
if (node_def.has_experimental_debug_info()) {
target->mutable_experimental_debug_info()->MergeFrom(
node_def.experimental_debug_info());
} else {
target->mutable_experimental_debug_info()->add_original_node_names(
node_def.name());
target->mutable_experimental_debug_info()->add_original_func_names(func);
}
}
}
} // namespace
string InlineFunctionBodyOptions::DebugString() const {
const auto true_false = [](bool b) { return b ? "true" : "false"; };
const auto keep_caller_node_str = [this]() -> string {
switch (keep_caller_node) {
case KeepCallerNode::kDoNotKeep:
return "DoNotKeep";
case KeepCallerNode::kFetchable:
return "Fetchable";
case KeepCallerNode::kTargetable:
return "Targetable";
}
};
return absl::StrCat(
"disable_inlining=", true_false(disable_inlining),
", ignore_noinline=", true_false(ignore_noinline),
", inline_impl_selection_group_functions=",
true_false(inline_impl_selection_group_functions),
", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
: "ControlOutputs",
", inlined_function_body_placer=", inlined_function_body_placer.name,
", uniquify_frame_names=", true_false(uniquify_frame_names));
}
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options) {
// TODO(ezhulenev): Currently common_runtime function inlining can't guarantee
// that all side-effectful ops will be executed after inlining. See Grappler
// function_optimizer for details. Unify all function inlining mechanisms.
// Do not inline if `!fbody->control_ret_nodes.empty()`.
const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
const auto num_node_outputs = static_cast<size_t>(node->num_outputs());
if (num_node_inputs != fbody->arg_types.size() ||
num_node_inputs != fbody->arg_nodes.size()) {
return errors::InvalidArgument(
"Node inputs do not match function arguments: inputs=", num_node_inputs,
" arg_types=", fbody->arg_types.size(),
" arg_nodes=", fbody->arg_nodes.size());
}
if (num_node_outputs != fbody->ret_types.size() ||
num_node_outputs != fbody->ret_nodes.size()) {
return errors::InvalidArgument(
"Node outputs do not match function returns: outputs=",
num_node_outputs, " ret_types=", fbody->ret_types.size(),
" ret_nodes=", fbody->ret_nodes.size());
}
for (int i = 0; i < node->num_inputs(); ++i) {
if (node->input_type(i) != fbody->arg_types[i]) {
return errors::InvalidArgument(
"Node input type doesn't match function argument type: ",
node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
}
}
for (int i = 0; i < node->num_outputs(); ++i) {
if (node->output_type(i) != fbody->ret_types[i]) {
return errors::InvalidArgument(
"Node output type doesn't match function return type: ",
node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
}
}
if (options.disable_inlining) {
return errors::InvalidArgument(
"Function inlining explicitly disabled by 'options.disable_inlining'");
}
if (!options.inline_impl_selection_group_functions) {
bool is_impl_selection_group_function =
fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end();
if (is_impl_selection_group_function) {
return errors::InvalidArgument(
"Inlining of implementation selection group function ",
fbody->fdef.signature().name(),
" is disabled by options.inline_impl_selection_group_functions");
}
}
if (!options.ignore_noinline) {
TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
}
return Status::OK();
}
// Function inlining must preserve function execution semantics with regard to
// side-effects visibility. Tensorflow in Eager mode has an automatic control
// dependencies tracking mechanism, which enforces well-defined execution order
// of all side-effects. Any other frontend (e.g. Swift) must produce graphs
// following the same rules, to ensure that function inlining works correctly.
//
// IMPORTANT: Currently we do not have a true notion of "side-effectful" node,
// we assume that all stateful nodes might have side-effects, though it's not
// true in practice, e.g. `ReadVariableOp` doesn't have an observable
// side-effect.
//
// Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
//
// 1) When a function has a resource (DT_RESOURCE data type) input argument it
// "captures" the mutable resource. This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
// touching that resource, and an outgoing control edge to the next
// side-effectful op using the same resource. This serializes the mutations
// of the resource to make graph execution deterministic.
//
// 2) All stateful ops inside a function body are guaranteed to execute in
//    program order; this is achieved by adding control edges between stateful
//    ops at graph construction time. Stateful ops (or ops that must execute)
//    should be in the function control return set. Having a data edge to the
//    regular function output might not be enough, because after function
//    inlining the data output might be unused.
//
// 3) Furthermore, all ops accepting the same resource as an input are
// guaranteed to run in program order. This is also done by adding control
// edges at graph construction time. The last op touching the resource
// must be in a control return set, which will guarantee that all side
// effects to the resource will happen before function completion.
//
// Function inlining must preserve side-effect visibility:
//
// 1) All side-effects to the captured resources that happened before the
//    function call must be visible to the function body nodes using those
//    resources.
//
// 2) All side-effects to the captured resources that happened inside the
//    function body must be visible to every op/function using that resource
//    after the function call has completed.
//
// To guarantee that these properties are preserved after inlining we:
//
// 1) Create "input_control_node" NoOp. Function call node incoming control
// edges will be forwarded *to* this node. Function inputs (Identity nodes)
// will have a control edge *from* this node. If function body has nodes
// without inputs, they will have a control edge *from* this node.
//
// 2) Create "output_control_node" NoOp. All nodes that have incoming control
// edge *from* the function call node, will be forwarded to this node.
//
// We have two options for choosing which nodes will have a control edge *to*
// the "output control node":
// a) control returns (`control_ret` field in FunctionDef)
// b) data returns (`ret` field in FunctionDef)
//
// We do a) for multi-device function calls in Tensorflow v2 and b)
// for the rest for compatibility with Tensorflow v1.
//
// Following the automatic control dependencies tracking rules, a node that
// has an incoming control edge from the function call node is dependent on
// the side-effects happening inside the function body. The output control
// node will guarantee side-effects execution order.
//
// If function call node doesn't have an outgoing control edge, it means that
// no one is interested in observing side-effects that might have happened.
//
// Function inlining might leave the graph in partially-placed state. Function
// inlining caller must call Placer to guarantee that all nodes are placed.
//
// Function inlining with `options.override_device=true` will leave graph in
// fully placed state, by overriding all inlined nodes devices with the caller
// node device, but it will make functions always single-device. These functions
// after inlining will not be able to handle resources on multiple devices. This
// is currently acceptable for XLA use cases (XLA cluster is always executed on
// a single device).
//
// TODO(ezhulenev): Documentation above is ahead of implementation below.
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
Node* caller, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options) {
VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
<< options.DebugString() << "]";
Status validation = ValidateInlining(caller, fbody, options);
if (!validation.ok()) {
return errors::Internal("Inlining mismatch: ", validation.error_message());
}
// Placer is responsible for assigning devices for all nodes that we will add
// to the graph.
const std::unique_ptr<InlinedFunctionBodyPlacer> placer =
options.inlined_function_body_placer.get(*g, *caller);
// We can't possibly introduce a duplicate control edge during function
// inlining, so we skip this check in calls to the 'g->AddControlEdge(...)'.
static constexpr bool kDoNotCheckDuplicates = true;
// ------------------------------------------------------------------------ //
// Helper functions to create `NoOp` and `Identity` nodes for auxiliary
// control nodes and inlined function inputs and outputs.
// Add a NoOp node for function control inputs/outputs.
const auto no_op = [&](StringPiece name) -> Node* {
Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
const absl::optional<string> device = placer->ControlNodeDevice();
if (device.has_value()) node->set_requested_device(*device);
return node;
};
// Add an Identity node for function input.
const auto input_identity = [&](StringPiece name, Endpoint input,
int index) -> Node* {
Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
const absl::optional<string> device = placer->InputNodeDevice(index);
if (device.has_value()) node->set_requested_device(*device);
bool colocate_identity = placer->ColocateInputOutputIdentities();
if (colocate_identity) {
node->AddAttr(kColocationAttrName,
std::vector<string>{absl::StrCat(kColocationGroupPrefix,
input.node->name())});
}
return node;
};
// Add an Identity node for function output.
const auto output_identity = [&](StringPiece name, Endpoint input,
int index) -> Node* {
Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
const absl::optional<string> device = placer->OutputNodeDevice(index);
if (device.has_value()) node->set_requested_device(*device);
bool colocate_identity = placer->ColocateInputOutputIdentities();
if (colocate_identity) {
node->AddAttr(kColocationAttrName,
std::vector<string>{absl::StrCat(kColocationGroupPrefix,
input.node->name())});
}
return node;
};
// ------------------------------------------------------------------------ //
// Input edges. For data edges coming into "caller", we first compute the
// <src>:<src_output> for the i-th input in "inputs".
// If "caller" has any input control dependencies, we add a NoOp
// node "input_control_node", which depends on "caller"'s control inputs.
std::vector<Endpoint> inputs(caller->num_inputs());
Node* input_control_node = nullptr;
for (const Edge* e : caller->in_edges()) {
if (e->IsControlEdge()) {
if (input_control_node == nullptr) {
input_control_node = no_op("input_control_node");
}
g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates);
} else {
inputs[e->dst_input()] = {e->src(), e->src_output()};
}
}
if (input_control_node != nullptr) {
VLOG(3) << "Created input control node: " << input_control_node->name();
}
// ------------------------------------------------------------------------ //
// Duplicate fbody->graph into 'g'. First, we copy the nodes of
// fbody->graph into 'g' except the source and sink nodes. We copy
// edges among nodes in 'fbody->graph'.
//
// If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
// remember 'y' in node_map[x->id()].
std::vector<Node*> node_map(fbody->graph->num_node_ids());
for (Node* n : fbody->graph->op_nodes()) {
NodeDef ndef = n->def();
// Maybe override requested node device assignment.
const absl::optional<string> device = placer->BodyNodeDevice(ndef);
if (device.has_value()) ndef.set_device(*device);
// Add inlined function name to inlined node debug information.
PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);
// Add the function node name as a prefix:
// 1) to node name to avoid collisions
// 2) to frame name to avoid multiple LoopCond nodes in one frame
// 3) to colocation attribute
const string prefix = strings::StrCat(caller->name(), "/");
TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
options.uniquify_frame_names));
Status added_node;
Node* clone = g->AddNode(ndef, &added_node);
TF_CHECK_OK(added_node);
node_map[n->id()] = clone;
// If there is an input control node, and one of:
// a) the node has no data or control inputs, or
// b) the node is a function call (including SymbolicGradient),
// then add a control edge from the input control node to the clone (only
// if it does not already have a control input).
//
// We must not execute any nodes if the original function call would not
// have executed. This is especially critical when the function call is
// inside a control-flow construct like tf.cond(). Case (a) ensures that
// such nodes do not run.
//
// The purpose of case (b) is to ensure that instances of case (a) created
// by further inlining steps also receive the control dependency.
//
// This edge is required to transfer execution frame down to all function
// body nodes of inlined nested function calls.
if (input_control_node) {
const auto is_input_edge = [](const Edge* e) -> bool {
return !e->src()->IsSource();
};
const auto is_control_edge = [](const Edge* e) -> bool {
return !e->src()->IsSource() && e->IsControlEdge();
};
// Forward execution frame if:
//
// a) The node has no data or control inputs.
// b) OR the node is a function call without control inputs (control edge
// will be used in nested function inlining to forward execution frame
// to constants inside the function body).
//
// c) Do not forward control frame to function argument nodes, they will
// be connected to the corresponding function input later.
const bool forward_execution_frame =
(absl::c_none_of(n->in_edges(), is_input_edge) || // (a)
(n->IsFunctionCall() && // (b)
absl::c_none_of(n->in_edges(), is_control_edge))) && //
!n->IsArg(); // (c)
if (forward_execution_frame) {
VLOG(4) << "Add control edge from input control node to: "
<< clone->name();
g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates);
}
}
}
for (const Edge* e : fbody->graph->edges()) {
if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
e->dst()->IsSink()) {
continue;
}
Node* src_copy = node_map[e->src()->id()];
Node* dst_copy = node_map[e->dst()->id()];
g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
}
// ------------------------------------------------------------------------ //
// Connect input edges.
//
// We create one Identity node for each input. Then, we connect inputs[i] to
// the i-th identity node added. The nodes that previously connected
// to the j-th output of i-th arg node are reconnected to the i-th
// identity node.
//
// The added identity nodes depend on "input_control_node".
VLOG(4) << "Add input Identity nodes for each function argument:";
for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
Node* arg = node_map[fbody->arg_nodes[i]->id()];
Node* n = input_identity("input", inputs[i], i);
VLOG(4) << " [index " << i << "] "
<< fbody->fdef.signature().input_arg(i).name() << " as "
<< n->name() << " (input: " << inputs[i].name()
<< ", requested_device: " << n->requested_device() << ")";
if (input_control_node) {
g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates);
}
for (const Edge* e : arg->out_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates);
} else {
g->AddEdge(n, 0, e->dst(), e->dst_input());
}
}
node_map[fbody->arg_nodes[i]->id()] = n;
g->RemoveNode(arg); // 'arg' is disconnected.
}
// ------------------------------------------------------------------------ //
// Connect output edges.
//
// For i-th return node in fbody->graph, we add in "g" an identity node
// (outputs[i-th]). We then reconnect every incoming edge into the i-th return
// node to the added identity node.
//
// For every data edge coming out of "callee"s i-th output, we reconnect it to
// the i-th identity added above.
//
// If "callee" is control-depended upon by any other nodes, we add a NoOp node
// "output_control_node". "output_control_node" depends on all identity nodes
// added above or on all control return nodes (controlled by
// `options.output_control_src` value). And nodes that previously depended on
// "callee" are changed to depend on "output_control_node".
//
// If `keep_node_fetchable` is `true` we always add an output control node, to
// guarantee that executing a fetchable node will execute all side-effects.
VLOG(4) << "Add output Identity nodes for each function output argument:";
std::vector<Node*> outputs(caller->num_outputs());
for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
Node* ret = node_map[fbody->ret_nodes[i]->id()];
Endpoint data; // Data input for the ret node.
for (const Edge* e : ret->in_edges()) {
if (!e->IsControlEdge()) {
data = {e->src(), e->src_output()};
break;
}
}
CHECK(data.node != nullptr);
Node* n = output_identity("output", data, i);
outputs[i] = n;
VLOG(4) << " [index " << i << "] "
<< fbody->fdef.signature().output_arg(i).name() << " as "
<< n->name() << " (ret: " << data.node->name() << ":" << data.index
<< ", requested_device: " << n->requested_device() << ")";
for (const Edge* e : ret->in_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates);
}
}
g->RemoveNode(ret); // 'ret' is disconnected.
}
Node* output_control_node = nullptr;
const bool has_control_outputs = absl::c_any_of(
caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });
using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
const bool keep_caller_node =
options.keep_caller_node == KeepCallerNode::kFetchable ||
options.keep_caller_node == KeepCallerNode::kTargetable;
if (has_control_outputs || keep_caller_node) {
output_control_node = no_op("output_control_node");
VLOG(4) << "Add output control node: " << output_control_node->name();
if (options.output_control_src == OutputControlSrc::kDataOutputs) {
for (Node* n : outputs) {
VLOG(4) << " [data output] add control edge from: " << n->name();
g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
}
} else {
for (Node* fbody_node : fbody->control_ret_nodes) {
Node* n = node_map[fbody_node->id()];
VLOG(4) << " [control output] add control edge from: " << n->name();
g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
}
}
}
// We can't leave the output control node without incoming control edges,
// because in this case the outgoing control edge will lose execution frame
// information.
// We connect input_control_node and output_control_node with a control edge
// to forward execution frame to the controlled nodes. Above we add a control
// edge to all function calls inside function body, to guarantee that we will
// always have input_control_node when we need it.
if (output_control_node && output_control_node->in_edges().empty()) {
if (input_control_node) {
VLOG(4)
<< "Add add a control edge between input and output control nodes: "
<< input_control_node->name() << " to "
<< output_control_node->name();
g->AddControlEdge(input_control_node, output_control_node,
kDoNotCheckDuplicates);
} else {
VLOG(4) << "Function inlining potentially dropped execution frame "
"information from outgoing control edges.";
}
}
for (const Edge* e : caller->out_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates);
} else {
g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
}
}
// ------------------------------------------------------------------------ //
// Add an IdentityN or NoOp node in-place of caller node to keep `caller`
// fetchable or targetable.
if (keep_caller_node) {
std::vector<NodeBuilder::NodeOut> output_tensors;
absl::c_transform(outputs, std::back_inserter(output_tensors),
[](Node* n) { return NodeBuilder::NodeOut(n, 0); });
Node* caller_substitute_node;
if (options.keep_caller_node == KeepCallerNode::kTargetable ||
output_tensors.empty()) {
// IdentityN node must have at least one data input. If function has no
// data outputs, we can't keep it fetchable.
TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp")
.Device(caller->requested_device())
.ControlInput(output_control_node)
.Finalize(g, &caller_substitute_node));
} else if (options.keep_caller_node == KeepCallerNode::kFetchable) {
TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
.Device(caller->requested_device())
.Input(output_tensors)
.ControlInput(output_control_node)
.Finalize(g, &caller_substitute_node));
}
}
// ------------------------------------------------------------------------ //
// 'caller' is replaced with inlined function body nodes and maybe IdentityN
// to keep it fetchable.
VLOG(3) << "Successfully inlined function call node: " << caller->name();
g->RemoveNode(caller);
return Status::OK();
}
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
const ExpandInlineFunctionsOptions& options) {
std::vector<std::pair<Node*, const FunctionBody*>> candidates;
const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
for (Node* node : graph->nodes()) {
// Skip nodes that are not function calls or SymbolicGradient calls.
if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
continue;
}
// Skip function calls that marked noinline.
bool noinline;
if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
VLOG(3) << "noinline: " << SummarizeNode(*node);
continue;
}
FunctionLibraryRuntime::Handle handle;
Status s = InstantiateFunctionCall(node->def(), lib, &handle);
if (!s.ok()) {
LOG(ERROR) << "Failed to instantiate a function: " << s.error_message();
continue;
}
const FunctionBody* fbody = lib->GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
candidates.emplace_back(node, fbody);
}
bool inlined_any = false;
for (const auto& p : candidates) {
Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
p.first->IsPartitionedCall()
? options.multi_device_options
: options.native_options);
if (inlined.ok()) {
inlined_any = true;
} else {
VLOG(1) << "Failed to inline function call: node=" << p.first->name()
<< " error=" << inlined.error_message();
}
}
// TODO(ezhulenev): Release handles for inlined function calls.
return inlined_any;
}
} // end namespace tensorflow

View File

@ -0,0 +1,236 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_
#include <functional>
#include <memory>
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
static constexpr const char* const kNoInlineAttr = "_noinline";
// Optionally override device assignment for nodes added to the graph for
// inlined functions:
// (1) Identity nodes added in place of function input arguments.
// (2) Identity nodes added in place of function return values.
// (3) Special NoOp nodes that enforce side-effects execution order.
// (4) All nodes inside function body specified in FunctionDef.
class InlinedFunctionBodyPlacer {
public:
virtual ~InlinedFunctionBodyPlacer() = default;
virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
// Returns true if the added input/output identity nodes should be colocated
// with the corresponding input/output from the function body.
virtual bool ColocateInputOutputIdentities() const = 0;
virtual absl::optional<string> ControlNodeDevice() const = 0;
virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;
// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for all other nodes.
static std::unique_ptr<InlinedFunctionBodyPlacer> DefaultPlacer(
const Graph& graph, const Node& caller);
// Place all nodes on the same device as caller node.
static std::unique_ptr<InlinedFunctionBodyPlacer> SingleDevicePlacer(
const Graph& graph, const Node& caller);
// Place input nodes on the same device as the corresponding caller input
// node. Do not place output node. Place control nodes on the same device as
// caller node. For all function body nodes, override the job, replica, and
// task parts of the device assignment to match the function caller node.
static std::unique_ptr<InlinedFunctionBodyPlacer> MultiDevicePlacer(
const Graph& graph, const Node& caller);
using Factory = std::function<std::unique_ptr<InlinedFunctionBodyPlacer>(
const Graph&, const Node&)>;
struct Config {
string name;
Factory get;
};
static Config Default() { return {"default", DefaultPlacer}; }
static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; }
static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; }
};
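// Usage sketch for the Config helpers above (illustrative; `graph` and
// `caller` are placeholders for the graph being rewritten and the function
// call node):
//
//   InlinedFunctionBodyPlacer::Config config =
//       InlinedFunctionBodyPlacer::MultiDevice();
//   std::unique_ptr<InlinedFunctionBodyPlacer> placer =
//       config.get(graph, caller);
//   absl::optional<string> control_device = placer->ControlNodeDevice();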
struct InlineFunctionBodyOptions {
// All nodes that have incoming control edge *from* the function call node,
// will be forwarded to the "output control node". There are two options for
// choosing which nodes will have a control edge *to* the "output control
// node":
// a) control returns (`control_ret` field in FunctionDef)
// b) data returns (`ret` field in FunctionDef)
enum class OutputControlSource { kDataOutputs, kControlOutputs };
// Keep a node in a graph with the same name as the function call node:
//
// a) DoNotKeep: Function call node is fully inlined, and there is no node in
// a graph with the same name.
//
// b) Fetchable: Add an IdentityN node to the graph in place of the inlined
// function call node. It will have a control edge from inlined
// 'output_control_node' and data edges from function output nodes.
// The IdentityN node will be placed on the same device as the caller node.
//
// This is mostly for compatibility with Tensorflow v1 and sessions.
// When we prepare a graph for execution in
// GraphExecutionState::MakeForBaseGraph we don't know what nodes will be
//    fetched, so we can't safely remove any of them. When a graph is executed
//    as a function it has 'Retval' nodes for all fetched tensors, and we can
//    safely inline function calls.
//
// c) Targetable: Add a NoOp node to the graph in place of the inlined
//    function call node. It will have a control edge from the inlined
// 'output_control_node' and no data edges. NoOp node will be placed on the
// same device as the caller node. This will keep the inlined function call
// node a valid 'session.run' target, and also will keep it a valid control
// output node.
enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable };
// If 'true', function inlining is completely disabled. This allows controlling
// function inlining for different types of function calls (see
// 'ExpandInlineFunctionsOptions' below).
bool disable_inlining = false;
// Ignore '_noinline' function attribute.
bool ignore_noinline = false;
// If 'true' function inlining will inline functions in implementation
// selection group. Normally those functions should not be inlined; they will
// be handled by Grappler.
bool inline_impl_selection_group_functions = false;
// Controls whether we want to keep a node with the same name as the function
// call node in the graph after function inlining.
KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
// For compatibility with Tensorflow v1 by default we will use data outputs.
// Control returns were added to Tensorflow v2 with automatic control
// dependencies tracking in Eager mode.
OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
// Inlined function body placer decides what requested device assignments
// should be added to the nodes added to the graph. See documentation above
// for available strategies.
InlinedFunctionBodyPlacer::Config inlined_function_body_placer =
InlinedFunctionBodyPlacer::Default();
// If true, frame names in the function body will be
// made unique in the resulting graph (e.g. by prepending a unique prefix).
// NOTE(mrry): Only set this option to false when there is a single function
// call in the graph (e.g. when making a remote function call via
// ClusterFunctionLibraryRuntime). This option is provided because the graph
// partitioner generates frame names that must remain unmodified across all
// partitions of a multi-device function.
bool uniquify_frame_names = true;
// A human-readable debug string for these options.
string DebugString() const;
};
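// Configuration sketch (illustrative): options roughly matching the
// multi-device inlining behavior described above.
//
//   InlineFunctionBodyOptions opts;
//   opts.output_control_src =
//       InlineFunctionBodyOptions::OutputControlSource::kControlOutputs;
//   opts.keep_caller_node =
//       InlineFunctionBodyOptions::KeepCallerNode::kFetchable;
//   opts.inlined_function_body_placer = InlinedFunctionBodyPlacer::MultiDevice();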
// Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
// based on the type signature of 'node' and 'fbody':
//
// (1) Caller node has the same number of inputs and outputs as the function.
// (2) Caller node inputs and outputs have the same data types as function
// inputs and returns.
// (3) Validation rules defined in InlineFunctionBodyOptions.
//
// If function can't be safely inlined, returns error message with details why
// inlining is not possible or safe.
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options);
// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
// edges properly. "override_device" specifies whether inlining should replace
// explicitly specified devices inside fbody with the callee's device.
//
// Returns 'Status::OK()' if function was successfully inlined into the graph.
// If function inlining is not possible returns an error with a reason, and
// leaves the graph in unmodified state.
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
Node* caller, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options);
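// Call sketch for InlineFunctionBody above (illustrative; `flib_def`, `graph`,
// `caller` and `fbody` are placeholders assumed to be set up by the caller):
//
//   InlineFunctionBodyOptions opts;
//   Status inlined = InlineFunctionBody(flib_def, graph, caller, fbody, opts);
//   if (!inlined.ok()) {
//     // Inlining was rejected; `graph` is left unmodified.
//   }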
// There are three types of function calls that could be invoked during
// *Tensorflow graph execution*:
//
// 1) Native function call (node.type_string() is the function name). These
//    functions are always executed on a single device, which is the device of
// the function call node.
//
// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
// ops) can execute on multiple devices and accept DT_RESOURCE inputs that
//    belong to different devices. This type of function was added in
// Tensorflow 2.0 Eager mode, and it has control outputs to represent
// side-effects that must always execute (see `control_ret` in FunctionDef).
//
// 3) SymbolicGradient has been deprecated for a while, but we still keep it and
// use `native` options for inlining for compatibility.
//
// We need to have distinct inlining rules for compatibility with Tensorflow v1.
//
// There are a few other places in Tensorflow that could execute functions:
//
// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
// functions directly via function library runtime, without going through
// the graph.
// 2) tf.data pipelines - also execute functions directly via function library
// runtime with custom executors.
struct ExpandInlineFunctionsOptions {
ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
}
InlineFunctionBodyOptions native_options;
InlineFunctionBodyOptions multi_device_options;
};
// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
// workaround that will be enabled only during the function inlining unification
// (b/126811947). Contact ezhulenev@ if you think you need it.
// TODO(ezhulenev): Delete this function.
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
const ExpandInlineFunctionsOptions& options);
// For each node in "graph", if "lib" indicates that the node is a
// function call, inline the function body. Returns true if at least
// one node is inlined.
//
// This routine goes through "graph" nodes once and applies the
// inlining. The caller may decide to apply the inlining on "graph"
// multiple times by calling ExpandInlineFunctions a few times.
//
// Function calls that can't be safely inlined into the graph (ValidateInlining
// returns error), are ignored.
//
// TODO(ezhulenev): We do not need FunctionLibraryRuntime for this. We need just the
// FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see
// lower_function_call.cc).
inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
}
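// Usage sketch (illustrative): repeatedly expand until no more calls can be
// inlined, which also picks up calls exposed by earlier rounds of inlining.
//
//   while (ExpandInlineFunctions(lib, graph)) {
//   }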
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_

View File

@ -20,12 +20,12 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@ -22,12 +20,13 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@ -22,12 +20,13 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -21,12 +21,12 @@ limitations under the License.
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@ -22,12 +20,13 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -13,19 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -16,10 +16,10 @@ limitations under the License.
#include <algorithm>
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
namespace tensorflow {

View File

@ -19,10 +19,10 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"
@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

View File

@ -20,7 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -28,9 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {

View File

@ -163,11 +163,10 @@ filegroup(
],
)
# Both of these files depend on common_runtime.
# This file depends on common_runtime.
filegroup(
name = "core_cpu_base_no_ops_srcs",
srcs = [
"graph_constructor.cc",
"graph_def_builder_util.cc",
],
)
@ -250,7 +249,6 @@ filegroup(
"gradients.h",
"graph.cc",
"graph.h",
"graph_constructor.cc",
"graph_constructor.h",
"graph_def_builder.cc",
"graph_def_builder.h",

View File

@ -16,189 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
#define TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class ShapeRefiner;
// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on
// error, in which case *g is left in an incomplete state.
//
// *g is expected to be an empty graph (with no more than source and sink
// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph,
// see ImportGraphDef.
struct GraphConstructorOptions {
GraphConstructorOptions() {}
// If true, allows internal ops in the GraphDef.
bool allow_internal_ops = false;
// If true, the graph def is expected to have fully specified
// devices for all nodes. A node in the resulting graph "g" has the
// device name set accordingly.
//
// TODO(zhifengc): if possible, consider removing this option.
bool expect_device_spec = false;
// If true, validates that nodes being converted have all expected attrs
// set and no unknown attrs set by calling ValidateNodeDef().
// Setting validate_nodes without add_default_attributes will fail if
// the GraphDef does not have all required attributes set.
bool validate_nodes = false;
// If true, GraphConstructor will add attributes with their default
// value to the Node when they are missing from the NodeDef.
bool add_default_attributes = true;
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g);
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
GraphDef&& gdef, Graph* g);
// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
// instantiation.
// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
gtl::ArraySlice<NodeDef> nodes, Graph* g);
// Options for calling ImportGraphDef().
struct ImportGraphDefOptions {
ImportGraphDefOptions()
: uniquify_names(false),
uniquify_prefix(false),
skip_mapped_nodes(false),
validate_shape(true) {}
// Name prefix to use for nodes imported from the GraphDef. For example, if
// prefix="animals" and GraphDef contains a node "bunny" then the node will be
// named "animals/bunny" in *g. Must not be already used as a node name or
// prefix in the graph.
string prefix;
// If true, imported node names will be modified if their name already exists
// in the graph. If false, conflicting names will be treated as an error. Note
// that this option has no effect if `prefix` is specified, since `prefix`
// will guarantee all node names are unique.
bool uniquify_names;
// If true, `prefix` will be modified if it already exists as a node name or
// prefix in the graph. If false, a conflicting prefix will be treated as an
// error. This option has no effect if `prefix` isn't specified.
bool uniquify_prefix;
// Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
// corresponding to `input_map` keys will be remapped to the nodes in `g`
// corresponding to the values.
//
// Keys should not include `prefix`, i.e., a key ID's name should be the name
// as it originally appears in `gdef`.
//
// If this is non-empty, ImportGraphDef must be called with the shape refiner
// used to create the existing nodes referenced in `input_map`.
// TODO(skyewm): can we remove this requirement? How do we access the original
// shape refiner?
std::map<SafeTensorId, SafeTensorId> input_map;
// If true, nodes that will have all output edges removed because of
// overrides in `input_map` will not be imported.
bool skip_mapped_nodes;
// The names of existing nodes in `g` that the imported graph should have
// control dependencies on.
//
// Note that to avoid creating many redundant control edges, ImportGraphDef()
// won't add control edges to nodes that will inherit the dependencies from
// other nodes in `gdef`.
std::vector<string> control_dependencies;
// Tensors in `gdef` that will be returned via the ImportGraphDefResults
// output parameter of `ImportGraphDef()`. If this list is non-empty, the
// caller must pass a results object to `ImportGraphDef()`. The
// `return_tensors` field will be populated with the imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each ID's name should be the
// name as it originally appears in `gdef`.
//
// If this contains a tensor that's also being remapped via `input_map`, the
// corresponding existing tensor in `g` will be returned.
std::vector<SafeTensorId> return_tensors;
// The names of nodes in `gdef` that will be returned via the
// ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
// is non-empty, the caller must pass a results object to
// `ImportGraphDef()`. The `return_nodes` field will be populated with the
// imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each node's name should be the
// name as it originally appears in `gdef`.
//
// Unlike `return_tensors`, `input_map` has no effect on the nodes
// returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
// TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
std::vector<string> return_nodes;
// If true, checks that all colocation constraints are nodes in the GraphDef.
bool validate_colocation_constraints = true;
// If false, skips shape validation.
bool validate_shape;
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
// with ops that are not defined in the binary calling ImportGraphDef.
// Similar to the producer_op_list argument to import_graph_def in the
// python API.
// Try to set default execution device for this graph.
string default_device;
};
// Optional results that may be returned by ImportGraphDef.
struct ImportGraphDefResults {
// The requested tensors associated with
// ImportGraphDefOptions::return_tensors. Note that the index may be different
// than the requested index if the returned tensor has been remapped according
// to `input_map`.
typedef int Index;
std::vector<std::pair<Node*, Index>> return_tensors;
// The requested nodes associated with ImportGraphDefOptions::return_nodes.
std::vector<Node*> return_nodes;
// Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
// weren't used as an input to any node in `gdef`. These keys are likely due
// to typos, and callers may wish to treat their existence as an error.
std::vector<SafeTensorId> missing_unused_input_map_keys;
};
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
//
// `refiner` can be null. It should be non-null if the caller
// intends to add additional nodes to the graph after the import. This
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
// `results` must be non-null if `opts.return_tensors` or `opts.return_nodes` is
// non-empty. It can also be set to fetch the unused input map keys. If it's
// non-null, all the vector fields must be empty.
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
const GraphDef& gdef, Graph* g,
ShapeRefiner* refiner,
ImportGraphDefResults* results = nullptr);
// Make a copy of "src" into "*dest".
//
// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
// other than the implicit Source/Sink nodes.
extern void CopyGraph(const Graph& src, Graph* dest);
} // namespace tensorflow
#include "tensorflow/core/common_runtime/graph_constructor.h"
#endif // TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
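
The forwarding include above keeps existing "tensorflow/core/graph/graph_constructor.h" call sites compiling while the implementation moves under common_runtime. For reference, a minimal sketch of how the two entry points documented in this header are typically driven; the GraphDef inputs, the helper name, and the node name "bunny" are placeholders for illustration, not part of this change:

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical helper: builds *g from `gdef`, then extends it with
// `imported_gdef`. The caller is expected to pass a freshly constructed
// Graph, e.g. `Graph g(OpRegistry::Global());`.
Status BuildAndExtendGraph(const GraphDef& gdef, const GraphDef& imported_gdef,
                           Graph* g, ImportGraphDefResults* results) {
  // Convert the first GraphDef into the (still empty) Graph.
  GraphConstructorOptions opts;
  opts.allow_internal_ops = false;
  opts.add_default_attributes = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, g));

  // Merge the second GraphDef into the now-populated Graph, prefixing its
  // node names and requesting one imported node back via `results`.
  ImportGraphDefOptions import_opts;
  import_opts.prefix = "imported";
  import_opts.return_nodes.push_back("bunny");  // name as it appears in gdef
  return ImportGraphDef(import_opts, imported_gdef, g, /*refiner=*/nullptr,
                        results);
}

}  // namespace tensorflow

Callers that only need a fresh Graph from a single GraphDef can stop after ConvertGraphDefToGraph; ImportGraphDef is the path for extending a graph that already contains nodes, per the comments above.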

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

View File

@ -5951,6 +5951,7 @@ filegroup(
"//tensorflow/core/common_runtime:core_cpu_rump_impl", # quantize_training
"//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session
"//tensorflow/core/common_runtime:device_factory", # device_lib, tfe, tf_session
"//tensorflow/core/common_runtime:graph_constructor", # tf_session
"//tensorflow/core/common_runtime:session_options", # device_lib, tfe, tf_session
"//tensorflow/core/common_runtime:session_state", # tf_session
"//tensorflow/core/data/service:server_lib", # server_lib

View File

@ -220,7 +220,7 @@ tensorflow::ImportGraphDef
[op_gen_lib] # tf_session
tensorflow::ApiDefMap::~ApiDefMap
[core_cpu_base_no_ops] # tf_session
[graph_constructor] # tf_session
tensorflow::ShapeRefiner::~ShapeRefiner
[python_api] # tf_session