[XLA:Python] Add support for non-CPU CustomCalls.
PiperOrigin-RevId: 261113049
This commit is contained in:
parent
2ece719ef2
commit
48af54a586
tensorflow/compiler/xla/python
@ -15,7 +15,7 @@ cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil:
|
|||||||
cpu_custom_call_targets = {}
|
cpu_custom_call_targets = {}
|
||||||
|
|
||||||
cdef register_custom_call_target(fn_name, void* fn):
|
cdef register_custom_call_target(fn_name, void* fn):
|
||||||
cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET"
|
cdef const char* name = "xla._CUSTOM_CALL_TARGET"
|
||||||
cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
|
cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
|
||||||
|
|
||||||
register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))
|
register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))
|
||||||
|
@ -110,18 +110,23 @@ StatusOr<std::string> GetComputationHloDotGraph(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Registers a 'fn_capsule' as a CPU custom call target.
|
// Registers a 'fn_capsule' as a CPU custom call target.
|
||||||
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
|
// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
|
||||||
// "xla._CPU_CUSTOM_CALL_TARGET".
|
// with name "xla._CUSTOM_CALL_TARGET".
|
||||||
Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
// 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
|
||||||
py::capsule capsule) {
|
Status PyRegisterCustomCallTarget(const std::string& fn_name,
|
||||||
static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
|
py::capsule capsule,
|
||||||
if (absl::string_view(capsule.name()) != kName) {
|
const std::string& platform) {
|
||||||
|
static const char* const kName = "xla._CUSTOM_CALL_TARGET";
|
||||||
|
// TODO(phawkins): remove old name after fixing users.
|
||||||
|
static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
|
||||||
|
if (absl::string_view(capsule.name()) != kName &&
|
||||||
|
absl::string_view(capsule.name()) != kOldCpuName) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
|
"Argument to RegisterCustomCallTargetRegistry was not a "
|
||||||
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
|
"xla._CUSTOM_CALL_TARGET capsule.");
|
||||||
}
|
}
|
||||||
CustomCallTargetRegistry::Global()->Register(
|
CustomCallTargetRegistry::Global()->Register(
|
||||||
fn_name, static_cast<void*>(capsule), "Host");
|
fn_name, static_cast<void*>(capsule), platform);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,8 +300,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
|
|
||||||
// Local XLA client methods.
|
// Local XLA client methods.
|
||||||
|
|
||||||
// CPU custom-call targets.
|
// Custom-call targets.
|
||||||
m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget);
|
m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget);
|
||||||
|
|
||||||
py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig");
|
py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig");
|
||||||
alloc_config.def(py::init<>())
|
alloc_config.def(py::init<>())
|
||||||
|
@ -116,9 +116,17 @@ class LocalBackend(Backend):
|
|||||||
compile_options.device_assignment)
|
compile_options.device_assignment)
|
||||||
|
|
||||||
|
|
||||||
|
xla_platform_names = {
|
||||||
|
'cpu': 'Host',
|
||||||
|
'gpu': 'CUDA',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _cpu_backend_factory():
|
def _cpu_backend_factory():
|
||||||
client = _xla.LocalClient.Get(
|
client = _xla.LocalClient.Get(
|
||||||
platform='cpu', xla_platform_id='Host', asynchronous=True)
|
platform='cpu',
|
||||||
|
xla_platform_id=xla_platform_names['cpu'],
|
||||||
|
asynchronous=True)
|
||||||
return LocalBackend(platform='cpu', client=client)
|
return LocalBackend(platform='cpu', client=client)
|
||||||
|
|
||||||
|
|
||||||
@ -143,7 +151,9 @@ def _gpu_backend_factory():
|
|||||||
config.preallocate = preallocate not in ('0', 'false', 'False')
|
config.preallocate = preallocate not in ('0', 'false', 'False')
|
||||||
|
|
||||||
client = _xla.LocalClient.Get(
|
client = _xla.LocalClient.Get(
|
||||||
platform='gpu', xla_platform_id='CUDA', asynchronous=True,
|
platform='gpu',
|
||||||
|
xla_platform_id=xla_platform_names['gpu'],
|
||||||
|
asynchronous=True,
|
||||||
allocator_config=config)
|
allocator_config=config)
|
||||||
return LocalBackend(platform='gpu', client=client)
|
return LocalBackend(platform='gpu', client=client)
|
||||||
|
|
||||||
@ -1596,14 +1606,18 @@ def _forward_methods_to_local_builder():
|
|||||||
_forward_methods_to_local_builder()
|
_forward_methods_to_local_builder()
|
||||||
|
|
||||||
|
|
||||||
def register_cpu_custom_call_target(name, fn):
|
def register_custom_call_target(name, fn, platform='cpu'):
|
||||||
"""Registers a CPU custom call target.
|
"""Registers a custom call target.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: bytes containing the name of the function.
|
name: bytes containing the name of the function.
|
||||||
fn: a PyCapsule object containing the function pointer.
|
fn: a PyCapsule object containing the function pointer.
|
||||||
|
platform: the target platform.
|
||||||
"""
|
"""
|
||||||
_xla.RegisterCpuCustomCallTarget(name, fn)
|
_xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
|
||||||
|
|
||||||
|
# Deprecated. Use register_custom_call_target instead.
|
||||||
|
register_cpu_custom_call_target = register_custom_call_target
|
||||||
|
|
||||||
|
|
||||||
class PaddingConfigDimension(object):
|
class PaddingConfigDimension(object):
|
||||||
|
@ -311,7 +311,7 @@ class ComputationsWithConstantsTest(ComputationTest):
|
|||||||
def testCustomCall(self):
|
def testCustomCall(self):
|
||||||
c = self._NewComputation()
|
c = self._NewComputation()
|
||||||
for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
|
for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
|
||||||
xla_client.register_cpu_custom_call_target(name, fn)
|
xla_client.register_custom_call_target(name, fn, platform="cpu")
|
||||||
c.CustomCall(
|
c.CustomCall(
|
||||||
b"test_subtract_f32",
|
b"test_subtract_f32",
|
||||||
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
|
operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
|
||||||
|
Loading…
Reference in New Issue
Block a user