diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 6a19a1fac09..0c9a2f3ab54 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -39,6 +39,15 @@ class Shape { // Construct a shape from a ShapeProto. explicit Shape(const ShapeProto& shape_proto); + Shape(PrimitiveType element_type, absl::Span dimensions, + absl::Span dynamic_dimensions, + std::vector tuple_shapes) + : element_type_(element_type), + dimensions_(dimensions.begin(), dimensions.end()), + dynamic_dimensions_(dynamic_dimensions.begin(), + dynamic_dimensions.end()), + tuple_shapes_(std::move(tuple_shapes)) {} + // Returns a ShapeProto representation of the Shape. ShapeProto ToProto() const; diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.cc b/tensorflow/stream_executor/tpu/c_api_conversions.cc index 0a7801f45fc..d8e7dac2c2d 100644 --- a/tensorflow/stream_executor/tpu/c_api_conversions.cc +++ b/tensorflow/stream_executor/tpu/c_api_conversions.cc @@ -149,18 +149,171 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) { return base; } -xla::Shape FromC(const XLA_Shape* shape) { - xla::ShapeProto p; - p.ParseFromArray(shape->bytes, shape->size); - return xla::Shape(p); +// Helper functions for copying data to possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64 and int64_t. This should not be used with +// types that require a static_cast. 
+template +static void CopyVectorBase(const absl::Span src, DstList* dst) { + static_assert(sizeof(Src) == sizeof(Dst)); + dst->size = src.size(); + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new Dst[dst->size]; + memcpy(dst->heap, src.data(), dst->size * sizeof(Src)); + } else { + memcpy(dst->inlined, src.data(), dst->size * sizeof(Src)); + } +} + +static void CopyVector(const absl::Span src, + Int64List* dst) { + return CopyVectorBase( + src, dst); +} +static void CopyVector(const absl::Span src, BoolList* dst) { + return CopyVectorBase(src, dst); +} + +static void CopyVector(const absl::Span src, TileList* dst) { + dst->size = src.size(); + XLA_Tile* c_tiles; + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new XLA_Tile[dst->size]; + c_tiles = dst->heap; + } else { + c_tiles = dst->inlined; + } + for (int i = 0; i < dst->size; ++i) { + ToC(src[i], &c_tiles[i]); + } +} + +// Helper functions for creating a view of possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64 and int64_t. This should not be used with +// types that require a static_cast. +template +static absl::Span MakeSpanBase(const SrcList& src_list) { + static_assert(sizeof(Src) == sizeof(Dst)); + const Src* src = src_list.size > TPU_C_API_MAX_INLINED ? 
src_list.heap + : &src_list.inlined[0]; + return absl::Span(reinterpret_cast(src), + src_list.size); +} + +static absl::Span MakeSpan( + const Int64List& src_list) { + return MakeSpanBase(src_list); +} +static absl::Span MakeSpan(const BoolList& src_list) { + return MakeSpanBase(src_list); } void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { - xla::ShapeProto p = xla_shape.ToProto(); - std::string p_str = p.SerializeAsString(); - c_shape->bytes = new char[p_str.size()]; - c_shape->size = p_str.size(); - memcpy(c_shape->bytes, p_str.data(), p_str.size()); + c_shape->element_type = xla_shape.element_type(); + + CopyVector(xla_shape.dimensions(), &c_shape->dimensions); + CopyVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions); + + c_shape->ntuple_shapes = xla_shape.tuple_shapes_size(); + if (c_shape->ntuple_shapes > 0) { + c_shape->tuple_shapes = new XLA_Shape[c_shape->ntuple_shapes]; + for (int i = 0; i < c_shape->ntuple_shapes; ++i) { + ToC(xla_shape.tuple_shapes(i), &c_shape->tuple_shapes[i]); + } + } + + if (xla_shape.has_layout()) { + ToC(xla_shape.layout(), &c_shape->layout); + } else { + c_shape->layout.format = xla::INVALID_FORMAT; + } +} + +xla::Shape FromC(const XLA_Shape* c_shape) { + absl::Span dims = + MakeSpan(c_shape->dimensions); + absl::Span dynamic_dims = MakeSpan(c_shape->dynamic_dimensions); + + std::vector tuple_shapes; + tuple_shapes.reserve(c_shape->ntuple_shapes); + for (int i = 0; i < c_shape->ntuple_shapes; ++i) { + tuple_shapes.push_back(FromC(&c_shape->tuple_shapes[i])); + } + + xla::Shape result(static_cast(c_shape->element_type), + dims, dynamic_dims, std::move(tuple_shapes)); + if (c_shape->layout.format != xla::INVALID_FORMAT) { + *result.mutable_layout() = FromC(&c_shape->layout); + } + return result; +} + +void Free(XLA_Shape* c_shape) { + if (c_shape->dimensions.size > TPU_C_API_MAX_INLINED) { + delete[] c_shape->dimensions.heap; + } + if (c_shape->dynamic_dimensions.size > TPU_C_API_MAX_INLINED) { + 
delete[] c_shape->dynamic_dimensions.heap; + } + if (c_shape->ntuple_shapes > 0) { + for (int i = 0; i < c_shape->ntuple_shapes; ++i) { + Free(&c_shape->tuple_shapes[i]); + } + delete[] c_shape->tuple_shapes; + } + if (c_shape->layout.format != xla::INVALID_FORMAT) { + Free(&c_shape->layout); + } +} + +void ToC(const xla::Layout& layout, XLA_Layout* c_layout) { + c_layout->format = layout.format(); + CopyVector(layout.minor_to_major(), &c_layout->minor_to_major); + c_layout->element_size_in_bits = layout.element_size_in_bits(); + c_layout->memory_space = layout.memory_space(); + CopyVector(layout.tiles(), &c_layout->tiles); +} + +xla::Layout FromC(const XLA_Layout* c_layout) { + absl::Span minor_to_major = + MakeSpan(c_layout->minor_to_major); + absl::InlinedVector tiles; + const XLA_Tile* c_tiles = c_layout->tiles.size > TPU_C_API_MAX_INLINED + ? c_layout->tiles.heap + : c_layout->tiles.inlined; + for (int i = 0; i < c_layout->tiles.size; ++i) { + tiles.push_back(FromC(&c_tiles[i])); + } + return xla::Layout(minor_to_major, tiles, c_layout->element_size_in_bits, + c_layout->memory_space); +} + +void Free(XLA_Layout* c_layout) { + if (c_layout->minor_to_major.size > TPU_C_API_MAX_INLINED) { + delete[] c_layout->minor_to_major.heap; + } + if (c_layout->tiles.size > TPU_C_API_MAX_INLINED) { + delete[] c_layout->tiles.heap; + } +} + +void ToC(const xla::Tile& tile, XLA_Tile* c_tile) { + CopyVector(tile.dimensions(), &c_tile->dimensions); +} + +xla::Tile FromC(const XLA_Tile* c_tile) { + absl::Span dims = + MakeSpan(c_tile->dimensions); + return xla::Tile(dims); +} + +void Free(XLA_Tile* c_tile) { + if (c_tile->dimensions.size > TPU_C_API_MAX_INLINED) { + delete[] c_tile->dimensions.heap; + } } XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape) { @@ -212,7 +365,6 @@ void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) { } } -void Free(XLA_Shape* shape) { delete[] shape->bytes; } void Free(XLA_ShapeIndex* shape_index) { delete[] shape_index; } 
void Free(SE_DeviceMemoryBase*) {} diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.h b/tensorflow/stream_executor/tpu/c_api_conversions.h index c4b5648e097..da856a8720b 100644 --- a/tensorflow/stream_executor/tpu/c_api_conversions.h +++ b/tensorflow/stream_executor/tpu/c_api_conversions.h @@ -43,9 +43,19 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base); void Free(SE_DeviceMemoryBase*); // xla::Shape -xla::Shape FromC(const XLA_Shape* shape); +xla::Shape FromC(const XLA_Shape* c_shape); void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape); -void Free(XLA_Shape* shape); +void Free(XLA_Shape* c_shape); + +// xla::Layout +xla::Layout FromC(const XLA_Layout* c_layout); +void ToC(const xla::Layout& xla_layout, XLA_Layout* c_layout); +void Free(XLA_Layout* c_layout); + +// xla::Tile +xla::Tile FromC(const XLA_Tile* c_tile); +void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile); +void Free(XLA_Tile* c_tile); // xla::ShapeIndex XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape); diff --git a/tensorflow/stream_executor/tpu/c_api_decl.h b/tensorflow/stream_executor/tpu/c_api_decl.h index dcb53823e0c..1b92913263e 100644 --- a/tensorflow/stream_executor/tpu/c_api_decl.h +++ b/tensorflow/stream_executor/tpu/c_api_decl.h @@ -25,6 +25,9 @@ limitations under the License. extern "C" { +// Maximum number of array elements to inline into structs for performance. 
// Maximum number of array elements to inline into structs for performance.
#define TPU_C_API_MAX_INLINED 6

