Add a way to register custom devices with the Python TFE_Context

Reapplying cl/305762978 / 17b7bc01bc

PiperOrigin-RevId: 305796963
Change-Id: Ic2e8e962a04600bf0847ae081887343edc32c918
Allen Lavoie 2020-04-09 17:47:27 -07:00 committed by TensorFlower Gardener
parent d8c89c1cd7
commit 109fd384bd
10 changed files with 468 additions and 125 deletions

View File

@@ -372,6 +372,22 @@ tf_cuda_cc_test(
],
)
cc_library(
name = "custom_device_testutil",
testonly = True,
srcs = ["custom_device_testutil.cc"],
hdrs = ["custom_device_testutil.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
@@ -382,6 +398,7 @@ tf_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
":custom_device_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",

View File

@@ -20,129 +20,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@@ -394,5 +274,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
} // namespace

View File

@@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status) {
return reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
->tensor;
}
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info) {
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device->delete_device = &DeleteLoggingDevice;
custom_device->execute = &LoggingDeviceExecute;
*device = custom_device;
LoggingDevice* logging_device = new LoggingDevice;
logging_device->arrived_flag = arrived_flag;
logging_device->executed_flag = executed_flag;
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status);
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_

View File

@@ -1,5 +1,8 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
load(
"//tensorflow/tools/test:performance.bzl",
@@ -163,6 +166,47 @@ py_library(
],
)
tf_python_pybind_extension(
name = "custom_device_testutil",
testonly = True,
srcs = ["custom_device_testutil.cc"],
module_name = "custom_device_testutil",
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:custom_device_testutil",
"//tensorflow/python:cpp_python_util",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"//tensorflow/python:safe_ptr",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
py_test(
name = "custom_device_test",
size = "small",
srcs = ["custom_device_test.py"],
python_version = "PY3",
# Note that this currently only works with --config=monolithic, since it
# requires the C API which runs static initializers again.
#
# TODO(allenl): Figure out a way to allow extensions to register custom
# devices which works with dynamic linking.
tags = [
"no_oss",
"no_pip",
],
deps = [
":context",
":custom_device_testutil",
":test",
],
)
cuda_py_test(
name = "context_test",
size = "small",

View File

@@ -1116,6 +1116,13 @@ class Context(object):
return function_def
def register_custom_device(self, device_capsule, device_name,
device_info_capsule):
"""Calls TFE_RegisterCustomDevice. See the non-member function."""
self.ensure_initialized()
pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
device_name, device_info_capsule)
def remove_function(self, name):
"""Remove a function from the context.
@@ -2287,6 +2294,32 @@ def get_function_def(name):
return context().get_function_def(name)
def register_custom_device(device_capsule, device_name, device_info_capsule):
"""Calls TFE_RegisterCustomDevice to register a custom device with Python.
Enables using C extensions specifying a custom device from Python. See the
experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
details.
Note that custom devices are not currently supported inside `tf.function`s.
Args:
device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
containing a pointer to a TFE_CustomDevice struct. The capsule retains
ownership of the memory.
device_name: A string indicating the name to register the custom device
under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
subsequently be passed to `with tf.device(...):`.
device_info_capsule: A PyCapsule with the name set to
'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
struct with the initial state of the custom device (the void* device_info
argument to TFE_RegisterCustomDevice). This method takes ownership of the
memory and clears the capsule destructor.
"""
context().register_custom_device(device_capsule, device_name,
device_info_capsule)
# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
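
A minimal end-to-end sketch of the new API (not part of the diff above; it assumes the logging test device and pybind extension added elsewhere in this change, and mirrors the new custom_device_test.py):

# Sketch: register the logging test device from Python and run an op on it.
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_device_testutil
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

ops.enable_eager_execution()  # As in custom_device_test.py's main().
device_name = '/job:localhost/replica:0/task:0/device:CUSTOM:0'
# The test extension returns the device and device_info capsules plus two
# bool-flag capsules recording copies onto the device and executed ops.
device, device_info, arrived_flag, executed_flag = (
    custom_device_testutil.GetLoggingDeviceCapsules(device_name))
context.register_custom_device(device, device_name, device_info)
with ops.device(device_name):
  y = constant_op.constant(1.) * constant_op.constant(2.)
# The multiply ran through the logging device's execute callback.
assert custom_device_testutil.FlagValue(executed_flag)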

View File

@@ -0,0 +1,50 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_device_testutil
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
class CustomDeviceTest(test.TestCase):
def testRegisterCustomDevice(self):
device_name = '/job:localhost/replica:0/task:0/device:CUSTOM:0'
device, device_info, arrived_flag, executed_flag = (
custom_device_testutil.GetLoggingDeviceCapsules(device_name))
context.register_custom_device(device, device_name, device_info)
self.assertFalse(custom_device_testutil.FlagValue(arrived_flag))
self.assertFalse(custom_device_testutil.FlagValue(executed_flag))
with ops.device(device_name):
x = constant_op.constant(1.)
y = x * constant_op.constant(2.)
self.assertTrue(custom_device_testutil.FlagValue(executed_flag))
# There was no copy onto the device. Actually I'm not sure how to trigger
# that from Python.
self.assertFalse(custom_device_testutil.FlagValue(arrived_flag))
with self.assertRaisesRegexp(errors.InternalError, 'Trying to copy'):
y.numpy()
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@@ -0,0 +1,77 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "Python.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/util.h"
namespace py = pybind11;
void CallDelete_Flag(PyObject* capsule) {
delete reinterpret_cast<bool*>(PyCapsule_GetPointer(capsule, "flag"));
}
void CallDelete_Device(PyObject* capsule) {
delete reinterpret_cast<TFE_CustomDevice*>(
PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
}
void CallDelete_DeviceInfo(PyObject* capsule) {
PyErr_SetString(PyExc_AssertionError,
"Capsule should be consumed by TFE_Py_RegisterCustomDevice");
}
PYBIND11_MODULE(custom_device_testutil, m) {
m.def("GetLoggingDeviceCapsules", [](const char* name) {
bool* arrived_flag = new bool;
bool* executed_flag = new bool;
*arrived_flag = false;
*executed_flag = false;
tensorflow::Safe_PyObjectPtr arrived_capsule(
PyCapsule_New(arrived_flag, "flag", &CallDelete_Flag));
tensorflow::Safe_PyObjectPtr executed_capsule(
PyCapsule_New(executed_flag, "flag", &CallDelete_Flag));
TFE_CustomDevice* device;
void* device_info;
AllocateLoggingDevice(name, arrived_flag, executed_flag, &device,
&device_info);
tensorflow::Safe_PyObjectPtr device_capsule(
PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
tensorflow::Safe_PyObjectPtr device_info_capsule(PyCapsule_New(
device_info, "TFE_CustomDevice_DeviceInfo", &CallDelete_DeviceInfo));
return tensorflow::pyo_or_throw(
PyTuple_Pack(4, device_capsule.get(), device_info_capsule.get(),
arrived_capsule.get(), executed_capsule.get()));
});
m.def("FlagValue", [](py::capsule flag_capsule) {
bool* flag = reinterpret_cast<bool*>(
PyCapsule_GetPointer(flag_capsule.ptr(), "flag"));
if (PyErr_Occurred()) throw py::error_already_set();
return *flag;
});
}

View File

@@ -1100,6 +1100,39 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
return py::handle(EagerTensorFromHandle(thandle));
});
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
const py::capsule& device,
const char* device_name,
const py::capsule& device_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
if (absl::string_view(device.name()) != "TFE_CustomDevice") {
status->status = tensorflow::errors::InvalidArgument(
"Expected a capsule named 'TFE_CustomDevice' for the `device` "
"argument, got ",
absl::string_view(device.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
if (absl::string_view(device_info.name()) !=
"TFE_CustomDevice_DeviceInfo") {
status->status = tensorflow::errors::InvalidArgument(
"Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
"the `device_info` argument, got ",
absl::string_view(device_info.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
// TFE_RegisterCustomDevice takes ownership
PyCapsule_SetDestructor(device_info.ptr(), nullptr);
TFE_RegisterCustomDevice(
tensorflow::InputTFE_Context(context),
*reinterpret_cast<TFE_CustomDevice*>(
PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
device_name,
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// C API Enum
py::enum_<TFE_ContextDevicePlacementPolicy>(
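
For reference, a sketch of the contract the binding above enforces when called directly. In practice it is reached through context.register_custom_device; the private Context._handle accessor appears here only for illustration:

# Sketch: calling the low-level binding directly, mirroring
# Context.register_custom_device.
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_device_testutil

ctx = context.context()
ctx.ensure_initialized()
name = '/job:localhost/replica:0/task:0/device:CUSTOM:1'
device, device_info, _, _ = (
    custom_device_testutil.GetLoggingDeviceCapsules(name))
# The capsules must be named 'TFE_CustomDevice' and
# 'TFE_CustomDevice_DeviceInfo'; other names raise InvalidArgument. The
# binding clears device_info's capsule destructor because
# TFE_RegisterCustomDevice takes ownership of the underlying struct.
pywrap_tfe.TFE_Py_RegisterCustomDevice(ctx._handle, device, name, device_info)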

View File

@@ -2814,6 +2814,7 @@ def pybind_extension(
deprecation = deprecation,
restricted_to = restricted_to,
compatible_with = compatible_with,
testonly = testonly,
)
native.py_library(
name = name,
@@ -2841,7 +2842,8 @@ def tf_python_pybind_extension(
hdrs = [],
deps = [],
defines = [],
visibility = None):
visibility = None,
testonly = None):
"""A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD.
Please do not use it anywhere else as it may behave unexpectedly. b/146445820
@@ -2860,6 +2862,7 @@ def tf_python_pybind_extension(
defines = defines,
visibility = visibility,
link_in_framework = True,
testonly = testonly,
)
def tf_pybind_cc_library_wrapper(name, deps, visibility = None):