[NFC] Use compile_all_resource_ops instead of allow_resource_ops for clarity
Firstly, rename compile_resource_ops to compile_all_resource_ops to emphasize that the device registration wants us to compile all kinds of resource operations, not just resource variable ops. Secondly using op_filter.allow_resource_ops was semantically incorrect; its purpose is to disallow clustering functional while nodes with resource variable operations, while the condition that is being changed only cares about whether we're compiling for XLA_* devices. It just so happens that op_filter.allow_resource_ops and registration->compile_all_resource_ops have the same value. PiperOrigin-RevId: 238348078
This commit is contained in:
parent
d82080a042
commit
7b50beb0b5
@ -507,7 +507,7 @@ Status FindCompilationCandidates(
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
|
||||
OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops = registration->compile_resource_ops;
|
||||
op_filter.allow_resource_ops = registration->compile_all_resource_ops;
|
||||
op_filter.allow_stateful_rng_ops = always_auto_cluster;
|
||||
op_filter.allow_control_trigger = always_auto_cluster;
|
||||
op_filter.allow_dummy_ops = always_auto_cluster;
|
||||
@ -542,7 +542,7 @@ Status FindCompilationCandidates(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!op_filter.allow_resource_ops &&
|
||||
if (!registration->compile_all_resource_ops &&
|
||||
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
|
||||
// We don't have a way of returning values of type DT_RESOURCE from XLA
|
||||
// computations so we avoid auto-clustering nodes producing DT_RESOURCE.
|
||||
@ -608,8 +608,8 @@ Status FindCompilationCandidates(
|
||||
}
|
||||
// We don't auto-cluster functional control flow nodes containing resource
|
||||
// operations because safety checks are trickier in this case.
|
||||
// registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not
|
||||
// for CPU/GPU.
|
||||
// registration->compile_all_resource_ops is true for XLA_CPU/XLA_GPU but
|
||||
// not for CPU/GPU.
|
||||
if (node->type_string() == "While" &&
|
||||
!IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) {
|
||||
continue;
|
||||
@ -936,7 +936,7 @@ static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) {
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
*ignore = true;
|
||||
} else {
|
||||
*ignore = registration->compile_resource_ops;
|
||||
*ignore = registration->compile_all_resource_ops;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
compile_on_demand
|
||||
? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
|
||||
: XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
registration.compile_all_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -66,7 +66,7 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
registration.compile_all_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -47,7 +47,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
registration.compile_all_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
|
||||
registration);
|
||||
|
||||
|
@ -148,7 +148,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
cpu_global_jit
|
||||
? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
|
||||
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
|
||||
registration.compile_resource_ops = false;
|
||||
registration.compile_all_resource_ops = false;
|
||||
}
|
||||
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
|
||||
DeviceRegistration& registration =
|
||||
@ -156,7 +156,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
|
||||
registration.compile_resource_ops = false;
|
||||
registration.compile_all_resource_ops = false;
|
||||
}
|
||||
return nullptr;
|
||||
}();
|
||||
|
@ -89,7 +89,7 @@ class XlaOpRegistry {
|
||||
AutoclusteringPolicy autoclustering_policy;
|
||||
|
||||
// Enable compilation of operators that use DT_RESOURCE types?
|
||||
bool compile_resource_ops = false;
|
||||
bool compile_all_resource_ops = false;
|
||||
};
|
||||
|
||||
// Registers an XLA backend. `compilation_device_name` is the name of the
|
||||
|
Loading…
Reference in New Issue
Block a user