[TF2XLA] Remove redundant compilability checking in xla_kernel_creator

There is no reason to run a duplicate compilability check before compilation even starts. Previously, this check was used to provide (slightly) better stack traces, but now that we have true Python stack traces, we can rely on them to provide the same UX while returning the error during the compilation process itself.

PiperOrigin-RevId: 358909080
Change-Id: I2d175d9d071ea87f8c35acba43d90a332bae6b21
This commit is contained in:
parent 7734ef812e
commit 19f8c999b3
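To make the new behavior concrete, here is a minimal sketch of what a user now sees. It is not part of the commit; the `jit_compile=True` spelling, the function name, and the choice of a string op are illustrative assumptions, picked because the updated tests below exercise unsupported ops. Instead of the kernel creator rejecting the function up front as "not compilable", the error now surfaces from the compiler itself as "Detected unsupported operations ..." and, when a Python stack trace is available, points at the line where the offending op was created.

# Illustrative sketch only; the op, function name, and `jit_compile` argument
# are assumptions, not code from this commit.
import tensorflow as tf


@tf.function(jit_compile=True)
def to_upper(x):
  # String ops have no XLA kernels, so XLA compilation of this function fails.
  return tf.strings.upper(x)


try:
  to_upper(tf.constant(["hello"]))
except tf.errors.InvalidArgumentError as e:
  # After this change the message comes from graph validation during
  # compilation ("Detected unsupported operations when trying to compile
  # graph ...") and, when available, is followed by "The op is created at:"
  # plus the Python stack trace of the op's creation site.
  print(e.message)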
@@ -164,29 +164,6 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
   return uncompilable_nodes;
 }
 
-RecursiveCompilabilityChecker::UncompilableNodesMap
-RecursiveCompilabilityChecker::FindUncompilableNodes(
-    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
-    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
-        node_stack_trace) const {
-  // If `node_stack_trace` is provided, that means `call_def` is inside
-  // a function body, and therefore, arg nodes and retval nodes are
-  // not considered uncompilable.
-  std::vector<StackFrameView> stack_trace;
-  if (node_stack_trace != nullptr) {
-    for (const auto& frame : *node_stack_trace) {
-      stack_trace.emplace_back(
-          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
-    }
-  }
-  stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr});
-
-  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
-  IsCompilableCall(call_def, lib_runtime, &stack_trace,
-                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
-  return uncompilable_nodes;
-}
-
 bool RecursiveCompilabilityChecker::HasXLAKernel(
     const Node& node, string* uncompilable_reason) const {
   // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
@@ -157,20 +157,6 @@ class RecursiveCompilabilityChecker {
       const Node& node, FunctionLibraryRuntime* lib_runtime,
       const std::vector<StackFrame>* node_stack_trace = nullptr) const;
 
-  // Returns a map where the key is the function identifier(short debug
-  // string) of the function encapsulating the uncompilable nodes, and the
-  // value is a pair of NameAttrList of the function and a vector of
-  // uncompilable node info. When uncompilable node is not inside any
-  // function call nodes, then key is a ShortDebugString() of an empty
-  // NameAttrList.
-  //
-  // Also, when `node` is inside a function body, users can set
-  // `node_stack_trace` to provide an additional context for `node`'s
-  // placement within the outer most graph.
-  UncompilableNodesMap FindUncompilableNodes(
-      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
-      const std::vector<StackFrame>* node_stack_trace = nullptr) const;
-
   // Returns true if `node` can be compiled by XLA.
   bool IsCompilableNode(const Node& node,
                         FunctionLibraryRuntime* lib_runtime) const {
@@ -179,15 +165,6 @@ class RecursiveCompilabilityChecker {
     return IsCompilableNode(node, lib_runtime, &stack_trace);
   }
 
-  // Returns true if `call_def` can be compiled by XLA. It is assumed that
-  // `call_def` is a call operation.
-  bool IsCompilableCall(const NodeDef& call_def,
-                        FunctionLibraryRuntime* lib_runtime) {
-    std::vector<StackFrameView> stack_trace;
-    stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
-    return IsCompilableCall(call_def, lib_runtime, &stack_trace);
-  }
-
   // Returns true if XLA supports this Op, but we don't want to cluster it (ie:
   // due to performance or correctness concerns).
   bool OpIsInaccurate(const Node& node) const;
@@ -32,44 +32,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-// Returns true iff 'ndef' is a call to a function that is compilable. A
-// function is compilable iff every operator in the function body is
-// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
-// null, we will populate 'uncompilable_node_info' with uncompilable node info.
-static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
-                         RecursiveCompilabilityChecker::UncompilableNodesMap*
-                             uncompilable_node_info) {
-  Device* device = flr->device();
-  const XlaOpRegistry::DeviceRegistration* registration;
-  CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
-                                            &registration));
-
-  // We can always *compile* resource operations, stateful RNGs and dummy ops,
-  // even if we are sometimes unable to auto-cluster them.
-  RecursiveCompilabilityChecker::OperationFilter op_filter;
-  op_filter.allow_resource_ops_in_called_functions = true;
-  op_filter.allow_stack_ops = true;
-  op_filter.allow_tensor_array_ops = true;
-  op_filter.allow_stateful_rng_ops = true;
-  op_filter.allow_control_trigger = true;
-  op_filter.allow_eliding_assert_and_checknumerics_ops = true;
-  op_filter.allow_ops_producing_or_consuming_variant = true;
-  op_filter.allow_slow_ops = true;
-  op_filter.allow_inaccurate_ops = true;
-
-  RecursiveCompilabilityChecker checker{
-      op_filter, DeviceType{registration->compilation_device_name}};
-  if (!uncompilable_node_info) {
-    // We do not need uncompilable node info. Just return the result.
-    return checker.IsCompilableCall(ndef, flr);
-  }
-
-  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
-      checker.FindUncompilableNodes(ndef, flr);
-  uncompilable_node_info->swap(uncompilable_node_result);
-  return uncompilable_node_info->empty();
-}
-
 bool XlaKernelCreator::CanCreateKernel(
     const FunctionLibraryRuntime& flr,
     const std::shared_ptr<const NodeProperties>& props) const {
@@ -98,56 +60,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
 
-  // Only check for compilability if the MLIR bridge is not enabled.
-  absl::optional<ConfigProto> config_proto;
-  if (flr->config_proto()) {
-    config_proto = *flr->config_proto();
-  }
-  // There is no easy way to check if we have uninitialized resource args here
-  // so we assume there are uninitialized resource args. This means that we
-  // might run the compilability checker in cases where we don't need to (when
-  // MLIR bridge is run later). Note that this is just temporary until
-  // b/171732021 gets fixed.
-  // We should also revisit if this check provides any value, otherwise we
-  // should remove it.
-  MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
-      *fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true);
-  if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
-    RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
-    if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
-      std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
-          uncompilable_node_info;
-      for (const auto& it : uncompilable_nodes_map) {
-        for (const auto& info : it.second.second) {
-          uncompilable_node_info.emplace_back(info);
-        }
-      }
-      std::string message = absl::StrCat(
-          "Function invoked by the following node is not compilable: ",
-          SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
-      absl::StrAppend(&message, "Uncompilable operations:");
-      for (const auto& node_info : uncompilable_node_info) {
-        std::string node_message = absl::StrCat(
-            "\n", node_info.name, ": ", node_info.uncompilable_reason, "\n",
-            "The op is created at:\n");
-        if (node_info.stack_trace.back().stack_trace) {
-          AbstractStackTrace::TracePrintingOptions opts;
-          opts.show_line_contents = true;
-          opts.filter_common_prefix = true;
-          opts.drop_internal_frames = true;
-          absl::StrAppend(
-              &node_message,
-              node_info.stack_trace.back().stack_trace->ToString(opts));
-        } else {
-          absl::StrAppend(&node_message, "<Unavailable>\n");
-        }
-        absl::StrAppend(&message, node_message);
-      }
-      VLOG(1) << message;
-      return errors::InvalidArgument(message);
-    }
-  }
-
   MemoryTypeVector input_memory_types =
       GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
   MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
@@ -1232,14 +1232,26 @@ Status ValidateGraph(const Graph* graph,
 
   auto maybe_error = [&](const Node* node, const Status& s) -> Status {
     if (!s.ok()) {
-      return errors::InvalidArgument(absl::StrCat(
+      std::string errmsg = absl::StrCat(
           "Detected unsupported operations when trying to compile graph ", name,
           " on ", device_type.type_string(), ": ", node->def().op(), " (",
-          s.error_message(), ")", FormatNodeForError(*node),
-          "One approach is to outside compile the unsupported ops to run on "
-          "CPUs by enabling soft placement "
-          "`tf.config.set_soft_device_placement(True)`."
-          " This has a potential performance penalty."));
+          s.error_message(), ")", FormatNodeForError(*node));
+      if (absl::StrContains(device_type.type_string(), "TPU")) {
+        absl::StrAppend(&errmsg,
+                        "\nOne approach is to outside compile the unsupported "
+                        "ops to run on CPUs by enabling soft placement "
+                        "`tf.config.set_soft_device_placement(True)`."
+                        " This has a potential performance penalty.\n");
+      }
+      if (std::shared_ptr<AbstractStackTrace> stack_trace =
+              node->GetStackTrace()) {
+        absl::StrAppend(&errmsg, "\nThe op is created at: \n",
+                        stack_trace->ToString({.show_line_contents = true,
+                                               .filter_common_prefix = true,
+                                               .drop_internal_frames = true}));
+      }
+
+      return errors::InvalidArgument(errmsg);
     }
     return Status::OK();
   };
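The TPU-specific hint appended above points users at soft device placement as a workaround. A minimal sketch of that workaround from the user side follows; only tf.config.set_soft_device_placement is an actual TensorFlow API call, and treating it as sufficient for a given model is an assumption, not something this commit asserts.

# Sketch of the workaround suggested by the new TPU error text.
import tensorflow as tf

# Let ops the TPU/XLA compiler cannot handle fall back to the (CPU) host
# instead of failing compilation, at a potential performance cost.
tf.config.set_soft_device_placement(True)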
@@ -94,7 +94,7 @@ class JitCompileTest(test.TestCase):
       inputs = array_ops.placeholder(dtypes.float32, [5])
       x = xla_func(inputs)
       with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                  "not compilable"):
+                                  "Detected unsupported operations"):
         with session.Session(graph=g) as sess:
           sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
 
@@ -149,7 +149,7 @@ class DefFunctionTest(xla_test.XLATestCase):
     inputs = constant_op.constant([1, 2, 2, 3, 3])
     with self.assertRaisesRegex(
         errors.InvalidArgumentError, 'legalization failed'
-        if test_util.is_mlir_bridge_enabled() else 'not compilable'):
+        if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
       func(inputs)
 
   def testUnsupportedOps(self):
@@ -168,7 +168,7 @@ class DefFunctionTest(xla_test.XLATestCase):
     self.assertAllClose([1, 2, 3], func(inputs))
     with self.assertRaisesRegex(
         errors.InvalidArgumentError, 'legalization failed'
-        if test_util.is_mlir_bridge_enabled() else 'not compilable'):
+        if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
      xla_func(inputs)
 
   @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
@@ -384,7 +384,7 @@ class DefFunctionTest(xla_test.XLATestCase):
     c = C()
     with self.assertRaisesRegex(
         errors.InvalidArgumentError, 'legalization failed'
-        if test_util.is_mlir_bridge_enabled() else 'not compilable'):
+        if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
       c.f1(inputs)
 
   def testMustBeConstantPropagation(self):