Clarify some interface comments and use libraries more consistently.

PiperOrigin-RevId: 280590844
Change-Id: I287921239fbd55d559912c8cb9f203918043cff4
This commit is contained in:
Wenhao Jia 2019-11-14 22:21:07 -08:00 committed by TensorFlower Gardener
parent e5402b6883
commit 2205360b73
3 changed files with 32 additions and 35 deletions

View File

@ -15,8 +15,7 @@
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
#include <stdint.h> #include <cstdint>
#include <ostream> #include <ostream>
#include <string> #include <string>
#include <utility> #include <utility>

View File

@ -66,9 +66,9 @@ class GrpcEvent : public Event {
class GrpcBufferHandle : public BufferHandle { class GrpcBufferHandle : public BufferHandle {
public: public:
explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event, explicit GrpcBufferHandle(
int64_t bytes, EventId id, std::shared_ptr<GrpcEvent> event, int64_t bytes,
std::optional<xla::ShapeProto> shape = std::nullopt) absl::optional<xla::ShapeProto> shape = absl::nullopt)
: id_(id), : id_(id),
stream_(event->stream()), stream_(event->stream()),
event_(std::move(event)), event_(std::move(event)),
@ -81,14 +81,14 @@ class GrpcBufferHandle : public BufferHandle {
EventId id() const { return id_; } EventId id() const { return id_; }
GrpcTpuStream* stream() const { return stream_; } GrpcTpuStream* stream() const { return stream_; }
std::optional<xla::ShapeProto> shape() override { return shape_; } absl::optional<xla::ShapeProto> shape() override { return shape_; }
private: private:
const EventId id_; const EventId id_;
GrpcTpuStream* stream_; GrpcTpuStream* stream_;
std::shared_ptr<GrpcEvent> event_; std::shared_ptr<GrpcEvent> event_;
int64_t bytes_; int64_t bytes_;
std::optional<xla::ShapeProto> shape_; absl::optional<xla::ShapeProto> shape_;
}; };
class GrpcCompiledProgramHandle : public CompiledProgramHandle { class GrpcCompiledProgramHandle : public CompiledProgramHandle {

View File

@ -20,7 +20,6 @@
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <vector> #include <vector>
@ -42,10 +41,6 @@ namespace tpu_driver {
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape); uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
struct TpuDriverConfig {
std::string worker;
};
// Represents the deferred completion of a scheduled operation. // Represents the deferred completion of a scheduled operation.
// //
// Events may be blocked on, or used as `wait_for` arguments to enforce // Events may be blocked on, or used as `wait_for` arguments to enforce
@ -54,7 +49,7 @@ class Event {
public: public:
virtual ~Event() {} virtual ~Event() {}
// Block until the event completes and return the result status. // Block until the event completes and returns the result status.
virtual xla::Status Await() = 0; virtual xla::Status Await() = 0;
virtual absl::optional<xla::Status> AwaitWithTimeout( virtual absl::optional<xla::Status> AwaitWithTimeout(
absl::Duration duration) = 0; absl::Duration duration) = 0;
@ -68,13 +63,14 @@ class BufferHandle {
public: public:
virtual ~BufferHandle() {} virtual ~BufferHandle() {}
// This Event completes after the device memory is actually allocated. // This event completes after the device memory is actually allocated.
// //
// Methods that take a BufferHandle, including ExecuteProgram and Transfer*, // Methods that take a buffer handle, such as ExecuteProgram and Transfer*,
// automatically add this event as a dependency. // automatically add this event as a dependency.
virtual std::shared_ptr<Event> OnReady() = 0; virtual std::shared_ptr<Event> OnReady() = 0;
virtual int64_t size_in_bytes() = 0; virtual int64_t size_in_bytes() = 0;
virtual std::optional<xla::ShapeProto> shape() = 0; virtual absl::optional<xla::ShapeProto> shape() = 0;
}; };
// Represents a compiled program on the host. // Represents a compiled program on the host.
@ -84,16 +80,16 @@ class CompiledProgramHandle {
// This Event completes after the program is actually compiled on the host. // This Event completes after the program is actually compiled on the host.
// //
// Methods that take a CompiledProgramHandle, including LoadProgram, // Methods that take a compiled program handle, including LoadProgram,
// automatically add this event as a dependency. // automatically add this event as a dependency.
virtual std::shared_ptr<Event> OnReady() = 0; virtual std::shared_ptr<Event> OnReady() = 0;
virtual int64_t size_in_bytes() { virtual int64_t size_in_bytes() {
LOG(FATAL) << "Unimplemented."; LOG(FATAL) << "Unimplemented.";
return 0; return 0;
} }
// Gets the compiled program metadata (including shape info). Will block until // Returns the shape of the compiled program. Blocks until compile completes.
// compile completes.
virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0; virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0;
}; };
@ -104,9 +100,10 @@ class LoadedProgramHandle {
// This Event completes after the program is actually loaded on the device. // This Event completes after the program is actually loaded on the device.
// //
// Methods that take a LoadedProgramHandle, including ExecuteProgram and // Methods that take a loaded program handle, including ExecuteProgram and
// UnloadProgram, automatically add this event as a dependency. // UnloadProgram, automatically add this event as a dependency.
virtual std::shared_ptr<Event> OnReady() = 0; virtual std::shared_ptr<Event> OnReady() = 0;
virtual int64_t size_in_bytes() { virtual int64_t size_in_bytes() {
LOG(FATAL) << "Unimplemented."; LOG(FATAL) << "Unimplemented.";
return 0; return 0;
@ -114,7 +111,7 @@ class LoadedProgramHandle {
}; };
// A TpuLinearizer manages the linearization and delinearization of user buffers // A TpuLinearizer manages the linearization and delinearization of user buffers
// in the TPU Driver. This interface is not yet implemented. // in the TPU driver. This interface is not yet implemented.
class TpuLinearizer { class TpuLinearizer {
public: public:
virtual ~TpuLinearizer() {} virtual ~TpuLinearizer() {}
@ -122,26 +119,23 @@ class TpuLinearizer {
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
return ::tpu_driver::ComputeBytesFromShape(shape); return ::tpu_driver::ComputeBytesFromShape(shape);
} }
virtual uint64_t ComputeLinearizedBytesFromShape( virtual uint64_t ComputeLinearizedBytesFromShape(
const xla::ShapeProto& shape) = 0; const xla::ShapeProto& shape) = 0;
virtual xla::Status LinearizeShape(void* dst, const void* src, virtual xla::Status LinearizeShape(void* dst, const void* src,
const xla::ShapeProto& shape) = 0; const xla::ShapeProto& shape) = 0;
virtual xla::Status DelinearizeShape(void* dst, const void* src, virtual xla::Status DelinearizeShape(void* dst, const void* src,
const xla::ShapeProto& shape) = 0; const xla::ShapeProto& shape) = 0;
}; };
// A TpuDriver manages a set of operations scheduled to run on a TPU device. // A TpuDriver manages a set of operations scheduled to run on a TPU system.
// //
// By default, two independently scheduled operations may execute in any order; // By default, two independently scheduled operations may execute in any order.
// ordering is imposed in one of two ways: // Ordering can be imposed in one of two ways:
// //
// 1. Users can specify event dependency via the `wait_for` argument. // 1. Users can specify event dependencies via the `wait_for` argument.
// 2. All API calls using the same BufferHandle are serialized in calling order. // 2. Operations using buffer or program handles implicitly wait for the handles
// In particular, operations using buffer or program handles will implicitly // to become ready before executing.
// wait for the handles to be materialized before executing.
// //
// For returned handle objects, the user is responsible for calling the release // For returned handle objects, the user is responsible for calling the release
// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr // methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr
@ -149,7 +143,7 @@ class TpuLinearizer {
// no release method; the user can let them go out of scope naturally. As soon // no release method; the user can let them go out of scope naturally. As soon
// as those methods accepting plain-pointer arguments return, the user can let // as those methods accepting plain-pointer arguments return, the user can let
// the corresponding smart-pointer objects be released or go out of scope, // the corresponding smart-pointer objects be released or go out of scope,
// regardless of whether the scheduled device operations have completed or not. // regardless of whether the scheduled device operations have started execution.
class TpuDriver { class TpuDriver {
public: public:
virtual ~TpuDriver() {} virtual ~TpuDriver() {}
@ -158,7 +152,7 @@ class TpuDriver {
// Synchronous. Reset the state of the TPU driver. All running programs // Synchronous. Reset the state of the TPU driver. All running programs
// will be terminated and all allocations reset. // will be terminated and all allocations reset.
// //
// All events and buffer handles created prior to Reset() will be invalid // All events and buffer handles created prior to Reset() will be invalid,
// and any use will result in undefined behavior. // and any use will result in undefined behavior.
virtual xla::Status Reset() = 0; virtual xla::Status Reset() = 0;
@ -190,11 +184,11 @@ class TpuDriver {
* `src` must be laid out in consecutive row-major format for ingestion, and * `src` must be laid out in consecutive row-major format for ingestion, and
* each element must take up the number of bytes specified by the type. * each element must take up the number of bytes specified by the type.
* *
* For example, if you have a [3,3,3] tensor with a Float32 type, then the * For example, for a [3,3,3] tensor with a Float32 type, the memory layout
* memory layout would be as follows: * would be as follows:
* *
* [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ... * [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ...
* [1,2,2], [2,0,0], ..., [2,2,2]. * [1,2,2], [2,0,0], ..., [2,2,2],
* *
* and the entire buffer will be 108 bytes (27 elements x 4 bytes). * and the entire buffer will be 108 bytes (27 elements x 4 bytes).
* *
@ -233,6 +227,10 @@ class TpuDriver {
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; } virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
}; };
struct TpuDriverConfig {
std::string worker;
};
class TpuDriverRegistry { class TpuDriverRegistry {
public: public:
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open( static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(