[libtpu] Reimplement XLA_Shape in C API for performance.
Changes the representation of XLA_Shape to directly mirror the corresponding C++ classes (including adding XLA_Layout and XLA_Tile structs), instead of using the serialized Shape proto. This improves the performance of converting between the C and C++ Shape class, which is currently in the hot execution path. PiperOrigin-RevId: 337928954 Change-Id: I941f5477007585e7f15e63f31b195b5745aa734f
This commit is contained in:
parent
af93956653
commit
2b036aebba
@ -39,6 +39,15 @@ class Shape {
|
|||||||
// Construct a shape from a ShapeProto.
|
// Construct a shape from a ShapeProto.
|
||||||
explicit Shape(const ShapeProto& shape_proto);
|
explicit Shape(const ShapeProto& shape_proto);
|
||||||
|
|
||||||
|
Shape(PrimitiveType element_type, absl::Span<const int64> dimensions,
|
||||||
|
absl::Span<const bool> dynamic_dimensions,
|
||||||
|
std::vector<Shape> tuple_shapes)
|
||||||
|
: element_type_(element_type),
|
||||||
|
dimensions_(dimensions.begin(), dimensions.end()),
|
||||||
|
dynamic_dimensions_(dynamic_dimensions.begin(),
|
||||||
|
dynamic_dimensions.end()),
|
||||||
|
tuple_shapes_(std::move(tuple_shapes)) {}
|
||||||
|
|
||||||
// Returns a ShapeProto representation of the Shape.
|
// Returns a ShapeProto representation of the Shape.
|
||||||
ShapeProto ToProto() const;
|
ShapeProto ToProto() const;
|
||||||
|
|
||||||
|
@ -149,18 +149,171 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) {
|
|||||||
return base;
|
return base;
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::Shape FromC(const XLA_Shape* shape) {
|
// Helper functions for copying data to possibly-inlined C arrays.
|
||||||
xla::ShapeProto p;
|
|
||||||
p.ParseFromArray(shape->bytes, shape->size);
|
// 'Src' and 'Dst' are allowed to be different types to make this usable with
|
||||||
return xla::Shape(p);
|
// memory-identical types, e.g. int64 and int64_t. This should not be used with
|
||||||
|
// types that require a static_cast.
|
||||||
|
template <typename Src, typename Dst, typename DstList>
|
||||||
|
static void CopyVectorBase(const absl::Span<Src> src, DstList* dst) {
|
||||||
|
static_assert(sizeof(Src) == sizeof(Dst));
|
||||||
|
dst->size = src.size();
|
||||||
|
if (dst->size > TPU_C_API_MAX_INLINED) {
|
||||||
|
dst->heap = new Dst[dst->size];
|
||||||
|
memcpy(dst->heap, src.data(), dst->size * sizeof(Src));
|
||||||
|
} else {
|
||||||
|
memcpy(dst->inlined, src.data(), dst->size * sizeof(Src));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies a span of int64 values into an Int64List (possibly heap-allocated).
static void CopyVector(const absl::Span<const typename tensorflow::int64> src,
                       Int64List* dst) {
  CopyVectorBase<const typename tensorflow::int64, int64_t, Int64List>(src,
                                                                       dst);
}
||||||
|
// Copies a span of bools into a BoolList (possibly heap-allocated).
static void CopyVector(const absl::Span<const bool> src, BoolList* dst) {
  CopyVectorBase<const bool, bool, BoolList>(src, dst);
}
||||||
|
|
||||||
|
// Copies a span of xla::Tile into a TileList, converting each element to its
// C representation. Tiles need element-wise conversion, so this cannot reuse
// the memcpy-based CopyVectorBase.
static void CopyVector(const absl::Span<const xla::Tile> src, TileList* dst) {
  dst->size = src.size();
  XLA_Tile* c_tiles;
  if (dst->size > TPU_C_API_MAX_INLINED) {
    dst->heap = new XLA_Tile[dst->size];
    c_tiles = dst->heap;
  } else {
    c_tiles = dst->inlined;
  }
  // int64_t matches dst->size and avoids a signed-width mismatch in the
  // comparison (dst->size is an int64_t).
  for (int64_t i = 0; i < dst->size; ++i) {
    ToC(src[i], &c_tiles[i]);
  }
}
|
||||||
|
|
||||||
|
// Helper functions for creating a view of possibly-inlined C arrays.
|
||||||
|
|
||||||
|
// 'Src' and 'Dst' are allowed to be different types to make this usable with
|
||||||
|
// memory-identical types, e.g. int64 and int64_t. This should not be used with
|
||||||
|
// types that require a static_cast.
|
||||||
|
template <typename Dst, typename Src, typename SrcList>
|
||||||
|
static absl::Span<const Dst> MakeSpanBase(const SrcList& src_list) {
|
||||||
|
static_assert(sizeof(Src) == sizeof(Dst));
|
||||||
|
const Src* src = src_list.size > TPU_C_API_MAX_INLINED ? src_list.heap
|
||||||
|
: &src_list.inlined[0];
|
||||||
|
return absl::Span<const Dst>(reinterpret_cast<const Dst*>(src),
|
||||||
|
src_list.size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a non-owning view of an Int64List's elements.
static absl::Span<const typename tensorflow::int64> MakeSpan(
    const Int64List& src_list) {
  return MakeSpanBase<typename tensorflow::int64, int64_t, Int64List>(
      src_list);
}
||||||
|
// Returns a non-owning view of a BoolList's elements.
static absl::Span<const bool> MakeSpan(const BoolList& src_list) {
  return MakeSpanBase<bool, bool, BoolList>(src_list);
}
|
}
|
||||||
|
|
||||||
// Converts an xla::Shape to its C representation. All heap storage hung off
// *c_shape is owned by the struct and must be released with Free(XLA_Shape*).
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) {
  c_shape->element_type = xla_shape.element_type();
  CopyVector(xla_shape.dimensions(), &c_shape->dimensions);
  CopyVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions);

  c_shape->ntuple_shapes = xla_shape.tuple_shapes_size();
  if (c_shape->ntuple_shapes > 0) {
    c_shape->tuple_shapes = new XLA_Shape[c_shape->ntuple_shapes];
    for (int i = 0; i < c_shape->ntuple_shapes; ++i) {
      ToC(xla_shape.tuple_shapes(i), &c_shape->tuple_shapes[i]);
    }
  } else {
    // Don't leave an uninitialized pointer in the struct; consumers that
    // inspect tuple_shapes without checking ntuple_shapes get a clean null.
    c_shape->tuple_shapes = nullptr;
  }

  if (xla_shape.has_layout()) {
    ToC(xla_shape.layout(), &c_shape->layout);
  } else {
    // INVALID_FORMAT is the "no layout" sentinel; FromC and Free both key off
    // it, so the remaining layout fields are deliberately left untouched.
    c_shape->layout.format = xla::INVALID_FORMAT;
  }
}
|
||||||
|
|
||||||
|
// Reconstructs an xla::Shape from its C representation. Does not take
// ownership of *c_shape.
xla::Shape FromC(const XLA_Shape* c_shape) {
  // Non-owning views into the C struct; the Shape constructor copies them out.
  const absl::Span<const typename tensorflow::int64> dims =
      MakeSpan(c_shape->dimensions);
  const absl::Span<const bool> dynamic_dims =
      MakeSpan(c_shape->dynamic_dimensions);

  std::vector<xla::Shape> tuple_shapes;
  tuple_shapes.reserve(c_shape->ntuple_shapes);
  for (int i = 0; i < c_shape->ntuple_shapes; ++i) {
    tuple_shapes.push_back(FromC(&c_shape->tuple_shapes[i]));
  }

  xla::Shape result(static_cast<xla::PrimitiveType>(c_shape->element_type),
                    dims, dynamic_dims, std::move(tuple_shapes));
  // INVALID_FORMAT is the sentinel ToC writes when the shape has no layout.
  if (c_shape->layout.format != xla::INVALID_FORMAT) {
    *result.mutable_layout() = FromC(&c_shape->layout);
  }
  return result;
}
|
||||||
|
|
||||||
|
// Releases all heap storage owned by *c_shape (but not the struct itself).
void Free(XLA_Shape* c_shape) {
  // The lists only spill to the heap when larger than the inlined capacity.
  if (c_shape->dimensions.size > TPU_C_API_MAX_INLINED) {
    delete[] c_shape->dimensions.heap;
  }
  if (c_shape->dynamic_dimensions.size > TPU_C_API_MAX_INLINED) {
    delete[] c_shape->dynamic_dimensions.heap;
  }

  if (c_shape->ntuple_shapes > 0) {
    // Recursively release each tuple element before the array itself.
    for (int i = 0; i < c_shape->ntuple_shapes; ++i) {
      Free(&c_shape->tuple_shapes[i]);
    }
    delete[] c_shape->tuple_shapes;
  }

  // The layout's storage is only initialized when the format is valid.
  if (c_shape->layout.format != xla::INVALID_FORMAT) {
    Free(&c_shape->layout);
  }
}
|
||||||
|
|
||||||
|
// Converts an xla::Layout to its C representation. Heap storage hung off
// *c_layout must be released with Free(XLA_Layout*).
void ToC(const xla::Layout& layout, XLA_Layout* c_layout) {
  // Scalar fields first, then the two copied lists.
  c_layout->format = layout.format();
  c_layout->element_size_in_bits = layout.element_size_in_bits();
  c_layout->memory_space = layout.memory_space();
  CopyVector(layout.minor_to_major(), &c_layout->minor_to_major);
  CopyVector(layout.tiles(), &c_layout->tiles);
}
|
||||||
|
|
||||||
|
// Reconstructs an xla::Layout from its C representation.
xla::Layout FromC(const XLA_Layout* c_layout) {
  absl::Span<const typename tensorflow::int64> minor_to_major =
      MakeSpan(c_layout->minor_to_major);
  absl::InlinedVector<xla::Tile, 1> tiles;
  // Reserve up front to avoid repeated growth while converting tiles.
  tiles.reserve(c_layout->tiles.size);
  const XLA_Tile* c_tiles = c_layout->tiles.size > TPU_C_API_MAX_INLINED
                                ? c_layout->tiles.heap
                                : c_layout->tiles.inlined;
  // int64_t matches tiles.size and avoids a signed-width mismatch.
  for (int64_t i = 0; i < c_layout->tiles.size; ++i) {
    tiles.push_back(FromC(&c_tiles[i]));
  }
  return xla::Layout(minor_to_major, tiles, c_layout->element_size_in_bits,
                     c_layout->memory_space);
}
|
||||||
|
|
||||||
|
// Releases heap storage owned by *c_layout (but not the struct itself).
void Free(XLA_Layout* c_layout) {
  // Each list owns a heap array only when it overflowed the inlined storage.
  if (c_layout->minor_to_major.size > TPU_C_API_MAX_INLINED) {
    delete[] c_layout->minor_to_major.heap;
  }
  if (c_layout->tiles.size > TPU_C_API_MAX_INLINED) {
    delete[] c_layout->tiles.heap;
  }
}
|
||||||
|
|
||||||
|
// Converts an xla::Tile to its C representation (just its dimension list).
void ToC(const xla::Tile& tile, XLA_Tile* c_tile) {
  CopyVector(tile.dimensions(), &c_tile->dimensions);
}
|
||||||
|
|
||||||
|
// Reconstructs an xla::Tile from its C representation.
xla::Tile FromC(const XLA_Tile* c_tile) {
  // The Tile constructor copies the dimensions out of the non-owning view.
  return xla::Tile(MakeSpan(c_tile->dimensions));
}
|
||||||
|
|
||||||
|
// Releases heap storage owned by *c_tile (but not the struct itself).
void Free(XLA_Tile* c_tile) {
  if (c_tile->dimensions.size > TPU_C_API_MAX_INLINED) {
    delete[] c_tile->dimensions.heap;
  }
}
|
}
|
||||||
|
|
||||||
XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape) {
|
XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape) {
|
||||||
@ -212,7 +365,6 @@ void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(XLA_Shape* shape) { delete[] shape->bytes; }
|
|
||||||
void Free(XLA_ShapeIndex* shape_index) { delete[] shape_index; }
|
void Free(XLA_ShapeIndex* shape_index) { delete[] shape_index; }
|
||||||
void Free(SE_DeviceMemoryBase*) {}
|
void Free(SE_DeviceMemoryBase*) {}
|
||||||
|
|
||||||
|
@ -43,9 +43,19 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
|
|||||||
void Free(SE_DeviceMemoryBase*);
|
void Free(SE_DeviceMemoryBase*);
|
||||||
|
|
||||||
// xla::Shape
|
// xla::Shape
|
||||||
xla::Shape FromC(const XLA_Shape* shape);
|
xla::Shape FromC(const XLA_Shape* c_shape);
|
||||||
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
|
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
|
||||||
void Free(XLA_Shape* shape);
|
void Free(XLA_Shape* c_shape);
|
||||||
|
|
||||||
|
// xla::Layout
|
||||||
|
xla::Layout FromC(const XLA_Layout* c_layout);
|
||||||
|
void ToC(const xla::Layout& xla_layout, XLA_Layout* c_layout);
|
||||||
|
void Free(XLA_Layout* c_layout);
|
||||||
|
|
||||||
|
// xla::Tile
|
||||||
|
xla::Tile FromC(const XLA_Tile* c_tile);
|
||||||
|
void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile);
|
||||||
|
void Free(XLA_Tile* c_tile);
|
||||||
|
|
||||||
// xla::ShapeIndex
|
// xla::ShapeIndex
|
||||||
XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape);
|
XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape);
|
||||||
|
@ -25,6 +25,9 @@ limitations under the License.
|
|||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
|
// Maximum number of array elements to inline into structs for performance.
|
||||||
|
#define TPU_C_API_MAX_INLINED 6
|
||||||
|
|
||||||
enum TpuCoreTypeEnum {
|
enum TpuCoreTypeEnum {
|
||||||
kTensorCore,
|
kTensorCore,
|
||||||
kEmbeddingV1,
|
kEmbeddingV1,
|
||||||
@ -168,11 +171,50 @@ typedef struct SE_MaybeOwningDeviceMemory {
|
|||||||
SE_DeviceMemoryAllocator allocator;
|
SE_DeviceMemoryAllocator allocator;
|
||||||
} SE_MaybeOwningDeviceMemory;
|
} SE_MaybeOwningDeviceMemory;
|
||||||
|
|
||||||
|
struct Int64List {
|
||||||
|
union {
|
||||||
|
int64_t* heap; // owned
|
||||||
|
int64_t inlined[TPU_C_API_MAX_INLINED];
|
||||||
|
};
|
||||||
|
int64_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BoolList {
|
||||||
|
union {
|
||||||
|
bool* heap; // owned
|
||||||
|
bool inlined[TPU_C_API_MAX_INLINED];
|
||||||
|
};
|
||||||
|
int64_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct XLA_Tile {
|
||||||
|
Int64List dimensions;
|
||||||
|
} XLA_Tile;
|
||||||
|
|
||||||
|
struct TileList {
|
||||||
|
union {
|
||||||
|
XLA_Tile* heap; // owned
|
||||||
|
XLA_Tile inlined[TPU_C_API_MAX_INLINED];
|
||||||
|
};
|
||||||
|
int64_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct XLA_Layout {
|
||||||
|
int format;
|
||||||
|
Int64List minor_to_major;
|
||||||
|
TileList tiles;
|
||||||
|
int64_t element_size_in_bits;
|
||||||
|
int64_t memory_space;
|
||||||
|
} XLA_Layout;
|
||||||
|
|
||||||
// Represents an XLA shape tree.
|
// Represents an XLA shape tree.
|
||||||
// Shapes are flattened in default traversal order.
|
|
||||||
typedef struct XLA_Shape {
|
typedef struct XLA_Shape {
|
||||||
char* bytes;
|
int element_type;
|
||||||
size_t size;
|
Int64List dimensions;
|
||||||
|
BoolList dynamic_dimensions;
|
||||||
|
XLA_Shape* tuple_shapes; // owned
|
||||||
|
int ntuple_shapes;
|
||||||
|
XLA_Layout layout;
|
||||||
} XLA_Shape;
|
} XLA_Shape;
|
||||||
|
|
||||||
// Represents a leaf node for a XLA shaped buffer.
|
// Represents a leaf node for a XLA shaped buffer.
|
||||||
|
Loading…
Reference in New Issue
Block a user