Add a way to register custom devices with the Python TFE_Context

The API accepts TFE_RegisterCustomDevice arguments as PyCapsules, so each custom device will need some method to create those. Presumably most custom devices will end up wrapping the PyCapsule creation+registration rather than exposing it to the user.

No public API yet, but this is roughly what I have in mind at the moment.

This only works with --config=monolithic or when the custom device registration is bundled with pywrap_tensorflow.so right now since that has its own copy of the C API. Something like this could work if we switched pywrap_tensorflow.so to instead rely on libtensorflow.so for the C API, then custom device extensions could link against that.

PiperOrigin-RevId: 305762978
Change-Id: I4d2d9bd9c01ba22391e138244a3948bae8963c5c
This commit is contained in:
Allen Lavoie 2020-04-09 14:39:08 -07:00 committed by TensorFlower Gardener
parent 747c37add5
commit 17b7bc01bc
11 changed files with 473 additions and 125 deletions

View File

@ -372,6 +372,22 @@ tf_cuda_cc_test(
],
)
cc_library(
name = "custom_device_testutil",
testonly = True,
srcs = ["custom_device_testutil.cc"],
hdrs = ["custom_device_testutil.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
@ -382,6 +398,7 @@ tf_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
":custom_device_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",

View File

@ -515,6 +515,11 @@ typedef struct TFE_CustomDevice {
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
//
// TODO(allenl): Currently custom devices need to know their device name, and
// this separately needs to be fed to the registration. It would be nice if that
// duplication wasn't necessary because custom devices could be written without
// knowing their name.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status);

View File

@ -20,129 +20,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@ -394,5 +274,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
} // namespace

View File

@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status) {
return reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
->tensor;
}
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info) {
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device->delete_device = &DeleteLoggingDevice;
custom_device->execute = &LoggingDeviceExecute;
*device = custom_device;
LoggingDevice* logging_device = new LoggingDevice;
logging_device->arrived_flag = arrived_flag;
logging_device->executed_flag = executed_flag;
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status);
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_

View File

