Allow default TF/XLA op registration with specific backend overrides.
PiperOrigin-RevId: 201252399
This commit is contained in:
parent
10091aa9a9
commit
bbba4e06e9
@ -489,3 +489,13 @@ cc_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "xla_op_registry_test",
|
||||||
|
srcs = ["xla_op_registry_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":xla_compiler",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -71,11 +71,12 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
|||||||
<< " have incompatible allow_resource_types settings.";
|
<< " have incompatible allow_resource_types settings.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!x.has_device_whitelist || !y.has_device_whitelist) {
|
if (!x.has_device_whitelist && !y.has_device_whitelist) {
|
||||||
LOG(WARNING) << "Registrations of " << x.name
|
LOG(WARNING) << "Duplicate registrations of " << x.name
|
||||||
<< " do not both have device whitelists.";
|
<< "with no device whitelists.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (x.has_device_whitelist && y.has_device_whitelist) {
|
||||||
for (const auto& device : x.device_whitelist) {
|
for (const auto& device : x.device_whitelist) {
|
||||||
if (y.device_whitelist.count(device) != 0) {
|
if (y.device_whitelist.count(device) != 0) {
|
||||||
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
|
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
|
||||||
@ -83,6 +84,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
|
if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
|
||||||
LOG(WARNING) << "Registrations of " << x.name
|
LOG(WARNING) << "Registrations of " << x.name
|
||||||
<< " have incompatible compile time constant inputs.";
|
<< " have incompatible compile time constant inputs.";
|
||||||
@ -157,15 +159,43 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
|||||||
registry.jit_kernels_registered_ = true;
|
registry.jit_kernels_registered_ = true;
|
||||||
|
|
||||||
OpRegistryInterface* op_registry = OpRegistry::Global();
|
OpRegistryInterface* op_registry = OpRegistry::Global();
|
||||||
for (const auto& op : registry.ops_) {
|
// Order of op registration:
|
||||||
const string& op_name = op.first;
|
// The goal is to allow the co-existence of backend-specific kernels and
|
||||||
const std::unique_ptr<OpRegistration>& op_registration = op.second;
|
// generic kernels. To achieve this, we enforce the following order of
|
||||||
|
// registrations for one op:
|
||||||
|
// 1. Process op registration with device whitelists:
|
||||||
|
// this pass registers backend-specific kernels for this op.
|
||||||
|
// 2. Process op registration without device whitelists:
|
||||||
|
// this pass registers the kernels for all the other supported backends.
|
||||||
|
for (auto& ops : registry.ops_) {
|
||||||
|
const string& op_name = ops.first;
|
||||||
|
std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
|
||||||
|
// Partition the op registration so that the ones with device whitelists
|
||||||
|
// precede the one without device whitelist.
|
||||||
|
std::partition(op_registrations.begin(), op_registrations.end(),
|
||||||
|
[](const std::unique_ptr<OpRegistration>& op_reg) {
|
||||||
|
return op_reg->has_device_whitelist;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Collect a set of backend registered by ops with device whitelists.
|
||||||
|
// The op registration without whitelists will register a generic kernel
|
||||||
|
// for all other backends not in this set.
|
||||||
|
std::unordered_set<string> whitelisted_backend;
|
||||||
|
for (auto& op_registration : op_registrations) {
|
||||||
|
if (op_registration->has_device_whitelist) {
|
||||||
|
whitelisted_backend.insert(op_registration->device_whitelist.begin(),
|
||||||
|
op_registration->device_whitelist.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& op_registration : op_registrations) {
|
||||||
const OpDef* op_def;
|
const OpDef* op_def;
|
||||||
Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
|
Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
|
||||||
if (!lookup_status.ok()) {
|
if (!lookup_status.ok()) {
|
||||||
LOG(ERROR) << lookup_status.error_message();
|
LOG(ERROR) << lookup_status.error_message();
|
||||||
XLA_LOG_LINES(
|
XLA_LOG_LINES(
|
||||||
ERROR, "Ops registered: \n" +
|
ERROR,
|
||||||
|
"Ops registered: \n" +
|
||||||
dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
|
dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
|
||||||
}
|
}
|
||||||
TF_CHECK_OK(lookup_status);
|
TF_CHECK_OK(lookup_status);
|
||||||
@ -194,6 +224,14 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the operator does NOT has a device whitelist, skip all devices
|
||||||
|
// that has already been registered.
|
||||||
|
if (!op_registration->has_device_whitelist &&
|
||||||
|
whitelisted_backend.find(backend.first) !=
|
||||||
|
whitelisted_backend.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<KernelDef> kdef(new KernelDef);
|
std::unique_ptr<KernelDef> kdef(new KernelDef);
|
||||||
kdef->set_op(op_registration->name);
|
kdef->set_op(op_registration->name);
|
||||||
kdef->set_device_type(backend.first);
|
kdef->set_device_type(backend.first);
|
||||||
@ -213,7 +251,8 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
|||||||
op_def_attr.has_allowed_values()
|
op_def_attr.has_allowed_values()
|
||||||
? &op_def_attr.allowed_values().list().type()
|
? &op_def_attr.allowed_values().list().type()
|
||||||
: nullptr;
|
: nullptr;
|
||||||
auto constraint_it = op_registration->type_constraints.find(type_attr);
|
auto constraint_it =
|
||||||
|
op_registration->type_constraints.find(type_attr);
|
||||||
const std::set<DataType>* type_constraints =
|
const std::set<DataType>* type_constraints =
|
||||||
constraint_it != op_registration->type_constraints.end()
|
constraint_it != op_registration->type_constraints.end()
|
||||||
? &constraint_it->second
|
? &constraint_it->second
|
||||||
@ -251,6 +290,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
|
std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
|
||||||
const string& compilation_device_name,
|
const string& compilation_device_name,
|
||||||
@ -265,12 +305,12 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
|
|||||||
<< "Unknown backend " << compilation_device_name;
|
<< "Unknown backend " << compilation_device_name;
|
||||||
for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
|
for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
|
||||||
auto op_iter = registry.ops_.find(k->op());
|
auto op_iter = registry.ops_.find(k->op());
|
||||||
CHECK(op_iter != registry.ops_.end());
|
CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty());
|
||||||
// The test in IsCompatible ensures that if there are multiple matching
|
// The test in IsCompatible ensures that if there are multiple matching
|
||||||
// registrations for this op name, they all have the same value of
|
// registrations for this op name, they all have the same value of
|
||||||
// compilation_only, so only the first match needs to be tested.
|
// compilation_only, so only the first match needs to be tested.
|
||||||
if (include_compilation_only_kernels ||
|
if (include_compilation_only_kernels ||
|
||||||
!op_iter->second->compilation_only) {
|
!op_iter->second.front()->compilation_only) {
|
||||||
kernels.push_back(k.get());
|
kernels.push_back(k.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -282,10 +322,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
|
|||||||
XlaOpRegistry& registry = Instance();
|
XlaOpRegistry& registry = Instance();
|
||||||
mutex_lock lock(registry.mutex_);
|
mutex_lock lock(registry.mutex_);
|
||||||
auto it = registry.ops_.find(op);
|
auto it = registry.ops_.find(op);
|
||||||
if (it == registry.ops_.end()) {
|
if (it == registry.ops_.end() || it->second.empty()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return &it->second->compile_time_constant_inputs;
|
// The test in IsCompatible ensures that if there are multiple matching
|
||||||
|
// registrations for this op name, they all have the same value of
|
||||||
|
// compile_time_constant_inputs, so only the first match is returned.
|
||||||
|
return &it->second.front()->compile_time_constant_inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> XlaOpRegistry::BackendNames() {
|
std::vector<string> XlaOpRegistry::BackendNames() {
|
||||||
@ -378,16 +421,15 @@ XlaOpRegistrar::XlaOpRegistrar(
|
|||||||
std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
|
std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
|
||||||
XlaOpRegistry& registry = XlaOpRegistry::Instance();
|
XlaOpRegistry& registry = XlaOpRegistry::Instance();
|
||||||
mutex_lock lock(registry.mutex_);
|
mutex_lock lock(registry.mutex_);
|
||||||
auto existing_ops = registry.ops_.equal_range(registration->name);
|
auto& existing_ops = registry.ops_[registration->name];
|
||||||
for (auto existing = existing_ops.first; existing != existing_ops.second;
|
for (auto& existing : existing_ops) {
|
||||||
++existing) {
|
if (!XlaOpRegistry::IsCompatible(*existing, *registration)) {
|
||||||
if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) {
|
|
||||||
LOG(FATAL)
|
LOG(FATAL)
|
||||||
<< "XLA op registration " << registration->name
|
<< "XLA op registration " << registration->name
|
||||||
<< " is incompatible with existing registration of the same name.";
|
<< " is incompatible with existing registration of the same name.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
registry.ops_.emplace(registration->name, std::move(registration));
|
existing_ops.emplace_back(std::move(registration));
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaBackendRegistrar::XlaBackendRegistrar(
|
XlaBackendRegistrar::XlaBackendRegistrar(
|
||||||
|
@ -203,7 +203,7 @@ class XlaOpRegistry {
|
|||||||
// Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
|
// Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
|
||||||
// Registrations present under the same key must satisfy IsCompatible above,
|
// Registrations present under the same key must satisfy IsCompatible above,
|
||||||
// and this is checked during registration.
|
// and this is checked during registration.
|
||||||
std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
|
std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
|
||||||
GUARDED_BY(mutex_);
|
GUARDED_BY(mutex_);
|
||||||
|
|
||||||
// Have we already registered the JIT kernels on the JIT devices?
|
// Have we already registered the JIT kernels on the JIT devices?
|
||||||
|
86
tensorflow/compiler/tf2xla/xla_op_registry_test.cc
Normal file
86
tensorflow/compiler/tf2xla/xla_op_registry_test.cc
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// This test is to verify the correctness of XLA op registration with specific
|
||||||
|
// backend overrides.
|
||||||
|
|
||||||
|
// A dummy backend-specific OpKernel for CPU.
|
||||||
|
class DummyCPUOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
ctx->SetOutput(0, ctx->Input(0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// A dummy generic OpKernel for all backends.
|
||||||
|
class DummyGenericOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
ctx->SetOutput(0, ctx->Input(0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_OP("DummyDuplicateOp")
|
||||||
|
.Attr("T: {float, int32}")
|
||||||
|
.Input("input: int32")
|
||||||
|
.Output("output: int32")
|
||||||
|
.Doc(R"doc(
|
||||||
|
A dummy Op.
|
||||||
|
|
||||||
|
input: dummy input.
|
||||||
|
output: dummy output.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
// Register the DummyCPUOp kernel for CPU with type INT32.
|
||||||
|
REGISTER_XLA_OP(Name("DummyDuplicateOp")
|
||||||
|
.Device(DEVICE_CPU_XLA_JIT)
|
||||||
|
.TypeConstraint("T", DT_INT32),
|
||||||
|
DummyCPUOp);
|
||||||
|
// Register the DummyGeneric kernel for all registered device (except CPU since
|
||||||
|
// it is already registered), with type FLOAT.
|
||||||
|
REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT),
|
||||||
|
DummyGenericOp);
|
||||||
|
|
||||||
|
// Test the correctness of registered kernels. The kernel registered for CPU
|
||||||
|
// should have type INT32 while all other kernels should have type FLOAT.
|
||||||
|
TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) {
|
||||||
|
XlaOpRegistry::RegisterCompilationKernels();
|
||||||
|
auto registered_kernels = GetAllRegisteredKernels();
|
||||||
|
for (const auto& kernels : registered_kernels) {
|
||||||
|
if (kernels.op() == "DummyDuplicateOp") {
|
||||||
|
EXPECT_EQ(kernels.constraint_size(), 1);
|
||||||
|
EXPECT_EQ(kernels.constraint(0).name(), "T");
|
||||||
|
if (kernels.device_type() == "XLA_CPU_JIT") {
|
||||||
|
EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
|
||||||
|
DT_INT32);
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
|
||||||
|
DT_FLOAT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user