Clarify some interface comments and use libraries more consistently.
PiperOrigin-RevId: 280590844 Change-Id: I287921239fbd55d559912c8cb9f203918043cff4
This commit is contained in:
parent
e5402b6883
commit
2205360b73
@ -15,8 +15,7 @@
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
|
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <cstdint>
|
||||||
|
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -66,9 +66,9 @@ class GrpcEvent : public Event {
|
|||||||
|
|
||||||
class GrpcBufferHandle : public BufferHandle {
|
class GrpcBufferHandle : public BufferHandle {
|
||||||
public:
|
public:
|
||||||
explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event,
|
explicit GrpcBufferHandle(
|
||||||
int64_t bytes,
|
EventId id, std::shared_ptr<GrpcEvent> event, int64_t bytes,
|
||||||
std::optional<xla::ShapeProto> shape = std::nullopt)
|
absl::optional<xla::ShapeProto> shape = absl::nullopt)
|
||||||
: id_(id),
|
: id_(id),
|
||||||
stream_(event->stream()),
|
stream_(event->stream()),
|
||||||
event_(std::move(event)),
|
event_(std::move(event)),
|
||||||
@ -81,14 +81,14 @@ class GrpcBufferHandle : public BufferHandle {
|
|||||||
EventId id() const { return id_; }
|
EventId id() const { return id_; }
|
||||||
GrpcTpuStream* stream() const { return stream_; }
|
GrpcTpuStream* stream() const { return stream_; }
|
||||||
|
|
||||||
std::optional<xla::ShapeProto> shape() override { return shape_; }
|
absl::optional<xla::ShapeProto> shape() override { return shape_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const EventId id_;
|
const EventId id_;
|
||||||
GrpcTpuStream* stream_;
|
GrpcTpuStream* stream_;
|
||||||
std::shared_ptr<GrpcEvent> event_;
|
std::shared_ptr<GrpcEvent> event_;
|
||||||
int64_t bytes_;
|
int64_t bytes_;
|
||||||
std::optional<xla::ShapeProto> shape_;
|
absl::optional<xla::ShapeProto> shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GrpcCompiledProgramHandle : public CompiledProgramHandle {
|
class GrpcCompiledProgramHandle : public CompiledProgramHandle {
|
||||||
|
@ -20,7 +20,6 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -42,10 +41,6 @@ namespace tpu_driver {
|
|||||||
|
|
||||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
||||||
|
|
||||||
struct TpuDriverConfig {
|
|
||||||
std::string worker;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Represents the deferred completion of a scheduled operation.
|
// Represents the deferred completion of a scheduled operation.
|
||||||
//
|
//
|
||||||
// Events may be blocked on, or used as `wait_for` arguments to enforce
|
// Events may be blocked on, or used as `wait_for` arguments to enforce
|
||||||
@ -54,7 +49,7 @@ class Event {
|
|||||||
public:
|
public:
|
||||||
virtual ~Event() {}
|
virtual ~Event() {}
|
||||||
|
|
||||||
// Block until the event completes and return the result status.
|
// Blocks until the event completes, then returns the result status.
|
||||||
virtual xla::Status Await() = 0;
|
virtual xla::Status Await() = 0;
|
||||||
virtual absl::optional<xla::Status> AwaitWithTimeout(
|
virtual absl::optional<xla::Status> AwaitWithTimeout(
|
||||||
absl::Duration duration) = 0;
|
absl::Duration duration) = 0;
|
||||||
@ -68,13 +63,14 @@ class BufferHandle {
|
|||||||
public:
|
public:
|
||||||
virtual ~BufferHandle() {}
|
virtual ~BufferHandle() {}
|
||||||
|
|
||||||
// This Event completes after the device memory is actually allocated.
|
// This event completes after the device memory is actually allocated.
|
||||||
//
|
//
|
||||||
// Methods that take a BufferHandle, including ExecuteProgram and Transfer*,
|
// Methods that take a buffer handle, such as ExecuteProgram and Transfer*,
|
||||||
// automatically add this event as a dependency.
|
// automatically add this event as a dependency.
|
||||||
virtual std::shared_ptr<Event> OnReady() = 0;
|
virtual std::shared_ptr<Event> OnReady() = 0;
|
||||||
|
|
||||||
virtual int64_t size_in_bytes() = 0;
|
virtual int64_t size_in_bytes() = 0;
|
||||||
virtual std::optional<xla::ShapeProto> shape() = 0;
|
virtual absl::optional<xla::ShapeProto> shape() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Represents a compiled program on the host.
|
// Represents a compiled program on the host.
|
||||||
@ -84,16 +80,16 @@ class CompiledProgramHandle {
|
|||||||
|
|
||||||
// This Event completes after the program is actually compiled on the host.
|
// This Event completes after the program is actually compiled on the host.
|
||||||
//
|
//
|
||||||
// Methods that take a CompiledProgramHandle, including LoadProgram,
|
// Methods that take a compiled program handle, including LoadProgram,
|
||||||
// automatically add this event as a dependency.
|
// automatically add this event as a dependency.
|
||||||
virtual std::shared_ptr<Event> OnReady() = 0;
|
virtual std::shared_ptr<Event> OnReady() = 0;
|
||||||
|
|
||||||
virtual int64_t size_in_bytes() {
|
virtual int64_t size_in_bytes() {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
LOG(FATAL) << "Unimplemented.";
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gets the compiled program metadata (including shape info). Will block until
|
// Returns the shape of the compiled program. Blocks until compile completes.
|
||||||
// compile completes.
|
|
||||||
virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0;
|
virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -104,9 +100,10 @@ class LoadedProgramHandle {
|
|||||||
|
|
||||||
// This Event completes after the program is actually loaded on the device.
|
// This Event completes after the program is actually loaded on the device.
|
||||||
//
|
//
|
||||||
// Methods that take a LoadedProgramHandle, including ExecuteProgram and
|
// Methods that take a loaded program handle, including ExecuteProgram and
|
||||||
// UnloadProgram, automatically add this event as a dependency.
|
// UnloadProgram, automatically add this event as a dependency.
|
||||||
virtual std::shared_ptr<Event> OnReady() = 0;
|
virtual std::shared_ptr<Event> OnReady() = 0;
|
||||||
|
|
||||||
virtual int64_t size_in_bytes() {
|
virtual int64_t size_in_bytes() {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
LOG(FATAL) << "Unimplemented.";
|
||||||
return 0;
|
return 0;
|
||||||
@ -114,7 +111,7 @@ class LoadedProgramHandle {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// A TpuLinearizer manages the linearization and delinearization of user buffers
|
// A TpuLinearizer manages the linearization and delinearization of user buffers
|
||||||
// in the TPU Driver. This interface is not yet implemented.
|
// in the TPU driver. This interface is not yet implemented.
|
||||||
class TpuLinearizer {
|
class TpuLinearizer {
|
||||||
public:
|
public:
|
||||||
virtual ~TpuLinearizer() {}
|
virtual ~TpuLinearizer() {}
|
||||||
@ -122,26 +119,23 @@ class TpuLinearizer {
|
|||||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||||
return ::tpu_driver::ComputeBytesFromShape(shape);
|
return ::tpu_driver::ComputeBytesFromShape(shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual uint64_t ComputeLinearizedBytesFromShape(
|
virtual uint64_t ComputeLinearizedBytesFromShape(
|
||||||
const xla::ShapeProto& shape) = 0;
|
const xla::ShapeProto& shape) = 0;
|
||||||
|
|
||||||
virtual xla::Status LinearizeShape(void* dst, const void* src,
|
virtual xla::Status LinearizeShape(void* dst, const void* src,
|
||||||
const xla::ShapeProto& shape) = 0;
|
const xla::ShapeProto& shape) = 0;
|
||||||
|
|
||||||
virtual xla::Status DelinearizeShape(void* dst, const void* src,
|
virtual xla::Status DelinearizeShape(void* dst, const void* src,
|
||||||
const xla::ShapeProto& shape) = 0;
|
const xla::ShapeProto& shape) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// A TpuDriver manages a set of operations scheduled to run on a TPU device.
|
// A TpuDriver manages a set of operations scheduled to run on a TPU system.
|
||||||
//
|
//
|
||||||
// By default, two independently scheduled operations may execute in any order;
|
// By default, two independently scheduled operations may execute in any order.
|
||||||
// ordering is imposed in one of two ways:
|
// Ordering can be imposed in one of two ways:
|
||||||
//
|
//
|
||||||
// 1. Users can specify event dependency via the `wait_for` argument.
|
// 1. Users can specify event dependencies via the `wait_for` argument.
|
||||||
// 2. All API calls using the same BufferHandle are serialized in calling order.
|
// 2. Operations using buffer or program handles implicitly wait for the handles
|
||||||
// In particular, operations using buffer or program handles will implicitly
|
// to become ready before executing.
|
||||||
// wait for the handles to be materialized before executing.
|
|
||||||
//
|
//
|
||||||
// For returned handle objects, the user is responsible for calling the release
|
// For returned handle objects, the user is responsible for calling the release
|
||||||
// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr
|
// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr
|
||||||
@ -149,7 +143,7 @@ class TpuLinearizer {
|
|||||||
// no release method; the user can let them go out of scope naturally. As soon
|
// no release method; the user can let them go out of scope naturally. As soon
|
||||||
// as those methods accepting plain-pointer arguments return, the user can let
|
// as those methods accepting plain-pointer arguments return, the user can let
|
||||||
// the corresponding smart-pointer objects be released or go out of scope,
|
// the corresponding smart-pointer objects be released or go out of scope,
|
||||||
// regardless of whether the scheduled device operations have completed or not.
|
// regardless of whether the scheduled device operations have started execution.
|
||||||
class TpuDriver {
|
class TpuDriver {
|
||||||
public:
|
public:
|
||||||
virtual ~TpuDriver() {}
|
virtual ~TpuDriver() {}
|
||||||
@ -158,7 +152,7 @@ class TpuDriver {
|
|||||||
// Synchronous. Reset the state of the TPU driver. All running programs
|
// Synchronous. Reset the state of the TPU driver. All running programs
|
||||||
// will be terminated and all allocations reset.
|
// will be terminated and all allocations reset.
|
||||||
//
|
//
|
||||||
// All events and buffer handles created prior to Reset() will be invalid
|
// All events and buffer handles created prior to Reset() will be invalid,
|
||||||
// and any use will result in undefined behavior.
|
// and any use will result in undefined behavior.
|
||||||
virtual xla::Status Reset() = 0;
|
virtual xla::Status Reset() = 0;
|
||||||
|
|
||||||
@ -190,11 +184,11 @@ class TpuDriver {
|
|||||||
* `src` must be laid out in consecutive row-major format for ingestion, and
|
* `src` must be laid out in consecutive row-major format for ingestion, and
|
||||||
* each element must take up the number of bytes specified by the type.
|
* each element must take up the number of bytes specified by the type.
|
||||||
*
|
*
|
||||||
* For example, if you have a [3,3,3] tensor with a Float32 type, then the
|
* For example, for a [3,3,3] tensor with a Float32 type, the memory layout
|
||||||
* memory layout would be as follows:
|
* would be as follows:
|
||||||
*
|
*
|
||||||
* [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ...
|
* [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ...
|
||||||
* [1,2,2], [2,0,0], ..., [2,2,2].
|
* [1,2,2], [2,0,0], ..., [2,2,2],
|
||||||
*
|
*
|
||||||
* and the entire buffer will be 108 bytes (27 elements x 4 bytes).
|
* and the entire buffer will be 108 bytes (27 elements x 4 bytes).
|
||||||
*
|
*
|
||||||
@ -233,6 +227,10 @@ class TpuDriver {
|
|||||||
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
|
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TpuDriverConfig {
|
||||||
|
std::string worker;
|
||||||
|
};
|
||||||
|
|
||||||
class TpuDriverRegistry {
|
class TpuDriverRegistry {
|
||||||
public:
|
public:
|
||||||
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(
|
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(
|
||||||
|
Loading…
Reference in New Issue
Block a user