// Variable-length list types used by the flat XLA_Shape representation.
// Lists of at most TPU_C_API_MAX_INLINED elements are stored directly in
// `inlined`; longer lists live in `heap`, which is owned by the struct and
// must be released via the matching Free(). `size` holds the element count
// and is the discriminator: size > TPU_C_API_MAX_INLINED means heap storage.
//
// `typedef struct Tag {...} Tag;` (rather than a bare struct) keeps these
// usable from plain C inside this extern "C" header, where the unadorned
// type name would otherwise require the `struct` keyword.
typedef struct Int64List {
  union {
    int64_t* heap;  // owned
    int64_t inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
} Int64List;

typedef struct BoolList {
  union {
    bool* heap;  // owned
    bool inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
} BoolList;

// Mirrors xla::Tile: tile dimensions in minor-to-major order.
typedef struct XLA_Tile {
  Int64List dimensions;
} XLA_Tile;

typedef struct TileList {
  union {
    XLA_Tile* heap;  // owned
    XLA_Tile inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
} TileList;

// Mirrors xla::Layout.
typedef struct XLA_Layout {
  int format;  // xla::Format; INVALID_FORMAT is the "no layout" sentinel
  Int64List minor_to_major;
  TileList tiles;
  int64_t element_size_in_bits;
  int64_t memory_space;
} XLA_Layout;

// Represents an XLA shape tree. Replaces the previous serialized-ShapeProto
// byte blob with a direct field-for-field mirror of xla::Shape. Owned
// members are released by Free(XLA_Shape*).
typedef struct XLA_Shape {
  int element_type;  // xla::PrimitiveType
  Int64List dimensions;
  BoolList dynamic_dimensions;
  // `struct XLA_Shape*` (not the typedef name) so the self-reference is
  // valid in C as well as C++. Owned; unused when ntuple_shapes == 0.
  struct XLA_Shape* tuple_shapes;
  int ntuple_shapes;
  XLA_Layout layout;
} XLA_Shape;