Support function calls through PartitionedCall in tf2xla
Also, make the eager runtime always emit PartitionedCall and remove the special handling of XLA compilation. Because this change makes XLA look inside PartitionedCalls, some tests that contain PartitionedCalls with uncompilable ops inside had to be updated or disabled. PiperOrigin-RevId: 237486703
commit ed2b195990 (parent 745b6aa434)
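To make the new dispatch concrete: after this change a function call node can be either a "native" call (the op type is the function name) or a PartitionedCall/StatefulPartitionedCall whose callee is stored in the "f" attr, and the helpers this commit adds to tensorflow/core/common_runtime/function.h handle both uniformly. Below is a minimal sketch of how a graph pass might use them; it is illustrative only, not part of the commit, and the helper name InstantiateAllCalls plus the surrounding error handling are assumptions.

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Sketch: instantiate every function call in a graph, whether it is a native
// call or a (Stateful)PartitionedCall, using the IsFunctionCall and
// InstantiateFunctionCall helpers introduced by this commit.
Status InstantiateAllCalls(const Graph& graph, FunctionLibraryRuntime* lib) {
  const FunctionLibraryDefinition& lib_def =
      *lib->GetFunctionLibraryDefinition();
  for (const Node* node : graph.nodes()) {
    if (!IsFunctionCall(lib_def, *node)) continue;
    FunctionLibraryRuntime::Handle handle;
    // For PartitionedCall nodes this reads the callee from the "f" attr;
    // otherwise it falls back to the node's op name.
    TF_RETURN_IF_ERROR(InstantiateFunctionCall(node->def(), *lib, &handle));
    // ... use `handle` (e.g. lib->Run), then lib->ReleaseHandle(handle).
  }
  return Status::OK();
}

}  // namespace tensorflow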
@ -154,11 +154,14 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));

VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
VLOG(3) << "Attemping to create XlaLaunchOp for " << node_def.DebugString();

// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
if (!IsCompilable(flr, node_def)) {
VLOG(1) << "Not creating XlaLaunchOp because function invoked by the "
"following node is not compilable: "
<< node_def.DebugString();
// node_def is calling a function that XLA can't compile.
return errors::InvalidArgument("Not compilable: ",
node_def.ShortDebugString());
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
@ -108,14 +109,14 @@ void MarkGuaranteedConstants(
|
||||
for (const auto& src_arg : src_arg_pairs) {
|
||||
srcs.push_back(src_arg.first);
|
||||
}
|
||||
ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
|
||||
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
|
||||
// TODO(vinuraja): Doesn't work in the presence of loops.
|
||||
if (AreAllParentsGuaranteedConst(*n,
|
||||
guaranteed_const_nodes)) {
|
||||
guaranteed_const_nodes.insert(n);
|
||||
}
|
||||
});
|
||||
ReverseDFSFrom(
|
||||
graph, srcs, /*enter=*/nullptr,
|
||||
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
|
||||
// TODO(vinuraja): Doesn't work in the presence of loops.
|
||||
if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
|
||||
guaranteed_const_nodes.insert(n);
|
||||
}
|
||||
});
|
||||
|
||||
for (auto& src_arg : src_arg_pairs) {
|
||||
if (guaranteed_const_nodes.count(src_arg.first) != 0) {
|
||||
@ -2319,12 +2320,12 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
|
||||
" in function library.");
|
||||
}
|
||||
FunctionBody* fbody = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBodyHelper(*fdef, node->attrs(), library,
|
||||
[library](const string& op, const OpDef** sig) {
|
||||
return library->LookUpOpDef(op, sig);
|
||||
},
|
||||
&fbody));
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
|
||||
*fdef, node->attrs(), library,
|
||||
[library](const string& op, const OpDef** sig) {
|
||||
return library->LookUpOpDef(op, sig);
|
||||
},
|
||||
&fbody));
|
||||
|
||||
InlineFunctionBodyOptions inline_opts;
|
||||
inline_opts.override_device = false;
|
||||
@ -2534,12 +2535,29 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
std::unique_ptr<Graph> graph_out;
|
||||
FunctionLibraryDefinition* const library = options.flib_def;
|
||||
|
||||
// Constant folding below might need to run part of the function to compute
|
||||
// constants. Create an FunctionLibraryRuntime with a single CPU device
|
||||
// that can run the part of the function.
|
||||
SessionOptions session_options;
|
||||
auto* device_count = session_options.config.mutable_device_count();
|
||||
device_count->insert({"CPU", 1});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(
|
||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||
std::unique_ptr<DeviceMgr> device_mgr =
|
||||
absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
|
||||
new ProcessFunctionLibraryRuntime(device_mgr.get(),
|
||||
options.session_options->env,
|
||||
TF_GRAPH_DEF_VERSION, library, opts));
|
||||
FunctionLibraryRuntime* flr =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
if (flr == nullptr) {
|
||||
return errors::Internal(
|
||||
"Failed to create and retrieve function library runtime to run "
|
||||
"constant folding");
|
||||
}
|
||||
|
||||
auto rewrite_subgraph =
|
||||
[flr](const std::vector<OutputTensor>& arg_source_tensors,
|
||||
|
@ -227,10 +227,9 @@ bool IsCompilableCall(const NodeDef& call_def,
}

FunctionLibraryRuntime::Handle handle;
Status status =
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
Status status = InstantiateFunctionCall(call_def, *lib_runtime, &handle);
if (!status.ok()) {
VLOG(2) << "Rejecting " << call_def.op()
VLOG(2) << "Rejecting " << call_def.DebugString()
<< ": could not instantiate: " << status;
return false;
}
@ -33,7 +33,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
@ -87,6 +89,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
||||
}
|
||||
break;
|
||||
case XlaExpression::Kind::kResource:
|
||||
// TODO(b/126601755): This is a fairly common use case in TF 2.0 that
|
||||
// we can hit when inlining is disabled or fails.
|
||||
return errors::Unimplemented(
|
||||
"Resource as function argument is not yet implemented.");
|
||||
case XlaExpression::Kind::kTensorList:
|
||||
@ -124,6 +128,8 @@ Status GraphCompiler::Compile() {
|
||||
|
||||
for (Node* n : topo_sorted_nodes) {
|
||||
OpKernel* op_kernel_raw = nullptr;
|
||||
// The kernel is not actually run for functional ops, we just need it
|
||||
// for metadata.
|
||||
Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
|
||||
// Transfer ownership of the kernel to a local smart pointer.
|
||||
std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);
|
||||
@ -157,7 +163,7 @@ Status GraphCompiler::Compile() {
|
||||
|
||||
OpKernelContext op_context(¶ms, n->num_outputs());
|
||||
VLOG(3) << "Translating " << params.op_kernel->name();
|
||||
if (IsFunctional(n)) {
|
||||
if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
|
||||
TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
|
||||
} else {
|
||||
device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
|
||||
@ -182,15 +188,37 @@ Status GraphCompiler::Compile() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool GraphCompiler::IsFunctional(Node* n) {
|
||||
return n->type_string() == FunctionLibraryDefinition::kGradientOp ||
|
||||
(flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) !=
|
||||
nullptr);
|
||||
namespace {
|
||||
|
||||
Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
|
||||
const Node& node, NameAttrList* func) {
|
||||
if (node.IsPartitionedCall()) {
|
||||
const AttrValue* attr_value;
|
||||
TF_RETURN_IF_ERROR(
|
||||
node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
|
||||
if (!attr_value->has_func()) {
|
||||
return errors::InvalidArgument(
|
||||
"The attribute value for attribute 'f' in node ", node.DebugString(),
|
||||
" does not have 'func' field set");
|
||||
}
|
||||
*func = attr_value->func();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
|
||||
func->set_name(node.type_string());
|
||||
} else {
|
||||
func->set_name(FunctionLibraryDefinition::kGradientOp);
|
||||
}
|
||||
*func->mutable_attr() = node.def().attr();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
OpKernelContext* op_context) {
|
||||
TF_RET_CHECK(IsFunctional(n));
|
||||
TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
|
||||
// For functional nodes, compile them using compiler from the context and call
|
||||
// into the functions.
|
||||
XlaOpKernelContext xla_op_context(op_context);
|
||||
@ -201,12 +229,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
XlaCompiler* compiler = xla_op_context.compiler();
|
||||
|
||||
NameAttrList func;
|
||||
if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) {
|
||||
func.set_name(n->def().op());
|
||||
} else {
|
||||
func.set_name(FunctionLibraryDefinition::kGradientOp);
|
||||
}
|
||||
*func.mutable_attr() = n->def().attr();
|
||||
TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));
|
||||
|
||||
std::vector<const XlaExpression*> expressions;
|
||||
|
||||
@ -227,7 +250,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));
|
||||
|
||||
bool add_token_input_output =
|
||||
HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName);
|
||||
func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = false;
|
||||
@ -247,8 +270,9 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
}
|
||||
if (add_token_input_output) {
|
||||
std::vector<string> token_input_nodes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
|
||||
kXlaTokenInputNodesAttrName,
|
||||
&token_input_nodes));
|
||||
std::vector<xla::XlaOp> token_inputs;
|
||||
for (const string& node_name : token_input_nodes) {
|
||||
auto token_or = compiler->GetNodeToken(node_name);
|
||||
|
@ -73,10 +73,6 @@ class GraphCompiler {
|
||||
// across multiple nodes visit.
|
||||
void PartiallySetupParams(OpKernelContext::Params* params);
|
||||
|
||||
// Tests if a node is a functional node. A functional node represents a
|
||||
// defined computation and should be compiled using `compiler_`.
|
||||
bool IsFunctional(Node* n);
|
||||
|
||||
// Compiles a functional node and writes result to OpkernelContext. A
|
||||
// functional node represents a defined computation and should be compiled
|
||||
// using `compiler_`.
|
||||
|
@ -180,6 +180,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:list_kernels",
|
||||
"//tensorflow/core/kernels:partitioned_function_ops",
|
||||
"//tensorflow/core/kernels:pooling_ops",
|
||||
"//tensorflow/core/kernels:random_op",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/kernels/partitioned_function_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -107,6 +108,10 @@ class SymbolicGradientOp : public AsyncOpKernel {
};

REGISTER_XLA_OP(Name(kGradientOp), SymbolicGradientOp);
REGISTER_XLA_OP(Name("PartitionedCall").AllowResourceTypes(),
PartitionedCallOp);
REGISTER_XLA_OP(Name("StatefulPartitionedCall").AllowResourceTypes(),
PartitionedCallOp);

} // namespace
} // namespace tensorflow
@ -956,6 +956,28 @@ Status ValidateFunctionDef(const FunctionDef* fdef,
return Status::OK();
}

// If node is PartitionedCall or StatefulPartitionedCall, returns the
// name from the "f" attr, else returns node.def().op().
// Returned pointer points to the internal string either in node's attributes
// or in its NodeDef. This pointer is valid as long as the node has not been
// modified.
Status GetPotentialFunctionName(const Node& node, const string** name) {
if (node.IsPartitionedCall()) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(
node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
if (!attr_value->has_func()) {
return errors::InvalidArgument(
"The attribute value for attribute 'f' in node ", node.DebugString(),
" does not have 'func' field set");
}
*name = &attr_value->func().name();
return Status::OK();
}
*name = &node.type_string();
return Status::OK();
}

// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
// given device_type, invalid data type, missing attributes...)
Status ValidateGraph(const Graph* graph,
@ -975,7 +997,9 @@ Status ValidateGraph(const Graph* graph,
|
||||
if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
|
||||
continue;
|
||||
}
|
||||
const FunctionDef* fdef = flib_def.Find(node->def().op());
|
||||
const string* function_name;
|
||||
TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
|
||||
const FunctionDef* fdef = flib_def.Find(*function_name);
|
||||
Status s;
|
||||
if (fdef) {
|
||||
s = ValidateFunctionDef(fdef, flib_def);
|
||||
|
@ -1629,8 +1629,8 @@ cc_library(
|
||||
srcs = ["common_runtime/testlib_ops.cc"],
|
||||
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
":framework",
|
||||
":lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -327,14 +327,15 @@ void FindConstantFoldableNodes(
|
||||
shape_replacement_map) {
|
||||
bool internal_node_inserted = false;
|
||||
// Walk the nodes in data flow order.
|
||||
ReverseDFS(*graph, nullptr,
|
||||
[nodes, constant_control_deps, shape_replacement_map,
|
||||
&internal_node_inserted, &opts](Node* n) {
|
||||
ConsiderConstantFoldableNode(
|
||||
n, opts, nodes, constant_control_deps, shape_replacement_map,
|
||||
&internal_node_inserted);
|
||||
},
|
||||
NodeComparatorName());
|
||||
ReverseDFS(
|
||||
*graph, nullptr,
|
||||
[nodes, constant_control_deps, shape_replacement_map,
|
||||
&internal_node_inserted, &opts](Node* n) {
|
||||
ConsiderConstantFoldableNode(n, opts, nodes, constant_control_deps,
|
||||
shape_replacement_map,
|
||||
&internal_node_inserted);
|
||||
},
|
||||
NodeComparatorName());
|
||||
// If we have inserted just leaf level nodes, then there is nothing to fold.
|
||||
if (!internal_node_inserted) {
|
||||
nodes->clear();
|
||||
|
@ -1462,6 +1462,26 @@ Status ValidateInlining(const Node* node, const FunctionBody* fbody,
return Status::OK();
}

Status InstantiateFunctionCall(const NodeDef& call_def,
FunctionLibraryRuntime& flr,
FunctionLibraryRuntime::Handle* handle) {
const string* func_name;
AttrSlice attrs;

NameAttrList func;
if (call_def.op() == "PartitionedCall" ||
call_def.op() == "StatefulPartitionedCall") {
TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", &func));
func_name = &func.name();
attrs = AttrSlice(&func.attr());
} else {
func_name = &call_def.op();
attrs = AttrSlice(call_def);
}

return flr.Instantiate(*func_name, attrs, handle);
}

Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
Node* caller, const FunctionBody* fbody,
const InlineFunctionBodyOptions& options) {
@ -1633,6 +1653,13 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
return Status::OK();
}

bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
const Node& node) {
return node.IsPartitionedCall() ||
node.type_string() == FunctionLibraryDefinition::kGradientOp ||
lib_def.Find(node.def().op()) != nullptr;
}

bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
const InlineFunctionBodyOptions& options) {
std::vector<std::pair<Node*, const FunctionBody*>> candidates;
@ -1641,8 +1668,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
|
||||
|
||||
for (Node* node : graph->nodes()) {
|
||||
// Skip nodes that are not function calls or SymbolicGradient calls.
|
||||
if (fld->Find(node->type_string()) == nullptr &&
|
||||
node->type_string() != FunctionLibraryDefinition::kGradientOp) {
|
||||
if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
|
||||
continue;
|
||||
}
|
||||
// Skip function calls that marked noinline.
|
||||
@ -1651,9 +1677,8 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
|
||||
VLOG(3) << "noinline: " << SummarizeNode(*node);
|
||||
continue;
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle);
|
||||
Status s = InstantiateFunctionCall(node->def(), *lib, &handle);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Failed to instantiate a function: " << s.error_message();
|
||||
continue;
|
||||
@ -1670,7 +1695,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
|
||||
if (inlined.ok()) {
|
||||
inlined_any = true;
|
||||
} else {
|
||||
VLOG(3) << "Failed to inline function call: node=" << p.first->name()
|
||||
VLOG(1) << "Failed to inline function call: node=" << p.first->name()
|
||||
<< " error=" << inlined.error_message();
|
||||
}
|
||||
}
|
||||
@ -1866,8 +1891,8 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
||||
const int num_y = static_cast<int>(gbody_->ret_nodes.size());
|
||||
|
||||
// Populate 'y_node_outputs_' with node function body outputs.
|
||||
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
|
||||
// the original function body (these will be 'arg' nodes in the function
|
||||
// Populate 'y_grad_nodes' with initial gradient nodes for each return node
|
||||
// of the original function body (these will be 'arg' nodes in the function
|
||||
// gradient body).
|
||||
std::vector<NodeOut> y_node_outputs;
|
||||
y_node_outputs.reserve(num_y);
|
||||
@ -1894,8 +1919,8 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
||||
}
|
||||
|
||||
// Call AddSymbolicGradients which will add nodes to graph 'g' that
|
||||
// compute the function gradient (adding an entry in 'x_grad_node_outputs' for
|
||||
// each node in 'x_node_outputs').
|
||||
// compute the function gradient (adding an entry in 'x_grad_node_outputs'
|
||||
// for each node in 'x_node_outputs').
|
||||
std::vector<NodeOut> x_grad_node_outputs;
|
||||
TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
|
||||
y_grad_node_outputs, &x_grad_node_outputs,
|
||||
|
@ -201,6 +201,20 @@ inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
return ExpandInlineFunctions(lib, graph, InlineFunctionBodyOptions());
}

// Extracts function name and attributes from `call_def` and invokes
// flr->Instantiate(name, attrs, handle).
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status InstantiateFunctionCall(const NodeDef& call_def,
FunctionLibraryRuntime& flr,
FunctionLibraryRuntime::Handle* handle);

// Returns true iff `n` represents a function call. `n` can be a native
// function call (n.type_string() is the function name),
// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
// has been deprecated for a while).
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);

// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef.
Status FunctionDefToBodyHelper(
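For illustration, the two call forms that these declarations cover can be sketched as NodeDefs: a native call whose op type is the function name, and a StatefulPartitionedCall that carries the callee in its "f" attr. IsFunctionCall returns true for both, and InstantiateFunctionCall resolves both to the same library function. This sketch is not from the commit; the function name "Square" is hypothetical, and the Tin/Tout attrs and input wiring are omitted for brevity.

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

namespace tensorflow {

// Native call: the op type itself names the library function.
NodeDef MakeNativeCall() {
  NodeDef call;
  call.set_name("call_native");
  call.set_op("Square");  // assumed to exist in the FunctionLibraryDefinition
  return call;
}

// Call through StatefulPartitionedCall: the callee lives in the "f" attr.
NodeDef MakePartitionedCall() {
  NodeDef call;
  call.set_name("call_partitioned");
  call.set_op("StatefulPartitionedCall");
  (*call.mutable_attr())["f"].mutable_func()->set_name("Square");
  return call;
}

}  // namespace tensorflow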
|
@ -85,6 +85,8 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
|
||||
{"CollectiveBcastSend", NC_COLLECTIVE},
|
||||
{"CollectiveBcastRecv", NC_COLLECTIVE},
|
||||
{"FakeParam", NC_FAKE_PARAM},
|
||||
{"PartitionedCall", NC_PARTITIONED_CALL},
|
||||
{"StatefulPartitionedCall", NC_PARTITIONED_CALL},
|
||||
});
|
||||
|
||||
#undef REF_CLASS
|
||||
|
@ -63,8 +63,8 @@ struct OutputTensor;
|
||||
class VersionDef;
|
||||
class WhileContext;
|
||||
|
||||
class NeighborIter; // Declared below
|
||||
class NodeIter; // Declared below
|
||||
class NeighborIter; // Declared below
|
||||
class NodeIter; // Declared below
|
||||
struct NodeProperties; // Defined in .cc
|
||||
|
||||
class Node {
|
||||
@ -173,6 +173,7 @@ class Node {
|
||||
|
||||
bool IsMetadata() const { return class_ == NC_METADATA; }
|
||||
bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
|
||||
bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
|
||||
|
||||
template <typename T>
|
||||
void AddAttr(const string& name, const T& val) {
|
||||
@ -254,6 +255,7 @@ class Node {
|
||||
NC_SCOPED_ALLOCATOR,
|
||||
NC_COLLECTIVE,
|
||||
NC_FAKE_PARAM,
|
||||
NC_PARTITIONED_CALL,
|
||||
NC_OTHER // Not a special kind of node
|
||||
};
|
||||
|
||||
|
@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/partitioned_function_ops.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -31,234 +34,219 @@ limitations under the License.
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// A `PartitionedCallOp` asynchronously executes a function, potentially across
|
||||
// multiple devices but within a single process. The kernel places and
|
||||
// partitions a given function's underlying graph, and executes each of the
|
||||
// partitioned subgraphs as a function.
|
||||
//
|
||||
// TODO(akshayka): Support distributed execution.
|
||||
class PartitionedCallOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit PartitionedCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
|
||||
string deprecated_config_serialized;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &deprecated_config_serialized));
|
||||
string config_proto_serialized;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("config_proto", &config_proto_serialized));
|
||||
|
||||
PartitionedCallOp::PartitionedCallOp(OpKernelConstruction* ctx)
|
||||
: AsyncOpKernel(ctx),
|
||||
func_(new NameAttrList),
|
||||
config_proto_(new ConfigProto) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, func_.get()));
|
||||
string deprecated_config_serialized;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &deprecated_config_serialized));
|
||||
string config_proto_serialized;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("config_proto", &config_proto_serialized));
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
deprecated_config_serialized.empty() || config_proto_serialized.empty(),
|
||||
errors::InvalidArgument("Provided both 'config' and 'config_proto' but "
|
||||
"only one should be provided. Note the "
|
||||
"'config' option is deprecated."));
|
||||
if (!deprecated_config_serialized.empty()) {
|
||||
OP_REQUIRES(ctx,
|
||||
config_proto_->mutable_graph_options()
|
||||
->mutable_rewrite_options()
|
||||
->ParseFromString(deprecated_config_serialized),
|
||||
errors::InvalidArgument("Unable to parse config string as "
|
||||
"tensorflow::RewriteOptions proto."));
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
deprecated_config_serialized.empty() || config_proto_serialized.empty(),
|
||||
errors::InvalidArgument("Provided both 'config' and 'config_proto' but "
|
||||
"only one should be provided. Note the "
|
||||
"'config' option is deprecated."));
|
||||
if (!deprecated_config_serialized.empty()) {
|
||||
OP_REQUIRES(ctx,
|
||||
config_proto_.mutable_graph_options()
|
||||
->mutable_rewrite_options()
|
||||
->ParseFromString(deprecated_config_serialized),
|
||||
errors::InvalidArgument("Unable to parse config string as "
|
||||
"tensorflow::RewriteOptions proto."));
|
||||
ctx, config_proto_->ParseFromString(config_proto_serialized),
|
||||
errors::InvalidArgument("Unable to parse config_proto string as "
|
||||
"tensorflow::ConfigProto proto."));
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("executor_type", &executor_type_));
|
||||
}
|
||||
|
||||
PartitionedCallOp::~PartitionedCallOp() {
|
||||
for (const auto& it : handles_) {
|
||||
Status status = it.first->ReleaseHandle(it.second);
|
||||
if (!status.ok()) {
|
||||
LOG(INFO) << "Ignoring error while destructing PartitionedCallOp: "
|
||||
<< status.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PartitionedCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
FunctionLibraryRuntime* lib = ctx->function_library();
|
||||
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
|
||||
errors::Internal("No function library is provided."), done);
|
||||
|
||||
// The function body's graph is placed and partitioned the first time
|
||||
// `ComputeAsync` is invoked; every subsequent invocation calls each
|
||||
// of the function shards yielded by partitioning.
|
||||
//
|
||||
// The partitioning step yields a set of devices on which to run the
|
||||
// function, and exactly one function shard is created for each device
|
||||
// Inputs and outputs are pinned to the local device, for simplicity.
|
||||
//
|
||||
// TODO(akshayka): Support re-sharding the function on subsequent calls,
|
||||
// via, e.g., virtual device annotations and a list of device names
|
||||
// supplied through an attribute.
|
||||
//
|
||||
// TODO(akshayka): Add a fastpath for functions that execute on a single
|
||||
// device.
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If we are instantiating the function, we can efficiently extract the
|
||||
// inputs while instantiating. Else, we extract them separately below.
|
||||
std::vector<Tensor> inputs;
|
||||
bool inputs_extracted = false;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
auto it = handles_.find(lib);
|
||||
if (it == handles_.end()) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, ctx, &inputs, &handle), done);
|
||||
inputs_extracted = true;
|
||||
handles_[lib] = handle;
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
ctx, config_proto_.ParseFromString(config_proto_serialized),
|
||||
errors::InvalidArgument("Unable to parse config_proto string as "
|
||||
"tensorflow::ConfigProto proto."));
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("executor_type", &executor_type_));
|
||||
}
|
||||
|
||||
~PartitionedCallOp() override {
|
||||
for (const auto& it : handles_) {
|
||||
Status status = it.first->ReleaseHandle(it.second);
|
||||
if (!status.ok()) {
|
||||
LOG(INFO) << "Ignoring error while destructing PartitionedCallOp: "
|
||||
<< status.ToString();
|
||||
}
|
||||
handle = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
FunctionLibraryRuntime* lib = ctx->function_library();
|
||||
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
|
||||
errors::Internal("No function library is provided."),
|
||||
done);
|
||||
|
||||
// The function body's graph is placed and partitioned the first time
|
||||
// `ComputeAsync` is invoked; every subsequent invocation calls each
|
||||
// of the function shards yielded by partitioning.
|
||||
//
|
||||
// The partitioning step yields a set of devices on which to run the
|
||||
// function, and exactly one function shard is created for each device
|
||||
// Inputs and outputs are pinned to the local device, for simplicity.
|
||||
//
|
||||
// TODO(akshayka): Support re-sharding the function on subsequent calls,
|
||||
// via, e.g., virtual device annotations and a list of device names
|
||||
// supplied through an attribute.
|
||||
//
|
||||
// TODO(akshayka): Add a fastpath for functions that execute on a single
|
||||
// device.
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If we are instantiating the function, we can efficiently extract the
|
||||
// inputs while instantiating. Else, we extract them separately below.
|
||||
std::vector<Tensor> inputs;
|
||||
bool inputs_extracted = false;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
auto it = handles_.find(lib);
|
||||
if (it == handles_.end()) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, ctx, &inputs, &handle),
|
||||
done);
|
||||
inputs_extracted = true;
|
||||
handles_[lib] = handle;
|
||||
} else {
|
||||
handle = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
if (!inputs_extracted) {
|
||||
OpInputList args;
|
||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done);
|
||||
inputs.reserve(args.size());
|
||||
for (const Tensor& tensor : args) {
|
||||
inputs.push_back(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
RunFunction(handle, inputs, lib, ctx, done);
|
||||
}
|
||||
|
||||
private:
|
||||
Status FillOutputDevices(const FunctionLibraryRuntime& lib,
|
||||
const Device& cpu_device, AttrSlice attrs,
|
||||
FunctionLibraryRuntime::InstantiateOptions* opts) {
|
||||
const FunctionLibraryDefinition* flib = lib.GetFunctionLibraryDefinition();
|
||||
const FunctionDef* fdef = flib->Find(func_.name());
|
||||
if (fdef == nullptr) {
|
||||
return errors::NotFound("Failed for find definiton for function \"",
|
||||
func_.name(), "\"");
|
||||
}
|
||||
|
||||
bool is_type_list;
|
||||
for (const OpDef::ArgDef& ret_def : fdef->signature().output_arg()) {
|
||||
DataTypeVector dtypes;
|
||||
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
|
||||
for (DataType dtype : dtypes) {
|
||||
if (MTypeFromDType(dtype) == HOST_MEMORY) {
|
||||
opts->output_devices.push_back(cpu_device.name());
|
||||
} else {
|
||||
opts->output_devices.push_back(opts->target);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
|
||||
std::vector<Tensor>* inputs,
|
||||
FunctionLibraryRuntime::Handle* handle) {
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
||||
|
||||
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
||||
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
||||
// However because the function will be executed via FunctionLibraryRuntime,
|
||||
// and current function implementation does not prune stateful and dataset
|
||||
// ops, we rely on Grappler to do the correct graph pruning.
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
||||
|
||||
// All the nested function calls will be executed and optimized via
|
||||
// PartitionedCallOp, there is no need to optimize functions now.
|
||||
optimization_options.optimize_function_library = false;
|
||||
|
||||
FunctionLibraryRuntime::InstantiateOptions opts;
|
||||
opts.target = lib->device()->name();
|
||||
opts.is_multi_device_function = true;
|
||||
opts.optimize_graph_fn =
|
||||
std::bind(grappler::OptimizeGraph, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3,
|
||||
std::placeholders::_4, std::placeholders::_5, config_proto_,
|
||||
func_.name(), optimization_options, std::placeholders::_6);
|
||||
opts.graph_collector = ctx->graph_collector();
|
||||
opts.executor_type = executor_type_;
|
||||
|
||||
if (!inputs_extracted) {
|
||||
OpInputList args;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list("args", &args));
|
||||
Device* cpu_device;
|
||||
TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
|
||||
|
||||
inputs->reserve(args.size());
|
||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done);
|
||||
inputs.reserve(args.size());
|
||||
for (const Tensor& tensor : args) {
|
||||
inputs->push_back(tensor);
|
||||
DataType dtype = tensor.dtype();
|
||||
if (dtype == DT_RESOURCE) {
|
||||
const ResourceHandle& handle = tensor.flat<ResourceHandle>()(0);
|
||||
opts.input_devices.push_back(handle.device());
|
||||
} else if (MTypeFromDType(dtype) == HOST_MEMORY) {
|
||||
opts.input_devices.push_back(cpu_device->name());
|
||||
inputs.push_back(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
RunFunction(handle, inputs, lib, ctx, done);
|
||||
}
|
||||
|
||||
Status PartitionedCallOp::FillOutputDevices(
|
||||
const FunctionLibraryRuntime& lib, const Device& cpu_device,
|
||||
AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions* opts) {
|
||||
const FunctionLibraryDefinition* flib = lib.GetFunctionLibraryDefinition();
|
||||
const FunctionDef* fdef = flib->Find(func_->name());
|
||||
if (fdef == nullptr) {
|
||||
return errors::NotFound("Failed for find definition for function \"",
|
||||
func_->name(), "\"");
|
||||
}
|
||||
|
||||
bool is_type_list;
|
||||
for (const OpDef::ArgDef& ret_def : fdef->signature().output_arg()) {
|
||||
DataTypeVector dtypes;
|
||||
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
|
||||
for (DataType dtype : dtypes) {
|
||||
if (MTypeFromDType(dtype) == HOST_MEMORY) {
|
||||
opts->output_devices.push_back(cpu_device.name());
|
||||
} else {
|
||||
opts.input_devices.push_back(opts.target);
|
||||
opts->output_devices.push_back(opts->target);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
FillOutputDevices(*lib, *cpu_device, AttrSlice(&func_.attr()), &opts));
|
||||
Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib,
|
||||
OpKernelContext* ctx,
|
||||
std::vector<Tensor>* inputs,
|
||||
FunctionLibraryRuntime::Handle* handle) {
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts, handle));
|
||||
return Status::OK();
|
||||
// Tensorflow 2.0 in eager mode with automatic control dependencies will
|
||||
// prune all nodes that are not in the transitive fanin of the fetch nodes.
|
||||
// However because the function will be executed via FunctionLibraryRuntime,
|
||||
// and current function implementation does not prune stateful and dataset
|
||||
// ops, we rely on Grappler to do the correct graph pruning.
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = true;
|
||||
|
||||
// All the nested function calls will be executed and optimized via
|
||||
// PartitionedCallOp, there is no need to optimize functions now.
|
||||
optimization_options.optimize_function_library = false;
|
||||
|
||||
FunctionLibraryRuntime::InstantiateOptions opts;
|
||||
// In some contexts like running the graph to evaluate constants,
|
||||
// the FLR won't have any device.
|
||||
opts.target = lib->device() == nullptr ? "" : lib->device()->name();
|
||||
opts.is_multi_device_function = true;
|
||||
opts.optimize_graph_fn =
|
||||
std::bind(grappler::OptimizeGraph, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3,
|
||||
std::placeholders::_4, std::placeholders::_5, *config_proto_,
|
||||
func_->name(), optimization_options, std::placeholders::_6);
|
||||
opts.graph_collector = ctx->graph_collector();
|
||||
opts.executor_type = executor_type_;
|
||||
|
||||
OpInputList args;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list("args", &args));
|
||||
Device* cpu_device;
|
||||
TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
|
||||
|
||||
inputs->reserve(args.size());
|
||||
for (const Tensor& tensor : args) {
|
||||
inputs->push_back(tensor);
|
||||
DataType dtype = tensor.dtype();
|
||||
if (dtype == DT_RESOURCE) {
|
||||
const ResourceHandle& handle = tensor.flat<ResourceHandle>()(0);
|
||||
opts.input_devices.push_back(handle.device());
|
||||
} else if (MTypeFromDType(dtype) == HOST_MEMORY) {
|
||||
opts.input_devices.push_back(cpu_device->name());
|
||||
} else {
|
||||
opts.input_devices.push_back(opts.target);
|
||||
}
|
||||
}
|
||||
|
||||
void RunFunction(FunctionLibraryRuntime::Handle handle,
|
||||
const std::vector<Tensor>& inputs,
|
||||
FunctionLibraryRuntime* lib, OpKernelContext* ctx,
|
||||
DoneCallback done) {
|
||||
FunctionLibraryRuntime::Options run_opts;
|
||||
run_opts.step_id = ctx->step_id();
|
||||
run_opts.step_container = ctx->step_container();
|
||||
run_opts.cancellation_manager = ctx->cancellation_manager();
|
||||
run_opts.stats_collector = ctx->stats_collector();
|
||||
run_opts.collective_executor = ctx->collective_executor();
|
||||
// TODO(akshayka): Consider selecting a runner on a per-device basis,
|
||||
// i.e., using device-specific threadpools when available.
|
||||
run_opts.runner = ctx->runner();
|
||||
run_opts.source_device = lib->device()->name();
|
||||
run_opts.allow_dead_tensors = true;
|
||||
// TODO(akshayka): Accommodate the multiple-worker scenario by adding the
|
||||
// constructed rendezvous to a rendezvous manager.
|
||||
Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
|
||||
run_opts.rendezvous = rendez;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FillOutputDevices(*lib, *cpu_device, AttrSlice(&func_->attr()), &opts));
|
||||
|
||||
std::vector<Tensor>* rets = new std::vector<Tensor>;
|
||||
const string& func_name = func_.name();
|
||||
lib->Run(run_opts, handle, inputs, rets,
|
||||
[rets, rendez, done, ctx, func_name](const Status& status) {
|
||||
if (!status.ok()) {
|
||||
const string function_and_msg =
|
||||
strings::StrCat(errors::FormatFunctionForError(func_name),
|
||||
" ", status.error_message());
|
||||
ctx->SetStatus(Status(status.code(), function_and_msg));
|
||||
} else {
|
||||
for (int i = 0; i < rets->size(); ++i) {
|
||||
ctx->set_output(i, (*rets)[i]);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
lib->Instantiate(func_->name(), AttrSlice(&func_->attr()), opts, handle));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle,
|
||||
const std::vector<Tensor>& inputs,
|
||||
FunctionLibraryRuntime* lib,
|
||||
OpKernelContext* ctx, DoneCallback done) {
|
||||
FunctionLibraryRuntime::Options run_opts;
|
||||
run_opts.step_id = ctx->step_id();
|
||||
run_opts.step_container = ctx->step_container();
|
||||
run_opts.cancellation_manager = ctx->cancellation_manager();
|
||||
run_opts.stats_collector = ctx->stats_collector();
|
||||
run_opts.collective_executor = ctx->collective_executor();
|
||||
// TODO(akshayka): Consider selecting a runner on a per-device basis,
|
||||
// i.e., using device-specific threadpools when available.
|
||||
run_opts.runner = ctx->runner();
|
||||
run_opts.source_device =
|
||||
lib->device() == nullptr ? "" : lib->device()->name();
|
||||
run_opts.allow_dead_tensors = true;
|
||||
// TODO(akshayka): Accommodate the multiple-worker scenario by adding the
|
||||
// constructed rendezvous to a rendezvous manager.
|
||||
Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
|
||||
run_opts.rendezvous = rendez;
|
||||
|
||||
std::vector<Tensor>* rets = new std::vector<Tensor>;
|
||||
const string& func_name = func_->name();
|
||||
lib->Run(run_opts, handle, inputs, rets,
|
||||
[rets, rendez, done, ctx, func_name](const Status& status) {
|
||||
if (!status.ok()) {
|
||||
const string function_and_msg =
|
||||
strings::StrCat(errors::FormatFunctionForError(func_name),
|
||||
" ", status.error_message());
|
||||
ctx->SetStatus(Status(status.code(), function_and_msg));
|
||||
} else {
|
||||
for (int i = 0; i < rets->size(); ++i) {
|
||||
ctx->set_output(i, (*rets)[i]);
|
||||
}
|
||||
delete rets;
|
||||
rendez->Unref();
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
NameAttrList func_;
|
||||
ConfigProto config_proto_;
|
||||
string executor_type_;
|
||||
mutex mu_;
|
||||
// Cache the handle per FLR because this kernel may be instantiated for
|
||||
// a stateful op, different invocations of it may use different FLRs.
|
||||
// Different device placements of PartitionedCallOp also use
|
||||
// different FLRs.
|
||||
gtl::FlatMap<FunctionLibraryRuntime*, FunctionLibraryRuntime::Handle> handles_
|
||||
GUARDED_BY(mu_);
|
||||
};
|
||||
}
|
||||
delete rets;
|
||||
rendez->Unref();
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU),
|
||||
PartitionedCallOp);
|
||||
@ -275,5 +263,4 @@ REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_SYCL),
|
||||
PartitionedCallOp);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
tensorflow/core/kernels/partitioned_function_ops.h (new file, 72 lines)
@ -0,0 +1,72 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class NameAttrList;
|
||||
class ConfigProto;
|
||||
|
||||
// A `PartitionedCallOp` asynchronously executes a function, potentially across
|
||||
// multiple devices but within a single process. The kernel places and
|
||||
// partitions a given function's underlying graph, and executes each of the
|
||||
// partitioned subgraphs as a function.
|
||||
//
|
||||
// TODO(akshayka): Support distributed execution.
|
||||
class PartitionedCallOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit PartitionedCallOp(OpKernelConstruction* ctx);
|
||||
|
||||
~PartitionedCallOp() override;
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
|
||||
|
||||
private:
|
||||
Status FillOutputDevices(const FunctionLibraryRuntime& lib,
|
||||
const Device& cpu_device, AttrSlice attrs,
|
||||
FunctionLibraryRuntime::InstantiateOptions* opts);
|
||||
|
||||
Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
|
||||
std::vector<Tensor>* inputs,
|
||||
FunctionLibraryRuntime::Handle* handle);
|
||||
|
||||
void RunFunction(FunctionLibraryRuntime::Handle handle,
|
||||
const std::vector<Tensor>& inputs,
|
||||
FunctionLibraryRuntime* lib, OpKernelContext* ctx,
|
||||
DoneCallback done);
|
||||
|
||||
// Using unique pointers to avoid including proto headers in kernel headers
|
||||
std::unique_ptr<NameAttrList> func_;
|
||||
std::unique_ptr<ConfigProto> config_proto_;
|
||||
string executor_type_;
|
||||
mutex mu_;
|
||||
// Cache the handle per FLR because this kernel may be instantiated for
|
||||
// a stateful op, different invocations of it may use different FLRs.
|
||||
// Different device placements of PartitionedCallOp also use
|
||||
// different FLRs.
|
||||
gtl::FlatMap<FunctionLibraryRuntime*, FunctionLibraryRuntime::Handle> handles_
|
||||
GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_
|
@ -38,7 +38,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes as dtypes_module
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
@ -64,7 +63,7 @@ BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
|
||||
|
||||
CacheKey = collections.namedtuple("CacheKey", [
|
||||
"input_signature", "parent_graph", "device_functions",
|
||||
"colocation_stack", "uses_xla"])
|
||||
"colocation_stack"])
|
||||
|
||||
CacheKey.replace = CacheKey._replace # pylint: disable=protected-access
|
||||
|
||||
@ -417,25 +416,6 @@ class _EagerDefinedFunction(object):
|
||||
ctx=ctx)
|
||||
# Replace empty list with None
|
||||
outputs = outputs or None
|
||||
elif self._graph._xla_compile: # pylint: disable=protected-access
|
||||
g = ops.get_default_graph()
|
||||
self.add_to_graph(g)
|
||||
signature = self.signature
|
||||
with ops.control_dependencies(self._control_captures):
|
||||
op = g.create_op(
|
||||
signature.name,
|
||||
[ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
|
||||
tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
|
||||
op_def=signature,
|
||||
name="FunctionCall",
|
||||
compute_shapes=False)
|
||||
outputs = op.outputs
|
||||
if not outputs:
|
||||
return op
|
||||
if isinstance(outputs, (ops.Tensor, type(None))):
|
||||
outputs = [outputs]
|
||||
else:
|
||||
outputs = list(outputs)
|
||||
else:
|
||||
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
|
||||
# creates `PartitionedCallOp` kernels by default, or remove the previous
|
||||
@ -1464,16 +1444,13 @@ class Function(object):
|
||||
default_graph._distribution_strategy_stack)
|
||||
if executing_eagerly:
|
||||
colocation_stack = ()
|
||||
uses_xla = ctx.device_spec.device_type == "TPU"
|
||||
if uses_distribution_strategy or uses_xla:
|
||||
if uses_distribution_strategy:
|
||||
device_functions = (pydev.merge_device(ctx.device_name),)
|
||||
else:
|
||||
device_functions = ()
|
||||
else:
|
||||
colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
|
||||
uses_xla = getattr(default_graph, "_xla_compile", False)
|
||||
if (uses_distribution_strategy
|
||||
or uses_xla
|
||||
or func_graph_module.device_stack_has_callable(
|
||||
default_graph._device_function_stack)):
|
||||
# Putting the device in the cache key ensures that call-site device
|
||||
@ -1483,7 +1460,7 @@ class Function(object):
|
||||
device_functions = ()
|
||||
# pylint: enable=protected-access
|
||||
return CacheKey(input_signature, parent_graph, device_functions,
|
||||
colocation_stack, uses_xla)
|
||||
colocation_stack)
|
||||
|
||||
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
||||
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
||||
@ -1572,11 +1549,10 @@ class Function(object):
|
||||
|
||||
call_context_key = cache_key.replace(input_signature=None)
|
||||
|
||||
# If there's a provided input signature, or XLA is being used, or
|
||||
# If there's a provided input signature, or
|
||||
# there's no cache miss for this calling context so far, go ahead and
|
||||
# build the function and bypass shape relaxation retracing.
|
||||
if (self.input_signature is not None
|
||||
or cache_key.uses_xla
|
||||
or call_context_key not in self._function_cache.missed):
|
||||
self._function_cache.missed.add(call_context_key)
|
||||
graph_function = self._create_graph_function(args, kwargs)
|
||||
|
@ -208,13 +208,9 @@ class FuncGraph(ops.Graph):
|
||||
# any None op_seed for random_op in the function, in which case we end up
|
||||
# using function seed, which could be unintended behavior for the op.
|
||||
self._seed_used = False
|
||||
device_type = context.context().device_spec.device_type
|
||||
self._xla_compile = (device_type == "TPU" or device_type == "XLA_GPU"
|
||||
or device_type == "XLA_CPU")
|
||||
else:
|
||||
self.seed = graph.seed
|
||||
self._seed_used = False
|
||||
self._xla_compile = getattr(graph, "_xla_compile", False)
|
||||
# TODO(allenl): Figure out if we can remove colocation stack
|
||||
# specialization (currently used in cond_v2), here and in the cache key.
|
||||
self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
|
||||
@ -297,11 +293,10 @@ class FuncGraph(ops.Graph):
|
||||
# restored.
|
||||
old_device_stack = self._device_function_stack
|
||||
if context.executing_eagerly():
|
||||
if self._distribution_strategy_stack or self._xla_compile:
|
||||
if self._distribution_strategy_stack:
|
||||
self._add_device_to_stack(context.context().device_name)
|
||||
else:
|
||||
if (self._distribution_strategy_stack
|
||||
or self._xla_compile
|
||||
or device_stack_has_callable(graph._device_function_stack)):
|
||||
# Hard-code devices from device functions in the function body
|
||||
self._device_function_stack = graph._device_function_stack.copy()
|
||||
|
@ -829,6 +829,8 @@ class CondV2Test(test.TestCase):
|
||||
self.evaluate(output_t), [-5, -4, -3, -2, -1, 0, 1, 4, 9, 16])
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
@test_util.disable_xla(
|
||||
"b/127846988: No tf2xla kernel for IfOp taking DT_VARIANT")
|
||||
def testCondAndTensorArrayInDefun(self):
|
||||
|
||||
@function.defun
|
||||
|
@ -534,6 +534,7 @@ class PyFuncTest(test.TestCase):
|
||||
self.assertIsNone(ret)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.disable_xla("XLA cannot compile functions containing py_func")
|
||||
def testEagerPyFuncInDefun(self):
|
||||
with test_util.device(use_gpu=True):
|
||||
def wrapper():