Allow default TF/XLA op registration with specific backend overrides.
PiperOrigin-RevId: 201252399
This commit is contained in:
parent
10091aa9a9
commit
bbba4e06e9
@ -489,3 +489,13 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_op_registry_test",
|
||||
srcs = ["xla_op_registry_test.cc"],
|
||||
deps = [
|
||||
":xla_compiler",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
@ -71,11 +71,12 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
<< " have incompatible allow_resource_types settings.";
|
||||
return false;
|
||||
}
|
||||
if (!x.has_device_whitelist || !y.has_device_whitelist) {
|
||||
LOG(WARNING) << "Registrations of " << x.name
|
||||
<< " do not both have device whitelists.";
|
||||
if (!x.has_device_whitelist && !y.has_device_whitelist) {
|
||||
LOG(WARNING) << "Duplicate registrations of " << x.name
|
||||
<< "with no device whitelists.";
|
||||
return false;
|
||||
}
|
||||
if (x.has_device_whitelist && y.has_device_whitelist) {
|
||||
for (const auto& device : x.device_whitelist) {
|
||||
if (y.device_whitelist.count(device) != 0) {
|
||||
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
|
||||
@ -83,6 +84,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
|
||||
LOG(WARNING) << "Registrations of " << x.name
|
||||
<< " have incompatible compile time constant inputs.";
|
||||
@ -157,15 +159,43 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
registry.jit_kernels_registered_ = true;
|
||||
|
||||
OpRegistryInterface* op_registry = OpRegistry::Global();
|
||||
for (const auto& op : registry.ops_) {
|
||||
const string& op_name = op.first;
|
||||
const std::unique_ptr<OpRegistration>& op_registration = op.second;
|
||||
// Order of op registration:
|
||||
// The goal is to allow the co-existence of backend-specific kernels and
|
||||
// generic kernels. To achieve this, we enforce the following order of
|
||||
// registrations for one op:
|
||||
// 1. Process op registration with device whitelists:
|
||||
// this pass registers backend-specific kernels for this op.
|
||||
// 2. Process op registration without device whitelists:
|
||||
// this pass registers the kernels for all the other supported backends.
|
||||
for (auto& ops : registry.ops_) {
|
||||
const string& op_name = ops.first;
|
||||
std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
|
||||
// Partition the op registration so that the ones with device whitelists
|
||||
// precede the one without device whitelist.
|
||||
std::partition(op_registrations.begin(), op_registrations.end(),
|
||||
[](const std::unique_ptr<OpRegistration>& op_reg) {
|
||||
return op_reg->has_device_whitelist;
|
||||
});
|
||||
|
||||
// Collect the set of backends registered by ops with device whitelists.
|
||||
// The op registration without whitelists will register a generic kernel
|
||||
// for all other backends not in this set.
|
||||
std::unordered_set<string> whitelisted_backend;
|
||||
for (auto& op_registration : op_registrations) {
|
||||
if (op_registration->has_device_whitelist) {
|
||||
whitelisted_backend.insert(op_registration->device_whitelist.begin(),
|
||||
op_registration->device_whitelist.end());
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& op_registration : op_registrations) {
|
||||
const OpDef* op_def;
|
||||
Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
|
||||
if (!lookup_status.ok()) {
|
||||
LOG(ERROR) << lookup_status.error_message();
|
||||
XLA_LOG_LINES(
|
||||
ERROR, "Ops registered: \n" +
|
||||
ERROR,
|
||||
"Ops registered: \n" +
|
||||
dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
|
||||
}
|
||||
TF_CHECK_OK(lookup_status);
|
||||
@ -194,6 +224,14 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the operator does NOT have a device whitelist, skip all devices
|
||||
// that have already been registered.
|
||||
if (!op_registration->has_device_whitelist &&
|
||||
whitelisted_backend.find(backend.first) !=
|
||||
whitelisted_backend.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unique_ptr<KernelDef> kdef(new KernelDef);
|
||||
kdef->set_op(op_registration->name);
|
||||
kdef->set_device_type(backend.first);
|
||||
@ -213,7 +251,8 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
op_def_attr.has_allowed_values()
|
||||
? &op_def_attr.allowed_values().list().type()
|
||||
: nullptr;
|
||||
auto constraint_it = op_registration->type_constraints.find(type_attr);
|
||||
auto constraint_it =
|
||||
op_registration->type_constraints.find(type_attr);
|
||||
const std::set<DataType>* type_constraints =
|
||||
constraint_it != op_registration->type_constraints.end()
|
||||
? &constraint_it->second
|
||||
@ -251,6 +290,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
|
||||
const string& compilation_device_name,
|
||||
@ -265,12 +305,12 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
|
||||
<< "Unknown backend " << compilation_device_name;
|
||||
for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
|
||||
auto op_iter = registry.ops_.find(k->op());
|
||||
CHECK(op_iter != registry.ops_.end());
|
||||
CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty());
|
||||
// The test in IsCompatible ensures that if there are multiple matching
|
||||
// registrations for this op name, they all have the same value of
|
||||
// compilation_only, so only the first match needs to be tested.
|
||||
if (include_compilation_only_kernels ||
|
||||
!op_iter->second->compilation_only) {
|
||||
!op_iter->second.front()->compilation_only) {
|
||||
kernels.push_back(k.get());
|
||||
}
|
||||
}
|
||||
@ -282,10 +322,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
|
||||
XlaOpRegistry& registry = Instance();
|
||||
mutex_lock lock(registry.mutex_);
|
||||
auto it = registry.ops_.find(op);
|
||||
if (it == registry.ops_.end()) {
|
||||
if (it == registry.ops_.end() || it->second.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return &it->second->compile_time_constant_inputs;
|
||||
// The test in IsCompatible ensures that if there are multiple matching
|
||||
// registrations for this op name, they all have the same value of
|
||||
// compile_time_constant_inputs, so only the first match is returned.
|
||||
return &it->second.front()->compile_time_constant_inputs;
|
||||
}
|
||||
|
||||
std::vector<string> XlaOpRegistry::BackendNames() {
|
||||
@ -378,16 +421,15 @@ XlaOpRegistrar::XlaOpRegistrar(
|
||||
std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
|
||||
XlaOpRegistry& registry = XlaOpRegistry::Instance();
|
||||
mutex_lock lock(registry.mutex_);
|
||||
auto existing_ops = registry.ops_.equal_range(registration->name);
|
||||
for (auto existing = existing_ops.first; existing != existing_ops.second;
|
||||
++existing) {
|
||||
if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) {
|
||||
auto& existing_ops = registry.ops_[registration->name];
|
||||
for (auto& existing : existing_ops) {
|
||||
if (!XlaOpRegistry::IsCompatible(*existing, *registration)) {
|
||||
LOG(FATAL)
|
||||
<< "XLA op registration " << registration->name
|
||||
<< " is incompatible with existing registration of the same name.";
|
||||
}
|
||||
}
|
||||
registry.ops_.emplace(registration->name, std::move(registration));
|
||||
existing_ops.emplace_back(std::move(registration));
|
||||
}
|
||||
|
||||
XlaBackendRegistrar::XlaBackendRegistrar(
|
||||
|
@ -203,7 +203,7 @@ class XlaOpRegistry {
|
||||
// Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
|
||||
// Registrations present under the same key must satisfy IsCompatible above,
|
||||
// and this is checked during registration.
|
||||
std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
|
||||
std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
|
||||
GUARDED_BY(mutex_);
|
||||
|
||||
// Have we already registered the JIT kernels on the JIT devices?
|
||||
|
86
tensorflow/compiler/tf2xla/xla_op_registry_test.cc
Normal file
86
tensorflow/compiler/tf2xla/xla_op_registry_test.cc
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// This test is to verify the correctness of XLA op registration with specific
|
||||
// backend overrides.
|
||||
|
||||
// A dummy backend-specific OpKernel for CPU.
|
||||
class DummyCPUOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
ctx->SetOutput(0, ctx->Input(0));
|
||||
}
|
||||
};
|
||||
|
||||
// A dummy generic OpKernel for all backends.
|
||||
class DummyGenericOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
ctx->SetOutput(0, ctx->Input(0));
|
||||
}
|
||||
};
|
||||
|
||||
// Op definition exercised by the registration tests below. Note that the type
// attr "T" is not referenced by the input/output signature (both are pinned to
// int32 in the OpDef); "T" exists solely so the REGISTER_XLA_OP calls below
// can attach different per-backend TypeConstraints to the same op name.
REGISTER_OP("DummyDuplicateOp")
    .Attr("T: {float, int32}")
    .Input("input: int32")
    .Output("output: int32")
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output: dummy output.
)doc");
|
||||
|
||||
// Register DummyCPUOp as a backend-specific kernel: it is whitelisted to the
// XLA CPU JIT device only, with attr "T" constrained to INT32.
REGISTER_XLA_OP(Name("DummyDuplicateOp")
                    .Device(DEVICE_CPU_XLA_JIT)
                    .TypeConstraint("T", DT_INT32),
                DummyCPUOp);
// Register DummyGenericOp as the default kernel for every other registered
// backend, with attr "T" constrained to FLOAT. CPU is skipped because the
// whitelisted registration above already claimed it.
REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT),
                DummyGenericOp);
|
||||
|
||||
// Test the correctness of registered kernels. The kernel registered for CPU
|
||||
// should have type INT32 while all other kernels should have type FLOAT.
|
||||
TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) {
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
auto registered_kernels = GetAllRegisteredKernels();
|
||||
for (const auto& kernels : registered_kernels) {
|
||||
if (kernels.op() == "DummyDuplicateOp") {
|
||||
EXPECT_EQ(kernels.constraint_size(), 1);
|
||||
EXPECT_EQ(kernels.constraint(0).name(), "T");
|
||||
if (kernels.device_type() == "XLA_CPU_JIT") {
|
||||
EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
|
||||
DT_INT32);
|
||||
} else {
|
||||
EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
|
||||
DT_FLOAT);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user