Cleanup in IrArray.

Remove unused methods, and don't use optional for the shape.

PiperOrigin-RevId: 236816066
This commit is contained in:
Adrian Kuegel 2019-03-05 02:57:02 -08:00 committed by TensorFlower Gardener
parent b0315f0960
commit 9d914a58d0
2 changed files with 13 additions and 60 deletions

View File

@ -92,18 +92,6 @@ IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
<< " should have a layout.";
}
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
const Shape& shape, llvm::IRBuilder<>* b)
: multidim_(multidim.begin(), multidim.end()),
layout_(shape.layout()),
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
CHECK_GT(multidim_.size(), 0);
index_type_ = multidim[0]->getType();
CHECK_NE(index_type_, nullptr);
CHECK_EQ(shape.dimensions_size(), multidim.size());
CHECK(LayoutUtil::HasLayout(shape));
}
IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
: base_ptr_(base_ptr), shape_(std::move(shape)) {
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
@ -117,10 +105,10 @@ IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
++depth;
}
if (!shape_->IsArray() || ShapeUtil::IsScalar(*shape_)) {
if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
DCHECK(depth == 1 || depth == 0) << depth;
} else {
DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString();
DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
}
}
@ -342,19 +330,19 @@ llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
llvm::IRBuilder<>* b,
absl::string_view name) const {
if (ShapeUtil::IsScalar(*shape_)) {
if (ShapeUtil::IsScalar(shape_)) {
// Special handling of scalars: a scalar pretends to have the same value for
// every index, thus effectively implementing broadcasting of its value
// over higher-rank arrays.
return base_ptr_;
}
CHECK_EQ(index.size(), shape_->rank());
CHECK_EQ(index.size(), shape_.rank());
if (index.LinearValidOnShape(*shape_)) {
if (index.LinearValidOnShape(shape_)) {
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
return b->CreateInBoundsGEP(
b->CreateBitCast(base_ptr_,
PrimitiveTypeToIrType(shape_->element_type(), module)
PrimitiveTypeToIrType(shape_.element_type(), module)
->getPointerTo()),
{index.linear()}, llvm_ir::AsStringRef(name));
}
@ -364,7 +352,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
// When dimension i is of size 1, LLVM optimization is able to replace
// index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
// produce better code in some cases.
auto dim = shape_->dimensions(i);
auto dim = shape_.dimensions(i);
actual_index.push_back(
dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
}
@ -377,8 +365,8 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
CHECK_GT(index.size(), 0);
std::vector<llvm::Value*> gep_indices(
1, llvm::ConstantInt::get(index[0]->getType(), 0));
for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) {
int64 dimension = LayoutUtil::Major(shape_->layout(), i);
for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
int64 dimension = LayoutUtil::Major(shape_.layout(), i);
gep_indices.push_back(actual_index[dimension]);
}
return b->CreateInBoundsGEP(base_ptr_, gep_indices,
@ -423,18 +411,5 @@ IrArray IrArray::CastToShape(const Shape& new_shape,
return new_irarray;
}
/* static */ IrArray::Index IrArray::BumpIndex(const Index& index,
int64 which_dimension,
int64 addend,
llvm::IRBuilder<>* b) {
Index new_index = index;
new_index[which_dimension] = b->CreateAdd(
index[which_dimension],
llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "",
/*HasNUW=*/true,
/*HasNSW=*/true);
return new_index;
}
} // namespace llvm_ir
} // namespace xla

View File

@ -55,13 +55,6 @@ class IrArray {
// multidimensional index, which LLVM DCE can delete.
class Index {
public:
// Constructs an index of rank "size". Each dimension of the index is
// initialized to "value".
explicit Index(size_t size, llvm::Value* value)
: multidim_(size, value), index_type_(value->getType()) {
CHECK_NE(index_type_, nullptr);
}
// Constructs an index of rank "size". Each dimension of the index is
// initialized to nullptr.
explicit Index(llvm::Type* index_ty, size_t size = 0)
@ -96,13 +89,6 @@ class IrArray {
// Precondition: "shape" has a layout.
Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
// Constructs an index from the given multi-dimensional index and the shape
// that it indexes into.
//
// Precondition: "shape" has a layout.
Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
llvm::IRBuilder<>* b);
// Constructs an index from both a multi-dimensional index and a linear
// index. "shape" has the same meaning as that in the constructor that takes
// only a linear index.
@ -145,13 +131,12 @@ class IrArray {
const_iterator begin() const { return multidim().begin(); }
const_iterator end() const { return multidim().end(); }
llvm::Value* back() const { return multidim().back(); }
bool LinearValidOnShape(const Shape& a) const;
// Given that "this" is the target index of a reshape from `operand_shape`
// to `shape`, returns the source index.
Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape,
Index SourceIndexOfReshape(const Shape& output_shape,
const Shape& input_shape,
llvm::IRBuilder<>* builder) const;
// Returns the index into the source operand from which a slice operation
@ -242,9 +227,7 @@ class IrArray {
llvm::Value* GetBasePointer() const { return base_ptr_; }
llvm::Type* GetElementLlvmType() const { return element_type_; }
const Shape& GetShape() const {
return *shape_;
}
const Shape& GetShape() const { return shape_; }
// Emit a sequence of instructions to compute the address of the element in
// the given array at the given index. Returns the address of the element as
@ -318,11 +301,6 @@ class IrArray {
const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
// Bumps the "which_dimension" value within the provided index by the provided
// addend.
static Index BumpIndex(const Index& index, int64 which_dimension,
int64 addend, llvm::IRBuilder<>* b);
private:
// Add the specified LLVM IR metadata to loads/stores associated with this
// IrArray.
@ -337,7 +315,7 @@ class IrArray {
llvm::Type* element_type_;
// Shape of the XLA array.
absl::optional<Shape> shape_;
Shape shape_;
// The list of key/value pairs used when attaching metadata to emitted
// loads/stores for this array. They keys are the metadata kinds and the