Rename ExternalTpuDriver to DirectTpuDriver and c_api_client to libtpu_client
PiperOrigin-RevId: 289919461 Change-Id: I8f7f2ff726af641b6fc7297c24bf84cf8aa55748
This commit is contained in:
parent
37f0ac13bd
commit
5c313b8dbb
@ -74,8 +74,8 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "external_tpu_driver",
|
||||
srcs = ["external_tpu_driver.cc"],
|
||||
name = "direct_tpu_driver",
|
||||
srcs = ["direct_tpu_driver.cc"],
|
||||
deps = [
|
||||
":tpu_driver",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -22,7 +22,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/python:local_client",
|
||||
"//tensorflow/compiler/xla/python:semaphore",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:external_tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc",
|
||||
|
@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Before you start, make sure c_api.so, c_api.h and and c_api_client.c are in
|
||||
// the same working directory.
|
||||
// Before you start, make sure libtpu.so, libtpu.h and and libtpu_client.c are
|
||||
// in the same working directory.
|
||||
//
|
||||
// To compile: gcc -o c_api_client c_api_client.c -ldl
|
||||
// To run: sudo ./c_api_client
|
||||
// To compile: gcc -o libtpu_client libtpu_client.c -ldl
|
||||
// To run: sudo ./libtpu_client
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <stdio.h>
|
||||
@ -28,7 +28,7 @@ limitations under the License.
|
||||
void* LoadAndInitializeDriver(const char* shared_lib,
|
||||
struct TpuDriverFn* driver_fn) {
|
||||
void* handle;
|
||||
handle = dlopen("libtpu.so", RTLD_NOW);
|
||||
handle = dlopen(shared_lib, RTLD_NOW);
|
||||
if (!handle) {
|
||||
fprintf(stderr, "Error: %s\n", dlerror());
|
||||
exit(EXIT_FAILURE);
|
||||
@ -42,8 +42,13 @@ void* LoadAndInitializeDriver(const char* shared_lib,
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
char* api_path = "./libtpu.so";
|
||||
if (argc == 2) {
|
||||
api_path = argv[1];
|
||||
}
|
||||
|
||||
struct TpuDriverFn driver_fn;
|
||||
void* handle = LoadAndInitializeDriver("./c_api.so", &driver_fn);
|
||||
void* handle = LoadAndInitializeDriver(api_path, &driver_fn);
|
||||
|
||||
fprintf(stdout, "------ Going to Query Version ------\n");
|
||||
fprintf(stdout, "TPU Driver Version: %s\n", driver_fn.TpuDriver_Version());
|
@ -27,7 +27,7 @@
|
||||
namespace tpu_driver {
|
||||
namespace {
|
||||
|
||||
constexpr char kExternalProtocol[] = "external://";
|
||||
constexpr char kDirectProtocol[] = "direct://";
|
||||
|
||||
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
|
||||
::TpuAllocationShape shape_;
|
||||
@ -42,14 +42,14 @@ constexpr char kExternalProtocol[] = "external://";
|
||||
return shape_;
|
||||
}
|
||||
|
||||
class ExternalTpuDriver;
|
||||
class DirectTpuDriver;
|
||||
|
||||
class ExternalEvent : public Event {
|
||||
class DirectEvent : public Event {
|
||||
public:
|
||||
explicit ExternalEvent(::TpuDriverFn* driver_fn, ::TpuEvent* event)
|
||||
explicit DirectEvent(::TpuDriverFn* driver_fn, ::TpuEvent* event)
|
||||
: driver_fn_(driver_fn), event_(event) {}
|
||||
|
||||
~ExternalEvent() override { driver_fn_->TpuDriver_FreeEvent(event_); }
|
||||
~DirectEvent() override { driver_fn_->TpuDriver_FreeEvent(event_); }
|
||||
|
||||
xla::Status Await() override {
|
||||
auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
|
||||
@ -97,14 +97,14 @@ class ExternalEvent : public Event {
|
||||
::TpuDriverFn* driver_fn_;
|
||||
::TpuEvent* event_;
|
||||
|
||||
friend ExternalTpuDriver;
|
||||
friend DirectTpuDriver;
|
||||
};
|
||||
|
||||
class ExternalBufferHandle : public BufferHandle {
|
||||
class DirectBufferHandle : public BufferHandle {
|
||||
public:
|
||||
explicit ExternalBufferHandle(::TpuDriverFn* driver_fn,
|
||||
::TpuBufferHandle* handle)
|
||||
: handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {}
|
||||
explicit DirectBufferHandle(::TpuDriverFn* driver_fn,
|
||||
::TpuBufferHandle* handle)
|
||||
: handle_(handle), event_(new DirectEvent(driver_fn, handle->event)) {}
|
||||
|
||||
std::shared_ptr<Event> OnReady() override { return event_; }
|
||||
|
||||
@ -117,18 +117,18 @@ class ExternalBufferHandle : public BufferHandle {
|
||||
|
||||
private:
|
||||
::TpuBufferHandle* handle_;
|
||||
std::shared_ptr<ExternalEvent> event_;
|
||||
std::shared_ptr<DirectEvent> event_;
|
||||
|
||||
friend ExternalTpuDriver;
|
||||
friend DirectTpuDriver;
|
||||
};
|
||||
|
||||
class ExternalCompiledProgramHandle : public CompiledProgramHandle {
|
||||
class DirectCompiledProgramHandle : public CompiledProgramHandle {
|
||||
public:
|
||||
explicit ExternalCompiledProgramHandle(::TpuDriverFn* driver_fn,
|
||||
::TpuCompiledProgramHandle* handle)
|
||||
explicit DirectCompiledProgramHandle(::TpuDriverFn* driver_fn,
|
||||
::TpuCompiledProgramHandle* handle)
|
||||
: handle_(handle),
|
||||
driver_fn_(driver_fn),
|
||||
event_(new ExternalEvent(driver_fn, handle->event)) {}
|
||||
event_(new DirectEvent(driver_fn, handle->event)) {}
|
||||
|
||||
std::shared_ptr<Event> OnReady() override { return event_; }
|
||||
|
||||
@ -152,16 +152,16 @@ class ExternalCompiledProgramHandle : public CompiledProgramHandle {
|
||||
private:
|
||||
::TpuCompiledProgramHandle* handle_;
|
||||
::TpuDriverFn* driver_fn_;
|
||||
std::shared_ptr<ExternalEvent> event_;
|
||||
std::shared_ptr<DirectEvent> event_;
|
||||
|
||||
friend ExternalTpuDriver;
|
||||
friend DirectTpuDriver;
|
||||
};
|
||||
|
||||
class ExternalLoadedProgramHandle : public LoadedProgramHandle {
|
||||
class DirectLoadedProgramHandle : public LoadedProgramHandle {
|
||||
public:
|
||||
explicit ExternalLoadedProgramHandle(::TpuDriverFn* driver_fn,
|
||||
::TpuLoadedProgramHandle* handle)
|
||||
: handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {}
|
||||
explicit DirectLoadedProgramHandle(::TpuDriverFn* driver_fn,
|
||||
::TpuLoadedProgramHandle* handle)
|
||||
: handle_(handle), event_(new DirectEvent(driver_fn, handle->event)) {}
|
||||
std::shared_ptr<Event> OnReady() override { return event_; }
|
||||
|
||||
int64_t size_in_bytes() override {
|
||||
@ -171,14 +171,14 @@ class ExternalLoadedProgramHandle : public LoadedProgramHandle {
|
||||
|
||||
private:
|
||||
::TpuLoadedProgramHandle* handle_;
|
||||
std::shared_ptr<ExternalEvent> event_;
|
||||
std::shared_ptr<DirectEvent> event_;
|
||||
|
||||
friend ExternalTpuDriver;
|
||||
friend DirectTpuDriver;
|
||||
};
|
||||
|
||||
class ExternalTpuLinearizer : public TpuLinearizer {
|
||||
class DirectTpuLinearizer : public TpuLinearizer {
|
||||
public:
|
||||
explicit ExternalTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn)
|
||||
explicit DirectTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn)
|
||||
: driver_(driver), driver_fn_(driver_fn) {}
|
||||
|
||||
int64_t ComputeLinearizedBytesFromShape(
|
||||
@ -221,9 +221,9 @@ class ExternalTpuLinearizer : public TpuLinearizer {
|
||||
::TpuDriverFn* driver_fn_;
|
||||
};
|
||||
|
||||
class ExternalTpuDriver : public TpuDriver {
|
||||
class DirectTpuDriver : public TpuDriver {
|
||||
public:
|
||||
explicit ExternalTpuDriver(const std::string& so_path) {
|
||||
explicit DirectTpuDriver(const std::string& so_path) {
|
||||
void* handle;
|
||||
handle = dlopen(so_path.c_str(), RTLD_NOW);
|
||||
if (!handle) {
|
||||
@ -238,7 +238,7 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
driver_ = driver_fn_.TpuDriver_Open("local://");
|
||||
}
|
||||
|
||||
~ExternalTpuDriver() override {}
|
||||
~DirectTpuDriver() override {}
|
||||
|
||||
void QuerySystemInfo(SystemInfo* system_info) override {
|
||||
LOG(FATAL) << "Unimplemented.";
|
||||
@ -250,7 +250,7 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
int32_t core_id, MemoryRegion region, int64_t num_bytes,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
auto bh = absl::make_unique<ExternalBufferHandle>(
|
||||
auto bh = absl::make_unique<DirectBufferHandle>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_Allocate(driver_, core_id, region, num_bytes,
|
||||
wait_for.size(), tpu_events));
|
||||
@ -264,7 +264,7 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
|
||||
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||
auto bh = absl::make_unique<ExternalBufferHandle>(
|
||||
auto bh = absl::make_unique<DirectBufferHandle>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_AllocateShape(driver_, core_id, region, shape_,
|
||||
wait_for.size(), tpu_events));
|
||||
@ -283,10 +283,10 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
::TpuBufferHandle** childbuf = new ::TpuBufferHandle*[children.size()];
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
childbuf[i] =
|
||||
static_cast<ExternalBufferHandle* const>(children[i])->handle_;
|
||||
static_cast<DirectBufferHandle* const>(children[i])->handle_;
|
||||
}
|
||||
|
||||
auto bh = absl::make_unique<ExternalBufferHandle>(
|
||||
auto bh = absl::make_unique<DirectBufferHandle>(
|
||||
&driver_fn_, driver_fn_.TpuDriver_AllocateTuple(
|
||||
driver_, core_id, region, children.size(), childbuf,
|
||||
wait_for.size(), tpu_events));
|
||||
@ -300,10 +300,10 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
std::unique_ptr<BufferHandle> handle,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
auto event = std::make_shared<ExternalEvent>(
|
||||
auto event = std::make_shared<DirectEvent>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_Deallocate(
|
||||
driver_, static_cast<ExternalBufferHandle*>(handle.get())->handle_,
|
||||
driver_, static_cast<DirectBufferHandle*>(handle.get())->handle_,
|
||||
wait_for.size(), tpu_events));
|
||||
delete[] tpu_events;
|
||||
return event;
|
||||
@ -313,10 +313,10 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
const void* src, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
auto event = std::make_shared<ExternalEvent>(
|
||||
auto event = std::make_shared<DirectEvent>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_TransferToDevice(
|
||||
driver_, src, static_cast<ExternalBufferHandle*>(dst)->handle_,
|
||||
driver_, src, static_cast<DirectBufferHandle*>(dst)->handle_,
|
||||
wait_for.size(), tpu_events));
|
||||
delete[] tpu_events;
|
||||
return event;
|
||||
@ -326,11 +326,11 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
const BufferHandle* src, void* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
auto event = std::make_shared<ExternalEvent>(
|
||||
auto event = std::make_shared<DirectEvent>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_TransferFromDevice(
|
||||
driver_, static_cast<const ExternalBufferHandle*>(src)->handle_,
|
||||
dst, wait_for.size(), tpu_events));
|
||||
driver_, static_cast<const DirectBufferHandle*>(src)->handle_, dst,
|
||||
wait_for.size(), tpu_events));
|
||||
delete[] tpu_events;
|
||||
return event;
|
||||
}
|
||||
@ -339,11 +339,11 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
const BufferHandle* src, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
auto event = std::make_shared<ExternalEvent>(
|
||||
auto event = std::make_shared<DirectEvent>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_TransferFromDeviceToDevice(
|
||||
driver_, static_cast<const ExternalBufferHandle*>(src)->handle_,
|
||||
static_cast<ExternalBufferHandle*>(dst)->handle_, wait_for.size(),
|
||||
driver_, static_cast<const DirectBufferHandle*>(src)->handle_,
|
||||
static_cast<DirectBufferHandle*>(dst)->handle_, wait_for.size(),
|
||||
tpu_events));
|
||||
delete[] tpu_events;
|
||||
return event;
|
||||
@ -362,7 +362,7 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto handle = absl::make_unique<ExternalCompiledProgramHandle>(
|
||||
auto handle = absl::make_unique<DirectCompiledProgramHandle>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas,
|
||||
wait_for.size(), tpu_events));
|
||||
@ -376,11 +376,11 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
|
||||
auto loaded_handle = absl::make_unique<ExternalLoadedProgramHandle>(
|
||||
auto loaded_handle = absl::make_unique<DirectLoadedProgramHandle>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_LoadProgram(
|
||||
driver_, core_id,
|
||||
static_cast<const ExternalCompiledProgramHandle*>(handle)->handle_,
|
||||
static_cast<const DirectCompiledProgramHandle*>(handle)->handle_,
|
||||
wait_for.size(), tpu_events));
|
||||
|
||||
delete[] tpu_events;
|
||||
@ -391,11 +391,11 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
std::unique_ptr<LoadedProgramHandle> handle,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
auto event = std::make_shared<ExternalEvent>(
|
||||
auto event = std::make_shared<DirectEvent>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_UnloadProgram(
|
||||
driver_,
|
||||
static_cast<ExternalLoadedProgramHandle*>(handle.get())->handle_,
|
||||
static_cast<DirectLoadedProgramHandle*>(handle.get())->handle_,
|
||||
wait_for.size(), tpu_events));
|
||||
delete[] tpu_events;
|
||||
return event;
|
||||
@ -412,22 +412,21 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
inputv.reserve(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
inputv.push_back(
|
||||
static_cast<ExternalBufferHandle* const>(inputs[i])->handle_);
|
||||
static_cast<DirectBufferHandle* const>(inputs[i])->handle_);
|
||||
}
|
||||
std::vector<::TpuBufferHandle*> outputv;
|
||||
outputv.reserve(outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
outputv.push_back(
|
||||
static_cast<ExternalBufferHandle* const>(outputs[i])->handle_);
|
||||
static_cast<DirectBufferHandle* const>(outputs[i])->handle_);
|
||||
}
|
||||
|
||||
struct DeviceAssignment da = {device_assignment.replica_count(),
|
||||
device_assignment.computation_count()};
|
||||
auto event = std::make_shared<ExternalEvent>(
|
||||
auto event = std::make_shared<DirectEvent>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_ExecuteProgram(
|
||||
driver_,
|
||||
static_cast<ExternalLoadedProgramHandle*>(program)->handle_,
|
||||
driver_, static_cast<DirectLoadedProgramHandle*>(program)->handle_,
|
||||
inputs.size(), inputv.data(), outputs.size(), outputv.data(), da,
|
||||
wait_for.size(), tpu_events));
|
||||
|
||||
@ -436,7 +435,7 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
}
|
||||
|
||||
std::unique_ptr<TpuLinearizer> GetLinearizer() override {
|
||||
return std::make_unique<ExternalTpuLinearizer>(driver_, &driver_fn_);
|
||||
return std::make_unique<DirectTpuLinearizer>(driver_, &driver_fn_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -447,20 +446,20 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
if (wait_for.empty()) return nullptr;
|
||||
::TpuEvent** ret = new ::TpuEvent*[wait_for.size()];
|
||||
for (int i = 0; i < wait_for.size(); i++) {
|
||||
ret[i] = static_cast<ExternalEvent* const>(wait_for[i])->event_;
|
||||
ret[i] = static_cast<DirectEvent* const>(wait_for[i])->event_;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterExternalTpuDriver(
|
||||
xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterDirectTpuDriver(
|
||||
const TpuDriverConfig& config) {
|
||||
std::string shared_lib = config.worker().substr(strlen(kExternalProtocol));
|
||||
std::string shared_lib = config.worker().substr(strlen(kDirectProtocol));
|
||||
return xla::StatusOr<std::unique_ptr<TpuDriver>>(
|
||||
absl::make_unique<ExternalTpuDriver>(shared_lib));
|
||||
absl::make_unique<DirectTpuDriver>(shared_lib));
|
||||
}
|
||||
|
||||
REGISTER_TPU_DRIVER(kExternalProtocol, RegisterExternalTpuDriver);
|
||||
REGISTER_TPU_DRIVER(kDirectProtocol, RegisterDirectTpuDriver);
|
||||
|
||||
} // namespace
|
||||
} // namespace tpu_driver
|
Loading…
Reference in New Issue
Block a user