Clarify some interface comments and use libraries more consistently.

PiperOrigin-RevId: 280590844
Change-Id: I287921239fbd55d559912c8cb9f203918043cff4
This commit is contained in:
Wenhao Jia 2019-11-14 22:21:07 -08:00 committed by TensorFlower Gardener
parent e5402b6883
commit 2205360b73
3 changed files with 32 additions and 35 deletions

View File

@ -15,8 +15,7 @@
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
#include <stdint.h>
#include <cstdint>
#include <ostream>
#include <string>
#include <utility>

View File

@ -66,9 +66,9 @@ class GrpcEvent : public Event {
class GrpcBufferHandle : public BufferHandle {
public:
explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event,
int64_t bytes,
std::optional<xla::ShapeProto> shape = std::nullopt)
explicit GrpcBufferHandle(
EventId id, std::shared_ptr<GrpcEvent> event, int64_t bytes,
absl::optional<xla::ShapeProto> shape = absl::nullopt)
: id_(id),
stream_(event->stream()),
event_(std::move(event)),
@ -81,14 +81,14 @@ class GrpcBufferHandle : public BufferHandle {
// Returns the event ID this handle was constructed with.
EventId id() const { return id_; }
// Returns the stream pointer captured at construction from event->stream().
GrpcTpuStream* stream() const { return stream_; }
std::optional<xla::ShapeProto> shape() override { return shape_; }
absl::optional<xla::ShapeProto> shape() override { return shape_; }
private:
const EventId id_;
GrpcTpuStream* stream_;
std::shared_ptr<GrpcEvent> event_;
int64_t bytes_;
std::optional<xla::ShapeProto> shape_;
absl::optional<xla::ShapeProto> shape_;
};
class GrpcCompiledProgramHandle : public CompiledProgramHandle {

View File

@ -20,7 +20,6 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>
@ -42,10 +41,6 @@ namespace tpu_driver {
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
// Configuration struct whose only field is `worker`.
struct TpuDriverConfig {
// NOTE(review): presumably identifies the worker the driver should use --
// confirm the expected format against the code that consumes this field.
std::string worker;
};
// Represents the deferred completion of a scheduled operation.
//
// Events may be blocked on, or used as `wait_for` arguments to enforce
@ -54,7 +49,7 @@ class Event {
public:
virtual ~Event() {}
// Block until the event completes and return the result status.
// Blocks until the event completes, then returns the result status.
virtual xla::Status Await() = 0;
virtual absl::optional<xla::Status> AwaitWithTimeout(
absl::Duration duration) = 0;
@ -68,13 +63,14 @@ class BufferHandle {
public:
virtual ~BufferHandle() {}
// This Event completes after the device memory is actually allocated.
// This event completes after the device memory is actually allocated.
//
// Methods that take a BufferHandle, including ExecuteProgram and Transfer*,
// Methods that take a buffer handle, such as ExecuteProgram and Transfer*,
// automatically add this event as a dependency.
virtual std::shared_ptr<Event> OnReady() = 0;
virtual int64_t size_in_bytes() = 0;
virtual std::optional<xla::ShapeProto> shape() = 0;
virtual absl::optional<xla::ShapeProto> shape() = 0;
};
// Represents a compiled program on the host.
@ -84,16 +80,16 @@ class CompiledProgramHandle {
// This Event completes after the program is actually compiled on the host.
//
// Methods that take a CompiledProgramHandle, including LoadProgram,
// Methods that take a compiled program handle, including LoadProgram,
// automatically add this event as a dependency.
virtual std::shared_ptr<Event> OnReady() = 0;
// Default implementation: logs a fatal "Unimplemented." message. Being
// virtual, it can be overridden by subclasses. The `return 0` after the log
// statement gives this non-void function a return statement on every path.
virtual int64_t size_in_bytes() {
LOG(FATAL) << "Unimplemented.";
return 0;
}
// Gets the compiled program metadata (including shape info). Will block until
// compile completes.
// Stores the shape of the compiled program in `*program_shape` and returns a
// status. Blocks until compile completes.
virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0;
};
@ -104,9 +100,10 @@ class LoadedProgramHandle {
// This Event completes after the program is actually loaded on the device.
//
// Methods that take a LoadedProgramHandle, including ExecuteProgram and
// Methods that take a loaded program handle, including ExecuteProgram and
// UnloadProgram, automatically add this event as a dependency.
virtual std::shared_ptr<Event> OnReady() = 0;
virtual int64_t size_in_bytes() {
LOG(FATAL) << "Unimplemented.";
return 0;
@ -114,7 +111,7 @@ class LoadedProgramHandle {
};
// A TpuLinearizer manages the linearization and delinearization of user buffers
// in the TPU Driver. This interface is not yet implemented.
// in the TPU driver. This interface is not yet implemented.
class TpuLinearizer {
public:
virtual ~TpuLinearizer() {}
@ -122,26 +119,23 @@ class TpuLinearizer {
// Convenience wrapper that forwards `shape` to the namespace-scope
// ::tpu_driver::ComputeBytesFromShape and returns its result unchanged. The
// call is fully qualified because an unqualified call from inside this class
// would resolve to this member function itself.
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
return ::tpu_driver::ComputeBytesFromShape(shape);
}
virtual uint64_t ComputeLinearizedBytesFromShape(
const xla::ShapeProto& shape) = 0;
virtual xla::Status LinearizeShape(void* dst, const void* src,
const xla::ShapeProto& shape) = 0;
virtual xla::Status DelinearizeShape(void* dst, const void* src,
const xla::ShapeProto& shape) = 0;
};
// A TpuDriver manages a set of operations scheduled to run on a TPU device.
// A TpuDriver manages a set of operations scheduled to run on a TPU system.
//
// By default, two independently scheduled operations may execute in any order;
// ordering is imposed in one of two ways:
// By default, two independently scheduled operations may execute in any order.
// Ordering can be imposed in one of two ways:
//
// 1. Users can specify event dependency via the `wait_for` argument.
// 2. All API calls using the same BufferHandle are serialized in calling order.
// In particular, operations using buffer or program handles will implicitly
// wait for the handles to be materialized before executing.
// 1. Users can specify event dependencies via the `wait_for` argument.
// 2. Operations using buffer or program handles implicitly wait for the handles
// to become ready before executing.
//
// For returned handle objects, the user is responsible for calling the release
// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr
@ -149,7 +143,7 @@ class TpuLinearizer {
// no release method; the user can let them go out of scope naturally. As soon
// as those methods accepting plain-pointer arguments return, the user can let
// the corresponding smart-pointer objects be released or go out of scope,
// regardless of whether the scheduled device operations have completed or not.
// regardless of whether the scheduled device operations have started execution.
class TpuDriver {
public:
virtual ~TpuDriver() {}
@ -158,7 +152,7 @@ class TpuDriver {
// Synchronous. Reset the state of the TPU driver. All running programs
// will be terminated and all allocations reset.
//
// All events and buffer handles created prior to Reset() will be invalid
// All events and buffer handles created prior to Reset() will be invalid,
// and any use will result in undefined behavior.
virtual xla::Status Reset() = 0;
@ -190,11 +184,11 @@ class TpuDriver {
* `src` must be laid out in consecutive row-major format for ingestion, and
* each element must take up the number of bytes specified by the type.
*
* For example, if you have a [3,3,3] tensor with a Float32 type, then the
* memory layout would be as follows:
* For example, for a [3,3,3] tensor with a Float32 type, the memory layout
* would be as follows:
*
* [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ...
* [1,2,2], [2,0,0], ..., [2,2,2].
* [1,2,2], [2,0,0], ..., [2,2,2],
*
* and the entire buffer will be 108 bytes (27 elements x 4 bytes).
*
@ -233,6 +227,10 @@ class TpuDriver {
// Default implementation returns nullptr.
// NOTE(review): presumably drivers that provide a TpuLinearizer override
// this, and callers must handle a null result -- confirm.
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
};
// Configuration struct whose only field is `worker`.
struct TpuDriverConfig {
// NOTE(review): presumably identifies the worker the driver should use --
// confirm the expected format against the code that consumes this field.
std::string worker;
};
class TpuDriverRegistry {
public:
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(