@ -1,5 +1,8 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
load(
"//tensorflow/tools/test:performance.bzl",
@ -163,6 +166,47 @@ py_library(
],
)
tf_python_pybind_extension(
name = "custom_device_testutil",
testonly = True,
srcs = ["custom_device_testutil.cc"],
module_name = "custom_device_testutil",
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:custom_device_testutil",
"//tensorflow/python:cpp_python_util",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"//tensorflow/python:safe_ptr",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
py_test(
name = "custom_device_test",
size = "small",
srcs = ["custom_device_test.py"],
python_version = "PY3",
# Note that this currently only works with --config=monolithic, since it
# requires the C API which runs static initializers again.
#
# TODO(allenl): Figure out a way to allow extensions to register custom
# devices which works with dynamic linking.
tags = [
"no_oss",
"no_pip",
],
deps = [
":context",
":custom_device_testutil",
":test",
],
)
cuda_py_test(
name = "context_test",
size = "small",

View File

@ -1116,6 +1116,13 @@ class Context(object):
return function_def
def register_custom_device(self, device_capsule, device_name,
device_info_capsule):
"""Calls TFE_RegisterCustomDevice. See the non-member function."""
self.ensure_initialized()
pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
device_name, device_info_capsule)
def remove_function(self, name):
"""Remove a function from the context.
@ -2287,6 +2294,32 @@ def get_function_def(name):
return context().get_function_def(name)
def register_custom_device(device_capsule, device_name, device_info_capsule):
"""Calls TFE_RegisterCustomDevice to register a custom device with Python.
Enables using C extensions specifying a custom device from Python. See the
experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
details.
Note that custom devices are not currently supported inside `tf.function`s.
Args:
device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
containing a pointer to a TFE_CustomDevice struct. The capsule retains
ownership of the memory.
device_name: A string indicating the name to register the custom device
under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
subsequently be passed to `with tf.device(...):`.
device_info_capsule: A PyCapsule with the name set to
'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
struct with the initial state of the custom device (the void* device_info
argument to TFE_RegisterCustomDevice). This method takes ownership of the
memory and clears the capsule destructor.
"""
context().register_custom_device(device_capsule, device_name,
device_info_capsule)
# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and

View File

@ -0,0 +1,50 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_device_testutil
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
class CustomDeviceTest(test.TestCase):
def testRegisterCustomDevice(self):
device_name = '/job:localhost/replica:0/task:0/device:CUSTOM:0'
device, device_info, arrived_flag, executed_flag = (
custom_device_testutil.GetLoggingDeviceCapsules(device_name))
context.register_custom_device(device, device_name, device_info)
self.assertFalse(custom_device_testutil.FlagValue(arrived_flag))
self.assertFalse(custom_device_testutil.FlagValue(executed_flag))
with ops.device(device_name):
x = constant_op.constant(1.)
y = x * constant_op.constant(2.)
self.assertTrue(custom_device_testutil.FlagValue(executed_flag))
# There was no copy onto the device. Actually I'm not sure how to trigger
# that from Python.
self.assertFalse(custom_device_testutil.FlagValue(arrived_flag))
with self.assertRaisesRegexp(errors.InternalError, 'Trying to copy'):
y.numpy()
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -0,0 +1,77 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "Python.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/util.h"
namespace py = pybind11;
void CallDelete_Flag(PyObject* capsule) {
delete reinterpret_cast<bool*>(PyCapsule_GetPointer(capsule, "flag"));
}
void CallDelete_Device(PyObject* capsule) {
delete reinterpret_cast<TFE_CustomDevice*>(
PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
}
void CallDelete_DeviceInfo(PyObject* capsule) {
PyErr_SetString(PyExc_AssertionError,
"Capsule should be consumed by TFE_Py_RegisterCustomDevice");
}
PYBIND11_MODULE(custom_device_testutil, m) {
m.def("GetLoggingDeviceCapsules", [](const char* name) {
bool* arrived_flag = new bool;
bool* executed_flag = new bool;
*arrived_flag = false;
*executed_flag = false;
tensorflow::Safe_PyObjectPtr arrived_capsule(
PyCapsule_New(arrived_flag, "flag", &CallDelete_Flag));
tensorflow::Safe_PyObjectPtr executed_capsule(
PyCapsule_New(executed_flag, "flag", &CallDelete_Flag));
TFE_CustomDevice* device;
void* device_info;
AllocateLoggingDevice(name, arrived_flag, executed_flag, &device,
&device_info);
tensorflow::Safe_PyObjectPtr device_capsule(
PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
tensorflow::Safe_PyObjectPtr device_info_capsule(PyCapsule_New(
device_info, "TFE_CustomDevice_DeviceInfo", &CallDelete_DeviceInfo));
return tensorflow::pyo_or_throw(
PyTuple_Pack(4, device_capsule.get(), device_info_capsule.get(),
arrived_capsule.get(), executed_capsule.get()));
});
m.def("FlagValue", [](py::capsule flag_capsule) {
bool* flag = reinterpret_cast<bool*>(
PyCapsule_GetPointer(flag_capsule.ptr(), "flag"));
if (PyErr_Occurred()) throw py::error_already_set();
return *flag;
});
}

View File

@ -1100,6 +1100,39 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
return py::handle(EagerTensorFromHandle(thandle));
});
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
const py::capsule& device,
const char* device_name,
const py::capsule& device_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
if (absl::string_view(device.name()) != "TFE_CustomDevice") {
status->status = tensorflow::errors::InvalidArgument(
"Expected a capsule named 'TFE_CustomDevice' for the `device` "
"argument, got ",
absl::string_view(device.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
if (absl::string_view(device_info.name()) !=
"TFE_CustomDevice_DeviceInfo") {
status->status = tensorflow::errors::InvalidArgument(
"Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
"the `device_info` argument, got ",
absl::string_view(device_info.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
// TFE_RegisterCustomDevice takes ownership
PyCapsule_SetDestructor(device_info.ptr(), nullptr);
TFE_RegisterCustomDevice(
tensorflow::InputTFE_Context(context),
*reinterpret_cast<TFE_CustomDevice*>(
PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
device_name,
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// C API Enum
py::enum_<TFE_ContextDevicePlacementPolicy>(

View File

@ -2814,6 +2814,7 @@ def pybind_extension(
deprecation = deprecation,
restricted_to = restricted_to,
compatible_with = compatible_with,
testonly = testonly,
)
native.py_library(
name = name,
@ -2841,7 +2842,8 @@ def tf_python_pybind_extension(
hdrs = [],
deps = [],
defines = [],
visibility = None):
visibility = None,
testonly = None):
"""A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD.
Please do not use it anywhere else as it may behave unexpectedly. b/146445820
@ -2860,6 +2862,7 @@ def tf_python_pybind_extension(
defines = defines,
visibility = visibility,
link_in_framework = True,
testonly = testonly,
)
def tf_pybind_cc_library_wrapper(name, deps, visibility = None):