Add device description fields to StreamExecutor C API.

While at it, changed the `SP_Device` argument of `host_callback` to be const.

PiperOrigin-RevId: 333195011
Change-Id: I5224812674c00332336a5c0fcf68ed28dd33d0aa
This commit is contained in:
Anna R 2020-09-22 17:35:44 -07:00 committed by TensorFlower Gardener
parent 42637d6dab
commit 070efb9ebf
4 changed files with 207 additions and 19 deletions

View File

@ -76,6 +76,8 @@ port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device_fns);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device_fns);
return port::Status::OK();
}
@ -104,6 +106,12 @@ port::Status ValidateSPDevice(const SP_Device& device) {
return port::Status::OK();
}
// Validates an SP_DeviceFns table.
// Only the struct size is checked (against SP_DEVICE_FNS_STRUCT_SIZE); none of
// the callbacks is required to be set. Returns OK if the size check does not
// bail out first.
port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
const SP_Platform& platform) {
VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
@ -311,11 +319,13 @@ void HostCallbackTrampoline(void* ctx, TF_Status* status) {
class CStreamExecutor : public internal::StreamExecutorInterface {
public:
explicit CStreamExecutor(SP_Device device, SP_StreamExecutor* stream_executor,
explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
SP_StreamExecutor* stream_executor,
SP_Platform* platform, SP_PlatformFns* platform_fns,
SP_TimerFns* timer_fns, const std::string& name,
int visible_device_count)
: device_(std::move(device)),
device_fns_(device_fns),
stream_executor_(stream_executor),
platform_(platform),
platform_fns_(platform_fns),
@ -678,10 +688,35 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
// Ownership is transferred to the caller.
port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
const override {
// TODO(annarev): Figure out if we need to support more description fields.
OwnedTFStatus c_status(TF_NewStatus());
internal::DeviceDescriptionBuilder builder;
builder.set_name(platform_name_);
// TODO(annarev): `Also supports_unified_memory` in DeviceDescription.
if (device_.hardware_name != nullptr) {
builder.set_name(device_.hardware_name);
}
if (device_.device_vendor != nullptr) {
builder.set_device_vendor(device_.device_vendor);
}
if (device_.pci_bus_id != nullptr) {
builder.set_pci_bus_id(device_.pci_bus_id);
}
if (device_fns_->get_numa_node != nullptr) {
int32_t numa_node = device_fns_->get_numa_node(&device_);
if (numa_node >= 0) {
builder.set_numa_node(numa_node);
}
}
if (device_fns_->get_memory_bandwidth != nullptr) {
int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_);
if (memory_bandwidth >= 0) {
builder.set_memory_bandwidth(memory_bandwidth);
}
}
// TODO(annarev): Add gflops field in DeviceDescription and set it here.
// TODO(annarev): Perhaps add `supports_unified_memory` in
// DeviceDescription.
return builder.Build();
}
@ -709,6 +744,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
private:
SP_Device device_;
SP_DeviceFns* device_fns_;
SP_StreamExecutor* stream_executor_;
SP_Platform* platform_;
SP_PlatformFns* platform_fns_;
@ -722,17 +758,20 @@ CPlatform::CPlatform(SP_Platform platform,
void (*destroy_platform)(SP_Platform*),
SP_PlatformFns platform_fns,
void (*destroy_platform_fns)(SP_PlatformFns*),
SP_StreamExecutor stream_executor, SP_TimerFns timer_fns)
SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
SP_TimerFns timer_fns)
: platform_(std::move(platform)),
destroy_platform_(destroy_platform),
platform_fns_(std::move(platform_fns)),
destroy_platform_fns_(destroy_platform_fns),
device_fns_(std::move(device_fns)),
stream_executor_(std::move(stream_executor)),
timer_fns_(std::move(timer_fns)),
name_(platform.name) {}
CPlatform::~CPlatform() {
executor_cache_.DestroyAllExecutors();
platform_fns_.destroy_device_fns(&platform_, &device_fns_);
platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
destroy_platform_(&platform_);
@ -781,8 +820,8 @@ port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
TF_RETURN_IF_ERROR(ValidateSPDevice(device));
auto executor = absl::make_unique<CStreamExecutor>(
std::move(device), &stream_executor_, &platform_, &platform_fns_,
&timer_fns_, name_, platform_.visible_device_count);
std::move(device), &device_fns_, &stream_executor_, &platform_,
&platform_fns_, &timer_fns_, name_, platform_.visible_device_count);
auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
config.ordinal);
return result;
@ -819,6 +858,17 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));
// Fill SP_DeviceFns creation params
SE_CreateDeviceFnsParams device_fns_params{
SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE};
SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE};
device_fns_params.device_fns = &device_fns;
// Create SP_DeviceFns
platform_fns.create_device_fns(&platform, &device_fns_params, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns));
// Fill stream executor creation params
SE_CreateStreamExecutorParams se_params{
SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
@ -844,7 +894,8 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
std::unique_ptr<stream_executor::CPlatform> cplatform(
new stream_executor::CPlatform(
std::move(platform), params.destroy_platform, std::move(platform_fns),
params.destroy_platform_fns, std::move(se), std::move(timer_fns)));
params.destroy_platform_fns, std::move(device_fns), std::move(se),
std::move(timer_fns)));
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(cplatform)));

