Add a way to register custom devices with the Python TFE_Context
The API accepts the TFE_RegisterCustomDevice arguments as PyCapsules, so each custom device will need some way to create those. Presumably most custom devices will end up wrapping the PyCapsule creation and registration rather than exposing it to the user.

There is no public API yet; this is roughly what I have in mind at the moment. Right now it only works with --config=monolithic, or when the custom device registration is bundled into pywrap_tensorflow.so, since that library has its own copy of the C API. Something like this could work if we switched pywrap_tensorflow.so to rely on libtensorflow.so for the C API instead; custom device extensions could then link against that.

PiperOrigin-RevId: 305762978
Change-Id: I4d2d9bd9c01ba22391e138244a3948bae8963c5c
commit 17b7bc01bc (parent 747c37add5)
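For orientation, a minimal sketch of the Python-side flow this change enables, mirroring the new custom_device_test.py added below (in-tree only, and currently only under --config=monolithic):

# Sketch: register the logging test device and run an op on it, as the new
# test in this commit does.
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_device_testutil
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

ops.enable_eager_execution()
device_name = '/job:localhost/replica:0/task:0/device:CUSTOM:0'
# The pybind extension returns PyCapsules wrapping the TFE_CustomDevice struct
# and its device_info pointer, plus two bool-flag capsules used by the test.
device, device_info, arrived_flag, executed_flag = (
    custom_device_testutil.GetLoggingDeviceCapsules(device_name))
context.register_custom_device(device, device_name, device_info)
with ops.device(device_name):
  y = constant_op.constant(1.) * constant_op.constant(2.)
# The logging device recorded that it executed an op.
assert custom_device_testutil.FlagValue(executed_flag)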
@@ -372,6 +372,22 @@ tf_cuda_cc_test(
    ],
)

cc_library(
    name = "custom_device_testutil",
    testonly = True,
    srcs = ["custom_device_testutil.cc"],
    hdrs = ["custom_device_testutil.h"],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":c_api",
        ":c_api_experimental",
        ":c_api_test_util",
        "//tensorflow/c:c_api",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
    ],
)

tf_cc_test(
    name = "custom_device_test",
    size = "small",
@@ -382,6 +398,7 @@ tf_cc_test(
        ":c_api",
        ":c_api_experimental",
        ":c_api_test_util",
        ":custom_device_testutil",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/cc/profiler",
@@ -515,6 +515,11 @@ typedef struct TFE_CustomDevice {
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
//
// TODO(allenl): Currently custom devices need to know their device name, and
// this separately needs to be fed to the registration. It would be nice if that
// duplication wasn't necessary because custom devices could be written without
// knowing their name.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                              const char* device_name, void* device_info,
                              TF_Status* status);
@@ -20,129 +20,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"

namespace {

struct LoggingDevice {
  tensorflow::string device_name;
  tensorflow::string underlying_device;
  // Set to true whenever a TensorHandle is copied onto the device
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
};

struct LoggedTensor {
  TFE_TensorHandle* tensor;
  LoggedTensor() = delete;
  explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
  ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};

void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
  delete reinterpret_cast<LoggedTensor*>(data);
}

TFE_TensorHandle* MakeLoggedTensorHandle(
    TFE_Context* context, const tensorflow::string& logging_device_name,
    std::unique_ptr<LoggedTensor> t, TF_Status* status) {
  std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  auto dtype = TFE_TensorHandleDataType(t->tensor);
  return TFE_NewTensorHandleFromDeviceMemory(
      context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
      t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}

TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
                                      TFE_TensorHandle* tensor,
                                      TF_Status* status, void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
      tensor, context, dev->underlying_device.c_str(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  auto dst = std::make_unique<LoggedTensor>(t);
  *(dev->arrived_flag) = true;
  return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
                                status);
}

TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
                                              TFE_TensorHandle* tensor,
                                              const char* target_device_name,
                                              TF_Status* status,
                                              void* device_info) {
  TF_SetStatus(status, TF_INTERNAL,
               "Trying to copy a tensor out of a logging device.");
  return nullptr;
}

void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
                          TFE_TensorHandle** inputs, const char* operation_name,
                          const TFE_OpAttrs* attributes, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_Op* op(TFE_NewOp(context, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, attributes);
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
  for (int j = 0; j < num_inputs; ++j) {
    TFE_TensorHandle* input = inputs[j];
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) return;
    if (dev->device_name == input_device) {
      LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(input, s));
      if (TF_GetCode(s) != TF_OK) return;
      TFE_OpAddInput(op, t->tensor, s);
    } else {
      TFE_OpAddInput(op, input, s);
    }
    if (TF_GetCode(s) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
  TFE_Execute(op, op_outputs.data(), num_outputs, s);
  TFE_DeleteOp(op);
  if (TF_GetCode(s) != TF_OK) return;
  std::vector<TFE_TensorHandle*> unwrapped_outputs;
  for (auto* handle : op_outputs) {
    unwrapped_outputs.push_back(handle);
  }
  for (int i = 0; i < *num_outputs; ++i) {
    auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
    outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
                                        std::move(logged_tensor), s);
  }
  *(dev->executed_flag) = true;
}

void DeleteLoggingDevice(void* device_info) {
  delete reinterpret_cast<LoggingDevice*>(device_info);
}

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
  custom_device.delete_device = &DeleteLoggingDevice;
  custom_device.execute = &LoggingDeviceExecute;
  LoggingDevice* device = new LoggingDevice;
  device->arrived_flag = arrived_flag;
  device->executed_flag = executed_flag;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}

TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
      tensorflow::string(
          TFE_TensorHandleBackingDeviceName(var_value, status.get())));
  TFE_TensorHandle* var_value_unpacked =
      reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(var_value, status.get()))
          ->tensor;
      UnpackTensorHandle(var_value, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
      TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@@ -394,5 +274,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());
}

}  // namespace
tensorflow/c/eager/custom_device_testutil.cc (new file, 172 lines)
@@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A simple logging device to test custom device registration.
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"

namespace {

struct LoggingDevice {
  tensorflow::string device_name;
  tensorflow::string underlying_device;
  // Set to true whenever a TensorHandle is copied onto the device
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
};

struct LoggedTensor {
  TFE_TensorHandle* tensor;
  LoggedTensor() = delete;
  explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
  ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};

void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
  delete reinterpret_cast<LoggedTensor*>(data);
}

TFE_TensorHandle* MakeLoggedTensorHandle(
    TFE_Context* context, const tensorflow::string& logging_device_name,
    std::unique_ptr<LoggedTensor> t, TF_Status* status) {
  std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  auto dtype = TFE_TensorHandleDataType(t->tensor);
  return TFE_NewTensorHandleFromDeviceMemory(
      context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
      t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}

TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
                                      TFE_TensorHandle* tensor,
                                      TF_Status* status, void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
      tensor, context, dev->underlying_device.c_str(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  auto dst = std::make_unique<LoggedTensor>(t);
  *(dev->arrived_flag) = true;
  return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
                                status);
}

TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
                                              TFE_TensorHandle* tensor,
                                              const char* target_device_name,
                                              TF_Status* status,
                                              void* device_info) {
  TF_SetStatus(status, TF_INTERNAL,
               "Trying to copy a tensor out of a logging device.");
  return nullptr;
}

void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
                          TFE_TensorHandle** inputs, const char* operation_name,
                          const TFE_OpAttrs* attributes, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_Op* op(TFE_NewOp(context, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, attributes);
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
  for (int j = 0; j < num_inputs; ++j) {
    TFE_TensorHandle* input = inputs[j];
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) return;
    if (dev->device_name == input_device) {
      LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(input, s));
      if (TF_GetCode(s) != TF_OK) return;
      TFE_OpAddInput(op, t->tensor, s);
    } else {
      TFE_OpAddInput(op, input, s);
    }
    if (TF_GetCode(s) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
  TFE_Execute(op, op_outputs.data(), num_outputs, s);
  TFE_DeleteOp(op);
  if (TF_GetCode(s) != TF_OK) return;
  std::vector<TFE_TensorHandle*> unwrapped_outputs;
  for (auto* handle : op_outputs) {
    unwrapped_outputs.push_back(handle);
  }
  for (int i = 0; i < *num_outputs; ++i) {
    auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
    outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
                                        std::move(logged_tensor), s);
  }
  *(dev->executed_flag) = true;
}

void DeleteLoggingDevice(void* device_info) {
  delete reinterpret_cast<LoggingDevice*>(device_info);
}

}  // namespace

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
  custom_device.delete_device = &DeleteLoggingDevice;
  custom_device.execute = &LoggingDeviceExecute;
  LoggingDevice* device = new LoggingDevice;
  device->arrived_flag = arrived_flag;
  device->executed_flag = executed_flag;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}

TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
                                     TF_Status* status) {
  return reinterpret_cast<LoggedTensor*>(
             TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
      ->tensor;
}

void AllocateLoggingDevice(const char* name, bool* arrived_flag,
                           bool* executed_flag, TFE_CustomDevice** device,
                           void** device_info) {
  TFE_CustomDevice* custom_device = new TFE_CustomDevice;
  custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
  custom_device->delete_device = &DeleteLoggingDevice;
  custom_device->execute = &LoggingDeviceExecute;
  *device = custom_device;
  LoggingDevice* logging_device = new LoggingDevice;
  logging_device->arrived_flag = arrived_flag;
  logging_device->executed_flag = executed_flag;
  logging_device->device_name = name;
  logging_device->underlying_device =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  *device_info = reinterpret_cast<void*>(logging_device);
}
tensorflow/c/eager/custom_device_testutil.h (new file, 36 lines)
@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_

// A simple logging device to test custom device registration.
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
                           bool* executed_flag, TFE_CustomDevice** device,
                           void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
                                     TF_Status* status);

#endif  // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
@@ -1,5 +1,8 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
load(
    "//tensorflow/tools/test:performance.bzl",
@@ -163,6 +166,47 @@ py_library(
    ],
)

tf_python_pybind_extension(
    name = "custom_device_testutil",
    testonly = True,
    srcs = ["custom_device_testutil.cc"],
    module_name = "custom_device_testutil",
    deps = [
        "//tensorflow/c:c_api",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/c/eager:custom_device_testutil",
        "//tensorflow/python:cpp_python_util",
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "//tensorflow/python:safe_ptr",
        "//third_party/python_runtime:headers",
        "@pybind11",
    ],
)

py_test(
    name = "custom_device_test",
    size = "small",
    srcs = ["custom_device_test.py"],
    python_version = "PY3",
    # Note that this currently only works with --config=monolithic, since it
    # requires the C API which runs static initializers again.
    #
    # TODO(allenl): Figure out a way to allow extensions to register custom
    # devices which works with dynamic linking.
    tags = [
        "no_oss",
        "no_pip",
    ],
    deps = [
        ":context",
        ":custom_device_testutil",
        ":test",
    ],
)

cuda_py_test(
    name = "context_test",
    size = "small",
@@ -1116,6 +1116,13 @@ class Context(object):

    return function_def

  def register_custom_device(self, device_capsule, device_name,
                             device_info_capsule):
    """Calls TFE_RegisterCustomDevice. See the non-member function."""
    self.ensure_initialized()
    pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
                                           device_name, device_info_capsule)

  def remove_function(self, name):
    """Remove a function from the context.

@@ -2287,6 +2294,32 @@ def get_function_def(name):
  return context().get_function_def(name)


def register_custom_device(device_capsule, device_name, device_info_capsule):
  """Calls TFE_RegisterCustomDevice to register a custom device with Python.

  Enables using C extensions specifying a custom device from Python. See the
  experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
  details.

  Note that custom devices are not currently supported inside `tf.function`s.

  Args:
    device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
      containing a pointer to a TFE_CustomDevice struct. The capsule retains
      ownership of the memory.
    device_name: A string indicating the name to register the custom device
      under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
      subsequently be passed to `with tf.device(...):`.
    device_info_capsule: A PyCapsule with the name set to
      'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
      struct with the initial state of the custom device (the void* device_info
      argument to TFE_RegisterCustomDevice). This method takes ownership of the
      memory and clears the capsule destructor.
  """
  context().register_custom_device(device_capsule, device_name,
                                   device_info_capsule)


# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
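A hypothetical sketch of how a custom device package might wrap this registration, per the commit message's suggestion that devices hide the PyCapsule plumbing from users; `my_device_extension` and its `capsules()` helper are illustrative names, not part of this change:

from tensorflow.python.eager import context

import my_device_extension  # hypothetical C extension shipped by a device author


def register_my_device(device_name):
  """Registers the hypothetical device under `device_name`."""
  # The extension is assumed to build the 'TFE_CustomDevice' and
  # 'TFE_CustomDevice_DeviceInfo' capsules internally, so users never see them.
  device_capsule, device_info_capsule = my_device_extension.capsules(device_name)
  context.register_custom_device(device_capsule, device_name,
                                 device_info_capsule)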
tensorflow/python/eager/custom_device_test.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.eager import custom_device_testutil
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test


class CustomDeviceTest(test.TestCase):

  def testRegisterCustomDevice(self):
    device_name = '/job:localhost/replica:0/task:0/device:CUSTOM:0'
    device, device_info, arrived_flag, executed_flag = (
        custom_device_testutil.GetLoggingDeviceCapsules(device_name))
    context.register_custom_device(device, device_name, device_info)
    self.assertFalse(custom_device_testutil.FlagValue(arrived_flag))
    self.assertFalse(custom_device_testutil.FlagValue(executed_flag))
    with ops.device(device_name):
      x = constant_op.constant(1.)
      y = x * constant_op.constant(2.)
    self.assertTrue(custom_device_testutil.FlagValue(executed_flag))
    # There was no copy onto the device. Actually I'm not sure how to trigger
    # that from Python.
    self.assertFalse(custom_device_testutil.FlagValue(arrived_flag))
    with self.assertRaisesRegexp(errors.InternalError, 'Trying to copy'):
      y.numpy()


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()
tensorflow/python/eager/custom_device_testutil.cc (new file, 77 lines)
@@ -0,0 +1,77 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/custom_device_testutil.h"

#include "Python.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/util.h"

namespace py = pybind11;

void CallDelete_Flag(PyObject* capsule) {
  delete reinterpret_cast<bool*>(PyCapsule_GetPointer(capsule, "flag"));
}

void CallDelete_Device(PyObject* capsule) {
  delete reinterpret_cast<TFE_CustomDevice*>(
      PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
}

void CallDelete_DeviceInfo(PyObject* capsule) {
  PyErr_SetString(PyExc_AssertionError,
                  "Capsule should be consumed by TFE_Py_RegisterCustomDevice");
}

PYBIND11_MODULE(custom_device_testutil, m) {
  m.def("GetLoggingDeviceCapsules", [](const char* name) {
    bool* arrived_flag = new bool;
    bool* executed_flag = new bool;
    *arrived_flag = false;
    *executed_flag = false;
    tensorflow::Safe_PyObjectPtr arrived_capsule(
        PyCapsule_New(arrived_flag, "flag", &CallDelete_Flag));
    tensorflow::Safe_PyObjectPtr executed_capsule(
        PyCapsule_New(executed_flag, "flag", &CallDelete_Flag));
    TFE_CustomDevice* device;
    void* device_info;
    AllocateLoggingDevice(name, arrived_flag, executed_flag, &device,
                          &device_info);
    tensorflow::Safe_PyObjectPtr device_capsule(
        PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
    tensorflow::Safe_PyObjectPtr device_info_capsule(PyCapsule_New(
        device_info, "TFE_CustomDevice_DeviceInfo", &CallDelete_DeviceInfo));
    return tensorflow::pyo_or_throw(
        PyTuple_Pack(4, device_capsule.get(), device_info_capsule.get(),
                     arrived_capsule.get(), executed_capsule.get()));
  });
  m.def("FlagValue", [](py::capsule flag_capsule) {
    bool* flag = reinterpret_cast<bool*>(
        PyCapsule_GetPointer(flag_capsule.ptr(), "flag"));
    if (PyErr_Occurred()) throw py::error_already_set();
    return *flag;
  });
}
@@ -1100,6 +1100,39 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
    return py::handle(EagerTensorFromHandle(thandle));
  });

  m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
                                          const py::capsule& device,
                                          const char* device_name,
                                          const py::capsule& device_info) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    if (absl::string_view(device.name()) != "TFE_CustomDevice") {
      status->status = tensorflow::errors::InvalidArgument(
          "Expected a capsule named 'TFE_CustomDevice' for the `device` "
          "argument, got ",
          absl::string_view(device.name()));
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    }
    if (absl::string_view(device_info.name()) !=
        "TFE_CustomDevice_DeviceInfo") {
      status->status = tensorflow::errors::InvalidArgument(
          "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
          "the `device_info` argument, got ",
          absl::string_view(device_info.name()));
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    }
    // TFE_RegisterCustomDevice takes ownership
    PyCapsule_SetDestructor(device_info.ptr(), nullptr);
    TFE_RegisterCustomDevice(
        tensorflow::InputTFE_Context(context),
        *reinterpret_cast<TFE_CustomDevice*>(
            PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
        device_name,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });

  // C API Enum

  py::enum_<TFE_ContextDevicePlacementPolicy>(
@@ -2814,6 +2814,7 @@ def pybind_extension(
        deprecation = deprecation,
        restricted_to = restricted_to,
        compatible_with = compatible_with,
        testonly = testonly,
    )
    native.py_library(
        name = name,
@@ -2841,7 +2842,8 @@ def tf_python_pybind_extension(
        hdrs = [],
        deps = [],
        defines = [],
        visibility = None):
        visibility = None,
        testonly = None):
    """A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD.

    Please do not use it anywhere else as it may behave unexpectedly. b/146445820
@@ -2860,6 +2862,7 @@ def tf_python_pybind_extension(
        defines = defines,
        visibility = visibility,
        link_in_framework = True,
        testonly = testonly,
    )

def tf_pybind_cc_library_wrapper(name, deps, visibility = None):