[XLA:Python] Add support for non-CPU CustomCalls.
PiperOrigin-RevId: 261113049
This commit is contained in:
parent
2ece719ef2
commit
48af54a586
@ -15,7 +15,7 @@ cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil:
|
||||
cpu_custom_call_targets = {}
|
||||
|
||||
cdef register_custom_call_target(fn_name, void* fn):
|
||||
cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET"
|
||||
cdef const char* name = "xla._CUSTOM_CALL_TARGET"
|
||||
cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
|
||||
|
||||
register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))
|
||||
|
@ -110,18 +110,23 @@ StatusOr<std::string> GetComputationHloDotGraph(
|
||||
}
|
||||
|
||||
// Registers a 'fn_capsule' as a CPU custom call target.
|
||||
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
|
||||
// "xla._CPU_CUSTOM_CALL_TARGET".
|
||||
Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
||||
py::capsule capsule) {
|
||||
static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
|
||||
if (absl::string_view(capsule.name()) != kName) {
|
||||
// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
|
||||
// with name "xla._CUSTOM_CALL_TARGET".
|
||||
// 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
|
||||
Status PyRegisterCustomCallTarget(const std::string& fn_name,
|
||||
py::capsule capsule,
|
||||
const std::string& platform) {
|
||||
static const char* const kName = "xla._CUSTOM_CALL_TARGET";
|
||||
// TODO(phawkins): remove old name after fixing users.
|
||||
static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
|
||||
if (absl::string_view(capsule.name()) != kName &&
|
||||
absl::string_view(capsule.name()) != kOldCpuName) {
|
||||
return InvalidArgument(
|
||||
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
|
||||
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
|
||||
"Argument to RegisterCustomCallTargetRegistry was not a "
|
||||
"xla._CUSTOM_CALL_TARGET capsule.");
|
||||
}
|
||||
CustomCallTargetRegistry::Global()->Register(
|
||||
fn_name, static_cast<void*>(capsule), "Host");
|
||||
fn_name, static_cast<void*>(capsule), platform);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -295,8 +300,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
|
||||
// Local XLA client methods.
|
||||
|
||||
// CPU custom-call targets.
|
||||
m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget);
|
||||
// Custom-call targets.
|
||||
m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget);
|
||||
|
||||
py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig");
|
||||
alloc_config.def(py::init<>())
|
||||
|
@ -116,9 +116,17 @@ class LocalBackend(Backend):
|
||||
compile_options.device_assignment)
|
||||
|
||||
|
||||
xla_platform_names = {
|
||||
'cpu': 'Host',
|
||||
'gpu': 'CUDA',
|
||||
}
|
||||
|
||||
|
||||
def _cpu_backend_factory():
|
||||
client = _xla.LocalClient.Get(
|
||||
platform='cpu', xla_platform_id='Host', asynchronous=True)
|
||||
platform='cpu',
|
||||
xla_platform_id=xla_platform_names['cpu'],
|
||||
asynchronous=True)
|
||||
return LocalBackend(platform='cpu', client=client)
|
||||
|
||||
|
||||
@ -143,7 +151,9 @@ def _gpu_backend_factory():
|
||||
config.preallocate = preallocate not in ('0', 'false', 'False')
|
||||
|
||||
client = _xla.LocalClient.Get(
|
||||
platform='gpu', xla_platform_id='CUDA', asynchronous=True,
|
||||
platform='gpu',
|
||||
xla_platform_id=xla_platform_names['gpu'],
|
||||
asynchronous=True,
|
||||
allocator_config=config)
|
||||
return LocalBackend(platform='gpu', client=client)
|
||||
|
||||
@ -1596,14 +1606,18 @@ def _forward_methods_to_local_builder():
|
||||
_forward_methods_to_local_builder()
|
||||
|
||||
|
||||
def register_cpu_custom_call_target(name, fn):
|
||||
"""Registers a CPU custom call target.
|
||||
def register_custom_call_target(name, fn, platform='cpu'):
|
||||
"""Registers a custom call target.
|
||||
|
||||
Args:
|
||||
name: bytes containing the name of the function.
|
||||
fn: a PyCapsule object containing the function pointer.
|
||||
platform: the target platform.
|
||||
"""
|
||||
_xla.RegisterCpuCustomCallTarget(name, fn)
|
||||
_xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
|
||||
|
||||
# Deprecated. Use register_custom_call_target instead.
|
||||
register_cpu_custom_call_target = register_custom_call_target
|
||||
|
||||
|
||||
class PaddingConfigDimension(object):
|
||||
|
@ -311,7 +311,7 @@ class ComputationsWithConstantsTest(ComputationTest):
|
||||
def testCustomCall(self):
|
||||
c = self._NewComputation()
|
||||
for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
|
||||
xla_client.register_cpu_custom_call_target(name, fn)
|
||||
xla_client.register_custom_call_target(name, fn, platform="cpu")
|
||||
c.CustomCall(
|
||||
b"test_subtract_f32",
|
||||
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
|
||||
|
Loading…
Reference in New Issue
Block a user