Changes the representation of XLA_Shape to directly mirror the corresponding C++ classes (including adding the XLA_Layout and XLA_Tile structs), instead of using the serialized Shape proto. This improves the performance of converting between the C and C++ Shape classes, which is currently in the hot execution path. PiperOrigin-RevId: 337928954 Change-Id: I941f5477007585e7f15e63f31b195b5745aa734f
344 lines
12 KiB
C++
344 lines
12 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
|
|
#define TENSORFLOW_COMPILER_XLA_SHAPE_H_
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "absl/container/inlined_vector.h"
|
|
#include "absl/types/optional.h"
|
|
#include "tensorflow/compiler/xla/layout.h"
|
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
#include "tensorflow/core/platform/types.h"
|
|
|
|
namespace xla {
|
|
|
|
// A shape describes the number of dimensions in a array, the bounds of each
|
|
// dimension, and the primitive component type. For tuples, shape describes the
|
|
// structure (number of elements and nesting).
|
|
class Shape {
|
|
public:
|
|
Shape() = default;
|
|
|
|
// Construct a shape from a ShapeProto.
|
|
explicit Shape(const ShapeProto& shape_proto);
|
|
|
|
Shape(PrimitiveType element_type, absl::Span<const int64> dimensions,
|
|
absl::Span<const bool> dynamic_dimensions,
|
|
std::vector<Shape> tuple_shapes)
|
|
: element_type_(element_type),
|
|
dimensions_(dimensions.begin(), dimensions.end()),
|
|
dynamic_dimensions_(dynamic_dimensions.begin(),
|
|
dynamic_dimensions.end()),
|
|
tuple_shapes_(std::move(tuple_shapes)) {}
|
|
|
|
// Returns a ShapeProto representation of the Shape.
|
|
ShapeProto ToProto() const;
|
|
|
|
// Returns a human-readable string that represents the given shape, with or
|
|
// without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
|
|
string ToString(bool print_layout = false) const;
|
|
|
|
// Returns the rank (number of dimensions) of the given shape. Shape must be
|
|
// an array.
|
|
int64 rank() const {
|
|
DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString();
|
|
return dimensions_.size();
|
|
}
|
|
|
|
// Returns whether the shape is of the specified type (array, tuple, etc).
|
|
bool IsArray() const { return primitive_util::IsArrayType(element_type()); }
|
|
bool IsTuple() const { return element_type() == TUPLE; }
|
|
bool IsToken() const { return element_type() == TOKEN; }
|
|
bool IsOpaque() const { return element_type() == OPAQUE_TYPE; }
|
|
|
|
// Returns true if no array dimension in the shape is dynamically sized. Tuple
|
|
// shapes are traversed recursively.
|
|
bool is_static() const;
|
|
|
|
bool is_dynamic() const { return !is_static(); }
|
|
|
|
// Returns true if the given dimension is dynamically-sized.
|
|
bool is_dynamic_dimension(int dimension) const {
|
|
return dynamic_dimensions_.at(dimension);
|
|
}
|
|
|
|
// Sets whether or not the given dimension is dynamically-sized.
|
|
void set_dynamic_dimension(int dimension, bool is_dynamic) {
|
|
dynamic_dimensions_[dimension] = is_dynamic;
|
|
}
|
|
|
|
absl::Span<const bool> dynamic_dimensions() const {
|
|
return dynamic_dimensions_;
|
|
}
|
|
|
|
absl::Span<bool> mutable_dynamic_dimensions() {
|
|
return absl::MakeSpan(dynamic_dimensions_);
|
|
}
|
|
|
|
// Add dimension_upper_bound().
|
|
|
|
// Removes the given dimension form the shape. Layout, if it exists, is
|
|
// adjusted to match the modified shape.
|
|
void DeleteDimension(int64 dim_to_delete);
|
|
|
|
// The following methods mirror the protobuf generated code interface for the
|
|
// message ShapeProto. This enabled easy migration of this data structure
|
|
// from a proto to a proper C++ class.
|
|
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
|
|
// interface.
|
|
|
|
// Methods for accessing the primitive type.
|
|
PrimitiveType element_type() const { return element_type_; }
|
|
void set_element_type(PrimitiveType value) { element_type_ = value; }
|
|
|
|
// Methods for accessing the dimensions array.
|
|
int dimensions_size() const { return dimensions_.size(); }
|
|
int64 dimensions(int index) const { return dimensions_.at(index); }
|
|
void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; }
|
|
void add_dimensions(int64 value) {
|
|
dimensions_.push_back(value);
|
|
dynamic_dimensions_.push_back(false);
|
|
}
|
|
void clear_dimensions() {
|
|
dimensions_.clear();
|
|
dynamic_dimensions_.clear();
|
|
}
|
|
absl::Span<const int64> dimensions() const { return dimensions_; }
|
|
absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); }
|
|
|
|
// Methods for accessing the tuple subshapes. This field only non-empty for
|
|
// tuple shapes.
|
|
int tuple_shapes_size() const { return tuple_shapes_.size(); }
|
|
const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
|
|
Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
|
|
Shape* add_tuple_shapes() {
|
|
tuple_shapes_.push_back(Shape());
|
|
return &tuple_shapes_.back();
|
|
}
|
|
void clear_tuple_shapes() { tuple_shapes_.clear(); }
|
|
const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
|
|
std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
|
|
|
|
// Methods for accessing the layout field.
|
|
bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
|
|
const Layout& layout() const { return layout_; }
|
|
Layout* mutable_layout() { return &layout_; }
|
|
void clear_layout() { layout_.Clear(); }
|
|
|
|
// Recursively clear dynamic dimension of a shape.
|
|
void clear_dynamic_dimensions() {
|
|
if (!IsTuple()) {
|
|
for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) {
|
|
dynamic_dimensions_[i] = false;
|
|
}
|
|
return;
|
|
}
|
|
for (auto& subshape : tuple_shapes_) {
|
|
subshape.clear_dynamic_dimensions();
|
|
}
|
|
}
|
|
|
|
void Swap(Shape* other) {
|
|
using std::swap;
|
|
swap(*this, *other);
|
|
}
|
|
|
|
void Clear() {
|
|
element_type_ = PRIMITIVE_TYPE_INVALID;
|
|
clear_dimensions();
|
|
tuple_shapes_.clear();
|
|
clear_layout();
|
|
}
|
|
|
|
string SerializeAsString() const { return ToProto().SerializeAsString(); }
|
|
string ShortDebugString() const { return ToProto().ShortDebugString(); }
|
|
string DebugString() const { return ToProto().DebugString(); }
|
|
|
|
// Equal is a configurable functor to check the equality of two shapes.
|
|
//
|
|
// Examples:
|
|
//
|
|
// - Comparing two shapes ignoring their layout difference:
|
|
// Equal().IgnoreLayout()(shape1, shape2);
|
|
//
|
|
// - Comparing two shapes ignoring their layout and element type difference:
|
|
// Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2);
|
|
class Equal {
|
|
public:
|
|
Equal() = default;
|
|
|
|
bool operator()(const Shape& lhs, const Shape& rhs);
|
|
|
|
Equal& IgnoreLayout() {
|
|
ignore_layout_ = true;
|
|
return *this;
|
|
}
|
|
Equal& IgnoreTilesInLayout() {
|
|
ignore_tiles_in_layout_ = true;
|
|
return *this;
|
|
}
|
|
Equal& IgnoreElementSizeInLayout() {
|
|
ignore_element_size_in_layout_ = true;
|
|
return *this;
|
|
}
|
|
Equal& IgnoreMemorySpaceInLayout() {
|
|
ignore_memory_space_in_layout_ = true;
|
|
return *this;
|
|
}
|
|
Equal& MinorToMajorOnlyInLayout() {
|
|
ignore_tiles_in_layout_ = true;
|
|
ignore_element_size_in_layout_ = true;
|
|
ignore_memory_space_in_layout_ = true;
|
|
return *this;
|
|
}
|
|
Equal& IgnoreElementType() {
|
|
ignore_element_type_ = true;
|
|
return *this;
|
|
}
|
|
Equal& IgnoreFpPrecision() {
|
|
ignore_fp_precision_ = true;
|
|
return *this;
|
|
}
|
|
Equal& IgnoreDynamicDimension() {
|
|
ignore_dynamic_dimension_ = true;
|
|
return *this;
|
|
}
|
|
|
|
private:
|
|
bool ignore_layout_ = false;
|
|
bool ignore_tiles_in_layout_ = false;
|
|
bool ignore_element_size_in_layout_ = false;
|
|
bool ignore_memory_space_in_layout_ = false;
|
|
bool ignore_element_type_ = false;
|
|
bool ignore_fp_precision_ = false;
|
|
bool ignore_dynamic_dimension_ = false;
|
|
};
|
|
|
|
// Test that all fields of the shape are the same, equivalent to Equal().
|
|
bool operator==(const Shape& other) const { return Equal()(*this, other); }
|
|
bool operator!=(const Shape& other) const { return !(*this == other); }
|
|
|
|
template <typename H>
|
|
friend H AbslHashValue(H h, const Shape& s) {
|
|
return H::combine(std::move(h), s.element_type_, s.dimensions_,
|
|
s.dynamic_dimensions_, s.tuple_shapes_, s.layout_);
|
|
}
|
|
|
|
private:
|
|
// The element type of this shape (tuple, array, etc).
|
|
PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
|
|
|
|
// The array bounds of the dimensions. This is nonempty only for array
|
|
// shapes. For a dynamically-sized dimension, the respective value in this
|
|
// vector is an inclusive upper limit of the array bound.
|
|
absl::InlinedVector<int64, 6> dimensions_;
|
|
|
|
// This vector is the same size as 'dimensions_' and indicates whether the
|
|
// respective dimension is dynamically sized.
|
|
absl::InlinedVector<bool, 6> dynamic_dimensions_;
|
|
|
|
// The tuple element subshapes. This is nonempty only for tuple shapes.
|
|
std::vector<Shape> tuple_shapes_;
|
|
|
|
// The layout of the shape. Only relevant for arrays.
|
|
Layout layout_;
|
|
};
|
|
|
|
// Shape of the parameters and output of an XLA computation. This is analogous
|
|
// to a traditional function signature.
|
|
class ProgramShape {
|
|
public:
|
|
ProgramShape() = default;
|
|
|
|
// Creates a ProgramShape from a ProgramShapeProto protobuf.
|
|
explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
|
|
|
|
// Returns a proto representation of the object.
|
|
ProgramShapeProto ToProto() const;
|
|
|
|
string ToString() const;
|
|
|
|
// The following methods mirror the protobuf generated code interface for the
|
|
// message ProgramShapeProto. This enabled easy migration of this data
|
|
// structure from a proto to a proper C++ class.
|
|
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
|
|
// interface.
|
|
|
|
// Methods for accessing and manipulating the Shape of the parameters.
|
|
int parameters_size() const { return parameters_.size(); }
|
|
const Shape& parameters(int index) const { return parameters_.at(index); }
|
|
Shape* mutable_parameters(int index) { return ¶meters_.at(index); }
|
|
Shape* add_parameters() {
|
|
parameters_.emplace_back();
|
|
return ¶meters_.back();
|
|
}
|
|
void clear_parameters() { parameters_.clear(); }
|
|
const std::vector<Shape>& parameters() const { return parameters_; }
|
|
std::vector<Shape>* mutable_parameters() { return ¶meters_; }
|
|
|
|
// Methods for accessing and manipulating the Shape of the result.
|
|
const Shape& result() const { return result_; }
|
|
Shape* mutable_result() { return &result_; }
|
|
|
|
// Methods for accessing and manipulating the names of the parameters.
|
|
int parameter_names_size() const { return parameter_names_.size(); }
|
|
const string& parameter_names(int index) const {
|
|
return parameter_names_.at(index);
|
|
}
|
|
void set_parameter_names(int index, const string& value) {
|
|
parameter_names_.at(index) = value;
|
|
}
|
|
string* mutable_parameter_names(int index) {
|
|
return ¶meter_names_.at(index);
|
|
}
|
|
void add_parameter_names(const string& value) {
|
|
parameter_names_.push_back(value);
|
|
}
|
|
string* add_parameter_names() {
|
|
parameter_names_.push_back("");
|
|
return ¶meter_names_.back();
|
|
}
|
|
void clear_parameter_names() { parameter_names_.clear(); }
|
|
const std::vector<string>& parameter_names() const {
|
|
return parameter_names_;
|
|
}
|
|
std::vector<string>* mutable_parameter_names() { return ¶meter_names_; }
|
|
|
|
string ShortDebugString() const { return ToProto().ShortDebugString(); }
|
|
string DebugString() const { return ToProto().DebugString(); }
|
|
|
|
private:
|
|
// The shapes of the parameters of the computation represented by this object.
|
|
std::vector<Shape> parameters_;
|
|
|
|
// The names of the parameters of the computation represented by this object.
|
|
std::vector<string> parameter_names_;
|
|
|
|
// The shape of the result of the computation represented by this object.
|
|
Shape result_;
|
|
};
|
|
|
|
// Stream-insertion operator so a Shape can be written with `out << shape`.
// Defined out of line; NOTE(review): presumably formats via Shape::ToString()
// — confirm against the implementation file.
std::ostream& operator<<(std::ostream& out, const Shape& shape);
|
|
// Stream-insertion operator so a ProgramShape can be written with
// `out << program_shape`. Defined out of line; NOTE(review): presumably
// formats via ProgramShape::ToString() — confirm against the implementation.
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
|
|
|
|
} // namespace xla
|
|
|
|
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_
|