diff --git a/tensorflow/compiler/xla/python/tpu_driver/event_id.h b/tensorflow/compiler/xla/python/tpu_driver/event_id.h
index 169631410a1..ed5f9c87cf0 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/event_id.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/event_id.h
@@ -15,8 +15,7 @@
 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
 #define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_
 
-#include
-
+#include
 #include
 #include
 #include
diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc
index 591792974aa..842b83299ae 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc
@@ -66,9 +66,9 @@ class GrpcEvent : public Event {
 
 class GrpcBufferHandle : public BufferHandle {
  public:
-  explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event,
-                            int64_t bytes,
-                            std::optional<xla::ShapeProto> shape = std::nullopt)
+  explicit GrpcBufferHandle(
+      EventId id, std::shared_ptr<GrpcEvent> event, int64_t bytes,
+      absl::optional<xla::ShapeProto> shape = absl::nullopt)
       : id_(id),
         stream_(event->stream()),
         event_(std::move(event)),
@@ -81,14 +81,14 @@ class GrpcBufferHandle : public BufferHandle {
   EventId id() const { return id_; }
   GrpcTpuStream* stream() const { return stream_; }
 
-  std::optional<xla::ShapeProto> shape() override { return shape_; }
+  absl::optional<xla::ShapeProto> shape() override { return shape_; }
 
  private:
   const EventId id_;
   GrpcTpuStream* stream_;
   std::shared_ptr<GrpcEvent> event_;
   int64_t bytes_;
-  std::optional<xla::ShapeProto> shape_;
+  absl::optional<xla::ShapeProto> shape_;
 };
 
 class GrpcCompiledProgramHandle : public CompiledProgramHandle {
diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h
index 2a93de8b6e5..3b010b38a17 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h
@@ -20,7 +20,6 @@
 #include
 #include
 #include
-#include <optional>
 #include
 #include
 
@@ -42,10 +41,6 @@ namespace tpu_driver {
 
 uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
 
-struct TpuDriverConfig {
-  std::string worker;
-};
-
 // Represents the deferred completion of a scheduled operation.
 //
 // Events may be blocked on, or used as `wait_for` arguments to enforce
@@ -54,7 +49,7 @@ class Event {
  public:
   virtual ~Event() {}
 
-  // Block until the event completes and return the result status.
+  // Block until the event completes and returns the result status.
   virtual xla::Status Await() = 0;
   virtual absl::optional<xla::Status> AwaitWithTimeout(
       absl::Duration duration) = 0;
@@ -68,13 +63,14 @@ class BufferHandle {
  public:
   virtual ~BufferHandle() {}
 
-  // This Event completes after the device memory is actually allocated.
+  // This event completes after the device memory is actually allocated.
   //
-  // Methods that take a BufferHandle, including ExecuteProgram and Transfer*,
+  // Methods that take a buffer handle, such as ExecuteProgram and Transfer*,
   // automatically add this event as a dependency.
   virtual std::shared_ptr<Event> OnReady() = 0;
+
   virtual int64_t size_in_bytes() = 0;
-  virtual std::optional<xla::ShapeProto> shape() = 0;
+  virtual absl::optional<xla::ShapeProto> shape() = 0;
 };
 
 // Represents a compiled program on the host.
@@ -84,16 +80,16 @@ class CompiledProgramHandle {
 
   // This Event completes after the program is actually compiled on the host.
   //
-  // Methods that take a CompiledProgramHandle, including LoadProgram,
+  // Methods that take a compiled program handle, including LoadProgram,
   // automatically add this event as a dependency.
   virtual std::shared_ptr<Event> OnReady() = 0;
+
   virtual int64_t size_in_bytes() {
     LOG(FATAL) << "Unimplemented.";
     return 0;
   }
 
-  // Gets the compiled program metadata (including shape info). Will block until
-  // compile completes.
+  // Returns the shape of the compiled program. Blocks until compile completes.
   virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0;
 };
 
@@ -104,9 +100,10 @@ class LoadedProgramHandle {
 
   // This Event completes after the program is actually loaded on the device.
   //
-  // Methods that take a LoadedProgramHandle, including ExecuteProgram and
+  // Methods that take a loaded program handle, including ExecuteProgram and
   // UnloadProgram, automatically add this event as a dependency.
   virtual std::shared_ptr<Event> OnReady() = 0;
+
   virtual int64_t size_in_bytes() {
     LOG(FATAL) << "Unimplemented.";
     return 0;
@@ -114,7 +111,7 @@ class LoadedProgramHandle {
 };
 
 // A TpuLinearizer manages the linearization and delinearization of user buffers
-// in the TPU Driver. This interface is not yet implemented.
+// in the TPU driver. This interface is not yet implemented.
 class TpuLinearizer {
  public:
   virtual ~TpuLinearizer() {}
@@ -122,26 +119,23 @@ class TpuLinearizer {
 
   uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
     return ::tpu_driver::ComputeBytesFromShape(shape);
   }
-
   virtual uint64_t ComputeLinearizedBytesFromShape(
       const xla::ShapeProto& shape) = 0;
   virtual xla::Status LinearizeShape(void* dst, const void* src,
                                      const xla::ShapeProto& shape) = 0;
-
   virtual xla::Status DelinearizeShape(void* dst, const void* src,
                                        const xla::ShapeProto& shape) = 0;
 };
 
-// A TpuDriver manages a set of operations scheduled to run on a TPU device.
+// A TpuDriver manages a set of operations scheduled to run on a TPU system.
 //
-// By default, two independently scheduled operations may execute in any order;
-// ordering is imposed in one of two ways:
+// By default, two independently scheduled operations may execute in any order.
+// Ordering can be imposed in one of two ways:
 //
-// 1. Users can specify event dependency via the `wait_for` argument.
-// 2. All API calls using the same BufferHandle are serialized in calling order.
-//    In particular, operations using buffer or program handles will implicitly
-//    wait for the handles to be materialized before executing.
+// 1. Users can specify event dependencies via the `wait_for` argument.
+// 2. Operations using buffer or program handles implicitly wait for the handles
+//    to become ready before executing.
 //
 // For returned handle objects, the user is responsible for calling the release
 // methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr
@@ -149,7 +143,7 @@ class TpuLinearizer {
 // no release method; the user can let them go out of scope naturally. As soon
 // as those methods accepting plain-pointer arguments return, the user can let
 // the corresponding smart-pointer objects be released or go out of scope,
-// regardless of whether the scheduled device operations have completed or not.
+// regardless of whether the scheduled device operations have started execution.
 class TpuDriver {
  public:
   virtual ~TpuDriver() {}
@@ -158,7 +152,7 @@ class TpuDriver {
 
   // Synchronous. Reset the state of the TPU driver. All running programs
   // will be terminated and all allocations reset.
   //
-  // All events and buffer handles created prior to Reset() will be invalid
+  // All events and buffer handles created prior to Reset() will be invalid,
   // and any use will result in undefined behavior.
   virtual xla::Status Reset() = 0;
@@ -190,11 +184,11 @@ class TpuDriver {
    * `src` must be laid out in consecutive row-major format for ingestion, and
    * each element must take up the number of bytes specified by the type.
    *
-   * For example, if you have a [3,3,3] tensor with a Float32 type, then the
-   * memory layout would be as follows:
+   * For example, for a [3,3,3] tensor with a Float32 type, the memory layout
+   * would be as follows:
    *
    * [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ...
-   * [1,2,2], [2,0,0], ..., [2,2,2].
+   * [1,2,2], [2,0,0], ..., [2,2,2],
    *
    * and the entire buffer will be 108 bytes (27 elements x 4 bytes).
    *
@@ -233,6 +227,10 @@ class TpuDriver {
   virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
 };
 
+struct TpuDriverConfig {
+  std::string worker;
+};
+
 class TpuDriverRegistry {
  public:
   static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(