[XLA:Python] Add support for non-CPU CustomCalls.

PiperOrigin-RevId: 261113049
This commit is contained in:
Peter Hawkins 2019-08-01 06:48:58 -07:00 committed by TensorFlower Gardener
parent 2ece719ef2
commit 48af54a586
4 changed files with 37 additions and 18 deletions

View File

@ -15,7 +15,7 @@ cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil:
cpu_custom_call_targets = {}
cdef register_custom_call_target(fn_name, void* fn):
cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET"
cdef const char* name = "xla._CUSTOM_CALL_TARGET"
cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))

View File

@ -110,18 +110,23 @@ StatusOr<std::string> GetComputationHloDotGraph(
}
// Registers 'fn_capsule' as a custom call target for the given 'platform'.
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
// "xla._CPU_CUSTOM_CALL_TARGET".
Status RegisterCpuCustomCallTarget(const std::string& fn_name,
py::capsule capsule) {
static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
if (absl::string_view(capsule.name()) != kName) {
// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
// with name "xla._CUSTOM_CALL_TARGET".
// 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
Status PyRegisterCustomCallTarget(const std::string& fn_name,
py::capsule capsule,
const std::string& platform) {
static const char* const kName = "xla._CUSTOM_CALL_TARGET";
// TODO(phawkins): remove old name after fixing users.
static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
if (absl::string_view(capsule.name()) != kName &&
absl::string_view(capsule.name()) != kOldCpuName) {
return InvalidArgument(
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
// NOTE(review): "RegisterCustomCallTargetRegistry" matches neither this
// function nor its Python binding "RegisterCustomCallTarget"; consider
// correcting the name in this message.
"Argument to RegisterCustomCallTargetRegistry was not a "
"xla._CUSTOM_CALL_TARGET capsule.");
}
CustomCallTargetRegistry::Global()->Register(
fn_name, static_cast<void*>(capsule), "Host");
fn_name, static_cast<void*>(capsule), platform);
return Status::OK();
}
@ -295,8 +300,8 @@ PYBIND11_MODULE(xla_extension, m) {
// Local XLA client methods.
// CPU custom-call targets.
m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget);
// Custom-call targets.
m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget);
py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig");
alloc_config.def(py::init<>())

View File

@ -116,9 +116,17 @@ class LocalBackend(Backend):
compile_options.device_assignment)
xla_platform_names = {
'cpu': 'Host',
'gpu': 'CUDA',
}
def _cpu_backend_factory():
client = _xla.LocalClient.Get(
platform='cpu', xla_platform_id='Host', asynchronous=True)
platform='cpu',
xla_platform_id=xla_platform_names['cpu'],
asynchronous=True)
return LocalBackend(platform='cpu', client=client)
@ -143,7 +151,9 @@ def _gpu_backend_factory():
config.preallocate = preallocate not in ('0', 'false', 'False')
client = _xla.LocalClient.Get(
platform='gpu', xla_platform_id='CUDA', asynchronous=True,
platform='gpu',
xla_platform_id=xla_platform_names['gpu'],
asynchronous=True,
allocator_config=config)
return LocalBackend(platform='gpu', client=client)
@ -1596,14 +1606,18 @@ def _forward_methods_to_local_builder():
_forward_methods_to_local_builder()
def register_cpu_custom_call_target(name, fn):
"""Registers a CPU custom call target.
def register_custom_call_target(name, fn, platform='cpu'):
"""Registers a custom call target.
Args:
name: bytes containing the name of the function.
fn: a PyCapsule object containing the function pointer.
platform: the target platform; must be a key of `xla_platform_names` (e.g. 'cpu' or 'gpu').
"""
_xla.RegisterCpuCustomCallTarget(name, fn)
_xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target
class PaddingConfigDimension(object):

View File

@ -311,7 +311,7 @@ class ComputationsWithConstantsTest(ComputationTest):
def testCustomCall(self):
c = self._NewComputation()
for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
xla_client.register_cpu_custom_call_target(name, fn)
xla_client.register_custom_call_target(name, fn, platform="cpu")
c.CustomCall(
b"test_subtract_f32",
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),