[TF2XLA] [NFC] Make a function only used in a single library fileprivate
PiperOrigin-RevId: 332588570 Change-Id: Id65323ba198b1caed6b73b43fe047fe2ad659a6d
This commit is contained in:
parent
42afb8a459
commit
877527646a
@ -1708,40 +1708,6 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
|
|||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
|
||||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
|
||||||
uncompilable_node_info) {
|
|
||||||
Device* device = flr->device();
|
|
||||||
const XlaOpRegistry::DeviceRegistration* registration;
|
|
||||||
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
|
||||||
®istration));
|
|
||||||
|
|
||||||
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
|
||||||
// even if we are sometimes unable to auto-cluster them.
|
|
||||||
RecursiveCompilabilityChecker::OperationFilter op_filter;
|
|
||||||
op_filter.allow_resource_ops_in_called_functions = true;
|
|
||||||
op_filter.allow_stack_ops = true;
|
|
||||||
op_filter.allow_tensor_array_ops = true;
|
|
||||||
op_filter.allow_stateful_rng_ops = true;
|
|
||||||
op_filter.allow_control_trigger = true;
|
|
||||||
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
|
|
||||||
op_filter.allow_ops_producing_or_consuming_variant = true;
|
|
||||||
op_filter.allow_slow_ops = true;
|
|
||||||
op_filter.allow_inaccurate_ops = true;
|
|
||||||
|
|
||||||
RecursiveCompilabilityChecker checker{
|
|
||||||
op_filter, DeviceType{registration->compilation_device_name}};
|
|
||||||
if (!uncompilable_node_info) {
|
|
||||||
// We do not need uncompilable node info. Just return the result.
|
|
||||||
return checker.IsCompilableCall(ndef, flr);
|
|
||||||
}
|
|
||||||
|
|
||||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
|
|
||||||
checker.FindUncompilableNodes(ndef, flr);
|
|
||||||
uncompilable_node_info->swap(uncompilable_node_result);
|
|
||||||
return uncompilable_node_info->empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MarkForCompilationPass::Run(
|
Status MarkForCompilationPass::Run(
|
||||||
const GraphOptimizationPassOptions& options) {
|
const GraphOptimizationPassOptions& options) {
|
||||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||||
|
@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
|||||||
friend class MarkForCompilationPassTestHelper;
|
friend class MarkForCompilationPassTestHelper;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns true iff 'ndef' is a call to a function that is compilable. A
|
|
||||||
// function is compilable iff every operator in the function body is
|
|
||||||
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
|
|
||||||
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
|
|
||||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
|
||||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
|
||||||
uncompilable_node_info = nullptr);
|
|
||||||
|
|
||||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
|
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
|
||||||
|
|
||||||
namespace testing {
|
namespace testing {
|
||||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/defs.h"
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
#include "tensorflow/compiler/jit/flags.h"
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
@ -32,6 +31,44 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Returns true iff 'ndef' is a call to a function that is compilable. A
|
||||||
|
// function is compilable iff every operator in the function body is
|
||||||
|
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
|
||||||
|
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
|
||||||
|
static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||||
|
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||||
|
uncompilable_node_info) {
|
||||||
|
Device* device = flr->device();
|
||||||
|
const XlaOpRegistry::DeviceRegistration* registration;
|
||||||
|
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
||||||
|
®istration));
|
||||||
|
|
||||||
|
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
||||||
|
// even if we are sometimes unable to auto-cluster them.
|
||||||
|
RecursiveCompilabilityChecker::OperationFilter op_filter;
|
||||||
|
op_filter.allow_resource_ops_in_called_functions = true;
|
||||||
|
op_filter.allow_stack_ops = true;
|
||||||
|
op_filter.allow_tensor_array_ops = true;
|
||||||
|
op_filter.allow_stateful_rng_ops = true;
|
||||||
|
op_filter.allow_control_trigger = true;
|
||||||
|
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
|
||||||
|
op_filter.allow_ops_producing_or_consuming_variant = true;
|
||||||
|
op_filter.allow_slow_ops = true;
|
||||||
|
op_filter.allow_inaccurate_ops = true;
|
||||||
|
|
||||||
|
RecursiveCompilabilityChecker checker{
|
||||||
|
op_filter, DeviceType{registration->compilation_device_name}};
|
||||||
|
if (!uncompilable_node_info) {
|
||||||
|
// We do not need uncompilable node info. Just return the result.
|
||||||
|
return checker.IsCompilableCall(ndef, flr);
|
||||||
|
}
|
||||||
|
|
||||||
|
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
|
||||||
|
checker.FindUncompilableNodes(ndef, flr);
|
||||||
|
uncompilable_node_info->swap(uncompilable_node_result);
|
||||||
|
return uncompilable_node_info->empty();
|
||||||
|
}
|
||||||
|
|
||||||
bool XlaKernelCreator::CanCreateKernel(
|
bool XlaKernelCreator::CanCreateKernel(
|
||||||
const FunctionLibraryRuntime& flr,
|
const FunctionLibraryRuntime& flr,
|
||||||
const std::shared_ptr<const NodeProperties>& props) const {
|
const std::shared_ptr<const NodeProperties>& props) const {
|
||||||
|
Loading…
Reference in New Issue
Block a user