Support function calls through PartitionedCall in tf2xla

Also, make eager runtime always emit PartitionedCall and remove
special handling of xla compilation.

Becase this change makes XLA look inside PartitionedCalls, this change
had to update/disable some tests that include PartitionedCalls with
some uncompilable ops inside.

PiperOrigin-RevId: 237486703
This commit is contained in:
Igor Ganichev 2019-03-08 11:26:39 -08:00 committed by TensorFlower Gardener
parent 745b6aa434
commit ed2b195990
20 changed files with 455 additions and 308 deletions

View File

@ -154,11 +154,14 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
VLOG(3) << "Attemping to create XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
if (!IsCompilable(flr, node_def)) {
VLOG(1) << "Not creating XlaLaunchOp because function invoked by the "
"following node is not compilable: "
<< node_def.DebugString();
// node_def is calling a function that XLA can't compile.
return errors::InvalidArgument("Not compilable: ",
node_def.ShortDebugString());

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
@ -108,14 +109,14 @@ void MarkGuaranteedConstants(
for (const auto& src_arg : src_arg_pairs) {
srcs.push_back(src_arg.first);
}
ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
if (AreAllParentsGuaranteedConst(*n,
guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
ReverseDFSFrom(
graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
for (auto& src_arg : src_arg_pairs) {
if (guaranteed_const_nodes.count(src_arg.first) != 0) {
@ -2319,12 +2320,12 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
" in function library.");
}
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, node->attrs(), library,
[library](const string& op, const OpDef** sig) {
return library->LookUpOpDef(op, sig);
},
&fbody));
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fdef, node->attrs(), library,
[library](const string& op, const OpDef** sig) {
return library->LookUpOpDef(op, sig);
},
&fbody));
InlineFunctionBodyOptions inline_opts;
inline_opts.override_device = false;
@ -2534,12 +2535,29 @@ Status EncapsulateSubgraphsPass::Run(
std::unique_ptr<Graph> graph_out;
FunctionLibraryDefinition* const library = options.flib_def;
// Constant folding below might need to run part of the function to compute
// constants. Create an FunctionLibraryRuntime with a single CPU device
// that can run the part of the function.
SessionOptions session_options;
auto* device_count = session_options.config.mutable_device_count();
device_count->insert({"CPU", 1});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
std::unique_ptr<DeviceMgr> device_mgr =
absl::make_unique<DeviceMgr>(std::move(devices));
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
new ProcessFunctionLibraryRuntime(device_mgr.get(),
options.session_options->env,
TF_GRAPH_DEF_VERSION, library, opts));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
if (flr == nullptr) {
return errors::Internal(
"Failed to create and retrieve function library runtime to run "
"constant folding");
}
auto rewrite_subgraph =
[flr](const std::vector<OutputTensor>& arg_source_tensors,

View File

@ -227,10 +227,9 @@ bool IsCompilableCall(const NodeDef& call_def,
}
FunctionLibraryRuntime::Handle handle;
Status status =
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
Status status = InstantiateFunctionCall(call_def, *lib_runtime, &handle);
if (!status.ok()) {
VLOG(2) << "Rejecting " << call_def.op()
VLOG(2) << "Rejecting " << call_def.DebugString()
<< ": could not instantiate: " << status;
return false;
}

View File

@ -33,7 +33,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
@ -87,6 +89,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
}
break;
case XlaExpression::Kind::kResource:
// TODO(b/126601755): This is a fairly common use case in TF 2.0 that
// we can hit when inlining is disabled or fails.
return errors::Unimplemented(
"Resource as function argument is not yet implemented.");
case XlaExpression::Kind::kTensorList:
@ -124,6 +128,8 @@ Status GraphCompiler::Compile() {
for (Node* n : topo_sorted_nodes) {
OpKernel* op_kernel_raw = nullptr;
// The kernel is not actually run for functional ops, we just need it
// for metadata.
Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
// Transfer ownership of the kernel to a local smart pointer.
std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);
@ -157,7 +163,7 @@ Status GraphCompiler::Compile() {
OpKernelContext op_context(&params, n->num_outputs());
VLOG(3) << "Translating " << params.op_kernel->name();
if (IsFunctional(n)) {
if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
} else {
device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
@ -182,15 +188,37 @@ Status GraphCompiler::Compile() {
return Status::OK();
}
bool GraphCompiler::IsFunctional(Node* n) {
return n->type_string() == FunctionLibraryDefinition::kGradientOp ||
(flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) !=
nullptr);
namespace {
Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
const Node& node, NameAttrList* func) {
if (node.IsPartitionedCall()) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(
node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
if (!attr_value->has_func()) {
return errors::InvalidArgument(
"The attribute value for attribute 'f' in node ", node.DebugString(),
" does not have 'func' field set");
}
*func = attr_value->func();
return Status::OK();
}
if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
func->set_name(node.type_string());
} else {
func->set_name(FunctionLibraryDefinition::kGradientOp);
}
*func->mutable_attr() = node.def().attr();
return Status::OK();
}
} // namespace
Status GraphCompiler::CompileFunctionalNode(Node* n,
OpKernelContext* op_context) {
TF_RET_CHECK(IsFunctional(n));
TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
// For functional nodes, compile them using compiler from the context and call
// into the functions.
XlaOpKernelContext xla_op_context(op_context);
@ -201,12 +229,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
XlaCompiler* compiler = xla_op_context.compiler();
NameAttrList func;
if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) {
func.set_name(n->def().op());
} else {
func.set_name(FunctionLibraryDefinition::kGradientOp);
}
*func.mutable_attr() = n->def().attr();
TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));
std::vector<const XlaExpression*> expressions;
@ -227,7 +250,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));
bool add_token_input_output =
HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName);
func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = false;
@ -247,8 +270,9 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
}
if (add_token_input_output) {
std::vector<string> token_input_nodes;
TF_RETURN_IF_ERROR(
GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes));
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
kXlaTokenInputNodesAttrName,
&token_input_nodes));
std::vector<xla::XlaOp> token_inputs;
for (const string& node_name : token_input_nodes) {
auto token_or = compiler->GetNodeToken(node_name);

View File

@ -73,10 +73,6 @@ class GraphCompiler {
// across multiple nodes visit.
void PartiallySetupParams(OpKernelContext::Params* params);
// Tests if a node is a functional node. A functional node represents a
// defined computation and should be compiled using `compiler_`.
bool IsFunctional(Node* n);
// Compiles a functional node and writes result to OpkernelContext. A
// functional node represents a defined computation and should be compiled
// using `compiler_`.

View File

@ -180,6 +180,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:list_kernels",
"//tensorflow/core/kernels:partitioned_function_ops",
"//tensorflow/core/kernels:pooling_ops",
"//tensorflow/core/kernels:random_op",
"//tensorflow/core/kernels:resource_variable_ops",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/kernels/partitioned_function_ops.h"
namespace tensorflow {
namespace {
@ -107,6 +108,10 @@ class SymbolicGradientOp : public AsyncOpKernel {
};
REGISTER_XLA_OP(Name(kGradientOp), SymbolicGradientOp);
REGISTER_XLA_OP(Name("PartitionedCall").AllowResourceTypes(),
PartitionedCallOp);
REGISTER_XLA_OP(Name("StatefulPartitionedCall").AllowResourceTypes(),
PartitionedCallOp);
} // namespace
} // namespace tensorflow

View File

@ -956,6 +956,28 @@ Status ValidateFunctionDef(const FunctionDef* fdef,
return Status::OK();
}
// If node is PartitionedCall or StatefulPartitionedCall, returns the
// name from the "f" attr, else returns node.def().op().
// Returned pointer points to the internal string either in node's attributes
// or in its NodeDef. This pointer is valid as long as the node has not been
// modified.
Status GetPotentialFunctionName(const Node& node, const string** name) {
if (node.IsPartitionedCall()) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(
node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
if (!attr_value->has_func()) {
return errors::InvalidArgument(
"The attribute value for attribute 'f' in node ", node.DebugString(),
" does not have 'func' field set");
}
*name = &attr_value->func().name();
return Status::OK();
}
*name = &node.type_string();
return Status::OK();
}
// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
// given device_type, invalid data type, missing attributes...)
Status ValidateGraph(const Graph* graph,
@ -975,7 +997,9 @@ Status ValidateGraph(const Graph* graph,
if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
continue;
}
const FunctionDef* fdef = flib_def.Find(node->def().op());
const string* function_name;
TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
const FunctionDef* fdef = flib_def.Find(*function_name);
Status s;
if (fdef) {
s = ValidateFunctionDef(fdef, flib_def);

View File

@ -1629,8 +1629,8 @@ cc_library(
srcs = ["common_runtime/testlib_ops.cc"],
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
":framework",
":lib",
],
alwayslink = 1,
)

View File

@ -327,14 +327,15 @@ void FindConstantFoldableNodes(
shape_replacement_map) {
bool internal_node_inserted = false;
// Walk the nodes in data flow order.
ReverseDFS(*graph, nullptr,
[nodes, constant_control_deps, shape_replacement_map,
&internal_node_inserted, &opts](Node* n) {
ConsiderConstantFoldableNode(
n, opts, nodes, constant_control_deps, shape_replacement_map,
&internal_node_inserted);
},
NodeComparatorName());
ReverseDFS(
*graph, nullptr,
[nodes, constant_control_deps, shape_replacement_map,
&internal_node_inserted, &opts](Node* n) {
ConsiderConstantFoldableNode(n, opts, nodes, constant_control_deps,
shape_replacement_map,
&internal_node_inserted);
},
NodeComparatorName());
// If we have inserted just leaf level nodes, then there is nothing to fold.
if (!internal_node_inserted) {
nodes->clear();

View File

@ -1462,6 +1462,26 @@ Status ValidateInlining(const Node* node, const FunctionBody* fbody,
return Status::OK();
}
Status InstantiateFunctionCall(const NodeDef& call_def,
FunctionLibraryRuntime& flr,
FunctionLibraryRuntime::Handle* handle) {
const string* func_name;
AttrSlice attrs;
NameAttrList func;
if (call_def.op() == "PartitionedCall" ||
call_def.op() == "StatefulPartitionedCall") {
TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", &func));
func_name = &func.name();
attrs = AttrSlice(&func.attr());
} else {
func_name = &call_def.op();
attrs = AttrSlice(call_def);
}
return flr.Instantiate(*func_name, attrs, handle);
}
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
Node* caller, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options) {
@ -1633,6 +1653,13 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
return Status::OK();
}
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
const Node& node) {
return node.IsPartitionedCall() ||
node.type_string() == FunctionLibraryDefinition::kGradientOp ||
lib_def.Find(node.def().op()) != nullptr;
}
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
const InlineFunctionBodyOptions& options) {
std::vector<std::pair<Node*, const FunctionBody*>> candidates;
@ -1641,8 +1668,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
for (Node* node : graph->nodes()) {
// Skip nodes that are not function calls or SymbolicGradient calls.
if (fld->Find(node->type_string()) == nullptr &&
node->type_string() != FunctionLibraryDefinition::kGradientOp) {
if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
continue;
}
// Skip function calls that marked noinline.
@ -1651,9 +1677,8 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
VLOG(3) << "noinline: " << SummarizeNode(*node);
continue;
}
FunctionLibraryRuntime::Handle handle;
Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle);
Status s = InstantiateFunctionCall(node->def(), *lib, &handle);
if (!s.ok()) {
LOG(ERROR) << "Failed to instantiate a function: " << s.error_message();
continue;
@ -1670,7 +1695,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
if (inlined.ok()) {
inlined_any = true;
} else {
VLOG(3) << "Failed to inline function call: node=" << p.first->name()
VLOG(1) << "Failed to inline function call: node=" << p.first->name()
<< " error=" << inlined.error_message();
}
}
@ -1866,8 +1891,8 @@ FunctionBody* SymbolicGradientHelper::Compute() {
const int num_y = static_cast<int>(gbody_->ret_nodes.size());
// Populate 'y_node_outputs_' with node function body outputs.
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
// the original function body (these will be 'arg' nodes in the function
// Populate 'y_grad_nodes' with initial gradient nodes for each return node
// of the original function body (these will be 'arg' nodes in the function
// gradient body).
std::vector<NodeOut> y_node_outputs;
y_node_outputs.reserve(num_y);
@ -1894,8 +1919,8 @@ FunctionBody* SymbolicGradientHelper::Compute() {
}
// Call AddSymbolicGradients which will add nodes to graph 'g' that
// compute the function gradient (adding an entry in 'x_grad_node_outputs' for
// each node in 'x_node_outputs').
// compute the function gradient (adding an entry in 'x_grad_node_outputs'
// for each node in 'x_node_outputs').
std::vector<NodeOut> x_grad_node_outputs;
TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
y_grad_node_outputs, &x_grad_node_outputs,

View File

@ -201,6 +201,20 @@ inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
return ExpandInlineFunctions(lib, graph, InlineFunctionBodyOptions());
}
// Extracts function name and attributes from `call_def` and invokes
// flr->Instantiate(name, attrs, handle).
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status InstantiateFunctionCall(const NodeDef& call_def,
FunctionLibraryRuntime& flr,
FunctionLibraryRuntime::Handle* handle);
// Returns true iff `n` represents a function call. `n` can be a native
// function call (n.type_string() is the function name),
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
// has been deprecated for a while).
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef.
Status FunctionDefToBodyHelper(

View File

@ -85,6 +85,8 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
{"CollectiveBcastSend", NC_COLLECTIVE},
{"CollectiveBcastRecv", NC_COLLECTIVE},
{"FakeParam", NC_FAKE_PARAM},
{"PartitionedCall", NC_PARTITIONED_CALL},
{"StatefulPartitionedCall", NC_PARTITIONED_CALL},
});
#undef REF_CLASS

View File

@ -63,8 +63,8 @@ struct OutputTensor;
class VersionDef;
class WhileContext;
class NeighborIter; // Declared below
class NodeIter; // Declared below
class NeighborIter; // Declared below
class NodeIter; // Declared below
struct NodeProperties; // Defined in .cc
class Node {
@ -173,6 +173,7 @@ class Node {
bool IsMetadata() const { return class_ == NC_METADATA; }
bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
template <typename T>
void AddAttr(const string& name, const T& val) {
@ -254,6 +255,7 @@ class Node {
NC_SCOPED_ALLOCATOR,
NC_COLLECTIVE,
NC_FAKE_PARAM,
NC_PARTITIONED_CALL,
NC_OTHER // Not a special kind of node
};

View File

@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/partitioned_function_ops.h"
#include "absl/strings/match.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@ -31,234 +34,219 @@ limitations under the License.
#endif // GOOGLE_CUDA
namespace tensorflow {
namespace {
// A `PartitionedCallOp` asynchronously executes a function, potentially across
// multiple devices but within a single process. The kernel places and
// partitions a given function's underlying graph, and executes each of the
// partitioned subgraphs as a function.
//
// TODO(akshayka): Support distributed execution.
class PartitionedCallOp : public AsyncOpKernel {
public:
explicit PartitionedCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
string deprecated_config_serialized;
OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &deprecated_config_serialized));
string config_proto_serialized;
OP_REQUIRES_OK(ctx, ctx->GetAttr("config_proto", &config_proto_serialized));
PartitionedCallOp::PartitionedCallOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
func_(new NameAttrList),
config_proto_(new ConfigProto) {
OP_REQUIRES_OK(
ctx, ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, func_.get()));
string deprecated_config_serialized;
OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &deprecated_config_serialized));
string config_proto_serialized;
OP_REQUIRES_OK(ctx, ctx->GetAttr("config_proto", &config_proto_serialized));
OP_REQUIRES(
ctx,
deprecated_config_serialized.empty() || config_proto_serialized.empty(),
errors::InvalidArgument("Provided both 'config' and 'config_proto' but "
"only one should be provided. Note the "
"'config' option is deprecated."));
if (!deprecated_config_serialized.empty()) {
OP_REQUIRES(ctx,
config_proto_->mutable_graph_options()
->mutable_rewrite_options()
->ParseFromString(deprecated_config_serialized),
errors::InvalidArgument("Unable to parse config string as "
"tensorflow::RewriteOptions proto."));
} else {
OP_REQUIRES(
ctx,
deprecated_config_serialized.empty() || config_proto_serialized.empty(),
errors::InvalidArgument("Provided both 'config' and 'config_proto' but "
"only one should be provided. Note the "
"'config' option is deprecated."));
if (!deprecated_config_serialized.empty()) {
OP_REQUIRES(ctx,
config_proto_.mutable_graph_options()
->mutable_rewrite_options()
->ParseFromString(deprecated_config_serialized),
errors::InvalidArgument("Unable to parse config string as "
"tensorflow::RewriteOptions proto."));
ctx, config_proto_->ParseFromString(config_proto_serialized),
errors::InvalidArgument("Unable to parse config_proto string as "
"tensorflow::ConfigProto proto."));
}
OP_REQUIRES_OK(ctx, ctx->GetAttr("executor_type", &executor_type_));
}
PartitionedCallOp::~PartitionedCallOp() {
for (const auto& it : handles_) {
Status status = it.first->ReleaseHandle(it.second);
if (!status.ok()) {
LOG(INFO) << "Ignoring error while destructing PartitionedCallOp: "
<< status.ToString();
}
}
}
void PartitionedCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library is provided."), done);
// The function body's graph is placed and partitioned the first time
// `ComputeAsync` is invoked; every subsequent invocation calls each
// of the function shards yielded by partitioning.
//
// The partitioning step yields a set of devices on which to run the
// function, and exactly one function shard is created for each device
// Inputs and outputs are pinned to the local device, for simplicity.
//
// TODO(akshayka): Support re-sharding the function on subsequent calls,
// via, e.g., virtual device annotations and a list of device names
// supplied through an attribute.
//
// TODO(akshayka): Add a fastpath for functions that execute on a single
// device.
FunctionLibraryRuntime::Handle handle;
// If we are instantiating the function, we can efficiently extract the
// inputs while instantiating. Else, we extract them separately below.
std::vector<Tensor> inputs;
bool inputs_extracted = false;
{
mutex_lock l(mu_);
auto it = handles_.find(lib);
if (it == handles_.end()) {
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, ctx, &inputs, &handle), done);
inputs_extracted = true;
handles_[lib] = handle;
} else {
OP_REQUIRES(
ctx, config_proto_.ParseFromString(config_proto_serialized),
errors::InvalidArgument("Unable to parse config_proto string as "
"tensorflow::ConfigProto proto."));
}
OP_REQUIRES_OK(ctx, ctx->GetAttr("executor_type", &executor_type_));
}
~PartitionedCallOp() override {
for (const auto& it : handles_) {
Status status = it.first->ReleaseHandle(it.second);
if (!status.ok()) {
LOG(INFO) << "Ignoring error while destructing PartitionedCallOp: "
<< status.ToString();
}
handle = it->second;
}
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library is provided."),
done);
// The function body's graph is placed and partitioned the first time
// `ComputeAsync` is invoked; every subsequent invocation calls each
// of the function shards yielded by partitioning.
//
// The partitioning step yields a set of devices on which to run the
// function, and exactly one function shard is created for each device
// Inputs and outputs are pinned to the local device, for simplicity.
//
// TODO(akshayka): Support re-sharding the function on subsequent calls,
// via, e.g., virtual device annotations and a list of device names
// supplied through an attribute.
//
// TODO(akshayka): Add a fastpath for functions that execute on a single
// device.
FunctionLibraryRuntime::Handle handle;
// If we are instantiating the function, we can efficiently extract the
// inputs while instantiating. Else, we extract them separately below.
std::vector<Tensor> inputs;
bool inputs_extracted = false;
{
mutex_lock l(mu_);
auto it = handles_.find(lib);
if (it == handles_.end()) {
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, ctx, &inputs, &handle),
done);
inputs_extracted = true;
handles_[lib] = handle;
} else {
handle = it->second;
}
}
if (!inputs_extracted) {
OpInputList args;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done);
inputs.reserve(args.size());
for (const Tensor& tensor : args) {
inputs.push_back(tensor);
}
}
RunFunction(handle, inputs, lib, ctx, done);
}
private:
Status FillOutputDevices(const FunctionLibraryRuntime& lib,
const Device& cpu_device, AttrSlice attrs,
FunctionLibraryRuntime::InstantiateOptions* opts) {
const FunctionLibraryDefinition* flib = lib.GetFunctionLibraryDefinition();
const FunctionDef* fdef = flib->Find(func_.name());
if (fdef == nullptr) {
return errors::NotFound("Failed for find definiton for function \"",
func_.name(), "\"");
}
bool is_type_list;
for (const OpDef::ArgDef& ret_def : fdef->signature().output_arg()) {
DataTypeVector dtypes;
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
for (DataType dtype : dtypes) {
if (MTypeFromDType(dtype) == HOST_MEMORY) {
opts->output_devices.push_back(cpu_device.name());
} else {
opts->output_devices.push_back(opts->target);
}
}
}
return Status::OK();
}
Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
std::vector<Tensor>* inputs,
FunctionLibraryRuntime::Handle* handle) {
grappler::GrapplerItem::OptimizationOptions optimization_options;
// Tensorflow 2.0 in eager mode with automatic control dependencies will
// prune all nodes that are not in the transitive fanin of the fetch nodes.
// However because the function will be executed via FunctionLibraryRuntime,
// and current function implementation does not prune stateful and dataset
// ops, we rely on Grappler to do the correct graph pruning.
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
// All the nested function calls will be executed and optimized via
// PartitionedCallOp, there is no need to optimize functions now.
optimization_options.optimize_function_library = false;
FunctionLibraryRuntime::InstantiateOptions opts;
opts.target = lib->device()->name();
opts.is_multi_device_function = true;
opts.optimize_graph_fn =
std::bind(grappler::OptimizeGraph, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3,
std::placeholders::_4, std::placeholders::_5, config_proto_,
func_.name(), optimization_options, std::placeholders::_6);
opts.graph_collector = ctx->graph_collector();
opts.executor_type = executor_type_;
if (!inputs_extracted) {
OpInputList args;
TF_RETURN_IF_ERROR(ctx->input_list("args", &args));
Device* cpu_device;
TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
inputs->reserve(args.size());
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done);
inputs.reserve(args.size());
for (const Tensor& tensor : args) {
inputs->push_back(tensor);
DataType dtype = tensor.dtype();
if (dtype == DT_RESOURCE) {
const ResourceHandle& handle = tensor.flat<ResourceHandle>()(0);
opts.input_devices.push_back(handle.device());
} else if (MTypeFromDType(dtype) == HOST_MEMORY) {
opts.input_devices.push_back(cpu_device->name());
inputs.push_back(tensor);
}
}
RunFunction(handle, inputs, lib, ctx, done);
}
Status PartitionedCallOp::FillOutputDevices(
const FunctionLibraryRuntime& lib, const Device& cpu_device,
AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions* opts) {
const FunctionLibraryDefinition* flib = lib.GetFunctionLibraryDefinition();
const FunctionDef* fdef = flib->Find(func_->name());
if (fdef == nullptr) {
return errors::NotFound("Failed for find definition for function \"",
func_->name(), "\"");
}
bool is_type_list;
for (const OpDef::ArgDef& ret_def : fdef->signature().output_arg()) {
DataTypeVector dtypes;
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
for (DataType dtype : dtypes) {
if (MTypeFromDType(dtype) == HOST_MEMORY) {
opts->output_devices.push_back(cpu_device.name());
} else {
opts.input_devices.push_back(opts.target);
opts->output_devices.push_back(opts->target);
}
}
}
return Status::OK();
}
TF_RETURN_IF_ERROR(
FillOutputDevices(*lib, *cpu_device, AttrSlice(&func_.attr()), &opts));
Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib,
OpKernelContext* ctx,
std::vector<Tensor>* inputs,
FunctionLibraryRuntime::Handle* handle) {
grappler::GrapplerItem::OptimizationOptions optimization_options;
TF_RETURN_IF_ERROR(
lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts, handle));
return Status::OK();
// Tensorflow 2.0 in eager mode with automatic control dependencies will
// prune all nodes that are not in the transitive fanin of the fetch nodes.
// However because the function will be executed via FunctionLibraryRuntime,
// and current function implementation does not prune stateful and dataset
// ops, we rely on Grappler to do the correct graph pruning.
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
// All the nested function calls will be executed and optimized via
// PartitionedCallOp, there is no need to optimize functions now.
optimization_options.optimize_function_library = false;
FunctionLibraryRuntime::InstantiateOptions opts;
// In some contexts like running the graph to evaluate constants,
// the FLR won't have any device.
opts.target = lib->device() == nullptr ? "" : lib->device()->name();
opts.is_multi_device_function = true;
opts.optimize_graph_fn =
std::bind(grappler::OptimizeGraph, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3,
std::placeholders::_4, std::placeholders::_5, *config_proto_,
func_->name(), optimization_options, std::placeholders::_6);
opts.graph_collector = ctx->graph_collector();
opts.executor_type = executor_type_;
OpInputList args;
TF_RETURN_IF_ERROR(ctx->input_list("args", &args));
Device* cpu_device;
TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
inputs->reserve(args.size());
for (const Tensor& tensor : args) {
inputs->push_back(tensor);
DataType dtype = tensor.dtype();
if (dtype == DT_RESOURCE) {
const ResourceHandle& handle = tensor.flat<ResourceHandle>()(0);
opts.input_devices.push_back(handle.device());
} else if (MTypeFromDType(dtype) == HOST_MEMORY) {
opts.input_devices.push_back(cpu_device->name());
} else {
opts.input_devices.push_back(opts.target);
}
}
void RunFunction(FunctionLibraryRuntime::Handle handle,
const std::vector<Tensor>& inputs,
FunctionLibraryRuntime* lib, OpKernelContext* ctx,
DoneCallback done) {
FunctionLibraryRuntime::Options run_opts;
run_opts.step_id = ctx->step_id();
run_opts.step_container = ctx->step_container();
run_opts.cancellation_manager = ctx->cancellation_manager();
run_opts.stats_collector = ctx->stats_collector();
run_opts.collective_executor = ctx->collective_executor();
// TODO(akshayka): Consider selecting a runner on a per-device basis,
// i.e., using device-specific threadpools when available.
run_opts.runner = ctx->runner();
run_opts.source_device = lib->device()->name();
run_opts.allow_dead_tensors = true;
// TODO(akshayka): Accommodate the multiple-worker scenario by adding the
// constructed rendezvous to a rendezvous manager.
Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
run_opts.rendezvous = rendez;
TF_RETURN_IF_ERROR(
FillOutputDevices(*lib, *cpu_device, AttrSlice(&func_->attr()), &opts));
std::vector<Tensor>* rets = new std::vector<Tensor>;
const string& func_name = func_.name();
lib->Run(run_opts, handle, inputs, rets,
[rets, rendez, done, ctx, func_name](const Status& status) {
if (!status.ok()) {
const string function_and_msg =
strings::StrCat(errors::FormatFunctionForError(func_name),
" ", status.error_message());
ctx->SetStatus(Status(status.code(), function_and_msg));
} else {
for (int i = 0; i < rets->size(); ++i) {
ctx->set_output(i, (*rets)[i]);
}
TF_RETURN_IF_ERROR(
lib->Instantiate(func_->name(), AttrSlice(&func_->attr()), opts, handle));
return Status::OK();
}
void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle,
const std::vector<Tensor>& inputs,
FunctionLibraryRuntime* lib,
OpKernelContext* ctx, DoneCallback done) {
FunctionLibraryRuntime::Options run_opts;
run_opts.step_id = ctx->step_id();
run_opts.step_container = ctx->step_container();
run_opts.cancellation_manager = ctx->cancellation_manager();
run_opts.stats_collector = ctx->stats_collector();
run_opts.collective_executor = ctx->collective_executor();
// TODO(akshayka): Consider selecting a runner on a per-device basis,
// i.e., using device-specific threadpools when available.
run_opts.runner = ctx->runner();
run_opts.source_device =
lib->device() == nullptr ? "" : lib->device()->name();
run_opts.allow_dead_tensors = true;
// TODO(akshayka): Accommodate the multiple-worker scenario by adding the
// constructed rendezvous to a rendezvous manager.
Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
run_opts.rendezvous = rendez;
std::vector<Tensor>* rets = new std::vector<Tensor>;
const string& func_name = func_->name();
lib->Run(run_opts, handle, inputs, rets,
[rets, rendez, done, ctx, func_name](const Status& status) {
if (!status.ok()) {
const string function_and_msg =
strings::StrCat(errors::FormatFunctionForError(func_name),
" ", status.error_message());
ctx->SetStatus(Status(status.code(), function_and_msg));
} else {
for (int i = 0; i < rets->size(); ++i) {
ctx->set_output(i, (*rets)[i]);
}
delete rets;
rendez->Unref();
done();
});
}
NameAttrList func_;
ConfigProto config_proto_;
string executor_type_;
mutex mu_;
// Cache the handle per FLR because this kernel may be instantiated for
// a stateful op, different invocations of it may use different FLRs.
// Different device placements of PartitionedCallOp also use
// different FLRs.
gtl::FlatMap<FunctionLibraryRuntime*, FunctionLibraryRuntime::Handle> handles_
GUARDED_BY(mu_);
};
}
delete rets;
rendez->Unref();
done();
});
}
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU),
PartitionedCallOp);
@ -275,5 +263,4 @@ REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_SYCL),
PartitionedCallOp);
#endif // TENSORFLOW_USE_SYCL
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
class NameAttrList;
class ConfigProto;
// A `PartitionedCallOp` asynchronously executes a function, potentially across
// multiple devices but within a single process. The kernel places and
// partitions a given function's underlying graph, and executes each of the
// partitioned subgraphs as a function.
//
// TODO(akshayka): Support distributed execution.
class PartitionedCallOp : public AsyncOpKernel {
public:
explicit PartitionedCallOp(OpKernelConstruction* ctx);
~PartitionedCallOp() override;
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
private:
Status FillOutputDevices(const FunctionLibraryRuntime& lib,
const Device& cpu_device, AttrSlice attrs,
FunctionLibraryRuntime::InstantiateOptions* opts);
Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
std::vector<Tensor>* inputs,
FunctionLibraryRuntime::Handle* handle);
void RunFunction(FunctionLibraryRuntime::Handle handle,
const std::vector<Tensor>& inputs,
FunctionLibraryRuntime* lib, OpKernelContext* ctx,
DoneCallback done);
// Using unique pointers to avoid including proto headers in kernel headers
std::unique_ptr<NameAttrList> func_;
std::unique_ptr<ConfigProto> config_proto_;
string executor_type_;
mutex mu_;
// Cache the handle per FLR because this kernel may be instantiated for
// a stateful op, different invocations of it may use different FLRs.
// Different device placements of PartitionedCallOp also use
// different FLRs.
gtl::FlatMap<FunctionLibraryRuntime*, FunctionLibraryRuntime::Handle> handles_
GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_

