Allow default TF/XLA op registration with specific backend overrides.

PiperOrigin-RevId: 201252399
Tony Wang authored 2018-06-19 15:29:38 -07:00, committed by TensorFlower Gardener
parent 10091aa9a9
commit bbba4e06e9
4 changed files with 238 additions and 100 deletions
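What this enables, in user terms: a backend-specific kernel registration (one that names its devices via a whitelist) can now coexist with a default registration of the same op that carries no whitelist; the default kernel is then registered only for backends not already claimed by a whitelist. Below is a minimal sketch of that pattern, assuming a hypothetical, made-up op "MyOp" and kernel class names; the real, checked-in example is xla_op_registry_test.cc at the end of this commit.

    // A hedged sketch only: "MyOp", MyOpCpuKernel, and MyOpDefaultKernel are
    // made-up names used to illustrate the registration pattern.
    #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
    #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
    #include "tensorflow/core/framework/op.h"

    namespace tensorflow {
    namespace {

    REGISTER_OP("MyOp").Input("input: float").Output("output: float");

    // Backend-specific kernel: meant only for the XLA CPU JIT device.
    class MyOpCpuKernel : public XlaOpKernel {
     public:
      explicit MyOpCpuKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
      void Compile(XlaOpKernelContext* ctx) override {
        ctx->SetOutput(0, ctx->Input(0));
      }
    };

    // Default kernel for every other backend.
    class MyOpDefaultKernel : public XlaOpKernel {
     public:
      explicit MyOpDefaultKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
      void Compile(XlaOpKernelContext* ctx) override {
        ctx->SetOutput(0, ctx->Input(0));
      }
    };

    // Backend-specific override, registered with a device whitelist.
    REGISTER_XLA_OP(Name("MyOp").Device(DEVICE_CPU_XLA_JIT), MyOpCpuKernel);

    // Default registration with no whitelist. Before this change this pair was
    // rejected ("do not both have device whitelists"); now the default kernel
    // is registered for every backend except DEVICE_CPU_XLA_JIT.
    REGISTER_XLA_OP(Name("MyOp"), MyOpDefaultKernel);

    }  // namespace
    }  // namespace tensorflow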

tensorflow/compiler/tf2xla/BUILD

@@ -489,3 +489,13 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
     ],
 )
+
+tf_cc_test(
+    name = "xla_op_registry_test",
+    srcs = ["xla_op_registry_test.cc"],
+    deps = [
+        ":xla_compiler",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)

tensorflow/compiler/tf2xla/xla_op_registry.cc

@@ -71,11 +71,12 @@ XlaOpRegistry::~XlaOpRegistry() = default;
                  << " have incompatible allow_resource_types settings.";
     return false;
   }
-  if (!x.has_device_whitelist || !y.has_device_whitelist) {
-    LOG(WARNING) << "Registrations of " << x.name
-                 << " do not both have device whitelists.";
+  if (!x.has_device_whitelist && !y.has_device_whitelist) {
+    LOG(WARNING) << "Duplicate registrations of " << x.name
+                 << " with no device whitelists.";
     return false;
   }
+  if (x.has_device_whitelist && y.has_device_whitelist) {
     for (const auto& device : x.device_whitelist) {
       if (y.device_whitelist.count(device) != 0) {
         LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
@@ -83,6 +84,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
         return false;
       }
     }
+  }
   if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
     LOG(WARNING) << "Registrations of " << x.name
                  << " have incompatible compile time constant inputs.";
@@ -157,15 +159,43 @@ void XlaOpRegistry::RegisterCompilationKernels() {
   registry.jit_kernels_registered_ = true;
 
   OpRegistryInterface* op_registry = OpRegistry::Global();
-  for (const auto& op : registry.ops_) {
-    const string& op_name = op.first;
-    const std::unique_ptr<OpRegistration>& op_registration = op.second;
+  // Order of op registration:
+  // The goal is to allow the co-existence of backend-specific kernels and
+  // generic kernels. To achieve this, we enforce the following order of
+  // registrations for one op:
+  // 1. Process op registrations with device whitelists:
+  //    this pass registers backend-specific kernels for this op.
+  // 2. Process op registrations without device whitelists:
+  //    this pass registers the kernels for all the other supported backends.
+  for (auto& ops : registry.ops_) {
+    const string& op_name = ops.first;
+    std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
+
+    // Partition the op registrations so that the ones with device whitelists
+    // precede the one without a device whitelist.
+    std::partition(op_registrations.begin(), op_registrations.end(),
+                   [](const std::unique_ptr<OpRegistration>& op_reg) {
+                     return op_reg->has_device_whitelist;
+                   });
+
+    // Collect the set of backends registered by ops with device whitelists.
+    // The op registration without a whitelist will register a generic kernel
+    // for all other backends not in this set.
+    std::unordered_set<string> whitelisted_backend;
+    for (auto& op_registration : op_registrations) {
+      if (op_registration->has_device_whitelist) {
+        whitelisted_backend.insert(op_registration->device_whitelist.begin(),
+                                   op_registration->device_whitelist.end());
+      }
+    }
+
+    for (auto& op_registration : op_registrations) {
       const OpDef* op_def;
       Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
       if (!lookup_status.ok()) {
         LOG(ERROR) << lookup_status.error_message();
         XLA_LOG_LINES(
-            ERROR, "Ops registered: \n" +
+            ERROR,
+            "Ops registered: \n" +
                 dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
       }
       TF_CHECK_OK(lookup_status);
@@ -194,6 +224,14 @@ void XlaOpRegistry::RegisterCompilationKernels() {
         continue;
       }
 
+      // If the operator does NOT have a device whitelist, skip all devices
+      // that have already been registered.
+      if (!op_registration->has_device_whitelist &&
+          whitelisted_backend.find(backend.first) !=
+              whitelisted_backend.end()) {
+        continue;
+      }
+
       std::unique_ptr<KernelDef> kdef(new KernelDef);
       kdef->set_op(op_registration->name);
       kdef->set_device_type(backend.first);
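Taken together, the partition, the whitelisted_backend set, and the skip above implement a two-pass scheme: backend-specific registrations claim their devices first, and the single default registration then fills in every remaining backend. A stripped-down, self-contained sketch of that scheme (the Registration struct, backend names, and kernel labels are illustrative stand-ins, not TensorFlow types):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <unordered_set>
    #include <vector>

    struct Registration {
      std::string kernel;                                // which kernel to register
      std::unordered_set<std::string> device_whitelist;  // empty => no whitelist
      bool has_device_whitelist() const { return !device_whitelist.empty(); }
    };

    int main() {
      const std::vector<std::string> backends = {"XLA_CPU_JIT", "XLA_GPU_JIT"};
      std::vector<Registration> regs = {
          {"GenericKernel", {}},           // default registration, no whitelist
          {"CpuKernel", {"XLA_CPU_JIT"}},  // backend-specific override
      };

      // Pass ordering: whitelisted registrations first, as in the std::partition
      // call added above.
      std::partition(regs.begin(), regs.end(), [](const Registration& r) {
        return r.has_device_whitelist();
      });

      // Backends already claimed by some whitelist.
      std::unordered_set<std::string> whitelisted;
      for (const auto& r : regs) {
        whitelisted.insert(r.device_whitelist.begin(), r.device_whitelist.end());
      }

      // Register: a registration without a whitelist skips every claimed backend.
      for (const auto& r : regs) {
        for (const auto& backend : backends) {
          if (r.has_device_whitelist() && r.device_whitelist.count(backend) == 0) {
            continue;  // backend-specific kernel only goes to its own devices
          }
          if (!r.has_device_whitelist() && whitelisted.count(backend) != 0) {
            continue;  // the default kernel leaves overridden backends alone
          }
          std::cout << r.kernel << " -> " << backend << "\n";
        }
      }
      // Prints: CpuKernel -> XLA_CPU_JIT, then GenericKernel -> XLA_GPU_JIT.
      return 0;
    }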
@@ -213,7 +251,8 @@ void XlaOpRegistry::RegisterCompilationKernels() {
             op_def_attr.has_allowed_values()
                 ? &op_def_attr.allowed_values().list().type()
                 : nullptr;
-        auto constraint_it = op_registration->type_constraints.find(type_attr);
+        auto constraint_it =
+            op_registration->type_constraints.find(type_attr);
         const std::set<DataType>* type_constraints =
             constraint_it != op_registration->type_constraints.end()
                 ? &constraint_it->second
@@ -251,6 +290,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
       }
     }
   }
+}
 
 std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
     const string& compilation_device_name,
@@ -265,12 +305,12 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
       << "Unknown backend " << compilation_device_name;
   for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
     auto op_iter = registry.ops_.find(k->op());
-    CHECK(op_iter != registry.ops_.end());
+    CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty());
     // The test in IsCompatible ensures that if there are multiple matching
     // registrations for this op name, they all have the same value of
     // compilation_only, so only the first match needs to be tested.
     if (include_compilation_only_kernels ||
-        !op_iter->second->compilation_only) {
+        !op_iter->second.front()->compilation_only) {
       kernels.push_back(k.get());
     }
   }
@@ -282,10 +322,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   auto it = registry.ops_.find(op);
-  if (it == registry.ops_.end()) {
+  if (it == registry.ops_.end() || it->second.empty()) {
     return nullptr;
   }
-  return &it->second->compile_time_constant_inputs;
+  // The test in IsCompatible ensures that if there are multiple matching
+  // registrations for this op name, they all have the same value of
+  // compile_time_constant_inputs, so only the first match is returned.
+  return &it->second.front()->compile_time_constant_inputs;
 }
 
 std::vector<string> XlaOpRegistry::BackendNames() {
@@ -378,16 +421,15 @@ XlaOpRegistrar::XlaOpRegistrar(
     std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
   XlaOpRegistry& registry = XlaOpRegistry::Instance();
   mutex_lock lock(registry.mutex_);
-  auto existing_ops = registry.ops_.equal_range(registration->name);
-  for (auto existing = existing_ops.first; existing != existing_ops.second;
-       ++existing) {
-    if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) {
+  auto& existing_ops = registry.ops_[registration->name];
+  for (auto& existing : existing_ops) {
+    if (!XlaOpRegistry::IsCompatible(*existing, *registration)) {
       LOG(FATAL)
           << "XLA op registration " << registration->name
           << " is incompatible with existing registration of the same name.";
     }
   }
-  registry.ops_.emplace(registration->name, std::move(registration));
+  existing_ops.emplace_back(std::move(registration));
 }
 
 XlaBackendRegistrar::XlaBackendRegistrar(

tensorflow/compiler/tf2xla/xla_op_registry.h

@@ -203,7 +203,7 @@ class XlaOpRegistry {
   // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
   // Registrations present under the same key must satisfy IsCompatible above,
   // and this is checked during registration.
-  std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
+  std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
       GUARDED_BY(mutex_);
 
   // Have we already registered the JIT kernels on the JIT devices?

tensorflow/compiler/tf2xla/xla_op_registry_test.cc (new file)

@@ -0,0 +1,86 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

// This test is to verify the correctness of XLA op registration with specific
// backend overrides.

// A dummy backend-specific OpKernel for CPU.
class DummyCPUOp : public XlaOpKernel {
 public:
  explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    ctx->SetOutput(0, ctx->Input(0));
  }
};

// A dummy generic OpKernel for all backends.
class DummyGenericOp : public XlaOpKernel {
 public:
  explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    ctx->SetOutput(0, ctx->Input(0));
  }
};

REGISTER_OP("DummyDuplicateOp")
    .Attr("T: {float, int32}")
    .Input("input: int32")
    .Output("output: int32")
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output: dummy output.
)doc");

// Register the DummyCPUOp kernel for CPU with type INT32.
REGISTER_XLA_OP(Name("DummyDuplicateOp")
                    .Device(DEVICE_CPU_XLA_JIT)
                    .TypeConstraint("T", DT_INT32),
                DummyCPUOp);
// Register the DummyGenericOp kernel for all other registered devices (CPU is
// already covered by the registration above), with type FLOAT.
REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT),
                DummyGenericOp);

// Test the correctness of registered kernels. The kernel registered for CPU
// should have type INT32 while all other kernels should have type FLOAT.
TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) {
  XlaOpRegistry::RegisterCompilationKernels();
  auto registered_kernels = GetAllRegisteredKernels();
  for (const auto& kernels : registered_kernels) {
    if (kernels.op() == "DummyDuplicateOp") {
      EXPECT_EQ(kernels.constraint_size(), 1);
      EXPECT_EQ(kernels.constraint(0).name(), "T");
      if (kernels.device_type() == "XLA_CPU_JIT") {
        EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
                  DT_INT32);
      } else {
        EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
                  DT_FLOAT);
      }
    }
  }
}
} // namespace
} // namespace tensorflow