Add iterator support to ElementsAttr and SparseElementsAttr.

This will allow iterating the values of a non-opaque ElementsAttr, with all of the types currently supported by DenseElementsAttr. This should help reduce the amount of specialization on DenseElementsAttr.

PiperOrigin-RevId: 264637293
This commit is contained in:
River Riddle 2019-08-21 10:23:14 -07:00 committed by TensorFlower Gardener
parent 4820c218a0
commit 8a960ef4b3
4 changed files with 293 additions and 35 deletions

View File

@ -229,7 +229,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
for (auto it : llvm::enumerate(permutation().cast<DenseIntElementsAttr>())) {
for (auto it : llvm::enumerate(permutation().getValues<APInt>())) {
if (it.index() != it.value()) {
return {};
}

View File

@ -20,6 +20,7 @@
#include "mlir/IR/AttributeSupport.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
namespace mlir {
class AffineMap;
@ -447,11 +448,18 @@ public:
// Elements Attributes
//===----------------------------------------------------------------------===//
namespace detail {
template <typename T> class ElementsAttrIterator;
template <typename T> class ElementsAttrRange;
} // namespace detail
/// A base attribute that represents a reference to a static shaped tensor or
/// vector constant.
class ElementsAttr : public Attribute {
public:
using Attribute::Attribute;
template <typename T> using iterator = detail::ElementsAttrIterator<T>;
template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
@ -467,6 +475,11 @@ public:
return getValue(index).template cast<T>();
}
/// Return the elements of this attribute as a value of type 'T'. Note:
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
/// iteration.
template <typename T> iterator_range<T> getValues() const;
/// Return if the given 'index' refers to a valid element in this attribute.
bool isValidIndex(ArrayRef<uint64_t> index) const;
@ -492,6 +505,11 @@ public:
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
}
protected:
/// Returns the 1 dimenional flattened row-major index from the given
/// multi-dimensional index.
uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
};
namespace detail {
@ -853,10 +871,6 @@ protected:
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
/// Returns the 1 dimenional flattened index from the given multi-dimensional
/// index.
uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
};
/// An attribute that represents a reference to a dense float vector or tensor
@ -964,6 +978,11 @@ class SparseElementsAttr
public:
using Base::Base;
template <typename T>
using iterator =
llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
std::function<T(ptrdiff_t)>>;
/// 'type' must be a vector or tensor with static shape.
static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
DenseElementsAttr values);
@ -972,6 +991,25 @@ public:
DenseElementsAttr getValues() const;
/// Return the values of this attribute in the form of the given type 'T'. 'T'
/// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
template <typename T> llvm::iterator_range<iterator<T>> getValues() const {
auto zeroValue = getZeroValue<T>();
auto valueIt = getValues().getValues<T>().begin();
std::vector<ptrdiff_t> flatSparseIndices = getFlattenedSparseIndices();
// TODO(riverriddle): Move-capture flatSparseIndices when c++14 is
// available.
std::function<T(ptrdiff_t)> mapFn = [=](ptrdiff_t index) {
// Try to map the current index to one of the sparse indices.
for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)
if (flatSparseIndices[i] == index)
return *std::next(valueIt, i);
// Otherwise, return the zero value.
return zeroValue;
};
return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
}
/// Return the value of the element at the given index. The 'index' is
/// expected to refer to a valid element.
Attribute getValue(ArrayRef<uint64_t> index) const;
@ -980,6 +1018,49 @@ public:
static bool kindof(unsigned kind) {
return kind == StandardAttributes::SparseElements;
}
private:
/// Get a zero APFloat for the given sparse attribute.
APFloat getZeroAPFloat() const;
/// Get a zero APInt for the given sparse attribute.
APInt getZeroAPInt() const;
/// Get a zero attribute for the given sparse attribute.
Attribute getZeroAttr() const;
/// Utility methods to generate a zero value of some type 'T'. This is used by
/// the 'iterator' class.
/// Get a zero for a given attribute type.
template <typename T>
typename std::enable_if<std::is_base_of<Attribute, T>::value, T>::type
getZeroValue() const {
return getZeroAttr().template cast<T>();
}
/// Get a zero for an APInt.
template <typename T>
typename std::enable_if<std::is_same<APInt, T>::value, T>::type
getZeroValue() const {
return getZeroAPInt();
}
/// Get a zero for an APFloat.
template <typename T>
typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
getZeroValue() const {
return getZeroAPFloat();
}
/// Get a zero for an C++ integer or float type.
template <typename T>
typename std::enable_if<std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double>::value,
T>::type
getZeroValue() const {
return T(0);
}
/// Flatten, and return, all of the sparse indices in this attribute in
/// row-major order.
std::vector<ptrdiff_t> getFlattenedSparseIndices() const;
};
/// An attribute that represents a reference to a splat vector or tensor
@ -995,6 +1076,136 @@ public:
}
};
namespace detail {
/// This class represents a general iterator over the values of an ElementsAttr.
/// It supports all subclasses aside from OpaqueElementsAttr.
template <typename T>
class ElementsAttrIterator
: public llvm::iterator_facade_base<ElementsAttrIterator<T>,
std::random_access_iterator_tag, T,
std::ptrdiff_t, T, T> {
using DenseIteratorT =
decltype(std::declval<DenseElementsAttr>().getValues<T>().begin());
using SparseIteratorT = SparseElementsAttr::iterator<T>;
/// A union containing the specific iterators for each derived attribute kind.
union Iterator {
explicit Iterator(DenseIteratorT it) : denseIt(it) {}
explicit Iterator(SparseIteratorT it) : sparseIt(it) {}
~Iterator() {}
operator const DenseIteratorT &() const { return denseIt; }
operator const SparseIteratorT &() const { return sparseIt; }
operator DenseIteratorT &() { return denseIt; }
operator SparseIteratorT &() { return sparseIt; }
/// An instance of a dense elements iterator.
DenseIteratorT denseIt;
/// An instance of a sparse elements iterator.
SparseIteratorT sparseIt;
};
/// Utility method to process a functor on each of the internal iterator
/// types.
template <typename RetT, template <typename> class ProcessFn,
typename... Args>
RetT process(Args &... args) const {
switch (attrKind) {
case StandardAttributes::DenseElements:
return ProcessFn<DenseIteratorT>()(args...);
case StandardAttributes::SparseElements:
return ProcessFn<SparseIteratorT>()(args...);
}
llvm_unreachable("unexpected attribute kind");
}
/// Utility functors used to generically implement the iterators methods.
template <typename ItT> struct PlusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
};
template <typename ItT> struct Minus {
ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
};
template <typename ItT> struct MinusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
};
template <typename ItT> struct Dereference {
T operator()(ItT &it) { return *it; }
};
template <typename ItT> struct ConstructIter {
Iterator operator()(const ItT &it) { return Iterator(it); }
};
public:
ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
: attrKind(rhs.attrKind),
it(rhs.process<Iterator, ConstructIter>(rhs.it)) {}
/// Methods necessary to support random access iteration.
ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
assert(attrKind == rhs.attrKind && "incompatible iterators");
return process<ptrdiff_t, Minus>(it, rhs.it);
}
bool operator==(const ElementsAttrIterator<T> &rhs) const {
return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
}
bool operator<(const ElementsAttrIterator<T> &rhs) const {
assert(attrKind == rhs.attrKind && "incompatible iterators");
return process<bool, std::less>(it, rhs.it);
}
ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
process<void, PlusAssign>(it, offset);
return *this;
}
ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
process<void, MinusAssign>(it, offset);
return *this;
}
/// Dereference the iterator at the current index.
T operator*() { return process<T, Dereference>(it); }
private:
template <typename IteratorT>
ElementsAttrIterator(unsigned attrKind, IteratorT it)
: attrKind(attrKind), it(it) {}
/// Allow accessing the constructor.
friend ElementsAttr;
/// The kind of derived elements attribute.
unsigned attrKind;
/// A union containing the specific iterators for each derived kind.
Iterator it;
};
template <typename T>
class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
};
} // namespace detail
/// Return the elements of this attribute as a value of type 'T'.
template <typename T>
auto ElementsAttr::getValues() const -> iterator_range<T> {
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
auto values = denseAttr.getValues<T>();
return {iterator<T>(getKind(), values.begin()),
iterator<T>(getKind(), values.end())};
}
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
auto values = sparseAttr.getValues<T>();
return {iterator<T>(getKind(), values.begin()),
iterator<T>(getKind(), values.end())};
}
llvm_unreachable("unexpected attribute kind");
}
//===----------------------------------------------------------------------===//
// Attributes Utils
//===----------------------------------------------------------------------===//
template <typename U> bool Attribute::isa() const {
assert(impl && "isa<> used on a null attribute.");
return U::classof(*this);
@ -1015,6 +1226,10 @@ inline ::llvm::hash_code hash_value(Attribute arg) {
return ::llvm::hash_value(arg.impl);
}
//===----------------------------------------------------------------------===//
// NamedAttributeList
//===----------------------------------------------------------------------===//
/// A NamedAttributeList is used to manage a list of named attributes. This
/// provides simple interfaces for adding/removing/finding attributes from
/// within a DictionaryAttr.

View File

@ -415,6 +415,25 @@ ElementsAttr ElementsAttr::mapValues(
}
}
/// Returns the 1 dimenional flattened row-major index from the given
/// multi-dimensional index.
uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
assert(isValidIndex(index) && "expected valid multi-dimensional index");
auto type = getType();
// Reduce the provided multidimensional index into a flattended 1D row-major
// index.
auto rank = type.getRank();
auto shape = type.getShape();
uint64_t valueIndex = 0;
uint64_t dimMultiplier = 1;
for (int i = rank - 1; i >= 0; --i) {
valueIndex += index[i] * dimMultiplier;
dimMultiplier *= shape[i];
}
return valueIndex;
}
//===----------------------------------------------------------------------===//
// DenseElementAttr Utilities
//===----------------------------------------------------------------------===//
@ -779,25 +798,6 @@ DenseElementsAttr DenseElementsAttr::mapValues(
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
}
/// Returns the 1 dimenional flattened index from the given multi-dimensional
/// index.
uint64_t DenseElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
assert(isValidIndex(index) && "expected valid multi-dimensional index");
auto type = getType();
// Reduce the provided multidimensional index into a flattended 1D row-major
// index.
auto rank = type.getRank();
auto shape = type.getShape();
uint64_t valueIndex = 0;
uint64_t dimMultiplier = 1;
for (int i = rank - 1; i >= 0; --i) {
valueIndex += index[i] * dimMultiplier;
dimMultiplier *= shape[i];
}
return valueIndex;
}
//===----------------------------------------------------------------------===//
// DenseFPElementsAttr
//===----------------------------------------------------------------------===//
@ -938,15 +938,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
assert(isValidIndex(index) && "expected valid multi-dimensional index");
auto type = getType();
/// Return an attribute corresponding to '0' for the element type.
auto getZeroAttr = [=]() -> Attribute {
auto eltType = type.getElementType();
if (eltType.isa<FloatType>())
return FloatAttr::get(eltType, 0);
assert(eltType.isa<IntegerType>() && "unexpected element type");
return IntegerAttr::get(eltType, 0);
};
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
// as a 1-D index array.
auto sparseIndices = getIndices();
@ -983,6 +974,58 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
return getValues().getValue(it->second);
}
/// Get a zero APFloat for the given sparse attribute.
APFloat SparseElementsAttr::getZeroAPFloat() const {
auto eltType = getType().getElementType().cast<FloatType>();
return APFloat(eltType.getFloatSemantics());
}
/// Get a zero APInt for the given sparse attribute.
APInt SparseElementsAttr::getZeroAPInt() const {
auto eltType = getType().getElementType().cast<IntegerType>();
return APInt::getNullValue(eltType.getWidth());
}
/// Get a zero attribute for the given attribute type.
Attribute SparseElementsAttr::getZeroAttr() const {
auto eltType = getType().getElementType();
// Handle floating point elements.
if (eltType.isa<FloatType>())
return FloatAttr::get(eltType, 0);
// Otherwise, this is an integer.
auto intEltTy = eltType.cast<IntegerType>();
if (intEltTy.getWidth() == 1)
return BoolAttr::get(false, eltType.getContext());
return IntegerAttr::get(eltType, 0);
}
/// Flatten, and return, all of the sparse indices in this attribute in
/// row-major order.
std::vector<intptr_t> SparseElementsAttr::getFlattenedSparseIndices() const {
std::vector<intptr_t> flatSparseIndices;
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
// as a 1-D index array.
auto sparseIndices = getIndices();
auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
if (sparseIndices.isSplat()) {
SmallVector<uint64_t, 8> indices(getType().getRank(),
*sparseIndexValues.begin());
flatSparseIndices.push_back(getFlattenedIndex(indices));
return flatSparseIndices;
}
// Otherwise, reinterpret each index as an ArrayRef when flattening.
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
size_t rank = getType().getRank();
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
flatSparseIndices.push_back(getFlattenedIndex(
{&*std::next(sparseIndexValues.begin(), i * rank), rank}));
return flatSparseIndices;
}
//===----------------------------------------------------------------------===//
// NamedAttributeList
//===----------------------------------------------------------------------===//

View File

@ -90,12 +90,12 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
splatAttr.getSplatValue(), loc);
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
}
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
SmallVector<llvm::Constant *, 8> constants;
uint64_t numElements = vectorType->getNumElements();
constants.reserve(numElements);
for (auto n : denseAttr.getAttributeValues()) {
for (auto n : elementsAttr.getValues<Attribute>()) {
constants.push_back(
getLLVMConstant(vectorType->getElementType(), n, loc));
if (!constants.back())