View File

@ -38,7 +38,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
@ -64,7 +63,7 @@ BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
CacheKey = collections.namedtuple("CacheKey", [
"input_signature", "parent_graph", "device_functions",
"colocation_stack", "uses_xla"])
"colocation_stack"])
CacheKey.replace = CacheKey._replace # pylint: disable=protected-access
@ -417,25 +416,6 @@ class _EagerDefinedFunction(object):
ctx=ctx)
# Replace empty list with None
outputs = outputs or None
elif self._graph._xla_compile: # pylint: disable=protected-access
g = ops.get_default_graph()
self.add_to_graph(g)
signature = self.signature
with ops.control_dependencies(self._control_captures):
op = g.create_op(
signature.name,
[ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
op_def=signature,
name="FunctionCall",
compute_shapes=False)
outputs = op.outputs
if not outputs:
return op
if isinstance(outputs, (ops.Tensor, type(None))):
outputs = [outputs]
else:
outputs = list(outputs)
else:
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
# creates `PartitionedCallOp` kernels by default, or remove the previous
@ -1464,16 +1444,13 @@ class Function(object):
default_graph._distribution_strategy_stack)
if executing_eagerly:
colocation_stack = ()
uses_xla = ctx.device_spec.device_type == "TPU"
if uses_distribution_strategy or uses_xla:
if uses_distribution_strategy:
device_functions = (pydev.merge_device(ctx.device_name),)
else:
device_functions = ()
else:
colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
uses_xla = getattr(default_graph, "_xla_compile", False)
if (uses_distribution_strategy
or uses_xla
or func_graph_module.device_stack_has_callable(
default_graph._device_function_stack)):
# Putting the device in the cache key ensures that call-site device
@ -1483,7 +1460,7 @@ class Function(object):
device_functions = ()
# pylint: enable=protected-access
return CacheKey(input_signature, parent_graph, device_functions,
colocation_stack, uses_xla)
colocation_stack)
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
@ -1572,11 +1549,10 @@ class Function(object):
call_context_key = cache_key.replace(input_signature=None)
# If there's a provided input signature, or XLA is being used, or
# If there's a provided input signature, or
# there's no cache miss for this calling context so far, go ahead and
# build the function and bypass shape relaxation retracing.
if (self.input_signature is not None
or cache_key.uses_xla
or call_context_key not in self._function_cache.missed):
self._function_cache.missed.add(call_context_key)
graph_function = self._create_graph_function(args, kwargs)

