Allow default TF/XLA op registration with specific backend overrides.

PiperOrigin-RevId: 201252399
This commit is contained in:
Tony Wang 2018-06-19 15:29:38 -07:00 committed by TensorFlower Gardener
parent 10091aa9a9
commit bbba4e06e9
4 changed files with 238 additions and 100 deletions

View File

@ -489,3 +489,13 @@ cc_library(
"//tensorflow/core:protos_all_cc",
],
)
# Unit test for XLA op registration, covering the co-existence of
# backend-specific kernel registrations and generic (all-backend) ones.
tf_cc_test(
name = "xla_op_registry_test",
srcs = ["xla_op_registry_test.cc"],
deps = [
":xla_compiler",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -71,16 +71,18 @@ XlaOpRegistry::~XlaOpRegistry() = default;
<< " have incompatible allow_resource_types settings.";
return false;
}
if (!x.has_device_whitelist || !y.has_device_whitelist) {
LOG(WARNING) << "Registrations of " << x.name
<< " do not both have device whitelists.";
if (!x.has_device_whitelist && !y.has_device_whitelist) {
LOG(WARNING) << "Duplicate registrations of " << x.name
<< "with no device whitelists.";
return false;
}
for (const auto& device : x.device_whitelist) {
if (y.device_whitelist.count(device) != 0) {
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
<< device;
return false;
if (x.has_device_whitelist && y.has_device_whitelist) {
for (const auto& device : x.device_whitelist) {
if (y.device_whitelist.count(device) != 0) {
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
<< device;
return false;
}
}
}
if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
@ -157,97 +159,135 @@ void XlaOpRegistry::RegisterCompilationKernels() {
registry.jit_kernels_registered_ = true;
OpRegistryInterface* op_registry = OpRegistry::Global();
for (const auto& op : registry.ops_) {
const string& op_name = op.first;
const std::unique_ptr<OpRegistration>& op_registration = op.second;
const OpDef* op_def;
Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
if (!lookup_status.ok()) {
LOG(ERROR) << lookup_status.error_message();
XLA_LOG_LINES(
ERROR, "Ops registered: \n" +
dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
}
TF_CHECK_OK(lookup_status);
// Order of op registration:
// The goal is to allow the co-existence of backend-specific kernels and
// generic kernels. To achieve this, we enforce the following order of
// registrations for one op:
// 1. Process op registration with device whitelists:
// this pass registers backend-specific kernels for this op.
// 2. Process op registration without device whitelists:
// this pass registers the kernels for all the other supported backends.
for (auto& ops : registry.ops_) {
const string& op_name = ops.first;
std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
// Partition the op registrations so that those with device whitelists
// precede the one without a device whitelist.
std::partition(op_registrations.begin(), op_registrations.end(),
[](const std::unique_ptr<OpRegistration>& op_reg) {
return op_reg->has_device_whitelist;
});
std::unordered_set<string> type_attrs;
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
type_attrs.insert(attr_def.name());
// Collect the set of backends registered by ops with device whitelists.
// Op registrations without whitelists will register a generic kernel
// for all other backends not in this set.
std::unordered_set<string> whitelisted_backend;
for (auto& op_registration : op_registrations) {
if (op_registration->has_device_whitelist) {
whitelisted_backend.insert(op_registration->device_whitelist.begin(),
op_registration->device_whitelist.end());
}
}
// Checks there are no type constraints referring to unknown attributes.
for (const auto& constraint : op_registration->type_constraints) {
if (type_attrs.find(constraint.first) == type_attrs.end()) {
LOG(FATAL) << "Unknown type attribute " << constraint.first
<< " in XLA op registration for " << op_name;
for (auto& op_registration : op_registrations) {
const OpDef* op_def;
Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
if (!lookup_status.ok()) {
LOG(ERROR) << lookup_status.error_message();
XLA_LOG_LINES(
ERROR,
"Ops registered: \n" +
dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
}
}
TF_CHECK_OK(lookup_status);
for (auto& backend : registry.backends_) {
// If the operator has a device whitelist, only register on whitelisted
// devices.
if (op_registration->has_device_whitelist &&
op_registration->device_whitelist.find(backend.first) ==
op_registration->device_whitelist.end()) {
continue;
}
std::unique_ptr<KernelDef> kdef(new KernelDef);
kdef->set_op(op_registration->name);
kdef->set_device_type(backend.first);
// Constrain each type attribute to the intersection of:
// a) the types supported by the backend, and
// b) the types allowed by the OpDef, and
// c) the type constraints.
for (const string& type_attr : type_attrs) {
KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
attr_constraint->set_name(type_attr);
auto* allowed_values =
attr_constraint->mutable_allowed_values()->mutable_list();
const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
const auto* op_def_allowed_types =
op_def_attr.has_allowed_values()
? &op_def_attr.allowed_values().list().type()
: nullptr;
auto constraint_it = op_registration->type_constraints.find(type_attr);
const std::set<DataType>* type_constraints =
constraint_it != op_registration->type_constraints.end()
? &constraint_it->second
: nullptr;
for (DataType dtype : backend.second.supported_types) {
// Filter out types that aren't allowed by the OpDef.
if (op_def_allowed_types != nullptr &&
std::find(op_def_allowed_types->begin(),
op_def_allowed_types->end(),
dtype) == op_def_allowed_types->end()) {
continue;
}
// Filter out types based on the type constraints.
if (type_constraints != nullptr &&
type_constraints->find(dtype) == type_constraints->end()) {
continue;
}
// Passed all the filters, this type is allowed.
allowed_values->add_type(dtype);
}
if (op_registration->allow_resource_types) {
allowed_values->add_type(DT_RESOURCE);
std::unordered_set<string> type_attrs;
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
type_attrs.insert(attr_def.name());
}
}
if (backend.second.op_filter != nullptr &&
!backend.second.op_filter(kdef.get())) {
continue;
// Checks there are no type constraints referring to unknown attributes.
for (const auto& constraint : op_registration->type_constraints) {
if (type_attrs.find(constraint.first) == type_attrs.end()) {
LOG(FATAL) << "Unknown type attribute " << constraint.first
<< " in XLA op registration for " << op_name;
}
}
for (auto& backend : registry.backends_) {
// If the operator has a device whitelist, only register on whitelisted
// devices.
if (op_registration->has_device_whitelist &&
op_registration->device_whitelist.find(backend.first) ==
op_registration->device_whitelist.end()) {
continue;
}
// If the operator does NOT have a device whitelist, skip all devices
// that have already been registered.
if (!op_registration->has_device_whitelist &&
whitelisted_backend.find(backend.first) !=
whitelisted_backend.end()) {
continue;
}
std::unique_ptr<KernelDef> kdef(new KernelDef);
kdef->set_op(op_registration->name);
kdef->set_device_type(backend.first);
// Constrain each type attribute to the intersection of:
// a) the types supported by the backend, and
// b) the types allowed by the OpDef, and
// c) the type constraints.
for (const string& type_attr : type_attrs) {
KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
attr_constraint->set_name(type_attr);
auto* allowed_values =
attr_constraint->mutable_allowed_values()->mutable_list();
const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
const auto* op_def_allowed_types =
op_def_attr.has_allowed_values()
? &op_def_attr.allowed_values().list().type()
: nullptr;
auto constraint_it =
op_registration->type_constraints.find(type_attr);
const std::set<DataType>* type_constraints =
constraint_it != op_registration->type_constraints.end()
? &constraint_it->second
: nullptr;
for (DataType dtype : backend.second.supported_types) {
// Filter out types that aren't allowed by the OpDef.
if (op_def_allowed_types != nullptr &&
std::find(op_def_allowed_types->begin(),
op_def_allowed_types->end(),
dtype) == op_def_allowed_types->end()) {
continue;
}
// Filter out types based on the type constraints.
if (type_constraints != nullptr &&
type_constraints->find(dtype) == type_constraints->end()) {
continue;
}
// Passed all the filters, this type is allowed.
allowed_values->add_type(dtype);
}
if (op_registration->allow_resource_types) {
allowed_values->add_type(DT_RESOURCE);
}
}
if (backend.second.op_filter != nullptr &&
!backend.second.op_filter(kdef.get())) {
continue;
}
VLOG(2) << "XLA op registration: device: " << backend.first
<< " op: " << op_name;
registry.kernel_registrars_.emplace_back(
new kernel_factory::OpKernelRegistrar(
new KernelDef(*kdef), "XlaJitOp", op_registration->factory));
backend.second.kernel_defs.push_back(std::move(kdef));
}
VLOG(2) << "XLA op registration: device: " << backend.first
<< " op: " << op_name;
registry.kernel_registrars_.emplace_back(
new kernel_factory::OpKernelRegistrar(
new KernelDef(*kdef), "XlaJitOp", op_registration->factory));
backend.second.kernel_defs.push_back(std::move(kdef));
}
}
}
@ -265,12 +305,12 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
<< "Unknown backend " << compilation_device_name;
for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
auto op_iter = registry.ops_.find(k->op());
CHECK(op_iter != registry.ops_.end());
CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty());
// The test in IsCompatible ensures that if there are multiple matching
// registrations for this op name, they all have the same value of
// compilation_only, so only the first match needs to be tested.
if (include_compilation_only_kernels ||
!op_iter->second->compilation_only) {
!op_iter->second.front()->compilation_only) {
kernels.push_back(k.get());
}
}
@ -282,10 +322,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
XlaOpRegistry& registry = Instance();
mutex_lock lock(registry.mutex_);
auto it = registry.ops_.find(op);
if (it == registry.ops_.end()) {
if (it == registry.ops_.end() || it->second.empty()) {
return nullptr;
}
return &it->second->compile_time_constant_inputs;
// The test in IsCompatible ensures that if there are multiple matching
// registrations for this op name, they all have the same value of
// compile_time_constant_inputs, so only the first match is returned.
return &it->second.front()->compile_time_constant_inputs;
}
std::vector<string> XlaOpRegistry::BackendNames() {
@ -378,16 +421,15 @@ XlaOpRegistrar::XlaOpRegistrar(
std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
XlaOpRegistry& registry = XlaOpRegistry::Instance();
mutex_lock lock(registry.mutex_);
auto existing_ops = registry.ops_.equal_range(registration->name);
for (auto existing = existing_ops.first; existing != existing_ops.second;
++existing) {
if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) {
auto& existing_ops = registry.ops_[registration->name];
for (auto& existing : existing_ops) {
if (!XlaOpRegistry::IsCompatible(*existing, *registration)) {
LOG(FATAL)
<< "XLA op registration " << registration->name
<< " is incompatible with existing registration of the same name.";
}
}
registry.ops_.emplace(registration->name, std::move(registration));
existing_ops.emplace_back(std::move(registration));
}
XlaBackendRegistrar::XlaBackendRegistrar(

View File

@ -203,7 +203,7 @@ class XlaOpRegistry {
// Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
// Registrations present under the same key must satisfy IsCompatible above,
// and this is checked during registration.
std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
GUARDED_BY(mutex_);
// Have we already registered the JIT kernels on the JIT devices?

View File

@ -0,0 +1,86 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// This test is to verify the correctness of XLA op registration with specific
// backend overrides.
// A dummy backend-specific OpKernel for CPU.
// Backend-specific dummy kernel registered only for the XLA CPU backend.
// It simply forwards input 0 to output 0; the interesting part of the test
// is which backends/types this kernel ends up registered for, not what it
// computes.
class DummyCPUOp : public XlaOpKernel {
 public:
  explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    // Pass the sole input straight through.
    const auto passthrough = ctx->Input(0);
    ctx->SetOutput(0, passthrough);
  }
};
// A dummy generic OpKernel for all backends.
// Generic dummy kernel registered (without a device whitelist) for every
// backend not already claimed by a backend-specific registration. Like
// DummyCPUOp, it just forwards input 0 to output 0.
class DummyGenericOp : public XlaOpKernel {
 public:
  explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    // Identity: the kernel body is irrelevant to the registration test.
    const auto passthrough = ctx->Input(0);
    ctx->SetOutput(0, passthrough);
  }
};
// Dummy op whose input/output are fixed to int32; the type attr "T" exists
// only so the XLA registrations below can attach per-backend type
// constraints to it.
REGISTER_OP("DummyDuplicateOp")
.Attr("T: {float, int32}")
.Input("input: int32")
.Output("output: int32")
.Doc(R"doc(
A dummy Op.
input: dummy input.
output: dummy output.
)doc");
// Register the DummyCPUOp kernel for the CPU backend only (device
// whitelist), constrained to type INT32.
REGISTER_XLA_OP(Name("DummyDuplicateOp")
.Device(DEVICE_CPU_XLA_JIT)
.TypeConstraint("T", DT_INT32),
DummyCPUOp);
// Register the DummyGenericOp kernel for all registered devices (except CPU,
// which is already covered by the whitelisted registration above), with type
// FLOAT.
REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT),
DummyGenericOp);
// Test the correctness of registered kernels. The kernel registered for CPU
// should have type INT32 while all other kernels should have type FLOAT.
// Verifies that backend-specific and generic registrations co-exist: the
// DummyDuplicateOp kernel registered on the CPU backend must be constrained
// to INT32, while the kernels on every other backend must be constrained to
// FLOAT.
TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) {
  XlaOpRegistry::RegisterCompilationKernels();
  for (const auto& kernel_def : GetAllRegisteredKernels()) {
    if (kernel_def.op() != "DummyDuplicateOp") {
      continue;  // Only inspect kernels for the op under test.
    }
    // Every registration should carry exactly one constraint, on attr "T".
    EXPECT_EQ(kernel_def.constraint_size(), 1);
    EXPECT_EQ(kernel_def.constraint(0).name(), "T");
    const auto allowed_type =
        kernel_def.constraint(0).allowed_values().list().type(0);
    if (kernel_def.device_type() == "XLA_CPU_JIT") {
      // CPU got the whitelisted backend-specific kernel (INT32).
      EXPECT_EQ(allowed_type, DT_INT32);
    } else {
      // All other backends got the generic kernel (FLOAT).
      EXPECT_EQ(allowed_type, DT_FLOAT);
    }
  }
}
} // namespace
} // namespace tensorflow