[TF2XLA] [NFC] Make a function only used in a single library file private

PiperOrigin-RevId: 332588570
Change-Id: Id65323ba198b1caed6b73b43fe047fe2ad659a6d
George Karpenkov 2020-09-18 23:24:57 -07:00 committed by TensorFlower Gardener
parent 42afb8a459
commit 877527646a
3 changed files with 38 additions and 43 deletions
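The change itself is mechanical: the helper's declaration is removed from the header and its definition is marked `static`, giving it internal linkage so it is only visible inside the one translation unit that uses it. A minimal sketch of that pattern, using hypothetical names (foo.h, Helper, PublicApi) rather than code from this commit:

```cpp
// foo.h -- before: Helper(int) was declared here and visible to every
// includer. After the change only the genuinely public entry point remains.
int PublicApi(int x);

// foo.cc -- `static` gives Helper internal linkage: other translation units
// can no longer reference it (or collide with another symbol of that name).
static int Helper(int x) { return x * 2; }

int PublicApi(int x) { return Helper(x) + 1; }
```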

tensorflow/compiler/jit/mark_for_compilation_pass.cc

@@ -1708,40 +1708,6 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
}
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info) {
Device* device = flr->device();
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
// We can always *compile* resource operations, stateful RNGs and dummy ops,
// even if we are sometimes unable to auto-cluster them.
RecursiveCompilabilityChecker::OperationFilter op_filter;
op_filter.allow_resource_ops_in_called_functions = true;
op_filter.allow_stack_ops = true;
op_filter.allow_tensor_array_ops = true;
op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_slow_ops = true;
op_filter.allow_inaccurate_ops = true;
RecursiveCompilabilityChecker checker{
op_filter, DeviceType{registration->compilation_device_name}};
if (!uncompilable_node_info) {
// We do not need uncompilable node info. Just return the result.
return checker.IsCompilableCall(ndef, flr);
}
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
checker.FindUncompilableNodes(ndef, flr);
uncompilable_node_info->swap(uncompilable_node_result);
return uncompilable_node_info->empty();
}
Status MarkForCompilationPass::Run(
    const GraphOptimizationPassOptions& options) {
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();

tensorflow/compiler/jit/mark_for_compilation_pass.h

@@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass {
  friend class MarkForCompilationPassTestHelper;
};
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info = nullptr);
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
namespace testing {

tensorflow/compiler/jit/xla_kernel_creator.cc

@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
@ -32,6 +31,44 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info) {
Device* device = flr->device();
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
// We can always *compile* resource operations, stateful RNGs and dummy ops,
// even if we are sometimes unable to auto-cluster them.
RecursiveCompilabilityChecker::OperationFilter op_filter;
op_filter.allow_resource_ops_in_called_functions = true;
op_filter.allow_stack_ops = true;
op_filter.allow_tensor_array_ops = true;
op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_slow_ops = true;
op_filter.allow_inaccurate_ops = true;
RecursiveCompilabilityChecker checker{
op_filter, DeviceType{registration->compilation_device_name}};
if (!uncompilable_node_info) {
// We do not need uncompilable node info. Just return the result.
return checker.IsCompilableCall(ndef, flr);
}
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
checker.FindUncompilableNodes(ndef, flr);
uncompilable_node_info->swap(uncompilable_node_result);
return uncompilable_node_info->empty();
}
bool XlaKernelCreator::CanCreateKernel(
    const FunctionLibraryRuntime& flr,
    const std::shared_ptr<const NodeProperties>& props) const {