View File

@ -158,9 +158,30 @@ typedef struct SP_Device {
// Device vendor can store handle to their device representation
// here.
void* device_handle;
// [Optional]
// Device hardware name. Used for printing.
// Must be null-terminated.
const char* hardware_name;
// [Optional]
// Device vendor name. Used for printing.
// Must be null-terminated.
const char* device_vendor;
// [Optional]
// Returns the PCI bus identifier for this device, of the form
// [domain]:[bus]:[device].[function]
// where domain number is usually 0000.
// Example: 0000:00:02.1
// For more information see:
// https://en.wikipedia.org/wiki/PCI_configuration_space
// https://www.oreilly.com/library/view/linux-device-drivers/0596005903/ch12.html
// Used for printing. Must be null-terminated.
const char* pci_bus_id;
} SP_Device;
#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, device_handle)
#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, pci_bus_id)
typedef struct SE_CreateDeviceParams {
size_t struct_size;
@ -174,6 +195,42 @@ typedef struct SE_CreateDeviceParams {
#define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \
TF_OFFSET_OF_END(SE_CreateDeviceParams, device)
// Table of optional callbacks that report properties of an SP_Device.
// Every callback takes the device as a const pointer and may be left null.
typedef struct SP_DeviceFns {
size_t struct_size;
void* ext; // reserved for future use
// [Optional]
// Returns the NUMA node associated with this device, for use in
// determining socket locality. If the NUMA node could not be determined, -1
// is returned.
// Negative values are treated as "unset".
int32_t (*get_numa_node)(const SP_Device* device);
// [Optional]
// Device's memory bandwidth in bytes/sec. (This is for reads/writes to/from
// the device's own memory, not for transfers between the host and device.)
// Negative values are treated as "unset".
int64_t (*get_memory_bandwidth)(const SP_Device* device);
// [Optional]
// Estimate of average number of floating point operations per second for
// this device * 1e-9 (i.e. expressed in gigaflops).
// Negative values are treated as "unset".
double (*get_gflops)(const SP_Device* device);
} SP_DeviceFns;
// Struct size is computed with TF_OFFSET_OF_END on the last field
// (`get_gflops`).
#define SP_DEVICE_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_DeviceFns, get_gflops)
// Parameter block passed to SP_PlatformFns::create_device_fns.
typedef struct SE_CreateDeviceFnsParams {
size_t struct_size;
void* ext; // reserved for future use
SP_DeviceFns* device_fns; // output, to be filled by plugin
} SE_CreateDeviceFnsParams;
// Struct size is computed with TF_OFFSET_OF_END on the last field
// (`device_fns`).
#define SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE \
TF_OFFSET_OF_END(SE_CreateDeviceFnsParams, device_fns)
typedef struct SP_StreamExecutor {
size_t struct_size;
void* ext; // reserved for future use
@ -337,7 +394,7 @@ typedef struct SP_StreamExecutor {
// Enqueues on a stream a user-specified function to be run on the host.
// `callback_arg` should be passed as the first argument to `callback_fn`.
TF_Bool (*host_callback)(SP_Device* device, SP_Stream stream,
TF_Bool (*host_callback)(const SP_Device* device, SP_Stream stream,
SE_StatusCallbackFn callback_fn, void* callback_arg);
} SP_StreamExecutor;
@ -389,6 +446,16 @@ typedef struct SP_PlatformFns {
// by the plugin. `device` itself should not be deleted here.
void (*destroy_device)(const SP_Platform* platform, SP_Device* device);
// Callbacks for creating/destroying SP_DeviceFns.
void (*create_device_fns)(const SP_Platform* platform,
SE_CreateDeviceFnsParams* params,
TF_Status* status);
// Clean up fields inside SP_DeviceFns that were allocated
// by the plugin. `device_fns` itself should not be deleted here.
void (*destroy_device_fns)(const SP_Platform* platform,
SP_DeviceFns* device_fns);
// Callbacks for creating/destroying SP_StreamExecutor.
void (*create_stream_executor)(const SP_Platform* platform,
SE_CreateStreamExecutorParams* params,

View File

@ -43,7 +43,8 @@ class CPlatform : public Platform {
void (*destroy_platform)(SP_Platform*),
SP_PlatformFns platform_fns,
void (*destroy_platform_fns)(SP_PlatformFns*),
SP_StreamExecutor stream_executor, SP_TimerFns timer_fns);
SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
SP_TimerFns timer_fns);
~CPlatform() override;
Id id() const override { return const_cast<int*>(&plugin_id_value_); }
@ -74,6 +75,7 @@ class CPlatform : public Platform {
void (*destroy_platform_)(SP_Platform*);
SP_PlatformFns platform_fns_;
void (*destroy_platform_fns_)(SP_PlatformFns*);
SP_DeviceFns device_fns_;
SP_StreamExecutor stream_executor_;
SP_TimerFns timer_fns_;
const std::string name_;

View File

@ -108,7 +108,7 @@ void block_host_for_event(const SP_Device* const device, SP_Event event,
TF_Status* const status) {}
void synchronize_all_activity(const SP_Device* const device,
TF_Status* const status) {}
TF_Bool host_callback(SP_Device* const device, SP_Stream stream,
TF_Bool host_callback(const SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) {
return true;
@ -144,6 +144,10 @@ void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
se->host_callback = host_callback;
}
// Resets `*device_fns` to a default table: the struct size is set and every
// remaining member is zero-initialized, so all callbacks start out null.
void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
  const SP_DeviceFns defaults = {SP_DEVICE_FNS_STRUCT_SIZE};
  *device_fns = defaults;
}
/*** Create SP_TimerFns ***/
uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; }
@ -171,10 +175,18 @@ void destroy_stream_executor(const SP_Platform* platform,
void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
params->device->struct_size = SP_DEVICE_STRUCT_SIZE;
params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
}
// Test stub installed as SP_PlatformFns::destroy_device; does nothing.
void destroy_device(const SP_Platform* platform, SP_Device* device) {}
void create_device_fns(const SP_Platform* platform,
SE_CreateDeviceFnsParams* params, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
}
// Test stub installed as SP_PlatformFns::destroy_device_fns; does nothing.
void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
}
void PopulateDefaultPlatform(SP_Platform* platform,
SP_PlatformFns* platform_fns) {
*platform = {SP_PLATFORM_STRUCT_SIZE};
@ -183,6 +195,8 @@ void PopulateDefaultPlatform(SP_Platform* platform,
platform->visible_device_count = DEVICE_COUNT;
platform_fns->create_device = create_device;
platform_fns->destroy_device = destroy_device;
platform_fns->create_device_fns = create_device_fns;
platform_fns->destroy_device_fns = destroy_device_fns;
platform_fns->create_stream_executor = create_stream_executor;
platform_fns->destroy_stream_executor = destroy_stream_executor;
platform_fns->create_timer_fns = create_timer_fns;
@ -213,8 +227,6 @@ TEST(StreamExecutor, SuccessfulRegistration) {
port::StatusOr<StreamExecutor*> maybe_executor =
platform->ExecutorForDevice(0);
TF_ASSERT_OK(maybe_executor.status());
StreamExecutor* executor = maybe_executor.ConsumeValueOrDie();
ASSERT_EQ(executor->GetDeviceDescription().name(), "MyDevice");
}
TEST(StreamExecutor, NameNotSet) {
@ -271,6 +283,7 @@ class StreamExecutorTest : public ::testing::Test {
StreamExecutorTest() {}
void SetUp() override {
PopulateDefaultPlatform(&platform_, &platform_fns_);
PopulateDefaultDeviceFns(&device_fns_);
PopulateDefaultStreamExecutor(&se_);
PopulateDefaultTimerFns(&timer_fns_);
}
@ -279,8 +292,8 @@ class StreamExecutorTest : public ::testing::Test {
StreamExecutor* GetExecutor(int ordinal) {
if (!cplatform_) {
cplatform_ = absl::make_unique<CPlatform>(
platform_, destroy_platform, platform_fns_, destroy_platform_fns, se_,
timer_fns_);
platform_, destroy_platform, platform_fns_, destroy_platform_fns,
device_fns_, se_, timer_fns_);
}
port::StatusOr<StreamExecutor*> maybe_executor =
cplatform_->ExecutorForDevice(ordinal);
@ -289,6 +302,7 @@ class StreamExecutorTest : public ::testing::Test {
}
SP_Platform platform_;
SP_PlatformFns platform_fns_;
SP_DeviceFns device_fns_;
SP_StreamExecutor se_;
SP_TimerFns timer_fns_;
std::unique_ptr<CPlatform> cplatform_;
@ -841,7 +855,7 @@ TEST_F(StreamExecutorTest, SynchronizeAllActivity) {
}
TEST_F(StreamExecutorTest, HostCallbackOk) {
se_.host_callback = [](SP_Device* const device, SP_Stream stream,
se_.host_callback = [](const SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) -> TF_Bool {
TF_Status* status = TF_NewStatus();
@ -861,7 +875,7 @@ TEST_F(StreamExecutorTest, HostCallbackOk) {
}
TEST_F(StreamExecutorTest, HostCallbackError) {
se_.host_callback = [](SP_Device* const device, SP_Stream stream,
se_.host_callback = [](const SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) -> TF_Bool {
TF_Status* status = TF_NewStatus();
@ -879,5 +893,59 @@ TEST_F(StreamExecutorTest, HostCallbackError) {
stream.ThenDoHostCallbackWithStatus(callback);
ASSERT_FALSE(stream.ok());
}
// Checks that the SP_Device string fields and the SP_DeviceFns callbacks set
// up below are reflected in the executor's DeviceDescription.
TEST_F(StreamExecutorTest, DeviceDescription) {
  // Static storage duration lets the captureless lambda below refer to these
  // without capturing, so it still converts to a plain function pointer.
  static const char* kHardwareName = "TestName";
  static const char* kVendor = "TestVendor";
  static const char* kPciBusId = "TestPCIBusId";

  // Device creation only fills in the descriptive strings.
  platform_fns_.create_device = [](const SP_Platform* /*platform*/,
                                   SE_CreateDeviceParams* create_params,
                                   TF_Status* /*status*/) {
    SP_Device* const dev = create_params->device;
    dev->hardware_name = kHardwareName;
    dev->device_vendor = kVendor;
    dev->pci_bus_id = kPciBusId;
  };

  // All three device callbacks are provided in this test.
  device_fns_.get_numa_node = [](const SP_Device* /*device*/) { return 123; };
  device_fns_.get_memory_bandwidth =
      [](const SP_Device* /*device*/) -> int64_t { return 54; };
  device_fns_.get_gflops = [](const SP_Device* /*device*/) -> double {
    return 32;
  };

  StreamExecutor* const se = GetExecutor(0);
  const DeviceDescription& desc = se->GetDeviceDescription();
  ASSERT_EQ(desc.name(), "TestName");
  ASSERT_EQ(desc.device_vendor(), "TestVendor");
  ASSERT_EQ(desc.pci_bus_id(), "TestPCIBusId");
  ASSERT_EQ(desc.numa_node(), 123);
  ASSERT_EQ(desc.memory_bandwidth(), 54);
}
// Same setup as the DeviceDescription test, except that `get_numa_node` is
// never assigned; the description is then expected to report numa_node() == -1
// while the other fields are still populated.
TEST_F(StreamExecutorTest, DeviceDescriptionNumaNodeNotSet) {
  // Static storage duration lets the captureless lambda below refer to these
  // without capturing, so it still converts to a plain function pointer.
  static const char* kHardwareName = "TestName";
  static const char* kVendor = "TestVendor";
  static const char* kPciBusId = "TestPCIBusId";

  // Device creation only fills in the descriptive strings.
  platform_fns_.create_device = [](const SP_Platform* /*platform*/,
                                   SE_CreateDeviceParams* create_params,
                                   TF_Status* /*status*/) {
    SP_Device* const dev = create_params->device;
    dev->hardware_name = kHardwareName;
    dev->device_vendor = kVendor;
    dev->pci_bus_id = kPciBusId;
  };

  // Note: device_fns_.get_numa_node is deliberately not set here.
  device_fns_.get_memory_bandwidth =
      [](const SP_Device* /*device*/) -> int64_t { return 54; };
  device_fns_.get_gflops = [](const SP_Device* /*device*/) -> double {
    return 32;
  };

  StreamExecutor* const se = GetExecutor(0);
  const DeviceDescription& desc = se->GetDeviceDescription();
  ASSERT_EQ(desc.name(), "TestName");
  ASSERT_EQ(desc.device_vendor(), "TestVendor");
  ASSERT_EQ(desc.pci_bus_id(), "TestPCIBusId");
  ASSERT_EQ(desc.numa_node(), -1);
  ASSERT_EQ(desc.memory_bandwidth(), 54);
}
} // namespace
} // namespace stream_executor