[XLA:Python] Add support for non-CPU CustomCalls.

PiperOrigin-RevId: 261113049
This commit is contained in:
Peter Hawkins 2019-08-01 06:48:58 -07:00 committed by TensorFlower Gardener
parent 2ece719ef2
commit 48af54a586
4 changed files with 37 additions and 18 deletions

View File

@@ -15,7 +15,7 @@ cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil:
cpu_custom_call_targets = {} cpu_custom_call_targets = {}
cdef register_custom_call_target(fn_name, void* fn): cdef register_custom_call_target(fn_name, void* fn):
cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET" cdef const char* name = "xla._CUSTOM_CALL_TARGET"
cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL) cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32)) register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))

View File

@@ -110,18 +110,23 @@ StatusOr<std::string> GetComputationHloDotGraph(
} }
// Registers a 'fn_capsule' as a CPU custom call target. // Registers a 'fn_capsule' as a CPU custom call target.
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name // 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
// "xla._CPU_CUSTOM_CALL_TARGET". // with name "xla._CUSTOM_CALL_TARGET".
Status RegisterCpuCustomCallTarget(const std::string& fn_name, // 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
py::capsule capsule) { Status PyRegisterCustomCallTarget(const std::string& fn_name,
static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET"; py::capsule capsule,
if (absl::string_view(capsule.name()) != kName) { const std::string& platform) {
static const char* const kName = "xla._CUSTOM_CALL_TARGET";
// TODO(phawkins): remove old name after fixing users.
static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
if (absl::string_view(capsule.name()) != kName &&
absl::string_view(capsule.name()) != kOldCpuName) {
return InvalidArgument( return InvalidArgument(
"Argument to RegisterCpuCustomCallTargetRegistry was not a " "Argument to RegisterCustomCallTargetRegistry was not a "
"xla._CPU_CUSTOM_CALL_TARGET capsule."); "xla._CUSTOM_CALL_TARGET capsule.");
} }
CustomCallTargetRegistry::Global()->Register( CustomCallTargetRegistry::Global()->Register(
fn_name, static_cast<void*>(capsule), "Host"); fn_name, static_cast<void*>(capsule), platform);
return Status::OK(); return Status::OK();
} }
@@ -295,8 +300,8 @@ PYBIND11_MODULE(xla_extension, m) {
// Local XLA client methods. // Local XLA client methods.
// CPU custom-call targets. // Custom-call targets.
m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget); m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget);
py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig"); py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig");
alloc_config.def(py::init<>()) alloc_config.def(py::init<>())

View File

@ -116,9 +116,17 @@ class LocalBackend(Backend):
compile_options.device_assignment) compile_options.device_assignment)
xla_platform_names = {
'cpu': 'Host',
'gpu': 'CUDA',
}
def _cpu_backend_factory(): def _cpu_backend_factory():
client = _xla.LocalClient.Get( client = _xla.LocalClient.Get(
platform='cpu', xla_platform_id='Host', asynchronous=True) platform='cpu',
xla_platform_id=xla_platform_names['cpu'],
asynchronous=True)
return LocalBackend(platform='cpu', client=client) return LocalBackend(platform='cpu', client=client)
@@ -143,7 +151,9 @@ def _gpu_backend_factory():
config.preallocate = preallocate not in ('0', 'false', 'False') config.preallocate = preallocate not in ('0', 'false', 'False')
client = _xla.LocalClient.Get( client = _xla.LocalClient.Get(
platform='gpu', xla_platform_id='CUDA', asynchronous=True, platform='gpu',
xla_platform_id=xla_platform_names['gpu'],
asynchronous=True,
allocator_config=config) allocator_config=config)
return LocalBackend(platform='gpu', client=client) return LocalBackend(platform='gpu', client=client)
@@ -1596,14 +1606,18 @@ def _forward_methods_to_local_builder():
_forward_methods_to_local_builder() _forward_methods_to_local_builder()
def register_cpu_custom_call_target(name, fn): def register_custom_call_target(name, fn, platform='cpu'):
"""Registers a CPU custom call target. """Registers a custom call target.
Args: Args:
name: bytes containing the name of the function. name: bytes containing the name of the function.
fn: a PyCapsule object containing the function pointer. fn: a PyCapsule object containing the function pointer.
platform: the target platform.
""" """
_xla.RegisterCpuCustomCallTarget(name, fn) _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target
class PaddingConfigDimension(object): class PaddingConfigDimension(object):

View File

@@ -311,7 +311,7 @@ class ComputationsWithConstantsTest(ComputationTest):
def testCustomCall(self): def testCustomCall(self):
c = self._NewComputation() c = self._NewComputation()
for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
xla_client.register_cpu_custom_call_target(name, fn) xla_client.register_custom_call_target(name, fn, platform="cpu")
c.CustomCall( c.CustomCall(
b"test_subtract_f32", b"test_subtract_f32",
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),