diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index ff1196b2be6..c4ab22274cf 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1708,40 +1708,6 @@ std::atomic* GetPointerToFuel(int64 initial_value) { } } // anonymous namespace -bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, - RecursiveCompilabilityChecker::UncompilableNodesMap* - uncompilable_node_info) { - Device* device = flr->device(); - const XlaOpRegistry::DeviceRegistration* registration; - CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), - ®istration)); - - // We can always *compile* resource operations, stateful RNGs and dummy ops, - // even if we are sometimes unable to auto-cluster them. - RecursiveCompilabilityChecker::OperationFilter op_filter; - op_filter.allow_resource_ops_in_called_functions = true; - op_filter.allow_stack_ops = true; - op_filter.allow_tensor_array_ops = true; - op_filter.allow_stateful_rng_ops = true; - op_filter.allow_control_trigger = true; - op_filter.allow_eliding_assert_and_checknumerics_ops = true; - op_filter.allow_ops_producing_or_consuming_variant = true; - op_filter.allow_slow_ops = true; - op_filter.allow_inaccurate_ops = true; - - RecursiveCompilabilityChecker checker{ - op_filter, DeviceType{registration->compilation_device_name}}; - if (!uncompilable_node_info) { - // We do not need uncompilable node info. Just return the result. - return checker.IsCompilableCall(ndef, flr); - } - - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result = - checker.FindUncompilableNodes(ndef, flr); - uncompilable_node_info->swap(uncompilable_node_result); - return uncompilable_node_info->empty(); -} - Status MarkForCompilationPass::Run( const GraphOptimizationPassOptions& options) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 0e9a64e7f28..810ebf38b5c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass { friend class MarkForCompilationPassTestHelper; }; -// Returns true iff 'ndef' is a call to a function that is compilable. A -// function is compilable iff every operator in the function body is -// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not -// null, we will populate 'uncompilable_node_info' with uncompilable node info. -bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, - RecursiveCompilabilityChecker::UncompilableNodesMap* - uncompilable_node_info = nullptr); - absl::flat_hash_map>* GetAllowlistTable(); namespace testing { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 7387978fbcd..68d726d69fb 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -32,6 +31,44 @@ limitations under the License. namespace tensorflow { +// Returns true iff 'ndef' is a call to a function that is compilable. A +// function is compilable iff every operator in the function body is +// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not +// null, we will populate 'uncompilable_node_info' with uncompilable node info. +static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, + RecursiveCompilabilityChecker::UncompilableNodesMap* + uncompilable_node_info) { + Device* device = flr->device(); + const XlaOpRegistry::DeviceRegistration* registration; + CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), + ®istration)); + + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. + RecursiveCompilabilityChecker::OperationFilter op_filter; + op_filter.allow_resource_ops_in_called_functions = true; + op_filter.allow_stack_ops = true; + op_filter.allow_tensor_array_ops = true; + op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_eliding_assert_and_checknumerics_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; + op_filter.allow_slow_ops = true; + op_filter.allow_inaccurate_ops = true; + + RecursiveCompilabilityChecker checker{ + op_filter, DeviceType{registration->compilation_device_name}}; + if (!uncompilable_node_info) { + // We do not need uncompilable node info. Just return the result. + return checker.IsCompilableCall(ndef, flr); + } + + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result = + checker.FindUncompilableNodes(ndef, flr); + uncompilable_node_info->swap(uncompilable_node_result); + return uncompilable_node_info->empty(); +} + bool XlaKernelCreator::CanCreateKernel( const FunctionLibraryRuntime& flr, const std::shared_ptr& props) const {