[TF2XLA] Check 'case' op compilation from autoclustering codepath, but not from tf.function
This is a bit complicated: we do allow uncompilable code under tf.function(compile=True) in 'case' statements (and, in the future, under 'if/while' statements), provided the input is constant-foldable, and the uncompiled branch is guaranteed to be never executed. We can not allow the same during autoclustering though, as we don't know the value in advance, and we have to be more conservative. PiperOrigin-RevId: 335928032 Change-Id: I84a4683079d1871672f429029e279522c258d66c
This commit is contained in:
parent
6ecbdff0c2
commit
67db2b7620
@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
|
||||
const Node& node, absl::string_view attr_name,
|
||||
absl::string_view call_name) {
|
||||
std::vector<NameAttrList> attr_lists;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
|
||||
|
||||
std::vector<NodeDef> out;
|
||||
for (int i = 0; i < attr_lists.size(); i++) {
|
||||
out.emplace_back();
|
||||
NodeDef& inserted = out.back();
|
||||
inserted.set_name(absl::StrCat(call_name, "_", i));
|
||||
inserted.set_op(attr_lists[i].name());
|
||||
*inserted.mutable_attr() = attr_lists[i].attr();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Utility which searches for values in a sorted list by scanning over it once.
|
||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||
@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
|
||||
return is_compilable;
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::IsCompilableCase(
|
||||
const Node& case_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
NameAttrList* encapsulating_function,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||
const {
|
||||
xla::StatusOr<std::vector<NodeDef>> calls =
|
||||
MakeCallNodesFromAttribute(case_node, "branches", "branch");
|
||||
if (!calls.ok()) {
|
||||
VLOG(2) << "Rejecting node " << case_node.name() << ": "
|
||||
<< "missing attribute 'branches'";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_compilable = true;
|
||||
|
||||
for (const NodeDef& call : *calls) {
|
||||
is_compilable &=
|
||||
IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
|
||||
uncompilable_nodes);
|
||||
}
|
||||
return is_compilable;
|
||||
}
|
||||
|
||||
// Tests whether 'while_node' is a completely compilable loop.
|
||||
// Every operator in the condition and body functions must be compilable for a
|
||||
// while loop to be compilable.
|
||||
@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op_filter_.require_always_compilable && node.IsCaseNode() &&
|
||||
!IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
|
||||
uncompilable_nodes)) {
|
||||
LogNotCompilable(node, "unsupported case");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_stateful_rng_ops &&
|
||||
IsStatefulRandomOp(node.type_string())) {
|
||||
absl::string_view uncompilable_reason = "stateful random op";
|
||||
|
@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
|
||||
// Whether ops known to have numerical accuracy issues should be considered
|
||||
// compilable..
|
||||
bool allow_inaccurate_ops = false;
|
||||
|
||||
// Require the function to be always compilable, regardless whether some
|
||||
// control flow branches might be dead for a given input.
|
||||
bool require_always_compilable = false;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
||||
@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
|
||||
NameAttrList* encapsulating_function,
|
||||
UncompilableNodesMap* uncompilable_nodes) const;
|
||||
|
||||
// Tests whether 'case_node' is compilable. Every operator in all branches
|
||||
// must be compilable.
|
||||
bool IsCompilableCase(const Node& case_node,
|
||||
FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
NameAttrList* encapsulating_function,
|
||||
UncompilableNodesMap* uncompilable_nodes) const;
|
||||
|
||||
// Returns compilability of node def retrieved from `node`'s attribute with
|
||||
// name `attr_name`.
|
||||
bool ExtractNodeDefAndCheckCompilability(
|
||||
|
@ -34,7 +34,16 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
|
||||
AttrValue attr;
|
||||
for (const char* name : names) {
|
||||
attr.mutable_list()->add_func()->set_name(name);
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
|
||||
constexpr char kFunctionalIfNodeName[] = "If";
|
||||
constexpr char kFunctionalCaseNodeName[] = "Case";
|
||||
constexpr char kFunctionalWhileNodeName[] = "While";
|
||||
constexpr char kCompilableFunctionName[] = "CompilableFn";
|
||||
constexpr char kCompilableFunctionNodeName[] = "n_c";
|
||||
@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
|
||||
op_filter_.allow_inaccurate_ops = false;
|
||||
op_filter_.allow_slow_ops = false;
|
||||
|
||||
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
||||
device_type_);
|
||||
checker_ = CreateCompilabilityChecker();
|
||||
}
|
||||
|
||||
std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
|
||||
return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
||||
device_type_);
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
|
||||
@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
||||
"unsupported op"));
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
|
||||
FunctionDefLibrary flib;
|
||||
*flib.add_function() = FunctionDefHelper::Define(
|
||||
/*Function*/ kUncompilableFunctionName,
|
||||
/*Inputs*/ {"n_a:float"},
|
||||
/*Outputs*/ {"n_c_uncompilable:float"},
|
||||
/*Attributes*/ {},
|
||||
// Node info
|
||||
{{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
|
||||
*flib.add_function() = FunctionDefHelper::Define(
|
||||
/*Function*/ kUncompilableFunctionTwoName,
|
||||
/*Inputs*/ {"n_a:float"},
|
||||
/*Outputs*/ {"n_d_uncompilable:float"},
|
||||
/*Attribute*/ {},
|
||||
// Node info
|
||||
{{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
|
||||
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
|
||||
auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
|
||||
auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||
std::vector<NodeBuilder::NodeOut> inputes(
|
||||
{NodeBuilder::NodeOut(placeholder.node())});
|
||||
Node* case_node;
|
||||
TF_ASSERT_OK(
|
||||
NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
|
||||
.Input(branch_index.node())
|
||||
.Input(inputes)
|
||||
.Attr("branches", FuncListAttr({kUncompilableFunctionName,
|
||||
kUncompilableFunctionTwoName}))
|
||||
.Attr("Tout", {DT_INT32})
|
||||
.Finalize(root.graph(), &case_node));
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
|
||||
|
||||
auto case_node_it = std::find_if(
|
||||
graph->nodes().begin(), graph->nodes().end(),
|
||||
[&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
|
||||
EXPECT_NE(case_node_it, graph->nodes().end());
|
||||
auto* flib_runtime = GetFunctionLibraryRuntime();
|
||||
|
||||
op_filter_.require_always_compilable = false;
|
||||
checker_ = CreateCompilabilityChecker();
|
||||
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||
op_filter_.require_always_compilable = true;
|
||||
checker_ = CreateCompilabilityChecker();
|
||||
EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!RecursiveCompilabilityChecker{
|
||||
CreateOperationFilter(*registration),
|
||||
DeviceType{registration->compilation_device_name}}
|
||||
.IsCompilableNode(*node, lib_runtime)) {
|
||||
RecursiveCompilabilityChecker::OperationFilter filter =
|
||||
CreateOperationFilter(*registration);
|
||||
filter.require_always_compilable = true;
|
||||
|
||||
RecursiveCompilabilityChecker checker(
|
||||
filter, DeviceType{registration->compilation_device_name});
|
||||
|
||||
if (!checker.IsCompilableNode(*node, lib_runtime)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -824,7 +824,6 @@ cuda_py_test(
|
||||
tags = [
|
||||
"no_oss",
|
||||
],
|
||||
xla_enable_strict_auto_jit = False,
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
@ -843,7 +842,6 @@ cuda_py_test(
|
||||
"no_cuda11",
|
||||
"no_oss",
|
||||
],
|
||||
xla_enable_strict_auto_jit = False,
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
|
Loading…
Reference in New Issue
Block a user