Add iterator support to ElementsAttr and SparseElementsAttr.
This will allow iterating the values of a non-opaque ElementsAttr, with all of the types currently supported by DenseElementsAttr. This should help reduce the amount of specialization on DenseElementsAttr. PiperOrigin-RevId: 264637293
This commit is contained in:
parent
4820c218a0
commit
8a960ef4b3
@ -229,7 +229,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
for (auto it : llvm::enumerate(permutation().cast<DenseIntElementsAttr>())) {
|
||||
for (auto it : llvm::enumerate(permutation().getValues<APInt>())) {
|
||||
if (it.index() != it.value()) {
|
||||
return {};
|
||||
}
|
||||
|
223
third_party/mlir/include/mlir/IR/Attributes.h
vendored
223
third_party/mlir/include/mlir/IR/Attributes.h
vendored
@ -20,6 +20,7 @@
|
||||
|
||||
#include "mlir/IR/AttributeSupport.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
@ -447,11 +448,18 @@ public:
|
||||
// Elements Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace detail {
|
||||
template <typename T> class ElementsAttrIterator;
|
||||
template <typename T> class ElementsAttrRange;
|
||||
} // namespace detail
|
||||
|
||||
/// A base attribute that represents a reference to a static shaped tensor or
|
||||
/// vector constant.
|
||||
class ElementsAttr : public Attribute {
|
||||
public:
|
||||
using Attribute::Attribute;
|
||||
template <typename T> using iterator = detail::ElementsAttrIterator<T>;
|
||||
template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
|
||||
|
||||
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
|
||||
/// with static shape.
|
||||
@ -467,6 +475,11 @@ public:
|
||||
return getValue(index).template cast<T>();
|
||||
}
|
||||
|
||||
/// Return the elements of this attribute as a value of type 'T'. Note:
|
||||
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
|
||||
/// iteration.
|
||||
template <typename T> iterator_range<T> getValues() const;
|
||||
|
||||
/// Return if the given 'index' refers to a valid element in this attribute.
|
||||
bool isValidIndex(ArrayRef<uint64_t> index) const;
|
||||
|
||||
@ -492,6 +505,11 @@ public:
|
||||
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
|
||||
attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Returns the 1 dimenional flattened row-major index from the given
|
||||
/// multi-dimensional index.
|
||||
uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
@ -853,10 +871,6 @@ protected:
|
||||
/// the current attribute. This method is used to verify specific type
|
||||
/// invariants that the templatized 'getValues' method cannot.
|
||||
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
|
||||
|
||||
/// Returns the 1 dimenional flattened index from the given multi-dimensional
|
||||
/// index.
|
||||
uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
|
||||
};
|
||||
|
||||
/// An attribute that represents a reference to a dense float vector or tensor
|
||||
@ -964,6 +978,11 @@ class SparseElementsAttr
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
template <typename T>
|
||||
using iterator =
|
||||
llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
|
||||
std::function<T(ptrdiff_t)>>;
|
||||
|
||||
/// 'type' must be a vector or tensor with static shape.
|
||||
static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
@ -972,6 +991,25 @@ public:
|
||||
|
||||
DenseElementsAttr getValues() const;
|
||||
|
||||
/// Return the values of this attribute in the form of the given type 'T'. 'T'
|
||||
/// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
|
||||
template <typename T> llvm::iterator_range<iterator<T>> getValues() const {
|
||||
auto zeroValue = getZeroValue<T>();
|
||||
auto valueIt = getValues().getValues<T>().begin();
|
||||
std::vector<ptrdiff_t> flatSparseIndices = getFlattenedSparseIndices();
|
||||
// TODO(riverriddle): Move-capture flatSparseIndices when c++14 is
|
||||
// available.
|
||||
std::function<T(ptrdiff_t)> mapFn = [=](ptrdiff_t index) {
|
||||
// Try to map the current index to one of the sparse indices.
|
||||
for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)
|
||||
if (flatSparseIndices[i] == index)
|
||||
return *std::next(valueIt, i);
|
||||
// Otherwise, return the zero value.
|
||||
return zeroValue;
|
||||
};
|
||||
return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
|
||||
}
|
||||
|
||||
/// Return the value of the element at the given index. The 'index' is
|
||||
/// expected to refer to a valid element.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
@ -980,6 +1018,49 @@ public:
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::SparseElements;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Get a zero APFloat for the given sparse attribute.
|
||||
APFloat getZeroAPFloat() const;
|
||||
|
||||
/// Get a zero APInt for the given sparse attribute.
|
||||
APInt getZeroAPInt() const;
|
||||
|
||||
/// Get a zero attribute for the given sparse attribute.
|
||||
Attribute getZeroAttr() const;
|
||||
|
||||
/// Utility methods to generate a zero value of some type 'T'. This is used by
|
||||
/// the 'iterator' class.
|
||||
/// Get a zero for a given attribute type.
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_base_of<Attribute, T>::value, T>::type
|
||||
getZeroValue() const {
|
||||
return getZeroAttr().template cast<T>();
|
||||
}
|
||||
/// Get a zero for an APInt.
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_same<APInt, T>::value, T>::type
|
||||
getZeroValue() const {
|
||||
return getZeroAPInt();
|
||||
}
|
||||
/// Get a zero for an APFloat.
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
|
||||
getZeroValue() const {
|
||||
return getZeroAPFloat();
|
||||
}
|
||||
/// Get a zero for an C++ integer or float type.
|
||||
template <typename T>
|
||||
typename std::enable_if<std::numeric_limits<T>::is_integer ||
|
||||
llvm::is_one_of<T, float, double>::value,
|
||||
T>::type
|
||||
getZeroValue() const {
|
||||
return T(0);
|
||||
}
|
||||
|
||||
/// Flatten, and return, all of the sparse indices in this attribute in
|
||||
/// row-major order.
|
||||
std::vector<ptrdiff_t> getFlattenedSparseIndices() const;
|
||||
};
|
||||
|
||||
/// An attribute that represents a reference to a splat vector or tensor
|
||||
@ -995,6 +1076,136 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
/// This class represents a general iterator over the values of an ElementsAttr.
|
||||
/// It supports all subclasses aside from OpaqueElementsAttr.
|
||||
template <typename T>
|
||||
class ElementsAttrIterator
|
||||
: public llvm::iterator_facade_base<ElementsAttrIterator<T>,
|
||||
std::random_access_iterator_tag, T,
|
||||
std::ptrdiff_t, T, T> {
|
||||
using DenseIteratorT =
|
||||
decltype(std::declval<DenseElementsAttr>().getValues<T>().begin());
|
||||
using SparseIteratorT = SparseElementsAttr::iterator<T>;
|
||||
|
||||
/// A union containing the specific iterators for each derived attribute kind.
|
||||
union Iterator {
|
||||
explicit Iterator(DenseIteratorT it) : denseIt(it) {}
|
||||
explicit Iterator(SparseIteratorT it) : sparseIt(it) {}
|
||||
~Iterator() {}
|
||||
|
||||
operator const DenseIteratorT &() const { return denseIt; }
|
||||
operator const SparseIteratorT &() const { return sparseIt; }
|
||||
operator DenseIteratorT &() { return denseIt; }
|
||||
operator SparseIteratorT &() { return sparseIt; }
|
||||
|
||||
/// An instance of a dense elements iterator.
|
||||
DenseIteratorT denseIt;
|
||||
/// An instance of a sparse elements iterator.
|
||||
SparseIteratorT sparseIt;
|
||||
};
|
||||
|
||||
/// Utility method to process a functor on each of the internal iterator
|
||||
/// types.
|
||||
template <typename RetT, template <typename> class ProcessFn,
|
||||
typename... Args>
|
||||
RetT process(Args &... args) const {
|
||||
switch (attrKind) {
|
||||
case StandardAttributes::DenseElements:
|
||||
return ProcessFn<DenseIteratorT>()(args...);
|
||||
case StandardAttributes::SparseElements:
|
||||
return ProcessFn<SparseIteratorT>()(args...);
|
||||
}
|
||||
llvm_unreachable("unexpected attribute kind");
|
||||
}
|
||||
|
||||
/// Utility functors used to generically implement the iterators methods.
|
||||
template <typename ItT> struct PlusAssign {
|
||||
void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
|
||||
};
|
||||
template <typename ItT> struct Minus {
|
||||
ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
|
||||
};
|
||||
template <typename ItT> struct MinusAssign {
|
||||
void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
|
||||
};
|
||||
template <typename ItT> struct Dereference {
|
||||
T operator()(ItT &it) { return *it; }
|
||||
};
|
||||
template <typename ItT> struct ConstructIter {
|
||||
Iterator operator()(const ItT &it) { return Iterator(it); }
|
||||
};
|
||||
|
||||
public:
|
||||
ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
|
||||
: attrKind(rhs.attrKind),
|
||||
it(rhs.process<Iterator, ConstructIter>(rhs.it)) {}
|
||||
|
||||
/// Methods necessary to support random access iteration.
|
||||
ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
|
||||
assert(attrKind == rhs.attrKind && "incompatible iterators");
|
||||
return process<ptrdiff_t, Minus>(it, rhs.it);
|
||||
}
|
||||
bool operator==(const ElementsAttrIterator<T> &rhs) const {
|
||||
return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
|
||||
}
|
||||
bool operator<(const ElementsAttrIterator<T> &rhs) const {
|
||||
assert(attrKind == rhs.attrKind && "incompatible iterators");
|
||||
return process<bool, std::less>(it, rhs.it);
|
||||
}
|
||||
ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
|
||||
process<void, PlusAssign>(it, offset);
|
||||
return *this;
|
||||
}
|
||||
ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
|
||||
process<void, MinusAssign>(it, offset);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Dereference the iterator at the current index.
|
||||
T operator*() { return process<T, Dereference>(it); }
|
||||
|
||||
private:
|
||||
template <typename IteratorT>
|
||||
ElementsAttrIterator(unsigned attrKind, IteratorT it)
|
||||
: attrKind(attrKind), it(it) {}
|
||||
|
||||
/// Allow accessing the constructor.
|
||||
friend ElementsAttr;
|
||||
|
||||
/// The kind of derived elements attribute.
|
||||
unsigned attrKind;
|
||||
|
||||
/// A union containing the specific iterators for each derived kind.
|
||||
Iterator it;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
|
||||
using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Return the elements of this attribute as a value of type 'T'.
|
||||
template <typename T>
|
||||
auto ElementsAttr::getValues() const -> iterator_range<T> {
|
||||
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
|
||||
auto values = denseAttr.getValues<T>();
|
||||
return {iterator<T>(getKind(), values.begin()),
|
||||
iterator<T>(getKind(), values.end())};
|
||||
}
|
||||
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
|
||||
auto values = sparseAttr.getValues<T>();
|
||||
return {iterator<T>(getKind(), values.begin()),
|
||||
iterator<T>(getKind(), values.end())};
|
||||
}
|
||||
llvm_unreachable("unexpected attribute kind");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attributes Utils
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename U> bool Attribute::isa() const {
|
||||
assert(impl && "isa<> used on a null attribute.");
|
||||
return U::classof(*this);
|
||||
@ -1015,6 +1226,10 @@ inline ::llvm::hash_code hash_value(Attribute arg) {
|
||||
return ::llvm::hash_value(arg.impl);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NamedAttributeList
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A NamedAttributeList is used to manage a list of named attributes. This
|
||||
/// provides simple interfaces for adding/removing/finding attributes from
|
||||
/// within a DictionaryAttr.
|
||||
|
99
third_party/mlir/lib/IR/Attributes.cpp
vendored
99
third_party/mlir/lib/IR/Attributes.cpp
vendored
@ -415,6 +415,25 @@ ElementsAttr ElementsAttr::mapValues(
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the 1 dimenional flattened row-major index from the given
|
||||
/// multi-dimensional index.
|
||||
uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
|
||||
assert(isValidIndex(index) && "expected valid multi-dimensional index");
|
||||
auto type = getType();
|
||||
|
||||
// Reduce the provided multidimensional index into a flattended 1D row-major
|
||||
// index.
|
||||
auto rank = type.getRank();
|
||||
auto shape = type.getShape();
|
||||
uint64_t valueIndex = 0;
|
||||
uint64_t dimMultiplier = 1;
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
valueIndex += index[i] * dimMultiplier;
|
||||
dimMultiplier *= shape[i];
|
||||
}
|
||||
return valueIndex;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseElementAttr Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -779,25 +798,6 @@ DenseElementsAttr DenseElementsAttr::mapValues(
|
||||
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
|
||||
}
|
||||
|
||||
/// Returns the 1 dimenional flattened index from the given multi-dimensional
|
||||
/// index.
|
||||
uint64_t DenseElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
|
||||
assert(isValidIndex(index) && "expected valid multi-dimensional index");
|
||||
auto type = getType();
|
||||
|
||||
// Reduce the provided multidimensional index into a flattended 1D row-major
|
||||
// index.
|
||||
auto rank = type.getRank();
|
||||
auto shape = type.getShape();
|
||||
uint64_t valueIndex = 0;
|
||||
uint64_t dimMultiplier = 1;
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
valueIndex += index[i] * dimMultiplier;
|
||||
dimMultiplier *= shape[i];
|
||||
}
|
||||
return valueIndex;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseFPElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -938,15 +938,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
assert(isValidIndex(index) && "expected valid multi-dimensional index");
|
||||
auto type = getType();
|
||||
|
||||
/// Return an attribute corresponding to '0' for the element type.
|
||||
auto getZeroAttr = [=]() -> Attribute {
|
||||
auto eltType = type.getElementType();
|
||||
if (eltType.isa<FloatType>())
|
||||
return FloatAttr::get(eltType, 0);
|
||||
assert(eltType.isa<IntegerType>() && "unexpected element type");
|
||||
return IntegerAttr::get(eltType, 0);
|
||||
};
|
||||
|
||||
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
||||
// as a 1-D index array.
|
||||
auto sparseIndices = getIndices();
|
||||
@ -983,6 +974,58 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
return getValues().getValue(it->second);
|
||||
}
|
||||
|
||||
/// Get a zero APFloat for the given sparse attribute.
|
||||
APFloat SparseElementsAttr::getZeroAPFloat() const {
|
||||
auto eltType = getType().getElementType().cast<FloatType>();
|
||||
return APFloat(eltType.getFloatSemantics());
|
||||
}
|
||||
|
||||
/// Get a zero APInt for the given sparse attribute.
|
||||
APInt SparseElementsAttr::getZeroAPInt() const {
|
||||
auto eltType = getType().getElementType().cast<IntegerType>();
|
||||
return APInt::getNullValue(eltType.getWidth());
|
||||
}
|
||||
|
||||
/// Get a zero attribute for the given attribute type.
|
||||
Attribute SparseElementsAttr::getZeroAttr() const {
|
||||
auto eltType = getType().getElementType();
|
||||
|
||||
// Handle floating point elements.
|
||||
if (eltType.isa<FloatType>())
|
||||
return FloatAttr::get(eltType, 0);
|
||||
|
||||
// Otherwise, this is an integer.
|
||||
auto intEltTy = eltType.cast<IntegerType>();
|
||||
if (intEltTy.getWidth() == 1)
|
||||
return BoolAttr::get(false, eltType.getContext());
|
||||
return IntegerAttr::get(eltType, 0);
|
||||
}
|
||||
|
||||
/// Flatten, and return, all of the sparse indices in this attribute in
|
||||
/// row-major order.
|
||||
std::vector<intptr_t> SparseElementsAttr::getFlattenedSparseIndices() const {
|
||||
std::vector<intptr_t> flatSparseIndices;
|
||||
|
||||
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
||||
// as a 1-D index array.
|
||||
auto sparseIndices = getIndices();
|
||||
auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
|
||||
if (sparseIndices.isSplat()) {
|
||||
SmallVector<uint64_t, 8> indices(getType().getRank(),
|
||||
*sparseIndexValues.begin());
|
||||
flatSparseIndices.push_back(getFlattenedIndex(indices));
|
||||
return flatSparseIndices;
|
||||
}
|
||||
|
||||
// Otherwise, reinterpret each index as an ArrayRef when flattening.
|
||||
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
|
||||
size_t rank = getType().getRank();
|
||||
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
|
||||
flatSparseIndices.push_back(getFlattenedIndex(
|
||||
{&*std::next(sparseIndexValues.begin(), i * rank), rank}));
|
||||
return flatSparseIndices;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NamedAttributeList
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -90,12 +90,12 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
|
||||
splatAttr.getSplatValue(), loc);
|
||||
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
|
||||
}
|
||||
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
|
||||
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
|
||||
auto *vectorType = cast<llvm::VectorType>(llvmType);
|
||||
SmallVector<llvm::Constant *, 8> constants;
|
||||
uint64_t numElements = vectorType->getNumElements();
|
||||
constants.reserve(numElements);
|
||||
for (auto n : denseAttr.getAttributeValues()) {
|
||||
for (auto n : elementsAttr.getValues<Attribute>()) {
|
||||
constants.push_back(
|
||||
getLLVMConstant(vectorType->getElementType(), n, loc));
|
||||
if (!constants.back())
|
||||
|
Loading…
x
Reference in New Issue
Block a user