View File

@ -208,13 +208,9 @@ class FuncGraph(ops.Graph):
# any None op_seed for random_op in the function, in which case we end up
# using function seed, which could be unintended behavior for the op.
self._seed_used = False
device_type = context.context().device_spec.device_type
self._xla_compile = (device_type == "TPU" or device_type == "XLA_GPU"
or device_type == "XLA_CPU")
else:
self.seed = graph.seed
self._seed_used = False
self._xla_compile = getattr(graph, "_xla_compile", False)
# TODO(allenl): Figure out if we can remove colocation stack
# specialization (currently used in cond_v2), here and in the cache key.
self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
@ -297,11 +293,10 @@ class FuncGraph(ops.Graph):
# restored.
old_device_stack = self._device_function_stack
if context.executing_eagerly():
if self._distribution_strategy_stack or self._xla_compile:
if self._distribution_strategy_stack:
self._add_device_to_stack(context.context().device_name)
else:
if (self._distribution_strategy_stack
or self._xla_compile
or device_stack_has_callable(graph._device_function_stack)):
# Hard-code devices from device functions in the function body
self._device_function_stack = graph._device_function_stack.copy()

View File

@ -829,6 +829,8 @@ class CondV2Test(test.TestCase):
self.evaluate(output_t), [-5, -4, -3, -2, -1, 0, 1, 4, 9, 16])
@test_util.enable_control_flow_v2
@test_util.disable_xla(
"b/127846988: No tf2xla kernel for IfOp taking DT_VARIANT")
def testCondAndTensorArrayInDefun(self):
@function.defun

View File

@ -534,6 +534,7 @@ class PyFuncTest(test.TestCase):
self.assertIsNone(ret)
@test_util.run_in_graph_and_eager_modes
@test_util.disable_xla("XLA cannot compile functions containing py_func")
def testEagerPyFuncInDefun(self):
with test_util.device(use_gpu=True):
def wrapper():