From 8a960ef4b3e85bf442ce8aa4c7e164f4fe55414a Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 21 Aug 2019 10:23:14 -0700 Subject: [PATCH] Add iterator support to ElementsAttr and SparseElementsAttr. This will allow iterating the values of a non-opaque ElementsAttr, with all of the types currently supported by DenseElementsAttr. This should help reduce the amount of specialization on DenseElementsAttr. PiperOrigin-RevId: 264637293 --- tensorflow/compiler/mlir/xla/ir/xla_ops.cc | 2 +- third_party/mlir/include/mlir/IR/Attributes.h | 223 +++++++++++++++++- third_party/mlir/lib/IR/Attributes.cpp | 99 +++++--- .../lib/Target/LLVMIR/ModuleTranslation.cpp | 4 +- 4 files changed, 293 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc b/tensorflow/compiler/mlir/xla/ir/xla_ops.cc index 36a21bdf5eb..cbc6ab475dd 100644 --- a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/xla_ops.cc @@ -229,7 +229,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult TransposeOp::fold(ArrayRef operands) { - for (auto it : llvm::enumerate(permutation().cast())) { + for (auto it : llvm::enumerate(permutation().getValues())) { if (it.index() != it.value()) { return {}; } diff --git a/third_party/mlir/include/mlir/IR/Attributes.h b/third_party/mlir/include/mlir/IR/Attributes.h index 824ec7afa0e..2d5f689a89f 100644 --- a/third_party/mlir/include/mlir/IR/Attributes.h +++ b/third_party/mlir/include/mlir/IR/Attributes.h @@ -20,6 +20,7 @@ #include "mlir/IR/AttributeSupport.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/Sequence.h" namespace mlir { class AffineMap; @@ -447,11 +448,18 @@ public: // Elements Attributes //===----------------------------------------------------------------------===// +namespace detail { +template class ElementsAttrIterator; +template class ElementsAttrRange; +} // namespace detail + /// A base attribute that represents a reference to a static shaped tensor or /// vector constant. class ElementsAttr : public Attribute { public: using Attribute::Attribute; + template using iterator = detail::ElementsAttrIterator; + template using iterator_range = detail::ElementsAttrRange; /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor /// with static shape. @@ -467,6 +475,11 @@ public: return getValue(index).template cast(); } + /// Return the elements of this attribute as a value of type 'T'. Note: + /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support + /// iteration. + template iterator_range getValues() const; + /// Return if the given 'index' refers to a valid element in this attribute. bool isValidIndex(ArrayRef index) const; @@ -492,6 +505,11 @@ public: return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR && attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR; } + +protected: + /// Returns the 1 dimenional flattened row-major index from the given + /// multi-dimensional index. + uint64_t getFlattenedIndex(ArrayRef index) const; }; namespace detail { @@ -853,10 +871,6 @@ protected: /// the current attribute. This method is used to verify specific type /// invariants that the templatized 'getValues' method cannot. bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const; - - /// Returns the 1 dimenional flattened index from the given multi-dimensional - /// index. - uint64_t getFlattenedIndex(ArrayRef index) const; }; /// An attribute that represents a reference to a dense float vector or tensor @@ -964,6 +978,11 @@ class SparseElementsAttr public: using Base::Base; + template + using iterator = + llvm::mapped_iterator, + std::function>; + /// 'type' must be a vector or tensor with static shape. static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices, DenseElementsAttr values); @@ -972,6 +991,25 @@ public: DenseElementsAttr getValues() const; + /// Return the values of this attribute in the form of the given type 'T'. 'T' + /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc. + template llvm::iterator_range> getValues() const { + auto zeroValue = getZeroValue(); + auto valueIt = getValues().getValues().begin(); + std::vector flatSparseIndices = getFlattenedSparseIndices(); + // TODO(riverriddle): Move-capture flatSparseIndices when c++14 is + // available. + std::function mapFn = [=](ptrdiff_t index) { + // Try to map the current index to one of the sparse indices. + for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i) + if (flatSparseIndices[i] == index) + return *std::next(valueIt, i); + // Otherwise, return the zero value. + return zeroValue; + }; + return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); + } + /// Return the value of the element at the given index. The 'index' is /// expected to refer to a valid element. Attribute getValue(ArrayRef index) const; @@ -980,6 +1018,49 @@ public: static bool kindof(unsigned kind) { return kind == StandardAttributes::SparseElements; } + +private: + /// Get a zero APFloat for the given sparse attribute. + APFloat getZeroAPFloat() const; + + /// Get a zero APInt for the given sparse attribute. + APInt getZeroAPInt() const; + + /// Get a zero attribute for the given sparse attribute. + Attribute getZeroAttr() const; + + /// Utility methods to generate a zero value of some type 'T'. This is used by + /// the 'iterator' class. + /// Get a zero for a given attribute type. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAttr().template cast(); + } + /// Get a zero for an APInt. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAPInt(); + } + /// Get a zero for an APFloat. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAPFloat(); + } + /// Get a zero for an C++ integer or float type. + template + typename std::enable_if::is_integer || + llvm::is_one_of::value, + T>::type + getZeroValue() const { + return T(0); + } + + /// Flatten, and return, all of the sparse indices in this attribute in + /// row-major order. + std::vector getFlattenedSparseIndices() const; }; /// An attribute that represents a reference to a splat vector or tensor @@ -995,6 +1076,136 @@ public: } }; +namespace detail { +/// This class represents a general iterator over the values of an ElementsAttr. +/// It supports all subclasses aside from OpaqueElementsAttr. +template +class ElementsAttrIterator + : public llvm::iterator_facade_base, + std::random_access_iterator_tag, T, + std::ptrdiff_t, T, T> { + using DenseIteratorT = + decltype(std::declval().getValues().begin()); + using SparseIteratorT = SparseElementsAttr::iterator; + + /// A union containing the specific iterators for each derived attribute kind. + union Iterator { + explicit Iterator(DenseIteratorT it) : denseIt(it) {} + explicit Iterator(SparseIteratorT it) : sparseIt(it) {} + ~Iterator() {} + + operator const DenseIteratorT &() const { return denseIt; } + operator const SparseIteratorT &() const { return sparseIt; } + operator DenseIteratorT &() { return denseIt; } + operator SparseIteratorT &() { return sparseIt; } + + /// An instance of a dense elements iterator. + DenseIteratorT denseIt; + /// An instance of a sparse elements iterator. + SparseIteratorT sparseIt; + }; + + /// Utility method to process a functor on each of the internal iterator + /// types. + template class ProcessFn, + typename... Args> + RetT process(Args &... args) const { + switch (attrKind) { + case StandardAttributes::DenseElements: + return ProcessFn()(args...); + case StandardAttributes::SparseElements: + return ProcessFn()(args...); + } + llvm_unreachable("unexpected attribute kind"); + } + + /// Utility functors used to generically implement the iterators methods. + template struct PlusAssign { + void operator()(ItT &it, ptrdiff_t offset) { it += offset; } + }; + template struct Minus { + ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; } + }; + template struct MinusAssign { + void operator()(ItT &it, ptrdiff_t offset) { it -= offset; } + }; + template struct Dereference { + T operator()(ItT &it) { return *it; } + }; + template struct ConstructIter { + Iterator operator()(const ItT &it) { return Iterator(it); } + }; + +public: + ElementsAttrIterator(const ElementsAttrIterator &rhs) + : attrKind(rhs.attrKind), + it(rhs.process(rhs.it)) {} + + /// Methods necessary to support random access iteration. + ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { + assert(attrKind == rhs.attrKind && "incompatible iterators"); + return process(it, rhs.it); + } + bool operator==(const ElementsAttrIterator &rhs) const { + return rhs.attrKind == attrKind && process(it, rhs.it); + } + bool operator<(const ElementsAttrIterator &rhs) const { + assert(attrKind == rhs.attrKind && "incompatible iterators"); + return process(it, rhs.it); + } + ElementsAttrIterator &operator+=(ptrdiff_t offset) { + process(it, offset); + return *this; + } + ElementsAttrIterator &operator-=(ptrdiff_t offset) { + process(it, offset); + return *this; + } + + /// Dereference the iterator at the current index. + T operator*() { return process(it); } + +private: + template + ElementsAttrIterator(unsigned attrKind, IteratorT it) + : attrKind(attrKind), it(it) {} + + /// Allow accessing the constructor. + friend ElementsAttr; + + /// The kind of derived elements attribute. + unsigned attrKind; + + /// A union containing the specific iterators for each derived kind. + Iterator it; +}; + +template +class ElementsAttrRange : public llvm::iterator_range> { + using llvm::iterator_range>::iterator_range; +}; +} // namespace detail + +/// Return the elements of this attribute as a value of type 'T'. +template +auto ElementsAttr::getValues() const -> iterator_range { + if (DenseElementsAttr denseAttr = dyn_cast()) { + auto values = denseAttr.getValues(); + return {iterator(getKind(), values.begin()), + iterator(getKind(), values.end())}; + } + if (SparseElementsAttr sparseAttr = dyn_cast()) { + auto values = sparseAttr.getValues(); + return {iterator(getKind(), values.begin()), + iterator(getKind(), values.end())}; + } + llvm_unreachable("unexpected attribute kind"); +} + +//===----------------------------------------------------------------------===// +// Attributes Utils +//===----------------------------------------------------------------------===// + template bool Attribute::isa() const { assert(impl && "isa<> used on a null attribute."); return U::classof(*this); @@ -1015,6 +1226,10 @@ inline ::llvm::hash_code hash_value(Attribute arg) { return ::llvm::hash_value(arg.impl); } +//===----------------------------------------------------------------------===// +// NamedAttributeList +//===----------------------------------------------------------------------===// + /// A NamedAttributeList is used to manage a list of named attributes. This /// provides simple interfaces for adding/removing/finding attributes from /// within a DictionaryAttr. diff --git a/third_party/mlir/lib/IR/Attributes.cpp b/third_party/mlir/lib/IR/Attributes.cpp index a8101a28990..82df80bde4f 100644 --- a/third_party/mlir/lib/IR/Attributes.cpp +++ b/third_party/mlir/lib/IR/Attributes.cpp @@ -415,6 +415,25 @@ ElementsAttr ElementsAttr::mapValues( } } +/// Returns the 1 dimenional flattened row-major index from the given +/// multi-dimensional index. +uint64_t ElementsAttr::getFlattenedIndex(ArrayRef index) const { + assert(isValidIndex(index) && "expected valid multi-dimensional index"); + auto type = getType(); + + // Reduce the provided multidimensional index into a flattended 1D row-major + // index. + auto rank = type.getRank(); + auto shape = type.getShape(); + uint64_t valueIndex = 0; + uint64_t dimMultiplier = 1; + for (int i = rank - 1; i >= 0; --i) { + valueIndex += index[i] * dimMultiplier; + dimMultiplier *= shape[i]; + } + return valueIndex; +} + //===----------------------------------------------------------------------===// // DenseElementAttr Utilities //===----------------------------------------------------------------------===// @@ -779,25 +798,6 @@ DenseElementsAttr DenseElementsAttr::mapValues( return cast().mapValues(newElementType, mapping); } -/// Returns the 1 dimenional flattened index from the given multi-dimensional -/// index. -uint64_t DenseElementsAttr::getFlattenedIndex(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - auto type = getType(); - - // Reduce the provided multidimensional index into a flattended 1D row-major - // index. - auto rank = type.getRank(); - auto shape = type.getShape(); - uint64_t valueIndex = 0; - uint64_t dimMultiplier = 1; - for (int i = rank - 1; i >= 0; --i) { - valueIndex += index[i] * dimMultiplier; - dimMultiplier *= shape[i]; - } - return valueIndex; -} - //===----------------------------------------------------------------------===// // DenseFPElementsAttr //===----------------------------------------------------------------------===// @@ -938,15 +938,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef index) const { assert(isValidIndex(index) && "expected valid multi-dimensional index"); auto type = getType(); - /// Return an attribute corresponding to '0' for the element type. - auto getZeroAttr = [=]() -> Attribute { - auto eltType = type.getElementType(); - if (eltType.isa()) - return FloatAttr::get(eltType, 0); - assert(eltType.isa() && "unexpected element type"); - return IntegerAttr::get(eltType, 0); - }; - // The sparse indices are 64-bit integers, so we can reinterpret the raw data // as a 1-D index array. auto sparseIndices = getIndices(); @@ -983,6 +974,58 @@ Attribute SparseElementsAttr::getValue(ArrayRef index) const { return getValues().getValue(it->second); } +/// Get a zero APFloat for the given sparse attribute. +APFloat SparseElementsAttr::getZeroAPFloat() const { + auto eltType = getType().getElementType().cast(); + return APFloat(eltType.getFloatSemantics()); +} + +/// Get a zero APInt for the given sparse attribute. +APInt SparseElementsAttr::getZeroAPInt() const { + auto eltType = getType().getElementType().cast(); + return APInt::getNullValue(eltType.getWidth()); +} + +/// Get a zero attribute for the given attribute type. +Attribute SparseElementsAttr::getZeroAttr() const { + auto eltType = getType().getElementType(); + + // Handle floating point elements. + if (eltType.isa()) + return FloatAttr::get(eltType, 0); + + // Otherwise, this is an integer. + auto intEltTy = eltType.cast(); + if (intEltTy.getWidth() == 1) + return BoolAttr::get(false, eltType.getContext()); + return IntegerAttr::get(eltType, 0); +} + +/// Flatten, and return, all of the sparse indices in this attribute in +/// row-major order. +std::vector SparseElementsAttr::getFlattenedSparseIndices() const { + std::vector flatSparseIndices; + + // The sparse indices are 64-bit integers, so we can reinterpret the raw data + // as a 1-D index array. + auto sparseIndices = getIndices(); + auto sparseIndexValues = sparseIndices.getValues(); + if (sparseIndices.isSplat()) { + SmallVector indices(getType().getRank(), + *sparseIndexValues.begin()); + flatSparseIndices.push_back(getFlattenedIndex(indices)); + return flatSparseIndices; + } + + // Otherwise, reinterpret each index as an ArrayRef when flattening. + auto numSparseIndices = sparseIndices.getType().getDimSize(0); + size_t rank = getType().getRank(); + for (size_t i = 0, e = numSparseIndices; i != e; ++i) + flatSparseIndices.push_back(getFlattenedIndex( + {&*std::next(sparseIndexValues.begin(), i * rank), rank})); + return flatSparseIndices; +} + //===----------------------------------------------------------------------===// // NamedAttributeList //===----------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index bea22c9753c..e872794d426 100644 --- a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -90,12 +90,12 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, splatAttr.getSplatValue(), loc); return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child); } - if (auto denseAttr = attr.dyn_cast()) { + if (auto elementsAttr = attr.dyn_cast()) { auto *vectorType = cast(llvmType); SmallVector constants; uint64_t numElements = vectorType->getNumElements(); constants.reserve(numElements); - for (auto n : denseAttr.getAttributeValues()) { + for (auto n : elementsAttr.getValues()) { constants.push_back( getLLVMConstant(vectorType->getElementType(), n, loc)); if (!constants.back())