[TF2XLA] [NFC] Make a function only used in a single library fileprivate
PiperOrigin-RevId: 332588570 Change-Id: Id65323ba198b1caed6b73b43fe047fe2ad659a6d
This commit is contained in:
parent
42afb8a459
commit
877527646a
@ -1708,40 +1708,6 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info) {
|
||||
Device* device = flr->device();
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
||||
®istration));
|
||||
|
||||
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
||||
// even if we are sometimes unable to auto-cluster them.
|
||||
RecursiveCompilabilityChecker::OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops_in_called_functions = true;
|
||||
op_filter.allow_stack_ops = true;
|
||||
op_filter.allow_tensor_array_ops = true;
|
||||
op_filter.allow_stateful_rng_ops = true;
|
||||
op_filter.allow_control_trigger = true;
|
||||
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
|
||||
op_filter.allow_ops_producing_or_consuming_variant = true;
|
||||
op_filter.allow_slow_ops = true;
|
||||
op_filter.allow_inaccurate_ops = true;
|
||||
|
||||
RecursiveCompilabilityChecker checker{
|
||||
op_filter, DeviceType{registration->compilation_device_name}};
|
||||
if (!uncompilable_node_info) {
|
||||
// We do not need uncompilable node info. Just return the result.
|
||||
return checker.IsCompilableCall(ndef, flr);
|
||||
}
|
||||
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
|
||||
checker.FindUncompilableNodes(ndef, flr);
|
||||
uncompilable_node_info->swap(uncompilable_node_result);
|
||||
return uncompilable_node_info->empty();
|
||||
}
|
||||
|
||||
Status MarkForCompilationPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
|
@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
friend class MarkForCompilationPassTestHelper;
|
||||
};
|
||||
|
||||
// Returns true iff 'ndef' is a call to a function that is compilable. A
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
|
||||
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info = nullptr);
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
|
||||
|
||||
namespace testing {
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -32,6 +31,44 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns true iff 'ndef' is a call to a function that is compilable. A
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
|
||||
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
|
||||
static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info) {
|
||||
Device* device = flr->device();
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
||||
®istration));
|
||||
|
||||
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
||||
// even if we are sometimes unable to auto-cluster them.
|
||||
RecursiveCompilabilityChecker::OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops_in_called_functions = true;
|
||||
op_filter.allow_stack_ops = true;
|
||||
op_filter.allow_tensor_array_ops = true;
|
||||
op_filter.allow_stateful_rng_ops = true;
|
||||
op_filter.allow_control_trigger = true;
|
||||
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
|
||||
op_filter.allow_ops_producing_or_consuming_variant = true;
|
||||
op_filter.allow_slow_ops = true;
|
||||
op_filter.allow_inaccurate_ops = true;
|
||||
|
||||
RecursiveCompilabilityChecker checker{
|
||||
op_filter, DeviceType{registration->compilation_device_name}};
|
||||
if (!uncompilable_node_info) {
|
||||
// We do not need uncompilable node info. Just return the result.
|
||||
return checker.IsCompilableCall(ndef, flr);
|
||||
}
|
||||
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
|
||||
checker.FindUncompilableNodes(ndef, flr);
|
||||
uncompilable_node_info->swap(uncompilable_node_result);
|
||||
return uncompilable_node_info->empty();
|
||||
}
|
||||
|
||||
bool XlaKernelCreator::CanCreateKernel(
|
||||
const FunctionLibraryRuntime& flr,
|
||||
const std::shared_ptr<const NodeProperties>& props) const {
|
||||
|
Loading…
Reference in New Issue
Block a user