diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index e155becab11..ea0c77bd684 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -164,29 +164,6 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( return uncompilable_nodes; } -RecursiveCompilabilityChecker::UncompilableNodesMap -RecursiveCompilabilityChecker::FindUncompilableNodes( - const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, - const std::vector<StackFrame>* - node_stack_trace) const { - // If `node_stack_trace` is provided, that means `call_def` is inside - // a function body, and therefore, arg nodes and retval nodes are - // not considered uncompilable. - std::vector<StackFrameView> stack_trace; - if (node_stack_trace != nullptr) { - for (const auto& frame : *node_stack_trace) { - stack_trace.emplace_back( - StackFrameView{frame.name, frame.function_name, frame.stack_trace}); - } - } - stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr}); - - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes; - IsCompilableCall(call_def, lib_runtime, &stack_trace, - /*encapsulating_function=*/nullptr, &uncompilable_nodes); - return uncompilable_nodes; -} - bool RecursiveCompilabilityChecker::HasXLAKernel( const Node& node, string* uncompilable_reason) const { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index c75a2090cdd..4481b6f7ad6 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -157,20 +157,6 @@ class RecursiveCompilabilityChecker { const Node& node, FunctionLibraryRuntime* lib_runtime, const std::vector<StackFrame>* node_stack_trace = nullptr) const; - // Returns a map where the key is the function identifier(short debug - // string) of 
the function encapsulating the uncompilable nodes, and the - // value is a pair of NameAttrList of the function and a vector of - // uncompilable node info. When uncompilable node is not inside any - // function call nodes, then key is a ShortDebugString() of an empty - // NameAttrList. - // - // Also, when `node` is inside a function body, users can set - // `node_stack_trace` to provide an additional context for `node`'s - // placement within the outer most graph. - UncompilableNodesMap FindUncompilableNodes( - const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, - const std::vector<StackFrame>* node_stack_trace = nullptr) const; - // Returns true if `node` can be compiled by XLA. bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) const { @@ -179,15 +165,6 @@ class RecursiveCompilabilityChecker { return IsCompilableNode(node, lib_runtime, &stack_trace); } - // Returns true if `call_def` can be compiled by XLA. It is assumed that - // `call_def` is a call operation. - bool IsCompilableCall(const NodeDef& call_def, - FunctionLibraryRuntime* lib_runtime) { - std::vector<StackFrameView> stack_trace; - stack_trace.emplace_back(StackFrameView{call_def.name(), ""}); - return IsCompilableCall(call_def, lib_runtime, &stack_trace); - } - // Returns true if XLA supports this Op, but we don't want to cluster it (ie: // due to performance or correctness concerns). bool OpIsInaccurate(const Node& node) const; diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 054ac9987c6..52991c5312b 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -32,44 +32,6 @@ limitations under the License. namespace tensorflow { -// Returns true iff 'ndef' is a call to a function that is compilable. A -// function is compilable iff every operator in the function body is -// compilable. 
If 'ndef' is not compilable and 'uncompilable_node_info' is not -// null, we will populate 'uncompilable_node_info' with uncompilable node info. -static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, - RecursiveCompilabilityChecker::UncompilableNodesMap* - uncompilable_node_info) { - Device* device = flr->device(); - const XlaOpRegistry::DeviceRegistration* registration; - CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), - &registration)); - - // We can always *compile* resource operations, stateful RNGs and dummy ops, - // even if we are sometimes unable to auto-cluster them. - RecursiveCompilabilityChecker::OperationFilter op_filter; - op_filter.allow_resource_ops_in_called_functions = true; - op_filter.allow_stack_ops = true; - op_filter.allow_tensor_array_ops = true; - op_filter.allow_stateful_rng_ops = true; - op_filter.allow_control_trigger = true; - op_filter.allow_eliding_assert_and_checknumerics_ops = true; - op_filter.allow_ops_producing_or_consuming_variant = true; - op_filter.allow_slow_ops = true; - op_filter.allow_inaccurate_ops = true; - - RecursiveCompilabilityChecker checker{ - op_filter, DeviceType{registration->compilation_device_name}}; - if (!uncompilable_node_info) { - // We do not need uncompilable node info. Just return the result. - return checker.IsCompilableCall(ndef, flr); - } - - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result = - checker.FindUncompilableNodes(ndef, flr); - uncompilable_node_info->swap(uncompilable_node_result); - return uncompilable_node_info->empty(); -} - bool XlaKernelCreator::CanCreateKernel( const FunctionLibraryRuntime& flr, const std::shared_ptr<const NodeProperties>& props) const { @@ -98,56 +60,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); - // Only check for compilability if the MLIR bridge is not enabled. 
- absl::optional<ConfigProto> config_proto; - if (flr->config_proto()) { - config_proto = *flr->config_proto(); - } - // There is no easy way to check if we have uninitialized resource args here - // so we assume there are uninitialized resource args. This means that we - // might run the compilability checker in cases where we don't need to (when - // MLIR bridge is run later). Note that this is just temporary until - // b/171732021 gets fixed. - // We should also revisit if this check provides any value, otherwise we - // should remove it. - MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( - *fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true); - if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) { - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; - if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { - std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo> - uncompilable_node_info; - for (const auto& it : uncompilable_nodes_map) { - for (const auto& info : it.second.second) { - uncompilable_node_info.emplace_back(info); - } - } - std::string message = absl::StrCat( - "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); - absl::StrAppend(&message, "Uncompilable operations:"); - for (const auto& node_info : uncompilable_node_info) { - std::string node_message = absl::StrCat( - "\n", node_info.name, ": ", node_info.uncompilable_reason, "\n", - "The op is created at:\n"); - if (node_info.stack_trace.back().stack_trace) { - AbstractStackTrace::TracePrintingOptions opts; - opts.show_line_contents = true; - opts.filter_common_prefix = true; - opts.drop_internal_frames = true; - absl::StrAppend( - &node_message, - node_info.stack_trace.back().stack_trace->ToString(opts)); - } else { - absl::StrAppend(&node_message, "\n"); - } - absl::StrAppend(&message, node_message); - } - VLOG(1) << message; - return errors::InvalidArgument(message); - } - } - MemoryTypeVector 
input_memory_types = GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 7c41e77661b..dda8355a932 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -1232,14 +1232,26 @@ Status ValidateGraph(const Graph* graph, auto maybe_error = [&](const Node* node, const Status& s) -> Status { if (!s.ok()) { - return errors::InvalidArgument(absl::StrCat( + std::string errmsg = absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", - s.error_message(), ")", FormatNodeForError(*node), - "One approach is to outside compile the unsupported ops to run on " - "CPUs by enabling soft placement " - "`tf.config.set_soft_device_placement(True)`." - " This has a potential performance penalty.")); + s.error_message(), ")", FormatNodeForError(*node)); + if (absl::StrContains(device_type.type_string(), "TPU")) { + absl::StrAppend(&errmsg, + "\nOne approach is to outside compile the unsupported " + "ops to run on CPUs by enabling soft placement " + "`tf.config.set_soft_device_placement(True)`." 
+ " This has a potential performance penalty.\n"); + } + if (std::shared_ptr<AbstractStackTrace> stack_trace = + node->GetStackTrace()) { + absl::StrAppend(&errmsg, "\nThe op is created at: \n", + stack_trace->ToString({.show_line_contents = true, + .filter_common_prefix = true, + .drop_internal_frames = true})); + } + + return errors::InvalidArgument(errmsg); } return Status::OK(); }; diff --git a/tensorflow/python/compiler/xla/jit_compile_test.py b/tensorflow/python/compiler/xla/jit_compile_test.py index 9ec0ffe38c7..a308aeb6cb0 100644 --- a/tensorflow/python/compiler/xla/jit_compile_test.py +++ b/tensorflow/python/compiler/xla/jit_compile_test.py @@ -94,7 +94,7 @@ class JitCompileTest(test.TestCase): inputs = array_ops.placeholder(dtypes.float32, [5]) x = xla_func(inputs) with self.assertRaisesRegex(errors.InvalidArgumentError, - "not compilable"): + "Detected unsupported operations"): with session.Session(graph=g) as sess: sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]}) diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 25f9753d4ba..04132c2a2be 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -149,7 +149,7 @@ class DefFunctionTest(xla_test.XLATestCase): inputs = constant_op.constant([1, 2, 2, 3, 3]) with self.assertRaisesRegex( errors.InvalidArgumentError, 'legalization failed' - if test_util.is_mlir_bridge_enabled() else 'not compilable'): + if test_util.is_mlir_bridge_enabled() else 'unsupported operations'): func(inputs) def testUnsupportedOps(self): @@ -168,7 +168,7 @@ class DefFunctionTest(xla_test.XLATestCase): self.assertAllClose([1, 2, 3], func(inputs)) with self.assertRaisesRegex( errors.InvalidArgumentError, 'legalization failed' - if test_util.is_mlir_bridge_enabled() else 'not compilable'): + if test_util.is_mlir_bridge_enabled() else 'unsupported operations'): xla_func(inputs) 
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' @@ -384,7 +384,7 @@ class DefFunctionTest(xla_test.XLATestCase): c = C() with self.assertRaisesRegex( errors.InvalidArgumentError, 'legalization failed' - if test_util.is_mlir_bridge_enabled() else 'not compilable'): + if test_util.is_mlir_bridge_enabled() else 'unsupported operations'): c.f1(inputs) def testMustBeConstantPropagation(self):