Cleanup in IrArray.
Remove unused methods, and don't use optional for the shape. PiperOrigin-RevId: 236816066
This commit is contained in:
parent
b0315f0960
commit
9d914a58d0
@ -92,18 +92,6 @@ IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
|
||||
<< " should have a layout.";
|
||||
}
|
||||
|
||||
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
|
||||
const Shape& shape, llvm::IRBuilder<>* b)
|
||||
: multidim_(multidim.begin(), multidim.end()),
|
||||
layout_(shape.layout()),
|
||||
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
|
||||
CHECK_GT(multidim_.size(), 0);
|
||||
index_type_ = multidim[0]->getType();
|
||||
CHECK_NE(index_type_, nullptr);
|
||||
CHECK_EQ(shape.dimensions_size(), multidim.size());
|
||||
CHECK(LayoutUtil::HasLayout(shape));
|
||||
}
|
||||
|
||||
IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
|
||||
: base_ptr_(base_ptr), shape_(std::move(shape)) {
|
||||
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
|
||||
@ -117,10 +105,10 @@ IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
|
||||
++depth;
|
||||
}
|
||||
|
||||
if (!shape_->IsArray() || ShapeUtil::IsScalar(*shape_)) {
|
||||
if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
|
||||
DCHECK(depth == 1 || depth == 0) << depth;
|
||||
} else {
|
||||
DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString();
|
||||
DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
|
||||
}
|
||||
}
|
||||
|
||||
@ -342,19 +330,19 @@ llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
|
||||
llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
|
||||
llvm::IRBuilder<>* b,
|
||||
absl::string_view name) const {
|
||||
if (ShapeUtil::IsScalar(*shape_)) {
|
||||
if (ShapeUtil::IsScalar(shape_)) {
|
||||
// Special handling of scalars: a scalar pretends to have the same value for
|
||||
// every index, thus effectively implementing broadcasting of its value
|
||||
// over higher-rank arrays.
|
||||
return base_ptr_;
|
||||
}
|
||||
CHECK_EQ(index.size(), shape_->rank());
|
||||
CHECK_EQ(index.size(), shape_.rank());
|
||||
|
||||
if (index.LinearValidOnShape(*shape_)) {
|
||||
if (index.LinearValidOnShape(shape_)) {
|
||||
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
|
||||
return b->CreateInBoundsGEP(
|
||||
b->CreateBitCast(base_ptr_,
|
||||
PrimitiveTypeToIrType(shape_->element_type(), module)
|
||||
PrimitiveTypeToIrType(shape_.element_type(), module)
|
||||
->getPointerTo()),
|
||||
{index.linear()}, llvm_ir::AsStringRef(name));
|
||||
}
|
||||
@ -364,7 +352,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
|
||||
// When dimension i is of size 1, LLVM optimization is able to replace
|
||||
// index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
|
||||
// produce better code in some cases.
|
||||
auto dim = shape_->dimensions(i);
|
||||
auto dim = shape_.dimensions(i);
|
||||
actual_index.push_back(
|
||||
dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
|
||||
}
|
||||
@ -377,8 +365,8 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
|
||||
CHECK_GT(index.size(), 0);
|
||||
std::vector<llvm::Value*> gep_indices(
|
||||
1, llvm::ConstantInt::get(index[0]->getType(), 0));
|
||||
for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) {
|
||||
int64 dimension = LayoutUtil::Major(shape_->layout(), i);
|
||||
for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
|
||||
int64 dimension = LayoutUtil::Major(shape_.layout(), i);
|
||||
gep_indices.push_back(actual_index[dimension]);
|
||||
}
|
||||
return b->CreateInBoundsGEP(base_ptr_, gep_indices,
|
||||
@ -423,18 +411,5 @@ IrArray IrArray::CastToShape(const Shape& new_shape,
|
||||
return new_irarray;
|
||||
}
|
||||
|
||||
/* static */ IrArray::Index IrArray::BumpIndex(const Index& index,
|
||||
int64 which_dimension,
|
||||
int64 addend,
|
||||
llvm::IRBuilder<>* b) {
|
||||
Index new_index = index;
|
||||
new_index[which_dimension] = b->CreateAdd(
|
||||
index[which_dimension],
|
||||
llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "",
|
||||
/*HasNUW=*/true,
|
||||
/*HasNSW=*/true);
|
||||
return new_index;
|
||||
}
|
||||
|
||||
} // namespace llvm_ir
|
||||
} // namespace xla
|
||||
|
@ -55,13 +55,6 @@ class IrArray {
|
||||
// multidimensional index, which LLVM DCE can delete.
|
||||
class Index {
|
||||
public:
|
||||
// Constructs an index of rank "size". Each dimension of the index is
|
||||
// initialized to "value".
|
||||
explicit Index(size_t size, llvm::Value* value)
|
||||
: multidim_(size, value), index_type_(value->getType()) {
|
||||
CHECK_NE(index_type_, nullptr);
|
||||
}
|
||||
|
||||
// Constructs an index of rank "size". Each dimension of the index is
|
||||
// initialized to nullptr.
|
||||
explicit Index(llvm::Type* index_ty, size_t size = 0)
|
||||
@ -96,13 +89,6 @@ class IrArray {
|
||||
// Precondition: "shape" has a layout.
|
||||
Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
|
||||
|
||||
// Constructs an index from the given multi-dimensional index and the shape
|
||||
// that it indexes into.
|
||||
//
|
||||
// Precondition: "shape" has a layout.
|
||||
Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
|
||||
llvm::IRBuilder<>* b);
|
||||
|
||||
// Constructs an index from both a multi-dimensional index and a linear
|
||||
// index. "shape" has the same meaning as that in the constructor that takes
|
||||
// only a linear index.
|
||||
@ -145,13 +131,12 @@ class IrArray {
|
||||
const_iterator begin() const { return multidim().begin(); }
|
||||
const_iterator end() const { return multidim().end(); }
|
||||
|
||||
llvm::Value* back() const { return multidim().back(); }
|
||||
|
||||
bool LinearValidOnShape(const Shape& a) const;
|
||||
|
||||
// Given that "this" is the target index of a reshape from `operand_shape`
|
||||
// to `shape`, returns the source index.
|
||||
Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape,
|
||||
Index SourceIndexOfReshape(const Shape& output_shape,
|
||||
const Shape& input_shape,
|
||||
llvm::IRBuilder<>* builder) const;
|
||||
|
||||
// Returns the index into the source operand from which a slice operation
|
||||
@ -242,9 +227,7 @@ class IrArray {
|
||||
llvm::Value* GetBasePointer() const { return base_ptr_; }
|
||||
llvm::Type* GetElementLlvmType() const { return element_type_; }
|
||||
|
||||
const Shape& GetShape() const {
|
||||
return *shape_;
|
||||
}
|
||||
const Shape& GetShape() const { return shape_; }
|
||||
|
||||
// Emit a sequence of instructions to compute the address of the element in
|
||||
// the given array at the given index. Returns the address of the element as
|
||||
@ -318,11 +301,6 @@ class IrArray {
|
||||
|
||||
const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
|
||||
|
||||
// Bumps the "which_dimension" value within the provided index by the provided
|
||||
// addend.
|
||||
static Index BumpIndex(const Index& index, int64 which_dimension,
|
||||
int64 addend, llvm::IRBuilder<>* b);
|
||||
|
||||
private:
|
||||
// Add the specified LLVM IR metadata to loads/stores associated with this
|
||||
// IrArray.
|
||||
@ -337,7 +315,7 @@ class IrArray {
|
||||
llvm::Type* element_type_;
|
||||
|
||||
// Shape of the XLA array.
|
||||
absl::optional<Shape> shape_;
|
||||
Shape shape_;
|
||||
|
||||
// The list of key/value pairs used when attaching metadata to emitted
|
||||
// loads/stores for this array. They keys are the metadata kinds and the
|
||||
|
Loading…
Reference in New Issue
Block a user