Add read-only copy of MLIR core to TensorFlow.

PiperOrigin-RevId: 260575790
Christian Sigg 2019-07-29 14:06:20 -07:00 committed by TensorFlower Gardener
parent 7191d4f3b6
commit 366ddc8948
442 changed files with 103585 additions and 46 deletions

@@ -1,3 +1,4 @@
llvm/llvm/projects/google_mlir/WORKSPACE
tensorflow/contrib/mpi/BUILD
tensorflow/stream_executor/build_defs.bzl
tensorflow/python/autograph/core/config.py
@@ -189,11 +190,6 @@ tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/mlir/BUILD
tensorflow/third_party/mlir/mlir_configure.bzl
tensorflow/third_party/mlir/bindings/python/BUILD
tensorflow/third_party/mlir/test/BUILD
tensorflow/third_party/mlir/tblgen.bzl
tensorflow/third_party/gast.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py

@@ -7,7 +7,6 @@ load("//third_party/nccl:nccl_configure.bzl", "nccl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
load("//third_party/mlir:mlir_configure.bzl", "mlir_configure")
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
load("//third_party/toolchains/remote:configure.bzl", "remote_execution_configure")
@@ -74,7 +73,10 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
syslibs_configure(name = "local_config_syslibs")
python_configure(name = "local_config_python")
rocm_configure(name = "local_config_rocm")
mlir_configure(name = "local_config_mlir")
native.local_repository(
name = "local_config_mlir",
path = "third_party/mlir",
)
remote_execution_configure(name = "local_config_remote_execution")
initialize_third_party()

third_party/mlir/.clang-format (new file, 2 lines)

@@ -0,0 +1,2 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes

@@ -42,6 +42,7 @@ cc_library(
"lib/IR/Diagnostics.cpp",
"lib/IR/Dialect.cpp",
"lib/IR/Function.cpp",
"lib/IR/FunctionSupport.cpp",
"lib/IR/IntegerSet.cpp",
"lib/IR/IntegerSetDetail.h",
"lib/IR/Location.cpp",
@@ -1341,9 +1342,9 @@ cc_binary(
":StandardDialectRegistration",
":Transforms",
":VectorDialectRegistration",
"//test:TestDialect",
"//test:TestTransforms",
"@llvm//:support",
"@local_config_mlir//test:TestDialect",
"@local_config_mlir//test:TestTransforms",
],
)

third_party/mlir/CMakeLists.txt (new file, 63 lines)

@@ -0,0 +1,63 @@
# MLIR project.
set(MLIR_MAIN_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include ) # --src-root
set(MLIR_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/include ) # --includedir
set(MLIR_TABLEGEN_EXE mlir-tblgen)
set(MLIR_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(MLIR_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
# TODO: Temporary, remove when no longer needed.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
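# mlir_tablegen(output flags...) runs mlir-tblgen through LLVM's tablegen()
# helper with the MLIR source and generated include directories, and appends
# the generated file to TABLEGEN_OUTPUT in the caller's scope.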
function(mlir_tablegen ofn)
tablegen(MLIR ${ARGV} "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}")
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)
endfunction()
# TODO: This is to handle the current static registration, but should be
# factored out a bit.
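# whole_archive_link(target libs...) forces the linker to keep every object
# file in the listed static archives, so that objects whose only purpose is
# static registration are not dropped: -force_load per archive on Darwin,
# /WHOLEARCHIVE on MSVC, and --whole-archive/--no-whole-archive elsewhere.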
function(whole_archive_link target)
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
set(link_flags "-L${CMAKE_BINARY_DIR}/lib ")
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${CMAKE_BINARY_DIR}/lib/lib${LIB}.a ")
ENDFOREACH(LIB)
elseif(MSVC)
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ")
ENDFOREACH(LIB)
else()
set(link_flags "-L${CMAKE_BINARY_DIR}/lib -Wl,--whole-archive,")
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "-l${LIB},")
ENDFOREACH(LIB)
string(CONCAT link_flags ${link_flags} "--no-whole-archive")
endif()
set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
endfunction(whole_archive_link)
# Build the CUDA conversions and run according tests if the NVPTX backend
# is available
if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
set(MLIR_CUDA_CONVERSIONS_ENABLED 1)
else()
set(MLIR_CUDA_CONVERSIONS_ENABLED 0)
endif()
set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
include_directories( "include")
include_directories( ${MLIR_INCLUDE_DIR})
add_subdirectory(include/mlir)
add_subdirectory(lib)
add_subdirectory(tools)
add_subdirectory(unittests)
add_subdirectory(test)
if( LLVM_INCLUDE_EXAMPLES )
add_subdirectory(examples)
endif()

third_party/mlir/CONTRIBUTING.md (new file, 49 lines)

@@ -0,0 +1,49 @@
# How to Contribute
Everyone is welcome to contribute to MLIR. There are several ways of getting involved and contributing, including reporting bugs, improving documentation, and writing models or tutorials.
Please read our [Code of Conduct](https://github.com/tensorflow/tensorflow/blob/master/CODE_OF_CONDUCT.md) before participating.
## Community Guidelines
This project follows [Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).
## How to become a contributor and submit your own code
### Contributor License Agreements
We'd love to accept your patches! Before we can take them, please fill out either the individual or corporate Contributor License Agreement (CLA).
* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html).
* If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html).
Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository.
### Contributing code
If you have improvements to MLIR, send us your pull requests! For those
just getting started, GitHub has a [howto](https://help.github.com/articles/using-pull-requests/).
MLIR team members will be assigned to review your pull request. Once it is approved and passes continuous integration checks, a team member will submit it to our internal repository. After the change has been applied internally, your pull request will be merged automatically on GitHub.
If you want to contribute, start working through the MLIR codebase, navigate to the [GitHub "issues" tab](https://github.com/tensorflow/mlir/issues), and start looking through interesting issues. If you decide to start on an issue, leave a comment so that other people know that you're working on it. If you want to help out but not work alone, use the issue comment thread to coordinate.
### Contribution guidelines and standards
* Read the [developer guide](g3doc/DeveloperGuide.md).
* Ensure that you use the correct license. Examples are provided below.
* Include tests when you contribute new features, as they help to a)
prove that your code works correctly, and b) guard against future breaking
changes to lower the maintenance cost.
* Bug fixes also generally require tests, because the presence of bugs
usually indicates insufficient test coverage.
#### License
Include a license at the top of new files.
* [C/C++ license example](https://github.com/tensorflow/mlir/blob/master/examples/toy/Ch1/toyc.cpp)
* [Python license example](https://github.com/tensorflow/mlir/blob/master/bindings/python/test/test_py2and3.py)
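For reference, the Python form of this header, exactly as it appears in the test files added by this commit (see `test_py2and3.py` below), is:

```python
# Copyright 2019 The MLIR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
```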

third_party/mlir/LICENSE.TXT (new file, 205 lines)

@@ -0,0 +1,205 @@
Copyright 2019 The MLIR Authors.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

third_party/mlir/README.md (new file, 132 lines)

@@ -0,0 +1,132 @@
# Multi-Level Intermediate Representation Overview
The MLIR project aims to define a common intermediate representation (IR) that
will unify the infrastructure required to execute high performance machine
learning models in TensorFlow and similar ML frameworks. This project will
include the application of HPC techniques, along with integration of search
algorithms like reinforcement learning. This project aims to reduce the cost to
bring up new hardware, and improve usability for existing TensorFlow users.
Note that this repository contains the core of the MLIR framework. The
TensorFlow compilers we are building on top of MLIR will be part of the
main TensorFlow repository soon.
# How to Contribute
Thank you for your interest in contributing to MLIR! If you want to contribute
to MLIR, be sure to review the [contribution guidelines](CONTRIBUTING.md).
## More resources
For more information on MLIR, please see:
* [The MLIR draft specification](g3doc/LangRef.md), which describes the IR
itself.
* [The MLIR rationale document](g3doc/Rationale.md), covering motivation
behind some decisions.
* Previous external [talks](#mlir-talks).
Join the [MLIR mailing list](https://groups.google.com/a/tensorflow.org/forum/#!forum/mlir)
to hear about announcements and discussions.
Please be mindful of the [TensorFlow Code of Conduct](https://github.com/tensorflow/tensorflow/blob/master/CODE_OF_CONDUCT.md),
which pledges to foster an open and welcoming environment.
## What is MLIR for?
MLIR is intended to be a hybrid IR which can support multiple different
requirements in a unified infrastructure. For example, this includes:
* The ability to represent all TensorFlow graphs, including dynamic shapes,
the user-extensible op ecosystem, TensorFlow variables, etc.
* Optimizations and transformations typically done on a TensorFlow graph, e.g.
in Grappler.
* Quantization and other graph transformations done on a TensorFlow graph or
the TF Lite representation.
* Representation of kernels for ML operations in a form suitable for
optimization.
* Ability to host high-performance-computing-style loop optimizations across
kernels (fusion, loop interchange, tiling, etc) and to transform memory
layouts of data.
* Code generation "lowering" transformations such as DMA insertion, explicit
cache management, memory tiling, and vectorization for 1D and 2D register
architectures.
* Ability to represent target-specific operations, e.g. the MXU on TPUs.
MLIR is a common IR that also supports hardware specific operations. Thus,
any investment into the infrastructure surrounding MLIR (e.g. the compiler
passes that work on it) should yield good returns; many targets can use that
infrastructure and will benefit from it.
MLIR is a powerful representation, but it also has non-goals. We do not try to
support low level machine code generation algorithms (like register allocation
and instruction scheduling). They are a better fit for lower level optimizers
(such as LLVM). Also, we do not intend MLIR to be a source language that
end-users would themselves write kernels in (analogous to CUDA C++). While we
would love to see a kernel language happen someday, that will be an independent
project that compiles down to MLIR.
## Compiler infrastructure
We benefited from experience gained from building other IRs (HLO, LLVM and SIL)
when building MLIR. We will directly adopt existing best practices, e.g. writing
and maintaining an IR spec, building an IR verifier, providing the ability to
dump and parse MLIR files to text, writing extensive unit tests with the
[FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tool, and
building the infrastructure as a set of modular libraries that can be combined
in new ways. We plan to use the infrastructure developed by the XLA team for
performance analysis and benchmarking.
Other lessons have been incorporated and integrated into the design in subtle
ways. For example, LLVM has non-obvious design mistakes that prevent a
multithreaded compiler from working on multiple functions in an LLVM module at
the same time. MLIR solves these problems by having per-function constant pools
and by making references explicit with `function_ref`.
# Getting started with MLIR
The following instructions for compiling and testing MLIR assume that you have
`git`, [`ninja`](https://ninja-build.org/), and a working C++ toolchain. In the
future, we aim to align on the same level of platform support as
[LLVM](https://llvm.org/docs/GettingStarted.html#requirements). For now, MLIR
has been tested on Linux and macOS, with recent versions of clang and with
gcc 7.
```sh
git clone https://github.com/llvm/llvm-project.git
git clone https://github.com/tensorflow/mlir llvm-project/llvm/projects/mlir
mkdir llvm-project/build
cd llvm-project/build
cmake -G Ninja ../llvm -DLLVM_BUILD_EXAMPLES=ON -DLLVM_ENABLE_CXX1Y=Y -DLLVM_TARGETS_TO_BUILD="host"
cmake --build . --target check-mlir
```
To compile and test on Windows using Visual Studio 2017:
```bat
REM In shell with Visual Studio environment set up, e.g., with command such as
REM "<visual-studio-install>\Auxiliary\Build\vcvarsall.bat" x64
REM invoked.
git clone https://github.com/llvm/llvm-project.git
git clone https://github.com/tensorflow/mlir llvm-project\llvm\projects\mlir
mkdir llvm-project\build
cd llvm-project\build
cmake ..\llvm -G "Visual Studio 15 2017 Win64" -DLLVM_BUILD_EXAMPLES=ON -DLLVM_ENABLE_CXX1Y=Y -DLLVM_TARGETS_TO_BUILD="host" -DCMAKE_BUILD_TYPE=Release -Thost=x64
cmake --build . --target check-mlir
```
As a starter, you may try [the tutorial](g3doc/Tutorials/Toy/Ch-1.md) on
building a compiler for a Toy language.
# MLIR talks
* "[MLIR Primer: A Compiler Infrastructure for the End of Moores Law](https://ai.google/research/pubs/pub48035.pdf)"
* Chris Lattner & Jacques Pienaar, Google at
[Compilers for Machine Learning](https://www.c4ml.org/) workshop at
[CGO 2019](http://cgo.org/cgo2019/)
* "[MLIR: Multi-Level Intermediate Representation for Compiler
Infrastructure](https://llvm.org/devmtg/2019-04/talks.html#Keynote_1)"
* Tatiana Shpeisman & Chris Lattner, Google at
[EuroLLVM 2019](https://llvm.org/devmtg/2019-04)
* "[Tutorial: Building a Compiler with MLIR](https://llvm.org/devmtg/2019-04/talks.html#Tutorial_1)"
* Mehdi Amini, Jacques Pienaar, Nicolas Vasilache, Google at
[EuroLLVM 2019](https://llvm.org/devmtg/2019-04)

third_party/mlir/WORKSPACE (new empty file)

@@ -27,8 +27,9 @@ py_extension(
features = ["-use_header_modules"],
module_name = "pybind",
deps = [
"@llvm//:ir",
"@llvm//:support",
"//third_party/llvm/llvm:ir",
"//third_party/llvm/llvm:support",
"//third_party/pybind11",
"@local_config_mlir//:EDSC",
"@local_config_mlir//:ExecutionEngine",
"@local_config_mlir//:IR",
@@ -37,6 +38,5 @@ py_extension(
"@local_config_mlir//:StandardDialectRegistration",
"@local_config_mlir//:TargetLLVMIR",
"@local_config_mlir//:Transforms",
"@pybind11",
],
)

@@ -0,0 +1,932 @@
//===- pybind.cpp - MLIR Python bindings ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <cstddef>
#include <unordered_map>
#include "mlir-c/Core.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/Passes.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
static bool inited = [] {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
return true;
}();
namespace mlir {
namespace edsc {
namespace python {
namespace py = pybind11;
struct PythonAttribute;
struct PythonAttributedType;
struct PythonBindable;
struct PythonExpr;
struct PythonFunctionContext;
struct PythonStmt;
struct PythonBlock;
struct PythonType {
PythonType() : type{nullptr} {}
PythonType(mlir_type_t t) : type{t} {}
operator mlir_type_t() const { return type; }
PythonAttributedType attachAttributeDict(
const std::unordered_map<std::string, PythonAttribute> &attrs) const;
std::string str() {
mlir::Type f = mlir::Type::getFromOpaquePointer(type);
std::string res;
llvm::raw_string_ostream os(res);
f.print(os);
return res;
}
mlir_type_t type;
};
struct PythonValueHandle {
PythonValueHandle(PythonType type)
: value(mlir::Type::getFromOpaquePointer(type.type)) {}
PythonValueHandle(const PythonValueHandle &other) = default;
PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {}
operator ValueHandle() const { return value; }
operator ValueHandle &() { return value; }
std::string str() const {
return std::to_string(reinterpret_cast<intptr_t>(value.getValue()));
}
PythonValueHandle call(const std::vector<PythonValueHandle> &args) {
assert(value.hasType() && value.getType().isa<FunctionType>() &&
"can only call function-typed values");
std::vector<Value *> argValues;
argValues.reserve(args.size());
for (auto arg : args)
argValues.push_back(arg.value.getValue());
return ValueHandle::create<CallIndirectOp>(value, argValues);
}
mlir::edsc::ValueHandle value;
};
struct PythonFunction {
PythonFunction() : function{nullptr} {}
PythonFunction(mlir_func_t f) : function{f} {}
PythonFunction(mlir::FuncOp f)
: function(const_cast<void *>(f.getAsOpaquePointer())) {}
operator mlir_func_t() { return function; }
std::string str() {
mlir::FuncOp f = mlir::FuncOp::getFromOpaquePointer(function);
std::string res;
llvm::raw_string_ostream os(res);
f.print(os);
return res;
}
// If the function does not yet have an entry block, i.e. if it is a function
// declaration, add the entry block, transforming the declaration into a
// definition. Return true if the block was added, false otherwise.
bool define() {
auto f = mlir::FuncOp::getFromOpaquePointer(function);
if (!f.getBlocks().empty())
return false;
f.addEntryBlock();
return true;
}
PythonValueHandle arg(unsigned index) {
auto f = mlir::FuncOp::getFromOpaquePointer(function);
assert(index < f.getNumArguments() && "argument index out of bounds");
return PythonValueHandle(ValueHandle(f.getArgument(index)));
}
mlir_func_t function;
};
/// Trivial C++ wrappers make use of the EDSC C API.
struct PythonMLIRModule {
PythonMLIRModule()
: mlirContext(),
module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlirContext))),
moduleManager(*module) {}
PythonType makeScalarType(const std::string &mlirElemType,
unsigned bitwidth) {
return ::makeScalarType(mlir_context_t{&mlirContext}, mlirElemType.c_str(),
bitwidth);
}
PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
int64_list_t{sizes.data(), sizes.size()});
}
PythonType makeIndexType() {
return ::makeIndexType(mlir_context_t{&mlirContext});
}
// Declare a function with the given name, input types and their attributes,
// output types, and function attributes, but do not define it.
PythonFunction declareFunction(const std::string &name,
const py::list &inputs,
const std::vector<PythonType> &outputTypes,
const py::kwargs &funcAttributes);
// Declare a function with the given name, input types and their attributes,
// output types, and function attributes.
PythonFunction makeFunction(const std::string &name, const py::list &inputs,
const std::vector<PythonType> &outputTypes,
const py::kwargs &funcAttributes) {
auto declaration =
declareFunction(name, inputs, outputTypes, funcAttributes);
declaration.define();
return declaration;
}
// Create a custom op given its name and arguments.
PythonExpr op(const std::string &name, PythonType type,
const py::list &arguments, const py::list &successors,
py::kwargs attributes);
// Create an integer attribute.
PythonAttribute integerAttr(PythonType type, int64_t value);
// Create a boolean attribute.
PythonAttribute boolAttr(bool value);
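// Lower the module to the LLVM IR dialect (canonicalization, CSE, affine
// lowering, standard-to-LLVM conversion), then JIT-compile it into an
// ExecutionEngine backed by the LLVM ORC JIT.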
void compile() {
PassManager manager;
manager.addPass(mlir::createCanonicalizerPass());
manager.addPass(mlir::createCSEPass());
manager.addPass(mlir::createLowerAffinePass());
manager.addPass(mlir::createConvertToLLVMIRPass());
if (failed(manager.run(*module))) {
llvm::errs() << "conversion to the LLVM IR dialect failed\n";
return;
}
auto created = mlir::ExecutionEngine::create(*module);
llvm::handleAllErrors(created.takeError(),
[](const llvm::ErrorInfoBase &b) {
b.log(llvm::errs());
assert(false);
});
engine = std::move(*created);
}
std::string getIR() {
std::string res;
llvm::raw_string_ostream os(res);
module->print(os);
return res;
}
uint64_t getEngineAddress() {
assert(engine && "module must be compiled into engine first");
return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get()));
}
PythonFunction getNamedFunction(const std::string &name) {
return moduleManager.lookupSymbol<FuncOp>(name);
}
PythonFunctionContext
makeFunctionContext(const std::string &name, const py::list &inputs,
const std::vector<PythonType> &outputs,
const py::kwargs &attributes);
private:
mlir::MLIRContext mlirContext;
// One single module in a python-exposed MLIRContext for now.
mlir::OwningModuleRef module;
mlir::ModuleManager moduleManager;
std::unique_ptr<mlir::ExecutionEngine> engine;
};
struct PythonFunctionContext {
PythonFunctionContext(PythonFunction f) : function(f) {}
PythonFunctionContext(PythonMLIRModule &module, const std::string &name,
const py::list &inputs,
const std::vector<PythonType> &outputs,
const py::kwargs &attributes) {
auto function = module.declareFunction(name, inputs, outputs, attributes);
function.define();
}
PythonFunction enter() {
assert(function.function && "function is not set up");
auto mlirFunc = mlir::FuncOp::getFromOpaquePointer(function.function);
contextBuilder.emplace(mlirFunc.getBody());
context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc());
return function;
}
void exit(py::object, py::object, py::object) {
delete context;
context = nullptr;
contextBuilder.reset();
}
PythonFunction function;
mlir::edsc::ScopedContext *context;
llvm::Optional<OpBuilder> contextBuilder;
};
PythonFunctionContext PythonMLIRModule::makeFunctionContext(
const std::string &name, const py::list &inputs,
const std::vector<PythonType> &outputs, const py::kwargs &attributes) {
auto func = declareFunction(name, inputs, outputs, attributes);
func.define();
return PythonFunctionContext(func);
}
struct PythonBlockHandle {
PythonBlockHandle() : value(nullptr) {}
PythonBlockHandle(const PythonBlockHandle &other) = default;
PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
operator mlir::edsc::BlockHandle() const { return value; }
PythonValueHandle arg(int index) { return arguments[index]; }
std::string str() {
std::string s;
llvm::raw_string_ostream os(s);
value.getBlock()->print(os);
return os.str();
}
mlir::edsc::BlockHandle value;
std::vector<mlir::edsc::ValueHandle> arguments;
};
struct PythonLoopContext {
PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
: lb(lb), ub(ub), step(step) {}
PythonLoopContext(const PythonLoopContext &) = delete;
PythonLoopContext(PythonLoopContext &&) = default;
PythonLoopContext &operator=(const PythonLoopContext &) = delete;
PythonLoopContext &operator=(PythonLoopContext &&) = default;
~PythonLoopContext() { assert(!builder && "did not exit from the context"); }
PythonValueHandle enter() {
ValueHandle iv(lb.value.getType());
builder = new LoopBuilder(&iv, lb.value, ub.value, step);
return iv;
}
void exit(py::object, py::object, py::object) {
(*builder)({}); // exit from the builder's scope.
delete builder;
builder = nullptr;
}
PythonValueHandle lb, ub;
int64_t step;
LoopBuilder *builder = nullptr;
};
struct PythonLoopNestContext {
PythonLoopNestContext(const std::vector<PythonValueHandle> &lbs,
const std::vector<PythonValueHandle> &ubs,
const std::vector<int64_t> steps)
: lbs(lbs), ubs(ubs), steps(steps) {
assert(lbs.size() == ubs.size() && lbs.size() == steps.size() &&
"expected the same number of lower, upper bounds, and steps");
}
PythonLoopNestContext(const PythonLoopNestContext &) = delete;
PythonLoopNestContext(PythonLoopNestContext &&) = default;
PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete;
PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default;
~PythonLoopNestContext() {
assert(!builder && "did not exit from the context");
}
std::vector<PythonValueHandle> enter() {
if (steps.empty())
return {};
auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer());
std::vector<PythonValueHandle> handles(steps.size(),
PythonValueHandle(type));
std::vector<ValueHandle *> handlePtrs;
handlePtrs.reserve(steps.size());
for (auto &h : handles)
handlePtrs.push_back(&h.value);
builder = new LoopNestBuilder(
handlePtrs, std::vector<ValueHandle>(lbs.begin(), lbs.end()),
std::vector<ValueHandle>(ubs.begin(), ubs.end()), steps);
return handles;
}
void exit(py::object, py::object, py::object) {
(*builder)({}); // exit from the builder's scope.
delete builder;
builder = nullptr;
}
std::vector<PythonValueHandle> lbs;
std::vector<PythonValueHandle> ubs;
std::vector<int64_t> steps;
LoopNestBuilder *builder = nullptr;
};
struct PythonBlockAppender {
PythonBlockAppender(const PythonBlockHandle &handle) : handle(handle) {}
PythonBlockHandle handle;
};
struct PythonBlockContext {
public:
PythonBlockContext() {
createBlockBuilder();
clearBuilder();
}
PythonBlockContext(const std::vector<PythonType> &argTypes) {
handle.arguments.reserve(argTypes.size());
for (const auto &t : argTypes) {
auto type =
Type::getFromOpaquePointer(reinterpret_cast<const void *>(t.type));
handle.arguments.emplace_back(type);
}
createBlockBuilder();
clearBuilder();
}
PythonBlockContext(const PythonBlockAppender &a) : handle(a.handle) {}
PythonBlockContext(const PythonBlockContext &) = delete;
PythonBlockContext(PythonBlockContext &&) = default;
PythonBlockContext &operator=(const PythonBlockContext &) = delete;
PythonBlockContext &operator=(PythonBlockContext &&) = default;
~PythonBlockContext() {
assert(!builder && "did not exit from the block context");
}
// EDSC maintains an implicit stack of builders (mostly for keeping track of
// insertion points); every operation gets inserted using the top-of-the-stack
// builder. Creating a new EDSC Builder automatically puts it on the stack,
// effectively entering the block for it.
void createBlockBuilder() {
if (handle.value.getBlock()) {
builder = new BlockBuilder(handle.value, mlir::edsc::Append());
} else {
std::vector<ValueHandle *> args;
args.reserve(handle.arguments.size());
for (auto &a : handle.arguments)
args.push_back(&a);
builder = new BlockBuilder(&handle.value, args);
}
}
PythonBlockHandle enter() {
createBlockBuilder();
return handle;
}
void exit(py::object, py::object, py::object) { clearBuilder(); }
PythonBlockHandle getHandle() { return handle; }
// EDSC maintains an implicit stack of builders (mostly for keeping track of
// insertion points); every operation gets inserted using the top-of-the-stack
// builder. Calling operator() on a builder pops the builder from the stack,
// effectively resetting the insertion point to its position before we entered
// the block.
void clearBuilder() {
(*builder)({}); // exit from the builder's scope.
delete builder;
builder = nullptr;
}
PythonBlockHandle handle;
BlockBuilder *builder = nullptr;
};
struct PythonAttribute {
PythonAttribute() : attr(nullptr) {}
PythonAttribute(const mlir_attr_t &a) : attr(a) {}
PythonAttribute(const PythonAttribute &other) = default;
operator mlir_attr_t() { return attr; }
std::string str() const {
if (!attr)
return "##null attr##";
std::string res;
llvm::raw_string_ostream os(res);
Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(attr))
.print(os);
return res;
}
mlir_attr_t attr;
};
struct PythonAttributedType {
PythonAttributedType() : type(nullptr) {}
PythonAttributedType(mlir_type_t t) : type(t) {}
PythonAttributedType(
PythonType t,
const std::unordered_map<std::string, PythonAttribute> &attributes =
std::unordered_map<std::string, PythonAttribute>())
: type(t), attrs(attributes) {}
operator mlir_type_t() const { return type.type; }
operator PythonType() const { return type; }
// Return a vector of named attribute descriptors. The vector owns the
// mlir_named_attr_t objects it contains, but not the names and attributes
// those objects point to (names and opaque pointers to attributes are owned
// by `this`).
std::vector<mlir_named_attr_t> getNamedAttrs() const {
std::vector<mlir_named_attr_t> result;
result.reserve(attrs.size());
for (const auto &namedAttr : attrs)
result.push_back({namedAttr.first.c_str(), namedAttr.second.attr});
return result;
}
std::string str() {
mlir::Type t = mlir::Type::getFromOpaquePointer(type);
std::string res;
llvm::raw_string_ostream os(res);
t.print(os);
if (attrs.empty())
return os.str();
os << '{';
bool first = true;
for (const auto &namedAttr : attrs) {
if (first)
first = false;
else
os << ", ";
os << namedAttr.first << ": " << namedAttr.second.str();
}
os << '}';
return os.str();
}
private:
PythonType type;
std::unordered_map<std::string, PythonAttribute> attrs;
};
struct PythonIndexedValue {
explicit PythonIndexedValue(PythonType type)
: indexed(Type::getFromOpaquePointer(type.type)) {}
explicit PythonIndexedValue(const IndexedValue &other) : indexed(other) {}
PythonIndexedValue(PythonValueHandle handle) : indexed(handle.value) {}
PythonIndexedValue(const PythonIndexedValue &other) = default;
// Create a new indexed value with the same base as this one but with indices
// provided as arguments.
PythonIndexedValue index(const std::vector<PythonValueHandle> &indices) {
std::vector<ValueHandle> handles(indices.begin(), indices.end());
return PythonIndexedValue(IndexedValue(indexed(handles)));
}
void store(const std::vector<PythonValueHandle> &indices,
PythonValueHandle value) {
// Uses the overloaded `operator=` to emit a store.
index(indices).indexed = value.value;
}
PythonValueHandle load(const std::vector<PythonValueHandle> &indices) {
// Uses the overloaded cast to `ValueHandle` to emit a load.
return static_cast<ValueHandle>(index(indices).indexed);
}
IndexedValue indexed;
};
template <typename ListTy, typename PythonTy, typename Ty>
ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
for (auto &inp : list) {
owning.push_back(Ty{inp.cast<PythonTy>()});
}
return ListTy{owning.data(), owning.size()};
}
static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
const py::list &types) {
return makeCList<mlir_type_list_t, PythonType>(owning, types);
}
PythonFunction
PythonMLIRModule::declareFunction(const std::string &name,
const py::list &inputs,
const std::vector<PythonType> &outputTypes,
const py::kwargs &funcAttributes) {
std::vector<PythonAttributedType> attributedInputs;
attributedInputs.reserve(inputs.size());
for (const auto &in : inputs) {
std::string className = in.get_type().str();
if (className.find(".Type'") != std::string::npos)
attributedInputs.emplace_back(in.cast<PythonType>());
else
attributedInputs.push_back(in.cast<PythonAttributedType>());
}
// Create the function type.
std::vector<mlir_type_t> ins(attributedInputs.begin(),
attributedInputs.end());
std::vector<mlir_type_t> outs(outputTypes.begin(), outputTypes.end());
auto funcType = ::makeFunctionType(
mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()},
mlir_type_list_t{outs.data(), outs.size()});
// Build the list of function attributes.
std::vector<mlir::NamedAttribute> attrs;
attrs.reserve(funcAttributes.size());
for (const auto &named : funcAttributes)
attrs.emplace_back(
Identifier::get(std::string(named.first.str()), &mlirContext),
mlir::Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(
named.second.cast<PythonAttribute>().attr)));
// Build the list of lists of function argument attributes.
std::vector<mlir::NamedAttributeList> inputAttrs;
inputAttrs.reserve(attributedInputs.size());
for (const auto &in : attributedInputs) {
std::vector<mlir::NamedAttribute> inAttrs;
for (const auto &named : in.getNamedAttrs())
inAttrs.emplace_back(Identifier::get(named.name, &mlirContext),
mlir::Attribute::getFromOpaquePointer(
reinterpret_cast<const void *>(named.value)));
inputAttrs.emplace_back(inAttrs);
}
// Create the function itself.
auto func = mlir::FuncOp::create(
UnknownLoc::get(&mlirContext), name,
mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
inputAttrs);
moduleManager.insert(func);
return func;
}
PythonAttributedType PythonType::attachAttributeDict(
const std::unordered_map<std::string, PythonAttribute> &attrs) const {
return PythonAttributedType(*this, attrs);
}
PythonAttribute PythonMLIRModule::integerAttr(PythonType type, int64_t value) {
return PythonAttribute(::makeIntegerAttr(type, value));
}
PythonAttribute PythonMLIRModule::boolAttr(bool value) {
return PythonAttribute(::makeBoolAttr(&mlirContext, value));
}
PYBIND11_MODULE(pybind, m) {
m.doc() =
"Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
m.def("version", []() { return "EDSC Python extensions v1.0"; });
py::class_<PythonLoopContext>(
m, "LoopContext", "A context for building the body of a 'for' loop")
.def(py::init<PythonValueHandle, PythonValueHandle, int64_t>())
.def("__enter__", &PythonLoopContext::enter)
.def("__exit__", &PythonLoopContext::exit);
py::class_<PythonLoopNestContext>(m, "LoopNestContext",
"A context for building the body of a the "
"innermost loop in a nest of 'for' loops")
.def(py::init<const std::vector<PythonValueHandle> &,
const std::vector<PythonValueHandle> &,
const std::vector<int64_t> &>())
.def("__enter__", &PythonLoopNestContext::enter)
.def("__exit__", &PythonLoopNestContext::exit);
m.def("constant_index", [](int64_t val) -> PythonValueHandle {
return ValueHandle(index_t(val));
});
m.def("constant_int", [](int64_t val, int width) -> PythonValueHandle {
return ValueHandle::create<ConstantIntOp>(val, width);
});
m.def("constant_float", [](double val, PythonType type) -> PythonValueHandle {
FloatType floatType =
Type::getFromOpaquePointer(type.type).cast<FloatType>();
assert(floatType);
auto value = APFloat(val);
bool lostPrecision;
value.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&lostPrecision);
return ValueHandle::create<ConstantFloatOp>(value, floatType);
});
m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
auto function = FuncOp::getFromOpaquePointer(func.function);
auto attr = SymbolRefAttr::get(function.getName(), function.getContext());
return ValueHandle::create<ConstantOp>(function.getType(), attr);
});
m.def("appendTo", [](const PythonBlockHandle &handle) {
return PythonBlockAppender(handle);
});
m.def(
"ret",
[](const std::vector<PythonValueHandle> &args) {
std::vector<ValueHandle> values(args.begin(), args.end());
(intrinsics::ret(ArrayRef<ValueHandle>{values})); // vexing parse
return PythonValueHandle(nullptr);
},
py::arg("args") = std::vector<PythonValueHandle>());
m.def(
"br",
[](const PythonBlockHandle &dest,
const std::vector<PythonValueHandle> &args) {
std::vector<ValueHandle> values(args.begin(), args.end());
intrinsics::br(dest, values);
return PythonValueHandle(nullptr);
},
py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>());
m.def(
"cond_br",
[](PythonValueHandle condition, const PythonBlockHandle &trueDest,
const std::vector<PythonValueHandle> &trueArgs,
const PythonBlockHandle &falseDest,
const std::vector<PythonValueHandle> &falseArgs) -> PythonValueHandle {
std::vector<ValueHandle> trueArguments(trueArgs.begin(),
trueArgs.end());
std::vector<ValueHandle> falseArguments(falseArgs.begin(),
falseArgs.end());
intrinsics::cond_br(condition, trueDest, trueArguments, falseDest,
falseArguments);
return PythonValueHandle(nullptr);
});
m.def("select",
[](PythonValueHandle condition, PythonValueHandle trueValue,
PythonValueHandle falseValue) -> PythonValueHandle {
return ValueHandle::create<SelectOp>(condition.value, trueValue.value,
falseValue.value);
});
m.def("op",
[](const std::string &name,
const std::vector<PythonValueHandle> &operands,
const std::vector<PythonType> &resultTypes,
const py::kwargs &attributes) -> PythonValueHandle {
std::vector<ValueHandle> operandHandles(operands.begin(),
operands.end());
std::vector<Type> types;
types.reserve(resultTypes.size());
for (auto t : resultTypes)
types.push_back(Type::getFromOpaquePointer(t.type));
std::vector<NamedAttribute> attrs;
attrs.reserve(attributes.size());
for (const auto &a : attributes) {
std::string name = a.first.str();
auto pyAttr = a.second.cast<PythonAttribute>();
auto cppAttr = Attribute::getFromOpaquePointer(pyAttr.attr);
auto identifier =
Identifier::get(name, ScopedContext::getContext());
attrs.emplace_back(identifier, cppAttr);
}
return ValueHandle::create(name, operandHandles, types, attrs);
});
py::class_<PythonFunction>(m, "Function", "Wrapping class for mlir::FuncOp.")
.def(py::init<PythonFunction>())
.def("__str__", &PythonFunction::str)
.def("define", &PythonFunction::define,
"Adds a body to the function if it does not already have one. "
"Returns true if the body was added")
.def("arg", &PythonFunction::arg,
"Get the ValueHandle to the indexed argument of the function");
py::class_<PythonAttribute>(m, "Attribute",
"Wrapping class for mlir::Attribute")
.def(py::init<PythonAttribute>())
.def("__str__", &PythonAttribute::str);
py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")
.def(py::init<PythonType>())
.def("__call__", &PythonType::attachAttributeDict,
"Attach the attributes to these type, making it suitable for "
"constructing functions with argument attributes")
.def("__str__", &PythonType::str);
py::class_<PythonAttributedType>(
m, "AttributedType",
"A class containing a wrapped mlir::Type and a wrapped "
"mlir::NamedAttributeList that are used together, e.g. in function "
"argument declaration")
.def(py::init<PythonAttributedType>())
.def("__str__", &PythonAttributedType::str);
py::class_<PythonMLIRModule>(
m, "MLIRModule",
"An MLIRModule is the abstraction that owns the allocations to support "
"compilation of a single mlir::ModuleOp into an ExecutionEngine backed "
"by "
"the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, "
"adding functions, compiling the module to obtain an ExecutionEngine on "
"which named functions may be called. For now the only means to retrieve "
"the ExecutionEngine is by calling `get_engine_address`. This mode of "
"execution is limited to passing the pointer to C++ where the function "
"is called. Extending the API to allow calling JIT compiled functions "
"directly require integration with a tensor library (e.g. numpy). This "
"is left as the prerogative of libraries and frameworks for now.")
.def(py::init<>())
.def("boolAttr", &PythonMLIRModule::boolAttr,
"Creates an mlir::BoolAttr with the given value")
.def(
"integerAttr", &PythonMLIRModule::integerAttr,
"Creates an mlir::IntegerAttr of the given type with the given value "
"in the context associated with this MLIR module.")
.def("declare_function", &PythonMLIRModule::declareFunction,
"Declares a new mlir::FuncOp in the current mlir::ModuleOp. The "
"function arguments can have attributes. The function has no "
"definition and can be linked to an external library.")
.def("make_function", &PythonMLIRModule::makeFunction,
"Defines a new mlir::FuncOp in the current mlir::ModuleOp.")
.def("function_context", &PythonMLIRModule::makeFunctionContext,
"Defines a new mlir::FuncOp in the mlir::ModuleOp and creates the "
"function context for building the body of the function.")
.def("get_function", &PythonMLIRModule::getNamedFunction,
"Looks up the function with the given name in the module.")
.def(
"make_scalar_type",
[](PythonMLIRModule &instance, const std::string &type,
unsigned bitwidth) {
return instance.makeScalarType(type, bitwidth);
},
py::arg("type"), py::arg("bitwidth") = 0,
"Returns a scalar mlir::Type using the following convention:\n"
" - makeScalarType(c, \"bf16\") return an "
"`mlir::FloatType::getBF16`\n"
" - makeScalarType(c, \"f16\") return an `mlir::FloatType::getF16`\n"
" - makeScalarType(c, \"f32\") return an `mlir::FloatType::getF32`\n"
" - makeScalarType(c, \"f64\") return an `mlir::FloatType::getF64`\n"
" - makeScalarType(c, \"index\") return an `mlir::IndexType::get`\n"
" - makeScalarType(c, \"i\", bitwidth) return an "
"`mlir::IntegerType::get(bitwidth)`\n\n"
" No other combinations are currently supported.")
.def("make_memref_type", &PythonMLIRModule::makeMemRefType,
"Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
"denote symbolic dimensions in the resulting memref shape.")
.def("make_index_type", &PythonMLIRModule::makeIndexType,
"Returns an mlir::IndexType")
.def("compile", &PythonMLIRModule::compile,
"Compiles the mlir::ModuleOp to LLVMIR a creates new opaque "
"ExecutionEngine backed by the ORC JIT.")
.def("get_ir", &PythonMLIRModule::getIR,
"Returns a dump of the MLIR representation of the module. This is "
"used for serde to support out-of-process execution as well as "
"debugging purposes.")
.def("get_engine_address", &PythonMLIRModule::getEngineAddress,
"Returns the address of the compiled ExecutionEngine. This is used "
"for in-process execution.")
.def("__str__", &PythonMLIRModule::getIR,
"Get the string representation of the module");
py::class_<PythonFunctionContext>(
m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
.def(py::init<PythonFunction>())
.def("__enter__", &PythonFunctionContext::enter)
.def("__exit__", &PythonFunctionContext::exit);
{
using namespace mlir::edsc::op;
py::class_<PythonValueHandle>(m, "ValueHandle",
"A wrapper around mlir::edsc::ValueHandle")
.def(py::init<PythonType>())
.def(py::init<PythonValueHandle>())
.def("__add__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value + rhs.value; })
.def("__sub__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value - rhs.value; })
.def("__mul__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value * rhs.value; })
.def("__div__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value / rhs.value; })
.def("__truediv__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value / rhs.value; })
.def("__floordiv__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return floorDiv(lhs, rhs); })
.def("__mod__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value % rhs.value; })
.def("__lt__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value,
rhs.value);
})
.def("__le__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value,
rhs.value);
})
.def("__gt__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value,
rhs.value);
})
.def("__ge__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value,
rhs.value);
})
.def("__eq__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value,
rhs.value);
})
.def("__ne__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value,
rhs.value);
})
.def("__invert__",
[](PythonValueHandle handle) -> PythonValueHandle {
return !handle.value;
})
.def("__and__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value && rhs.value; })
.def("__or__",
[](PythonValueHandle lhs, PythonValueHandle rhs)
-> PythonValueHandle { return lhs.value || rhs.value; })
.def("__call__", &PythonValueHandle::call);
}
py::class_<PythonBlockAppender>(
m, "BlockAppender",
"A dummy class signaling BlockContext to append IR to the given block "
"instead of creating a new block")
.def(py::init<const PythonBlockHandle &>());
py::class_<PythonBlockHandle>(m, "BlockHandle",
"A wrapper around mlir::edsc::BlockHandle")
.def(py::init<PythonBlockHandle>())
.def("arg", &PythonBlockHandle::arg);
py::class_<PythonBlockContext>(m, "BlockContext",
"A wrapper around mlir::edsc::BlockBuilder")
.def(py::init<>())
.def(py::init<const std::vector<PythonType> &>())
.def(py::init<const PythonBlockAppender &>())
.def("__enter__", &PythonBlockContext::enter)
.def("__exit__", &PythonBlockContext::exit)
.def("handle", &PythonBlockContext::getHandle);
py::class_<PythonIndexedValue>(m, "IndexedValue",
"A wrapper around mlir::edsc::IndexedValue")
.def(py::init<PythonValueHandle>())
.def("load", &PythonIndexedValue::load)
.def("store", &PythonIndexedValue::store);
}
} // namespace python
} // namespace edsc
} // namespace mlir
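To give a feel for how these pieces compose, here is a minimal, hypothetical usage sketch. Every call mirrors the patterns in `test_py2and3.py` below (same import path and APIs); only the function name `answer` and the constants are made up:

```python
import google_mlir.bindings.python.pybind as E

module = E.MLIRModule()

# function_context declares and defines a FuncOp, then opens an EDSC
# ScopedContext so the calls below emit IR into its entry block.
with module.function_context("answer", [], []) as fun:
    forty = E.constant_index(40)
    two = E.constant_index(2)
    forty + two  # ValueHandle.__add__; index addition emits an affine.apply

print(fun)              # MLIR for this function (PythonFunction.__str__)
print(module.get_ir())  # MLIR for the whole module
```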

@@ -0,0 +1,36 @@
# Description:
# BUILD file for the Python wrappers for EDSCs
licenses(["notice"]) # Apache 2.0
# Export the BUILD file so automated tooling can check licenses
exports_files(["BUILD"])
load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests")
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = ["py"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
":test_edsc",
"//third_party/llvm/llvm:FileCheck",
],
)
py_binary(
name = "test_edsc",
srcs = ["test_py2and3.py"],
main = "test_py2and3.py",
python_version = "PY2",
deps = [
"//testing/pybase",
"@local_config_mlir//bindings/python:_pybind",
],
)


@ -0,0 +1,486 @@
# Copyright 2019 The MLIR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/test_edsc %s | FileCheck %s
"""Python2 and 3 test for the MLIR EDSC Python bindings"""
import google_mlir.bindings.python.pybind as E
import inspect
# Prints `str` prefixed by the current test function name so we can use it in
# FileCheck label directives.
# This is achieved by inspecting the stack and getting the parent name.
def printWithCurrentFunctionName(str):
print(inspect.stack()[1][3])
print(str)
class EdscTest:
def setUp(self):
self.module = E.MLIRModule()
self.boolType = self.module.make_scalar_type("i", 1)
self.i32Type = self.module.make_scalar_type("i", 32)
self.f32Type = self.module.make_scalar_type("f32")
self.indexType = self.module.make_index_type()
def testBlockArguments(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
E.constant_index(42)
with E.BlockContext([self.f32Type, self.f32Type]) as b:
b.arg(0) + b.arg(1)
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBlockArguments
# CHECK: %{{.*}} = constant 42 : index
# CHECK: ^bb{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
# CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
def testBlockContext(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
cst = E.constant_index(42)
with E.BlockContext():
cst + cst
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBlockContext
# CHECK: %{{.*}} = constant 42 : index
# CHECK: ^bb
# CHECK: %{{.*}} = "affine.apply"() {map = () -> (84)} : () -> index
def testBlockContextAppend(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
E.constant_index(41)
with E.BlockContext() as b:
blk = b # save block handle for later
E.constant_index(0)
E.constant_index(42)
with E.BlockContext(E.appendTo(blk)):
E.constant_index(1)
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBlockContextAppend
# CHECK: %{{.*}} = constant 41 : index
# CHECK: %{{.*}} = constant 42 : index
# CHECK: ^bb
# CHECK: %{{.*}} = constant 0 : index
# CHECK: %{{.*}} = constant 1 : index
def testBlockContextStandalone(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
blk1 = E.BlockContext()
blk2 = E.BlockContext()
with blk1:
E.constant_index(0)
with blk2:
E.constant_index(56)
E.constant_index(57)
E.constant_index(41)
with blk1:
E.constant_index(1)
E.constant_index(42)
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBlockContextStandalone
# CHECK: %{{.*}} = constant 41 : index
# CHECK: %{{.*}} = constant 42 : index
# CHECK: ^bb
# CHECK: %{{.*}} = constant 0 : index
# CHECK: %{{.*}} = constant 1 : index
# CHECK: ^bb
# CHECK: %{{.*}} = constant 56 : index
# CHECK: %{{.*}} = constant 57 : index
def testBooleanOps(self):
self.setUp()
with self.module.function_context(
"booleans", [self.boolType for _ in range(4)], []) as fun:
i, j, k, l = (fun.arg(x) for x in range(4))
stmt1 = (i < j) & (j >= k)
stmt2 = ~(stmt1 | (k == l))
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBooleanOps
# CHECK: %{{.*}} = cmpi "slt", %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = cmpi "sge", %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = constant 1 : i1
# CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = constant 1 : i1
# CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = constant 1 : i1
# CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
# CHECK: %{{.*}} = constant 1 : i1
# CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
def testBr(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
with E.BlockContext() as b:
blk = b
E.ret()
E.br(blk)
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBr
# CHECK: br ^bb
# CHECK: ^bb
# CHECK: return
def testBrArgs(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
# Create an infinite loop.
with E.BlockContext([self.indexType, self.indexType]) as b:
E.br(b, [b.arg(1), b.arg(0)])
E.br(b, [E.constant_index(0), E.constant_index(1)])
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBrArgs
# CHECK: %{{.*}} = constant 0 : index
# CHECK: %{{.*}} = constant 1 : index
# CHECK: br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
# CHECK: ^bb{{.*}}(%{{.*}}: index, %{{.*}}: index):
# CHECK: br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
def testBrDeclaration(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
blk = E.BlockContext()
E.br(blk.handle())
with blk:
E.ret()
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testBrDeclaration
# CHECK: br ^bb
# CHECK: ^bb
# CHECK: return
def testCallOp(self):
self.setUp()
callee = self.module.declare_function("sqrtf", [self.f32Type],
[self.f32Type])
with self.module.function_context("call", [self.f32Type], []) as fun:
funCst = E.constant_function(callee)
funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type)
printWithCurrentFunctionName(str(self.module))
# CHECK-LABEL: testCallOp
# CHECK: func @sqrtf(f32) -> f32
# CHECK: %{{.*}} = constant @sqrtf : (f32) -> f32
# CHECK: %{{.*}} = call_indirect %{{.*}}(%{{.*}}) : (f32) -> f32
def testCondBr(self):
self.setUp()
with self.module.function_context("foo", [self.boolType], []) as fun:
with E.BlockContext() as blk1:
E.ret([])
with E.BlockContext([self.indexType]) as blk2:
E.ret([])
cst = E.constant_index(0)
E.cond_br(fun.arg(0), blk1, [], blk2, [cst])
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testCondBr
# CHECK: cond_br %{{.*}}, ^bb{{.*}}, ^bb{{.*}}(%{{.*}} : index)
def testConstants(self):
self.setUp()
with self.module.function_context("constants", [], []) as fun:
E.constant_float(1.23, self.module.make_scalar_type("bf16"))
E.constant_float(1.23, self.module.make_scalar_type("f16"))
E.constant_float(1.23, self.module.make_scalar_type("f32"))
E.constant_float(1.23, self.module.make_scalar_type("f64"))
E.constant_int(1, 1)
E.constant_int(123, 8)
E.constant_int(123, 16)
E.constant_int(123, 32)
E.constant_int(123, 64)
E.constant_index(123)
E.constant_function(fun)
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testConstants
# CHECK: constant 1.230000e+00 : bf16
# CHECK: constant 1.230470e+00 : f16
# CHECK: constant 1.230000e+00 : f32
# CHECK: constant 1.230000e+00 : f64
# CHECK: constant 1 : i1
# CHECK: constant 123 : i8
# CHECK: constant 123 : i16
# CHECK: constant 123 : i32
# CHECK: constant 123 : index
# CHECK: constant @constants : () -> ()
def testCustom(self):
self.setUp()
with self.module.function_context("custom", [self.indexType, self.f32Type],
[]) as fun:
E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1)
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testCustom
# CHECK: %{{.*}} = "foo"(%{{.*}}) : (index) -> f32
# CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
# Create 'addi' using the generic Op interface. We need an operation known
# to the execution engine so that the engine can compile it.
def testCustomOpCompilation(self):
self.setUp()
with self.module.function_context("adder", [self.i32Type], []) as f:
c1 = E.op(
"std.constant", [], [self.i32Type],
value=self.module.integerAttr(self.i32Type, 42))
E.op("std.addi", [c1, f.arg(0)], [self.i32Type])
E.ret([])
self.module.compile()
printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
# CHECK-LABEL: testCustomOpCompilation
# CHECK: False
def testDivisions(self):
self.setUp()
with self.module.function_context(
"division", [self.indexType, self.i32Type, self.i32Type], []) as fun:
# indices only support floor division
fun.arg(0) // E.constant_index(42)
# regular values only support regular division
fun.arg(1) / fun.arg(2)
printWithCurrentFunctionName(str(self.module))
# CHECK-LABEL: testDivisions
# CHECK: floordiv 42
# CHECK: divis %{{.*}}, %{{.*}} : i32
def testFunctionArgs(self):
self.setUp()
with self.module.function_context("foo", [self.f32Type, self.f32Type],
[self.indexType]) as fun:
pass
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testFunctionArgs
# CHECK: func @foo(%{{.*}}: f32, %{{.*}}: f32) -> index
def testFunctionContext(self):
self.setUp()
with self.module.function_context("foo", [], []):
pass
printWithCurrentFunctionName(self.module.get_function("foo"))
# CHECK-LABEL: testFunctionContext
# CHECK: func @foo() {
def testFunctionDeclaration(self):
self.setUp()
boolAttr = self.module.boolAttr(True)
t = self.module.make_memref_type(self.f32Type, [10])
t_llvm_noalias = t({"llvm.noalias": boolAttr})
t_readonly = t({"readonly": boolAttr})
f = self.module.declare_function("foo", [t, t_llvm_noalias, t_readonly], [])
printWithCurrentFunctionName(str(self.module))
# CHECK-LABEL: testFunctionDeclaration
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias = true}, memref<10xf32> {readonly = true})
def testFunctionMultiple(self):
self.setUp()
with self.module.function_context("foo", [], []):
pass
with self.module.function_context("foo", [], []):
E.constant_index(0)
printWithCurrentFunctionName(str(self.module))
# CHECK-LABEL: testFunctionMultiple
# CHECK: func @foo()
# CHECK: func @foo_0()
# CHECK: %{{.*}} = constant 0 : index
def testIndexedValue(self):
self.setUp()
memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
with self.module.function_context("indexed", [memrefType],
[memrefType]) as fun:
A = E.IndexedValue(fun.arg(0))
cst = E.constant_float(1., self.f32Type)
with E.LoopNestContext(
[E.constant_index(0), E.constant_index(0)],
[E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j):
A.store([i, j], A.load([i, j]) + cst)
E.ret([fun.arg(0)])
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testIndexedValue
# CHECK: "affine.for"()
# CHECK: "affine.for"()
# CHECK: "affine.load"
# CHECK-SAME: memref<10x42xf32>
# CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
# CHECK: "affine.store"
# CHECK-SAME: memref<10x42xf32>
# CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
# CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (10)}
def testLoopContext(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
lhs = E.constant_index(0)
rhs = E.constant_index(42)
with E.LoopContext(lhs, rhs, 1) as i:
lhs + rhs + i
with E.LoopContext(rhs, rhs + rhs, 2) as j:
x = i + j
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testLoopContext
# CHECK: "affine.for"() (
# CHECK: ^bb{{.*}}(%{{.*}}: index):
# CHECK: "affine.for"(%{{.*}}, %{{.*}}) (
# CHECK: ^bb{{.*}}(%{{.*}}: index):
# CHECK: "affine.apply"(%{{.*}}, %{{.*}}) {map = (d0, d1) -> (d0 + d1)} : (index, index) -> index
# CHECK: {lower_bound = (d0) -> (d0), step = 2 : index, upper_bound = (d0) -> (d0)} : (index, index) -> ()
# CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
def testLoopNestContext(self):
self.setUp()
with self.module.function_context("foo", [], []) as fun:
lbs = [E.constant_index(i) for i in range(4)]
ubs = [E.constant_index(10 * i + 5) for i in range(4)]
with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
i + j + k + l
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testLoopNestContext
# CHECK: "affine.for"() (
# CHECK: ^bb{{.*}}(%{{.*}}: index):
# CHECK: "affine.for"() (
# CHECK: ^bb{{.*}}(%{{.*}}: index):
# CHECK: "affine.for"() (
# CHECK: ^bb{{.*}}(%{{.*}}: index):
# CHECK: "affine.for"() (
# CHECK: ^bb{{.*}}(%{{.*}}: index):
# CHECK: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {map = (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index
def testMLIRBooleanCompilation(self):
self.setUp()
    m = self.module.make_memref_type(self.boolType, [10])  # i1 memref
with self.module.function_context("mkbooltensor", [m, m], []) as f:
input = E.IndexedValue(f.arg(0))
output = E.IndexedValue(f.arg(1))
zero = E.constant_index(0)
ten = E.constant_index(10)
with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
b1 = (i < j) & (j < k)
b2 = ~b1
b3 = b2 | (k < j)
output.store([i], input.load([i]) & b3)
E.ret([])
self.module.compile()
printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
# CHECK-LABEL: testMLIRBooleanCompilation
# CHECK: False
def testMLIRFunctionCreation(self):
self.setUp()
module = E.MLIRModule()
t = module.make_scalar_type("f32")
m = module.make_memref_type(t, [3, 4, -1, 5])
printWithCurrentFunctionName(str(t))
print(str(m))
print(str(module.make_function("copy", [m, m], [])))
print(str(module.make_function("sqrtf", [t], [t])))
# CHECK-LABEL: testMLIRFunctionCreation
# CHECK: f32
# CHECK: memref<3x4x?x5xf32>
# CHECK: func @copy(%{{.*}}: memref<3x4x?x5xf32>, %{{.*}}: memref<3x4x?x5xf32>) {
# CHECK: func @sqrtf(%{{.*}}: f32) -> f32
def testMLIRScalarTypes(self):
self.setUp()
module = E.MLIRModule()
printWithCurrentFunctionName(str(module.make_scalar_type("bf16")))
print(str(module.make_scalar_type("f16")))
print(str(module.make_scalar_type("f32")))
print(str(module.make_scalar_type("f64")))
print(str(module.make_scalar_type("i", 1)))
print(str(module.make_scalar_type("i", 8)))
print(str(module.make_scalar_type("i", 32)))
print(str(module.make_scalar_type("i", 123)))
print(str(module.make_scalar_type("index")))
# CHECK-LABEL: testMLIRScalarTypes
# CHECK: bf16
# CHECK: f16
# CHECK: f32
# CHECK: f64
# CHECK: i1
# CHECK: i8
# CHECK: i32
# CHECK: i123
# CHECK: index
def testMatrixMultiply(self):
self.setUp()
memrefType = self.module.make_memref_type(self.f32Type, [32, 32])
with self.module.function_context(
"matmul", [memrefType, memrefType, memrefType], []) as fun:
A = E.IndexedValue(fun.arg(0))
B = E.IndexedValue(fun.arg(1))
C = E.IndexedValue(fun.arg(2))
c0 = E.constant_index(0)
c32 = E.constant_index(32)
with E.LoopNestContext([c0, c0, c0], [c32, c32, c32], [1, 1, 1]) as (i, j,
k):
C.store([i, j], A.load([i, k]) * B.load([k, j]))
E.ret([])
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testMatrixMultiply
# CHECK: "affine.for"()
# CHECK: "affine.for"()
# CHECK: "affine.for"()
# CHECK-DAG: %{{.*}} = "affine.load"
# CHECK-DAG: %{{.*}} = "affine.load"
# CHECK: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
# CHECK: "affine.store"
# CHECK-SAME: memref<32x32xf32>
# CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
# CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
# CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
def testRet(self):
self.setUp()
with self.module.function_context("foo", [],
[self.indexType, self.indexType]) as fun:
c42 = E.constant_index(42)
c0 = E.constant_index(0)
E.ret([c42, c0])
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testRet
# CHECK: %{{.*}} = constant 42 : index
# CHECK: %{{.*}} = constant 0 : index
# CHECK: return %{{.*}}, %{{.*}} : index, index
def testSelectOp(self):
self.setUp()
with self.module.function_context("foo", [self.boolType],
[self.i32Type]) as fun:
a = E.constant_int(42, 32)
b = E.constant_int(0, 32)
E.ret([E.select(fun.arg(0), a, b)])
printWithCurrentFunctionName(str(fun))
# CHECK-LABEL: testSelectOp
# CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : i32
# Before Python 3.6, the order of attributes in a class dict does not match the
# order of method declaration, so the tests are discovered and sorted by name
# instead.
def runTests():
def isTest(attr):
return inspect.ismethod(attr) and "EdscTest.setUp " not in str(attr)
edscTest = EdscTest()
tests = sorted(filter(isTest,
(getattr(edscTest, attr) for attr in dir(edscTest))),
key = lambda x : str(x))
for test in tests:
test()
if __name__ == '__main__':
runTests()

third_party/mlir/include/mlir-c/Core.h

@ -0,0 +1,119 @@
/*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\
|* *|
|* Copyright 2019 The MLIR Authors. *|
|* *|
|* Licensed under the Apache License, Version 2.0 (the "License"); *|
|* you may not use this file except in compliance with the License. *|
|* You may obtain a copy of the License at *|
|* *|
|* http://www.apache.org/licenses/LICENSE-2.0 *|
|* *|
|* Unless required by applicable law or agreed to in writing, software *|
|* distributed under the License is distributed on an "AS IS" BASIS, *|
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *|
|* See the License for the specific language governing permissions and *|
|* limitations under the License. *|
|* *|
|*===----------------------------------------------------------------------===*|
|* *|
|* This header declares the C interface to MLIR. *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifndef MLIR_C_CORE_H
#define MLIR_C_CORE_H
#ifdef __cplusplus
#include <cstdint>
extern "C" {
#else
#include <stdint.h>
#endif
/// Opaque MLIR types.
/// Opaque C type for mlir::MLIRContext*.
typedef void *mlir_context_t;
/// Opaque C type for mlir::Type.
typedef const void *mlir_type_t;
/// Opaque C type for mlir::FuncOp.
typedef void *mlir_func_t;
/// Opaque C type for mlir::Attribute.
typedef const void *mlir_attr_t;
/// Simple C lists for non-owning mlir Opaque C types.
/// Recommended usage is construction from the `data()` and `size()` of a scoped
/// owning SmallVectorImpl<...> and passing to one of the C functions declared
/// later in this file.
/// Once the function returns and the proper EDSC has been constructed,
/// resources are freed by exiting the scope.
typedef struct {
int64_t *values;
uint64_t n;
} int64_list_t;
typedef struct {
mlir_type_t *types;
uint64_t n;
} mlir_type_list_t;
typedef struct {
const char *name;
mlir_attr_t value;
} mlir_named_attr_t;
typedef struct {
mlir_named_attr_t *list;
uint64_t n;
} mlir_named_attr_list_t;
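/// For example, a type list can be built over a stack array (a sketch; the
/// array contents are hypothetical):
///   mlir_type_t types[2] = {someType, someOtherType};
///   mlir_type_list_t typeList = {types, 2};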
/// Minimal C API for exposing EDSCs to Swift, Python and other languages.
/// Returns a simple scalar mlir::Type using the following convention:
/// - makeScalarType(c, "bf16") return an `mlir::FloatType::getBF16`
/// - makeScalarType(c, "f16") return an `mlir::FloatType::getF16`
/// - makeScalarType(c, "f32") return an `mlir::FloatType::getF32`
/// - makeScalarType(c, "f64") return an `mlir::FloatType::getF64`
/// - makeScalarType(c, "index") return an `mlir::IndexType::get`
/// - makeScalarType(c, "i", bitwidth) return an
/// `mlir::IntegerType::get(bitwidth)`
///
/// No other combinations are currently supported.
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
unsigned bitwidth);
/// Returns an `mlir::MemRefType` of the element type `elemType` and shape
/// `sizes`.
mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
int64_list_t sizes);
/// Returns an `mlir::FunctionType` with the given argument types `inputs` and
/// result types `outputs`.
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
mlir_type_list_t outputs);
/// Returns an `mlir::IndexType`.
mlir_type_t makeIndexType(mlir_context_t context);
/// Returns an `mlir::IntegerAttr` of the specified type that contains the given
/// value.
mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value);
/// Returns an `mlir::BoolAttr` with the given value.
mlir_attr_t makeBoolAttr(mlir_context_t context, bool value);
/// Returns the arity of `function`.
unsigned getFunctionArity(mlir_func_t function);
/// Returns the rank of the `function` argument at position `pos`.
/// If the argument is of MemRefType, this returns the rank of the MemRef.
/// Otherwise returns `0`.
/// TODO(ntv): support more than MemRefType and scalar Type.
unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos);
/// Returns an opaque mlir::Type of the `function` argument at position `pos`.
mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos);
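/// A minimal usage sketch of this API (assuming a valid `mlir_context_t ctx`
/// obtained elsewhere; `ctx` is hypothetical, and the bitwidth argument is
/// only meaningful for the "i" name):
///   mlir_type_t f32 = makeScalarType(ctx, "f32", /*bitwidth=*/0);
///   int64_t shape[2] = {10, 42};
///   int64_list_t sizes = {shape, 2};
///   mlir_type_t memref = makeMemRefType(ctx, f32, sizes);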
#ifdef __cplusplus
} // end extern "C"
#endif
#endif // MLIR_C_CORE_H


@ -0,0 +1,598 @@
//===- AffineOps.h - MLIR Affine Operations -------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines convenience types for working with Affine operations
// in the MLIR operation set.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_AFFINEOPS_AFFINEOPS_H
#define MLIR_AFFINEOPS_AFFINEOPS_H
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class AffineBound;
class AffineValueMap;
class AffineTerminatorOp;
class FlatAffineConstraints;
class OpBuilder;
/// A utility function to check if a value is defined at the top level of a
/// function. A value defined at the top level is always a valid symbol.
bool isTopLevelSymbol(Value *value);
class AffineOpsDialect : public Dialect {
public:
AffineOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "affine"; }
};
/// The "affine.apply" operation applies an affine map to a list of operands,
/// yielding a single result. The operand list must be the same size as the
/// number of arguments to the affine mapping. All operands and the result are
/// of type 'Index'. This operation requires a single affine map attribute named
/// "map". For example:
///
/// %y = "affine.apply" (%x) { map: (d0) -> (d0 + 1) } :
/// (index) -> (index)
///
/// equivalently:
///
/// #map42 = (d0)->(d0+1)
/// %y = affine.apply #map42(%x)
///
class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
using Op::Op;
/// Builds an affine apply op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<Value *> operands);
/// Returns the affine map to be applied by this operation.
AffineMap getAffineMap() {
return getAttrOfType<AffineMapAttr>("map").getValue();
}
/// Returns true if the result of this operation can be used as dimension id.
bool isValidDim();
/// Returns true if the result of this operation is a symbol.
bool isValidSymbol();
static StringRef getOperationName() { return "affine.apply"; }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
OpFoldResult fold(ArrayRef<Attribute> operands);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
};
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
/// from a source memref to a destination memref. The source and destination
/// memref need not be of the same dimensionality, but need to have the same
/// elemental type. The operands include the source and destination memref's
/// each followed by its indices, size of the data transfer in terms of the
/// number of elements (of the elemental type of the memref), a tag memref with
/// its indices, and optionally at the end, a stride and a
/// number_of_elements_per_stride arguments. The tag location is used by an
/// AffineDmaWaitOp to check for completion. The indices of the source memref,
/// destination memref, and the tag memref have the same restrictions as any
/// affine.load/store. In particular, index for each memref dimension must be an
/// affine expression of loop induction variables and symbols.
/// The optional stride arguments should be of 'index' type, and specify a
/// stride for the slower memory space (memory space with a lower memory space
/// id), transferring chunks of number_of_elements_per_stride every stride until
/// %num_elements are transferred. Either both or no stride arguments should be
/// specified. The value of 'num_elements' must be a multiple of
/// 'number_of_elements_per_stride'.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory
// space 1 at indices [%k + 7, %l], would be specified as follows:
//
// %num_elements = constant 256
// %idx = constant 0 : index
// %tag = alloc() : memref<1xi32, 4>
// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
// %num_elements :
// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
// %stride, %num_elt_per_stride : ...
//
// TODO(mlir-team): add additional operands to allow source and destination
// striding, and multiple stride levels (possibly using AffineMaps to specify
// multiple levels of striding).
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult> {
public:
using Op::Op;
static void build(Builder *builder, OperationState *result, Value *srcMemRef,
AffineMap srcMap, ArrayRef<Value *> srcIndices,
Value *destMemRef, AffineMap dstMap,
ArrayRef<Value *> destIndices, Value *tagMemRef,
AffineMap tagMap, ArrayRef<Value *> tagIndices,
Value *numElements, Value *stride = nullptr,
Value *elementsPerStride = nullptr);
/// Returns the operand index of the src memref.
unsigned getSrcMemRefOperandIndex() { return 0; }
/// Returns the source memref operand for this DMA operation.
Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
MemRefType getSrcMemRefType() {
return getSrcMemRef()->getType().cast<MemRefType>();
}
/// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }
/// Returns the affine map used to access the src memref.
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
AffineMapAttr getSrcMapAttr() {
return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the source memref affine map indices for this DMA operation.
operand_range getSrcIndices() {
return {operand_begin() + getSrcMemRefOperandIndex() + 1,
operand_begin() + getSrcMemRefOperandIndex() + 1 +
getSrcMap().getNumInputs()};
}
/// Returns the memory space of the src memref.
unsigned getSrcMemorySpace() {
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
}
/// Returns the operand index of the dst memref.
unsigned getDstMemRefOperandIndex() {
return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
}
/// Returns the destination memref operand for this DMA operation.
Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
MemRefType getDstMemRefType() {
return getDstMemRef()->getType().cast<MemRefType>();
}
/// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
return getDstMemRef()->getType().cast<MemRefType>().getRank();
}
/// Returns the memory space of the dst memref.
unsigned getDstMemorySpace() {
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
}
/// Returns the affine map used to access the dst memref.
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
AffineMapAttr getDstMapAttr() {
return getAttr(getDstMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the destination memref indices for this DMA operation.
operand_range getDstIndices() {
return {operand_begin() + getDstMemRefOperandIndex() + 1,
operand_begin() + getDstMemRefOperandIndex() + 1 +
getDstMap().getNumInputs()};
}
/// Returns the operand index of the tag memref.
unsigned getTagMemRefOperandIndex() {
return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
}
/// Returns the Tag MemRef for this DMA operation.
Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
MemRefType getTagMemRefType() {
return getTagMemRef()->getType().cast<MemRefType>();
}
/// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
return getTagMemRef()->getType().cast<MemRefType>().getRank();
}
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the tag memref indices for this DMA operation.
operand_range getTagIndices() {
return {operand_begin() + getTagMemRefOperandIndex() + 1,
operand_begin() + getTagMemRefOperandIndex() + 1 +
getTagMap().getNumInputs()};
}
/// Returns the number of elements being transferred by this DMA operation.
Value *getNumElements() {
return getOperand(getTagMemRefOperandIndex() + 1 +
getTagMap().getNumInputs());
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
if (memref == getSrcMemRef())
return {Identifier::get(getSrcMapAttrName(), getContext()),
getSrcMapAttr()};
else if (memref == getDstMemRef())
return {Identifier::get(getDstMapAttrName(), getContext()),
getDstMapAttr()};
assert(memref == getTagMemRef() &&
"DmaStartOp expected source, destination or tag memref");
return {Identifier::get(getTagMapAttrName(), getContext()),
getTagMapAttr()};
}
/// Returns true if this is a DMA from a slower memory space to a faster one.
bool isDestMemorySpaceFaster() {
return (getSrcMemorySpace() < getDstMemorySpace());
}
/// Returns true if this is a DMA from a faster memory space to a slower one.
bool isSrcMemorySpaceFaster() {
// Assumes that a lower number is for a slower memory space.
return (getDstMemorySpace() < getSrcMemorySpace());
}
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy. Asserts failure if neither is true.
unsigned getFasterMemPos() {
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex();
}
static StringRef getSrcMapAttrName() { return "src_map"; }
static StringRef getDstMapAttrName() { return "dst_map"; }
static StringRef getTagMapAttrName() { return "tag_map"; }
static StringRef getOperationName() { return "affine.dma_start"; }
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
/// Returns true if this DMA operation is strided, returns false otherwise.
bool isStrided() {
return getNumOperands() !=
getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
}
/// Returns the stride value for this DMA operation.
Value *getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
/// Returns the number of elements to transfer per stride for this DMA op.
Value *getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
}
};
/// AffineDmaWaitOp blocks until the completion of a DMA operation associated
/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
/// an index with the same restrictions as any load/store index. In particular,
/// index for each memref dimension must be an affine expression of loop
/// induction variables and symbols. %num_elements is the number of elements
/// associated with the DMA operation. For example:
//
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements :
// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
// ...
// ...
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
//
class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult> {
public:
using Op::Op;
static void build(Builder *builder, OperationState *result, Value *tagMemRef,
AffineMap tagMap, ArrayRef<Value *> tagIndices,
Value *numElements);
static StringRef getOperationName() { return "affine.dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on.
Value *getTagMemRef() { return getOperand(0); }
MemRefType getTagMemRefType() {
return getTagMemRef()->getType().cast<MemRefType>();
}
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
}
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
return {operand_begin() + 1,
operand_begin() + 1 + getTagMap().getNumInputs()};
}
// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() {
return getTagMemRef()->getType().cast<MemRefType>().getRank();
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
assert(memref == getTagMemRef());
return {Identifier::get(getTagMapAttrName(), getContext()),
getTagMapAttr()};
}
/// Returns the number of elements transferred in the associated DMA op.
Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
static StringRef getTagMapAttrName() { return "tag_map"; }
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
};
/// The "affine.load" op reads an element from a memref, where the index
/// for each memref dimension is an affine expression of loop induction
/// variables and symbols. The output of 'affine.load' is a new value with the
/// same type as the elements of the memref. An affine expression of loop IVs
/// and symbols must be specified for each dimension of the memref. The keyword
/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
//
// Example 1:
//
// %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
//
// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
//
// %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)]
// : memref<100x100xf32>
//
class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult,
OpTrait::AtLeastNOperands<1>::Impl> {
public:
using Op::Op;
/// Builds an affine load op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<Value *> operands);
/// Builds an affine load op with an identity map and the given operands.
static void build(Builder *builder, OperationState *result, Value *memref,
ArrayRef<Value *> indices = {});
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 0; }
/// Get memref operand.
Value *getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef()->getType().cast<MemRefType>();
}
/// Get affine map operands.
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
static StringRef getMapAttrName() { return "map"; }
static StringRef getOperationName() { return "affine.load"; }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
};
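// A builder sketch for the identity-map form (assuming an OpBuilder `b`, a
// Location `loc`, a memref Value* `buf`, and an index Value* `iv`; all
// hypothetical):
//   auto load = b.create<AffineLoadOp>(loc, buf, llvm::makeArrayRef(iv));
//   Value *element = load.getResult();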
/// The "affine.store" op writes an element to a memref, where the index
/// for each memref dimension is an affine expression of loop induction
/// variables and symbols. The 'affine.store' op stores a new value which is the
/// same type as the elements of the memref. An affine expression of loop IVs
/// and symbols must be specified for each dimension of the memref. The keyword
/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
//
// Example 1:
//
// affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
//
// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
//
// affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)]
// : memref<100x100xf32>
//
class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
OpTrait::AtLeastNOperands<1>::Impl> {
public:
using Op::Op;
/// Builds an affine store operation with the specified map and operands.
static void build(Builder *builder, OperationState *result,
Value *valueToStore, AffineMap map,
ArrayRef<Value *> operands);
/// Builds an affine store operation with an identity map and operands.
static void build(Builder *builder, OperationState *result,
Value *valueToStore, Value *memref,
ArrayRef<Value *> operands);
/// Get value to be stored by store operation.
Value *getValueToStore() { return getOperand(0); }
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 1; }
/// Get memref operand.
Value *getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef()->getType().cast<MemRefType>();
}
/// Get affine map operands.
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
static StringRef getMapAttrName() { return "map"; }
static StringRef getOperationName() { return "affine.store"; }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
};
/// Returns true if the given Value can be used as a dimension id.
bool isValidDim(Value *value);
/// Returns true if the given Value can be used as a symbol.
bool isValidSymbol(Value *value);
/// Modifies both `map` and `operands` in-place so as to:
/// 1. drop duplicate operands
/// 2. drop unused dims and symbols from map
void canonicalizeMapAndOperands(AffineMap *map,
llvm::SmallVectorImpl<Value *> *operands);
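// For example (a sketch of the intended effect): a map (d0, d1) -> (d0 + d1)
// with operands [%x, %x] canonicalizes to the map (d0) -> (d0 + d0) with the
// single operand [%x].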
/// Returns a composed AffineApplyOp by composing `map` and `operands` with
/// other AffineApplyOps supplying those operands. The operands of the resulting
/// AffineApplyOp do not change the length of AffineApplyOp chains.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
llvm::ArrayRef<Value *> operands);
/// Given an affine map `map` and its input `operands`, this method composes
/// into `map`, maps of AffineApplyOps whose results are the values in
/// `operands`, iteratively until no more of `operands` are the result of an
/// AffineApplyOp. When this function returns, `map` becomes the composed affine
/// map, and each Value in `operands` is guaranteed to be either a loop IV or a
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
/// argument.
void fullyComposeAffineMapAndOperands(AffineMap *map,
llvm::SmallVectorImpl<Value *> *operands);
#define GET_OP_CLASSES
#include "mlir/AffineOps/AffineOps.h.inc"
/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
bool isForInductionVar(Value *val);
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
AffineForOp getForInductionVarOwner(Value *val);
/// Extracts the induction variables from a list of AffineForOps and places them
/// in the output argument `ivs`.
void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
SmallVectorImpl<Value *> *ivs);
/// AffineBound represents a lower or upper bound in the for operation.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the AffineForOp. Its life span should not exceed
/// that of the for operation it refers to.
class AffineBound {
public:
AffineForOp getAffineForOp() { return op; }
AffineMap getMap() { return map; }
/// Returns an AffineValueMap representing this bound.
AffineValueMap getAsAffineValueMap();
unsigned getNumOperands() { return opEnd - opStart; }
Value *getOperand(unsigned idx) {
return op.getOperation()->getOperand(opStart + idx);
}
using operand_iterator = AffineForOp::operand_iterator;
using operand_range = AffineForOp::operand_range;
operand_iterator operand_begin() { return op.operand_begin() + opStart; }
operand_iterator operand_end() { return op.operand_begin() + opEnd; }
operand_range getOperands() { return {operand_begin(), operand_end()}; }
private:
// 'affine.for' operation that contains this bound.
AffineForOp op;
// Start and end positions of this affine bound operands in the list of
// the containing 'affine.for' operation operands.
unsigned opStart, opEnd;
// Affine map for this bound.
AffineMap map;
AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map)
: op(op), opStart(opStart), opEnd(opEnd), map(map) {}
friend class AffineForOp;
};
} // end namespace mlir
#endif


@ -0,0 +1,257 @@
//===- Ops.td - Affine operation definitions ---------------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines MLIR affine operations.
//
//===----------------------------------------------------------------------===//
#ifdef AFFINE_OPS
#else
#define AFFINE_OPS
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def Affine_Dialect : Dialect {
let name = "affine";
let cppNamespace = "";
}
// Base class for Affine dialect ops.
class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Affine_Dialect, mnemonic, traits> {
// For every affine op, there needs to be a:
// * void print(OpAsmPrinter *p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
// OperationState *result)
// functions.
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
// Require regions to have affine terminator.
def ImplicitAffineTerminator
: SingleBlockImplicitTerminator<"AffineTerminatorOp">;
def AffineForOp : Affine_Op<"for", [ImplicitAffineTerminator]> {
let summary = "for operation";
let description = [{
The "affine.for" operation represents an affine loop nest, defining an SSA
value for its induction variable. It has one region capturing the loop body.
The induction variable is represented as an argument of this region. This SSA
value always has type index, which is the size of the machine word. The
stride, represented by step, is a positive constant integer which defaults
to "1" if not present. The lower and upper bounds specify a half-open range:
the range includes the lower bound but does not include the upper bound.
The body region must contain exactly one block that terminates with
"affine.terminator". Calling AffineForOp::build will create such region
and insert the terminator, so will the parsing even in cases if it is absent
from the custom format.
The lower and upper bounds of a for operation are represented as an
application of an affine mapping to a list of SSA values passed to the map.
The same restrictions hold for these SSA values as for all bindings of SSA
values to dimensions and symbols. The affine mappings for the bounds may
return multiple results, in which case the max/min keywords are required
(for the lower/upper bound respectively), and the bound is the
maximum/minimum of the returned values.
Example:
affine.for %i = 1 to 10 {
...
}
}];
let arguments = (ins Variadic<AnyType>);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState *result, "
"int64_t lowerBound, int64_t upperBound, int64_t step = 1">,
OpBuilder<"Builder *builder, OperationState *result, "
"ArrayRef<Value *> lbOperands, AffineMap lbMap, "
"ArrayRef<Value *> ubOperands, AffineMap ubMap, "
"int64_t step = 1">
];
let extraClassDeclaration = [{
static StringRef getStepAttrName() { return "step"; }
static StringRef getLowerBoundAttrName() { return "lower_bound"; }
static StringRef getUpperBoundAttrName() { return "upper_bound"; }
Block *getBody() { return &region().front(); }
Value *getInductionVar() { return getBody()->getArgument(0); }
OpBuilder getBodyBuilder() {
return OpBuilder(getBody(), std::prev(getBody()->end()));
}
// TODO: provide iterators for the lower and upper bound operands
// if the current access via getLowerBound(), getUpperBound() is too slow.
/// Returns operands for the lower bound map.
operand_range getLowerBoundOperands();
/// Returns operands for the upper bound map.
operand_range getUpperBoundOperands();
/// Returns information about the lower bound as a single object.
AffineBound getLowerBound();
/// Returns information about the upper bound as a single object.
AffineBound getUpperBound();
/// Returns loop step.
int64_t getStep() {
return getAttr(getStepAttrName()).cast<IntegerAttr>().getInt();
}
/// Returns affine map for the lower bound.
AffineMap getLowerBoundMap() { return getLowerBoundMapAttr().getValue(); }
AffineMapAttr getLowerBoundMapAttr() {
return getAttr(getLowerBoundAttrName()).cast<AffineMapAttr>();
}
/// Returns affine map for the upper bound. The upper bound is exclusive.
AffineMap getUpperBoundMap() { return getUpperBoundMapAttr().getValue(); }
AffineMapAttr getUpperBoundMapAttr() {
return getAttr(getUpperBoundAttrName()).cast<AffineMapAttr>();
}
/// Set lower bound. The new bound must have the same number of operands as
/// the current bound map. Otherwise, 'replaceForLowerBound' should be used.
void setLowerBound(ArrayRef<Value *> operands, AffineMap map);
/// Set upper bound. The new bound must not have more operands than the
/// current bound map. Otherwise, 'replaceForUpperBound' should be used.
void setUpperBound(ArrayRef<Value *> operands, AffineMap map);
/// Set the lower bound map without changing operands.
void setLowerBoundMap(AffineMap map);
/// Set the upper bound map without changing operands.
void setUpperBoundMap(AffineMap map);
/// Set loop step.
void setStep(int64_t step) {
assert(step > 0 && "step has to be a positive integer constant");
auto *context = getLowerBoundMap().getContext();
setAttr(Identifier::get(getStepAttrName(), context),
IntegerAttr::get(IndexType::get(context), step));
}
/// Returns true if the lower bound is constant.
bool hasConstantLowerBound();
/// Returns true if the upper bound is constant.
bool hasConstantUpperBound();
/// Returns true if both bounds are constant.
bool hasConstantBounds() {
return hasConstantLowerBound() && hasConstantUpperBound();
}
/// Returns the value of the constant lower bound.
/// Fails assertion if the bound is non-constant.
int64_t getConstantLowerBound();
/// Returns the value of the constant upper bound. The upper bound is
/// exclusive. Fails assertion if the bound is non-constant.
int64_t getConstantUpperBound();
/// Sets the lower bound to the given constant value.
void setConstantLowerBound(int64_t value);
/// Sets the upper bound to the given constant value.
void setConstantUpperBound(int64_t value);
/// Returns true if both the lower and upper bound have the same operand
/// lists (same operands in the same order).
bool matchingBoundOperandList();
}];
let hasCanonicalizer = 1;
}
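// A C++ builder sketch for the constant-bound form of this op (assuming an
// OpBuilder `b` and a Location `loc`; both hypothetical):
//   auto forOp = b.create<AffineForOp>(loc, /*lowerBound=*/0,
//                                      /*upperBound=*/10, /*step=*/1);
//   OpBuilder bodyBuilder = forOp.getBodyBuilder();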
def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
let summary = "if-then-else operation";
let description = [{
The "if" operation represents an if-then-else construct for conditionally
executing two regions of code. The operands to an if operation are an
IntegerSet condition and a set of symbol/dimension operands to the
condition set. The operation produces no results. For example:
affine.if #set(%i) {
...
} else {
...
}
The 'else' blocks to the if operation are optional, and may be omitted. For
example:
affine.if #set(%i) {
...
}
}];
let arguments = (ins Variadic<AnyType>);
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState *result, "
"Value *cond, bool withElseRegion">
];
let extraClassDeclaration = [{
static StringRef getConditionAttrName() { return "condition"; }
IntegerSet getIntegerSet();
void setIntegerSet(IntegerSet newSet);
OpBuilder getThenBodyBuilder() {
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
Block &body = thenRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
OpBuilder getElseBodyBuilder() {
assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
Block &body = elseRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
}];
}
def AffineTerminatorOp :
Affine_Op<"terminator", [Terminator]> {
let summary = "affine terminator operation";
let description = [{
Affine terminator is a special terminator operation for blocks inside affine
loops and branches. It unconditionally transmits the control flow to the
successor of the operation enclosing the region.
This operation does _not_ have a custom syntax. However, affine control
operations omit the terminator in their custom syntax for brevity.
}];
// No custom parsing/printing form.
let parser = ?;
let printer = ?;
// Fully specified by traits.
let verifier = ?;
}
#endif // AFFINE_OPS


@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS AffineOps.td)
mlir_tablegen(AffineOps.h.inc -gen-op-decls)
mlir_tablegen(AffineOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRAffineOpsIncGen)


@ -0,0 +1,134 @@
//===- AffineAnalysis.h - analyses for affine structures --------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes for methods that perform analysis
// involving affine structures (AffineExprStorage, AffineMap, IntegerSet, etc.)
// and other IR structures that in turn use these.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
class AffineApplyOp;
class AffineForOp;
class AffineValueMap;
class FlatAffineConstraints;
class Operation;
class Value;
/// Returns in `affineApplyOps`, the sequence of those AffineApplyOp
/// Operations that are reachable via a search starting from `operands` and
/// ending at those operands that are not the result of an AffineApplyOp.
void getReachableAffineApplyOps(
llvm::ArrayRef<Value *> operands,
llvm::SmallVectorImpl<Operation *> &affineApplyOps);
/// Builds a system of constraints with dimensional identifiers corresponding to
/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
/// used to add appropriate inequalities. Any symbols found in the bound
/// operands are added as symbols in the system. Returns failure for the yet
/// unimplemented cases.
// TODO(bondhugula): handle non-unit strides.
LogicalResult getIndexSet(llvm::MutableArrayRef<AffineForOp> forOps,
FlatAffineConstraints *domain);
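// Sketch (assuming `loops` holds the AffineForOps of a loop nest, outermost
// first; `loops` is hypothetical):
//   FlatAffineConstraints domain;
//   if (failed(getIndexSet(loops, &domain)))
//     return; // Unimplemented case, e.g. a non-unit stride.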
/// Encapsulates a memref load or store access information.
struct MemRefAccess {
Value *memref;
Operation *opInst;
llvm::SmallVector<Value *, 4> indices;
/// Constructs a MemRefAccess from a load or store operation.
// TODO(b/119949820): add accessors to standard op's load, store, DMA op's to
// return MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess.
explicit MemRefAccess(Operation *opInst);
// Returns the rank of the memref associated with this access.
unsigned getRank() const;
// Returns true if this access is of a store op.
bool isStore() const;
/// Populates 'accessMap' with composition of AffineApplyOps reachable from
/// 'indices'.
void getAccessMap(AffineValueMap *accessMap) const;
};
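// Sketch (assuming `loadOp` is an Operation* known to be an affine load or
// store; `loadOp` is hypothetical):
//   MemRefAccess access(loadOp);
//   unsigned rank = access.getRank();
//   bool isRead = !access.isStore();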
// DependenceComponent contains state about the direction of a dependence as an
// interval [lb, ub] for an AffineForOp.
// Distance vector components are represented by the interval [lb, ub] with
// lb == ub.
// Direction vector components are represented by the interval [lb, ub] with
// lb < ub. Note that ub/lb == None means unbounded.
struct DependenceComponent {
// The AffineForOp Operation associated with this dependence component.
Operation *op;
// The lower bound of the dependence distance.
llvm::Optional<int64_t> lb;
// The upper bound of the dependence distance (inclusive).
llvm::Optional<int64_t> ub;
DependenceComponent() : lb(llvm::None), ub(llvm::None) {}
};
/// Checks whether two accesses to the same memref access the same element.
/// Each access is specified using the MemRefAccess structure, which contains
/// the operation, indices and memref associated with the access. Returns
/// 'NoDependence' if it can be determined conclusively that the accesses do not
/// access the same memref element. If 'allowRAR' is true, will consider
/// read-after-read dependences (typically used by applications trying to
/// optimize input reuse).
// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
// a single struct.
// TODO(andydavis) Make 'dependenceConstraints' optional arg.
struct DependenceResult {
enum ResultEnum {
HasDependence, // A dependence exists between 'srcAccess' and 'dstAccess'.
NoDependence, // No dependence exists between 'srcAccess' and 'dstAccess'.
Failure, // Dependence check failed due to unsupported cases.
} value;
DependenceResult(ResultEnum v) : value(v) {}
};
DependenceResult checkMemrefAccessDependence(
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
bool allowRAR = false);
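// Sketch (assuming MemRefAccess values `srcAccess` and `dstAccess` built from
// two ops on the same memref; both are hypothetical):
//   FlatAffineConstraints dependenceConstraints;
//   llvm::SmallVector<DependenceComponent, 2> dependenceComponents;
//   DependenceResult result = checkMemrefAccessDependence(
//       srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
//       &dependenceComponents);
//   if (hasDependence(result)) { /* handle the dependence */ }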
/// Utility function that returns true if the provided DependenceResult
/// indicates that a dependence exists between the two accesses.
inline bool hasDependence(DependenceResult result) {
return result.value == DependenceResult::HasDependence;
}
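// A minimal usage sketch (illustrative; `srcOp` and `dstOp` are assumed to be
// a store and a load on the same memref in the surrounding pass):
//
//   MemRefAccess srcAccess(srcOp), dstAccess(dstOp);
//   FlatAffineConstraints dependenceConstraints;
//   DependenceResult result = checkMemrefAccessDependence(
//       srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
//       /*dependenceComponents=*/nullptr);
//   if (hasDependence(result)) {
//     // The accesses may touch the same element; their relative order must
//     // be preserved by any transformation.
//   }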
/// Returns in 'depCompsVec', dependence components for dependences between all
/// load and store ops in loop nest rooted at 'forOp', at loop depths in range
/// [1, maxLoopDepth].
void getDependenceComponents(
AffineForOp forOp, unsigned maxLoopDepth,
std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec);
} // end namespace mlir
#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H

View File

@ -0,0 +1,813 @@
//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Structures for affine/polyhedral analysis of ML functions.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_AFFINE_STRUCTURES_H
#define MLIR_ANALYSIS_AFFINE_STRUCTURES_H
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
class AffineApplyOp;
class AffineBound;
class AffineCondition;
class AffineMap;
class AffineForOp;
class IntegerSet;
class MLIRContext;
class Value;
class HyperRectangularSet;
class MemRefType;
/// A mutable affine map. Its affine expressions are however unique.
struct MutableAffineMap {
public:
MutableAffineMap() {}
MutableAffineMap(AffineMap map);
ArrayRef<AffineExpr> getResults() const { return results; }
AffineExpr getResult(unsigned idx) const { return results[idx]; }
void setResult(unsigned idx, AffineExpr result) { results[idx] = result; }
unsigned getNumResults() const { return results.size(); }
unsigned getNumDims() const { return numDims; }
void setNumDims(unsigned d) { numDims = d; }
unsigned getNumSymbols() const { return numSymbols; }
void setNumSymbols(unsigned d) { numSymbols = d; }
MLIRContext *getContext() const { return context; }
/// Returns true if the idx'th result expression is a multiple of factor.
bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Resets this MutableAffineMap with 'map'.
void reset(AffineMap map);
/// Simplify the (result) expressions in this map using analysis (used by
/// the -simplify-affine-expr pass).
void simplify();
/// Get the AffineMap corresponding to this MutableAffineMap. Note that an
/// AffineMap will be uniqued and stored in context, while a mutable one
/// isn't.
AffineMap getAffineMap() const;
private:
// Same meaning as AffineMap's fields.
SmallVector<AffineExpr, 8> results;
unsigned numDims;
unsigned numSymbols;
/// A pointer to the IR's context to store all newly created
/// AffineExprStorage's.
MLIRContext *context;
};
/// A mutable integer set. Its affine expressions are however unique.
struct MutableIntegerSet {
public:
MutableIntegerSet(IntegerSet set, MLIRContext *context);
/// Create a universal set (no constraints).
MutableIntegerSet(unsigned numDims, unsigned numSymbols,
MLIRContext *context);
unsigned getNumDims() const { return numDims; }
unsigned getNumSymbols() const { return numSymbols; }
unsigned getNumConstraints() const { return constraints.size(); }
void clear() {
constraints.clear();
eqFlags.clear();
}
private:
unsigned numDims;
unsigned numSymbols;
SmallVector<AffineExpr, 8> constraints;
SmallVector<bool, 8> eqFlags;
};
/// An AffineValueMap is an affine map plus its ML value operands and
/// results for analysis purposes. The structure is still a tree form that is
/// the same as that of an affine map or an AffineApplyOp. However, its operands,
/// results, and its map can themselves change as a result of
/// substitutions, simplifications, and other analysis.
// An affine value map can readily be constructed from an AffineApplyOp, or an
// AffineBound of a AffineForOp. It can be further transformed, substituted
// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and
// destroyed during analysis. Only the AffineMap expressions that are pointed by
// them are unique'd. An affine value map, and the operations on it, maintain
// the invariant that operands are always positionally aligned with the
// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap.
// TODO(bondhugula): Some of these classes could go into separate files.
class AffineValueMap {
public:
// Creates an empty AffineValueMap (users should call 'reset' to reset map
// and operands).
AffineValueMap() {}
AffineValueMap(AffineMap map);
AffineValueMap(AffineMap map, ArrayRef<Value *> operands,
ArrayRef<Value *> results = llvm::None);
explicit AffineValueMap(AffineApplyOp applyOp);
explicit AffineValueMap(AffineBound bound);
~AffineValueMap();
// Resets this AffineValueMap with 'map', 'operands', and 'results'.
void reset(AffineMap map, ArrayRef<Value *> operands,
ArrayRef<Value *> results = llvm::None);
/// Return true if the idx^th result can be proved to be a multiple of
/// 'factor', false otherwise.
inline bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Return true if the idx^th result depends on 'value', false otherwise.
bool isFunctionOf(unsigned idx, Value *value) const;
/// Return true if the result at 'idx' is a constant, false
/// otherwise.
bool isConstant(unsigned idx) const;
/// Return true if this is an identity map.
bool isIdentity() const;
inline unsigned getNumOperands() const { return operands.size(); }
inline unsigned getNumDims() const { return map.getNumDims(); }
inline unsigned getNumSymbols() const { return map.getNumSymbols(); }
inline unsigned getNumResults() const { return map.getNumResults(); }
Value *getOperand(unsigned i) const;
ArrayRef<Value *> getOperands() const;
AffineMap getAffineMap() const;
private:
// A mutable affine map.
MutableAffineMap map;
// TODO: make these trailing objects?
/// The SSA operands binding to the dim's and symbols of 'map'.
SmallVector<Value *, 4> operands;
/// The SSA results binding to the results of 'map'.
SmallVector<Value *, 4> results;
};
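// A minimal usage sketch (illustrative; `applyOp` is assumed to be an
// AffineApplyOp obtained from the surrounding code):
//
//   AffineValueMap avm(applyOp);
//   if (avm.isMultipleOf(/*idx=*/0, 4)) {
//     // The first result is provably a multiple of 4.
//   }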
/// An IntegerValueSet is an integer set plus its operands.
// Both the integer set being pointed to and the operands can change during
// analysis, simplification, and transformation.
class IntegerValueSet {
/// Constructs an integer value set from an affine value map.
// This will lead to a single equality in 'set'.
explicit IntegerValueSet(const AffineValueMap &avm);
/// Returns true if this integer set is determined to be empty. Emptiness is
/// checked by eliminating identifiers successively (through either
/// Gaussian or Fourier-Motzkin elimination) while using the GCD test and a
/// trivial invalid constraint check. Returns 'true' if the constraint system
/// is found to be empty; false otherwise. This method is exact for rational
/// spaces but not integer spaces - thus, if it returns true, the set is
/// provably integer empty as well, but if it returns false, an integer point
/// is not guaranteed to exist in it. Due to the super-exponential worst-case
/// complexity of Fourier-Motzkin elimination (rare for realistic problem
/// cases but possible for artificial, adversarial, or improperly constructed
/// ones), this method also conservatively returns false when an explosion of
/// constraints is detected.
bool isEmpty() const;
bool getNumDims() const { return set.getNumDims(); }
bool getNumSymbols() const { return set.getNumSymbols(); }
private:
// The set pointed to may itself change unlike in IR structures like
// 'AffineCondition'.
MutableIntegerSet set;
/// The SSA operands binding to the dim's and symbols of 'set'.
SmallVector<Value *, 4> operands;
};
/// A flat list of affine equalities and inequalities in the form:
/// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} >= 0
/// Equality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} == 0
///
/// FlatAffineConstraints stores coefficients in a contiguous buffer (one buffer
/// for equalities and one for inequalities). The size of each buffer is
/// numReservedCols * number of inequalities (or equalities). The reserved size
/// is numReservedCols * numReservedInequalities (or numReservedEqualities). A
/// coefficient (r, c) lives at the location numReservedCols * r + c in the
/// buffer. The extra space between getNumCols() and numReservedCols exists to
/// prevent frequent movement of data when adding columns, especially at the
/// end.
///
/// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers,
/// symbolic identifiers, and local identifiers. The local identifiers
/// correspond to local/internal variables created when converting from
/// AffineExpr's containing mod's and div's; they are thus needed to increase
/// representational power. Each local identifier is always (by construction) a
/// floordiv of a pure add/mul affine function of dimensional, symbolic, and
/// other local identifiers, in a non-mutually recursive way. Hence, every local
/// identifier can ultimately always be recovered as an affine function of
/// dimensional and symbolic identifiers (involving floordiv's); note however
/// that some floordiv combinations are converted to mod's by AffineExpr
/// construction.
///
class FlatAffineConstraints {
public:
enum IdKind { Dimension, Symbol, Local };
/// Constructs a constraint system reserving memory for the specified number
/// of constraints and identifiers.
FlatAffineConstraints(unsigned numReservedInequalities,
unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims = 0,
unsigned numSymbols = 0, unsigned numLocals = 0,
ArrayRef<Optional<Value *>> idArgs = {})
: numReservedCols(numReservedCols), numDims(numDims),
numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1);
assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals);
equalities.reserve(numReservedCols * numReservedEqualities);
inequalities.reserve(numReservedCols * numReservedInequalities);
numIds = numDims + numSymbols + numLocals;
ids.reserve(numReservedCols);
if (idArgs.empty())
ids.resize(numIds, None);
else
ids.append(idArgs.begin(), idArgs.end());
}
/// Constructs a constraint system with the specified number of
/// dimensions and symbols.
FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0,
ArrayRef<Optional<Value *>> idArgs = {})
: numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims),
numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1);
assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals);
numIds = numDims + numSymbols + numLocals;
ids.reserve(numIds);
if (idArgs.empty())
ids.resize(numIds, None);
else
ids.append(idArgs.begin(), idArgs.end());
}
explicit FlatAffineConstraints(const HyperRectangularSet &set);
/// Create a flat affine constraint system from an AffineValueMap or a list of
/// these. The constructed system will only include equalities.
// TODO(bondhugula)
explicit FlatAffineConstraints(const AffineValueMap &avm);
explicit FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef);
/// Creates an affine constraint system from an IntegerSet.
explicit FlatAffineConstraints(IntegerSet set);
/// Create an affine constraint system from an IntegerValueSet.
// TODO(bondhugula)
explicit FlatAffineConstraints(const IntegerValueSet &set);
FlatAffineConstraints(const FlatAffineConstraints &other);
FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef,
IntegerSet set);
FlatAffineConstraints(const MutableAffineMap &map);
~FlatAffineConstraints() {}
// Clears any existing data and reserves memory for the specified constraints.
void reset(unsigned numReservedInequalities, unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims, unsigned numSymbols,
unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
void reset(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
/// Appends constraints from 'other' into this. This is equivalent to an
/// intersection with no simplification of any sort attempted.
void append(const FlatAffineConstraints &other);
// Checks for emptiness by performing variable elimination on all identifiers,
// running the GCD test on each equality constraint, and checking for invalid
// constraints.
// Returns true if the GCD test fails for any equality, or if any invalid
// constraints are discovered on any row. Returns false otherwise.
bool isEmpty() const;
// Runs the GCD test on all equality constraints. Returns 'true' if this test
// fails on any equality. Returns 'false' otherwise.
// This test can be used to disprove the existence of a solution. If it
// returns true, no integer solution to the equality constraints can exist.
bool isEmptyByGCDTest() const;
// Clones this object.
std::unique_ptr<FlatAffineConstraints> clone() const;
/// Returns the value at the specified equality row and column.
inline int64_t atEq(unsigned i, unsigned j) const {
return equalities[i * numReservedCols + j];
}
inline int64_t &atEq(unsigned i, unsigned j) {
return equalities[i * numReservedCols + j];
}
inline int64_t atIneq(unsigned i, unsigned j) const {
return inequalities[i * numReservedCols + j];
}
inline int64_t &atIneq(unsigned i, unsigned j) {
return inequalities[i * numReservedCols + j];
}
/// Returns the number of columns in the constraint system.
inline unsigned getNumCols() const { return numIds + 1; }
inline unsigned getNumEqualities() const {
assert(equalities.size() % numReservedCols == 0 &&
"inconsistent equality buffer size");
return equalities.size() / numReservedCols;
}
inline unsigned getNumInequalities() const {
assert(inequalities.size() % numReservedCols == 0 &&
"inconsistent inequality buffer size");
return inequalities.size() / numReservedCols;
}
inline unsigned getNumReservedEqualities() const {
return equalities.capacity() / numReservedCols;
}
inline unsigned getNumReservedInequalities() const {
return inequalities.capacity() / numReservedCols;
}
inline ArrayRef<int64_t> getEquality(unsigned idx) const {
return ArrayRef<int64_t>(&equalities[idx * numReservedCols], getNumCols());
}
inline ArrayRef<int64_t> getInequality(unsigned idx) const {
return ArrayRef<int64_t>(&inequalities[idx * numReservedCols],
getNumCols());
}
AffineExpr toAffineExpr(unsigned idx, MLIRContext *context);
/// Adds constraints (lower and upper bounds) for the specified 'affine.for'
/// operation's Value using IR information stored in its bound maps. The
/// right identifier is first looked up using forOp's Value. Asserts if the
/// Value corresponding to the 'affine.for' operation isn't found in the
/// constraint system. Returns failure for the yet unimplemented/unsupported
/// cases. Any new identifiers that are found in the bound operands of the
/// 'affine.for' operation are added as trailing identifiers (either
/// dimensional or symbolic depending on whether the operand is a valid
/// symbol).
// TODO(bondhugula): add support for non-unit strides.
LogicalResult addAffineForOpDomain(AffineForOp forOp);
/// Adds a lower or an upper bound for the identifier at the specified
/// position with constraints being drawn from the specified bound map and
/// operands. If `eq` is true, add a single equality equal to the bound map's
/// first result expr.
LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
ArrayRef<Value *> operands, bool eq,
bool lower = true);
/// Computes the lower and upper bounds of the first 'num' dimensional
/// identifiers (starting at 'offset') as an affine map of the remaining
/// identifiers (dimensional and symbolic). This method is able to detect
/// identifiers as floordiv's and mod's of affine expressions of other
/// identifiers with respect to (positive) constants. Sets bound map to a
/// null AffineMap if such a bound can't be found (or yet unimplemented).
void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context,
SmallVectorImpl<AffineMap> *lbMaps,
SmallVectorImpl<AffineMap> *ubMaps);
/// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
/// bounds in 'ubMaps' to each identifier in the constraint system which has
/// a value in 'values'. Note that both lower/upper bounds share the same
/// operand list 'operands'.
/// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size'.
LogicalResult addSliceBounds(ArrayRef<Value *> values,
ArrayRef<AffineMap> lbMaps,
ArrayRef<AffineMap> ubMaps,
ArrayRef<Value *> operands);
// Adds an inequality (>= 0) from the coefficients specified in inEq.
void addInequality(ArrayRef<int64_t> inEq);
// Adds an equality from the coefficients specified in eq.
void addEquality(ArrayRef<int64_t> eq);
/// Adds a constant lower bound constraint for the specified identifier.
void addConstantLowerBound(unsigned pos, int64_t lb);
/// Adds a constant upper bound constraint for the specified identifier.
void addConstantUpperBound(unsigned pos, int64_t ub);
/// Adds a new local identifier as the floordiv of an affine function of other
/// identifiers, the coefficients of which are provided in 'dividend' and with
/// respect to a positive constant 'divisor'. Two constraints are added to the
/// system to capture equivalence with the floordiv:
/// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1.
void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor);
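// For example (illustrative; assuming the dividend coefficients include the
// trailing constant column), in a system with identifiers (d0, d1), calling
// addLocalFloorDiv({1, 1, 0}, 4) adds a local identifier q = (d0 + d1)
// floordiv 4 via the inequalities d0 + d1 - 4*q >= 0 and
// -(d0 + d1) + 4*q + 3 >= 0.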
/// Adds a constant lower bound constraint for the specified expression.
void addConstantLowerBound(ArrayRef<int64_t> expr, int64_t lb);
/// Adds a constant upper bound constraint for the specified expression.
void addConstantUpperBound(ArrayRef<int64_t> expr, int64_t ub);
/// Sets the identifier at the specified position to a constant.
void setIdToConstant(unsigned pos, int64_t val);
/// Sets the identifier corresponding to the specified Value id to a
/// constant. Asserts if the 'id' is not found.
void setIdToConstant(Value &id, int64_t val);
/// Looks up the position of the identifier with the specified Value. Returns
/// true if found (false otherwise). `pos' is set to the (column) position of
/// the identifier.
bool findId(Value &id, unsigned *pos) const;
/// Returns true if an identifier with the specified Value exists, false
/// otherwise.
bool containsId(Value &id) const;
// Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. The coefficient column corresponding to the added
// identifier is initialized to zero. 'id' is the Value corresponding to the
// identifier that can optionally be provided.
void addDimId(unsigned pos, Value *id = nullptr);
void addSymbolId(unsigned pos, Value *id = nullptr);
void addLocalId(unsigned pos);
void addId(IdKind kind, unsigned pos, Value *id = nullptr);
/// Adds the specified value as a dim or symbol id, depending on its nature, if
/// it doesn't already exist in the system. `id' has to be either a terminal
/// symbol or a loop IV, i.e., it cannot be the result affine.apply of any
/// symbols or loop IVs. The identifier is added to the end of the existing
/// dims or symbols. Additional information on the identifier is extracted
/// from the IR and added to the constraint system.
void addInductionVarOrTerminalSymbol(Value *id);
/// Composes the affine value map with this FlatAffineConstraints, adding the
/// results of the map as dimensions at the front [0, vMap->getNumResults())
/// and with the dimensions set to the equalities specified by the value map.
/// Returns failure if the composition fails (when vMap is a semi-affine map).
/// The vMap's operand Value's are used to look up the right positions in
/// the FlatAffineConstraints with which to associate. The dimensional and
/// symbolic operands of vMap should match 1:1 (in the same order) with those
/// of this constraint system, but the latter could have additional trailing
/// operands.
LogicalResult composeMap(AffineValueMap *vMap);
/// Projects out (aka eliminates) 'num' identifiers starting at position
/// 'pos'. The resulting constraint system is the shadow along the dimensions
/// that still exist. This method may not always be integer exact.
// TODO(bondhugula): deal with integer exactness when necessary - can return a
// value to mark exactness for example.
void projectOut(unsigned pos, unsigned num);
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
/// Projects out the identifier that is associated with the given Value.
void projectOut(Value *id);
void removeId(IdKind idKind, unsigned pos);
void removeId(unsigned pos);
void removeDim(unsigned pos);
void removeEquality(unsigned pos);
void removeInequality(unsigned pos);
/// Changes the partition between dimensions and symbols. Depending on the new
/// symbol count, either a chunk of trailing dimensional identifiers becomes
/// symbols, or some of the leading symbols become dimensions.
void setDimSymbolSeparation(unsigned newSymbolCount);
/// Changes all symbol identifiers which are loop IVs to dim identifiers.
void convertLoopIVSymbolsToDims();
/// Sets the specified identifier to a constant and removes it.
void setAndEliminate(unsigned pos, int64_t constVal);
/// Tries to fold the specified identifier to a constant using a trivial
/// equality detection; if successful, the constant is substituted for the
/// identifier everywhere in the constraint system and then removed from the
/// system.
LogicalResult constantFoldId(unsigned pos);
/// This method calls constantFoldId for the specified range of identifiers,
/// 'num' identifiers starting at position 'pos'.
void constantFoldIdRange(unsigned pos, unsigned num);
/// Returns true if all the identifiers in the specified range [start, limit)
/// can only take a single value each if the remaining identifiers are treated
/// as symbols/parameters, i.e., for given values of the latter, there only
/// exists a unique value for each of the dimensions in the specified range.
bool isRangeOneToOne(unsigned start, unsigned limit) const;
/// Updates the constraints to be the smallest bounding (enclosing) box that
/// contains the points of 'this' set and that of 'other', with the symbols
/// being treated specially. For each of the dimensions, the min of the lower
/// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
/// to determine such a bounding box. `other' is expected to have the same
/// dimensional identifiers as this constraint system (in the same order).
///
/// Eg: 1) 'this' = {0 <= d0 <= 127}, 'other' = {16 <= d0 <= 192}, output =
/// {0 <= d0 <= 192}.
/// 2) 'this' = {s0 + 5 <= d0 <= s0 + 20}, 'other' = {s0 + 1 <= d0 <= s0 +
/// 9}, output = {s0 + 1 <= d0 <= s0 + 20}.
/// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1
/// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}.
LogicalResult unionBoundingBox(const FlatAffineConstraints &other);
/// Returns 'true' if this constraint system and 'other' are in the same
/// space, i.e., if they are associated with the same set of identifiers,
/// appearing in the same order. Returns 'false' otherwise.
bool areIdsAlignedWithOther(const FlatAffineConstraints &other);
/// Merge and align the identifiers of 'this' and 'other' starting at
/// 'offset', so that both constraint systems get the union of the contained
/// identifiers that is dimension-wise and symbol-wise unique; both
/// constraint systems are updated so that they have the union of all
/// identifiers, with this's original identifiers appearing first followed by
/// any of other's identifiers that didn't appear in 'this'. Local
/// identifiers of each system are by design separate/local and are placed
/// one after other (this's followed by other's).
// Eg: Input: 'this' has (%i, %j) [%M, %N]
//            'other' has (%k, %j) [%P, %N, %M]
//     Output: both 'this', 'other' have (%i, %j, %k) [%M, %N, %P]
//
void mergeAndAlignIdsWithOther(unsigned offset, FlatAffineConstraints *other);
unsigned getNumConstraints() const {
return getNumInequalities() + getNumEqualities();
}
inline unsigned getNumIds() const { return numIds; }
inline unsigned getNumDimIds() const { return numDims; }
inline unsigned getNumSymbolIds() const { return numSymbols; }
inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
inline unsigned getNumLocalIds() const {
return numIds - numDims - numSymbols;
}
inline ArrayRef<Optional<Value *>> getIds() const {
return {ids.data(), ids.size()};
}
inline MutableArrayRef<Optional<Value *>> getIds() {
return {ids.data(), ids.size()};
}
/// Returns the optional Value corresponding to the pos^th identifier.
inline Optional<Value *> getId(unsigned pos) const { return ids[pos]; }
inline Optional<Value *> &getId(unsigned pos) { return ids[pos]; }
/// Returns the Value associated with the pos^th identifier. Asserts if
/// no Value identifier was associated.
inline Value *getIdValue(unsigned pos) const {
assert(ids[pos].hasValue() && "identifier's Value not set");
return ids[pos].getValue();
}
/// Returns the Values associated with identifiers in range [start, end).
/// Asserts if no Value was associated with one of these identifiers.
void getIdValues(unsigned start, unsigned end,
SmallVectorImpl<Value *> *values) const {
assert((start < numIds || start == end) && "invalid start position");
assert(end <= numIds && "invalid end position");
values->clear();
values->reserve(end - start);
for (unsigned i = start; i < end; i++) {
values->push_back(getIdValue(i));
}
}
inline void getAllIdValues(SmallVectorImpl<Value *> *values) const {
getIdValues(0, numIds, values);
}
/// Sets Value associated with the pos^th identifier.
inline void setIdValue(unsigned pos, Value *val) {
assert(pos < numIds && "invalid id position");
ids[pos] = val;
}
/// Sets Values associated with identifiers in the range [start, end).
void setIdValues(unsigned start, unsigned end, ArrayRef<Value *> values) {
assert((start < numIds || end == start) && "invalid start position");
assert(end <= numIds && "invalid end position");
assert(values.size() == end - start);
for (unsigned i = start; i < end; ++i)
ids[i] = values[i - start];
}
/// Clears this list of constraints and copies other into it.
void clearAndCopyFrom(const FlatAffineConstraints &other);
/// Returns the smallest known constant bound for the extent of the specified
/// identifier (pos^th), i.e., the smallest known constant that is greater
/// than or equal to 'exclusive upper bound' - 'lower bound' of the
/// identifier. Returns None if it's not a constant. This method employs
/// trivial (low complexity / cost) checks and detection. Symbolic identifiers
/// are treated specially, i.e., it looks for constant differences between
/// affine expressions involving only the symbolic identifiers. See comments
/// at function definition for examples. 'lb' and 'lbDivisor', if provided,
/// are used to express the lower bound associated with the constant
/// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg.,
/// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three
/// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32.
Optional<int64_t>
getConstantBoundOnDimSize(unsigned pos,
SmallVectorImpl<int64_t> *lb = nullptr,
int64_t *lbFloorDivisor = nullptr,
SmallVectorImpl<int64_t> *ub = nullptr) const;
/// Returns the constant lower bound for the pos^th identifier if there is
/// one; None otherwise.
Optional<int64_t> getConstantLowerBound(unsigned pos) const;
/// Returns the constant upper bound for the pos^th identifier if there is
/// one; None otherwise.
Optional<int64_t> getConstantUpperBound(unsigned pos) const;
/// Gets the lower and upper bound of the pos^th identifier treating
/// [0, offset) U [offset + num, symbStartPos) as dimensions and
/// [symStartPos, getNumDimAndSymbolIds) as symbols. The returned
/// multi-dimensional maps in the pair represent the max and min of
/// potentially multiple affine expressions. The upper bound is exclusive.
/// 'localExprs' holds pre-computed AffineExpr's for all local identifiers in
/// the system.
std::pair<AffineMap, AffineMap>
getLowerAndUpperBound(unsigned pos, unsigned offset, unsigned num,
unsigned symStartPos, ArrayRef<AffineExpr> localExprs,
MLIRContext *context);
/// Returns true if the set can be trivially detected as being
/// hyper-rectangular on the specified contiguous set of identifiers.
bool isHyperRectangular(unsigned pos, unsigned num) const;
/// Removes duplicate constraints, trivially true constraints, and constraints
/// that can be detected as redundant as a result of differing only in their
/// constant term part. A constraint of the form <non-negative constant> >= 0
/// is considered trivially true. This method is a linear time method on the
/// constraints, does a single scan, and updates in place.
void removeTrivialRedundancy();
/// A more expensive check to detect redundant inequalities than
/// removeTrivialRedundancy.
void removeRedundantInequalities();
// Removes all equalities and inequalities.
void clearConstraints();
void print(raw_ostream &os) const;
void dump() const;
private:
/// Returns false if the fields corresponding to various identifier counts, or
/// equality/inequality buffer sizes aren't consistent; true otherwise. This
/// is meant to be used within an assert internally.
bool hasConsistentState() const;
/// Checks all rows of equality/inequality constraints for trivial
/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
/// after elimination. Returns 'true' if an invalid constraint is found;
/// 'false' otherwise.
bool hasInvalidConstraint() const;
/// Returns the constant lower bound if isLower is true, and the constant
/// upper bound if isLower is false.
template <bool isLower>
Optional<int64_t> computeConstantLowerOrUpperBound(unsigned pos);
// Eliminates a single identifier at 'position' from equality and inequality
// constraints. Returns 'success' if the identifier was eliminated, and
// 'failure' otherwise.
inline LogicalResult gaussianEliminateId(unsigned position) {
return success(gaussianEliminateIds(position, position + 1) == 1);
}
// Eliminates identifiers from equality and inequality constraints
// in column range [posStart, posLimit).
// Returns the number of variables eliminated.
unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit);
/// Eliminates identifier at the specified position using Fourier-Motzkin
/// variable elimination, but uses Gaussian elimination if there is an
/// equality involving that identifier. If the result of the elimination is
/// integer exact, *isResultIntegerExact is set to true. If 'darkShadow' is
/// set to true, a potential under approximation (subset) of the rational
/// shadow / exact integer shadow is computed.
// See implementation comments for more details.
void FourierMotzkinEliminate(unsigned pos, bool darkShadow = false,
bool *isResultIntegerExact = nullptr);
/// Tightens inequalities given that we are dealing with integer spaces. This
/// is similar to the GCD test but applied to inequalities. The constant term
/// can be reduced to the preceding multiple of the GCD of the coefficients,
/// i.e.,
/// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a
/// fast method (linear in the number of coefficients).
void GCDTightenInequalities();
/// Normalizes each constraint by the GCD of its coefficients.
void normalizeConstraintsByGCD();
/// Removes identifiers in column range [idStart, idLimit), and copies any
/// remaining valid data into place, updates member variables, and resizes
/// arrays as needed.
void removeIdRange(unsigned idStart, unsigned idLimit);
/// Coefficients of affine equalities (in == 0 form).
SmallVector<int64_t, 64> equalities;
/// Coefficients of affine inequalities (in >= 0 form).
SmallVector<int64_t, 64> inequalities;
/// Number of columns reserved. The actual number of columns in use is
/// returned by getNumCols().
unsigned numReservedCols;
/// Total number of identifiers.
unsigned numIds;
/// Number of identifiers corresponding to real dimensions.
unsigned numDims;
/// Number of identifiers corresponding to symbols (unknown but constant for
/// analysis).
unsigned numSymbols;
/// Values corresponding to the (column) identifiers of this constraint
/// system appearing in the order the identifiers correspond to columns.
/// Temporary ones or those that aren't associated to any Value are set to
/// None.
SmallVector<Optional<Value *>, 8> ids;
/// A parameter that controls detection of an unrealistic number of
/// constraints. If the number of constraints is this many times the number of
/// variables, we consider such a system out of line with the intended use
/// case of FlatAffineConstraints.
// The rationale for 32 is that in the typical simplest of cases, an
// identifier is expected to have one lower bound and one upper bound
// constraint. With a level of tiling or a connection to another identifier
// through a div or mod, an extra pair of bounds gets added. As a limit, we
// don't expect an identifier to have more than 32 lower/upper/equality
// constraints. This is conservatively set low and can be raised if needed.
constexpr static unsigned kExplosionFactor = 32;
};
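// A minimal usage sketch (illustrative): building the set {0 <= d0 <= 127}
// and checking it for emptiness.
//
//   FlatAffineConstraints cst(/*numDims=*/1);
//   cst.addConstantLowerBound(/*pos=*/0, 0);
//   cst.addConstantUpperBound(/*pos=*/0, 127);
//   assert(!cst.isEmpty());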
/// Simplify an affine expression by flattening and some amount of
/// simple analysis. This has complexity linear in the number of nodes in
/// 'expr'. Returns the simplified expression, which is the same as the input
/// expression if it can't be simplified.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols);
/// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' could not be
/// flattened (i.e., semi-affine is not yet handled). 'cst' contains constraints
/// that connect newly introduced local identifiers to existing dimensional and
/// symbolic identifiers. See documentation for AffineExprFlattener on how
/// mod's and div's are flattened.
LogicalResult
getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
llvm::SmallVectorImpl<int64_t> *flattenedExpr,
FlatAffineConstraints *cst = nullptr);
/// Flattens the result expressions of the map to their corresponding flattened
/// forms and set in 'flattenedExprs'. Returns failure if any expression in the
/// map could not be flattened (i.e., semi-affine is not yet handled). 'cst'
/// contains constraints that connect newly introduced local identifiers to
/// existing dimensional and symbolic identifiers. See documentation for
/// AffineExprFlattener on how mod's and div's are flattened. For all affine
/// expressions that share the same operands (like those of an affine map), this
/// method should be used instead of repeatedly calling getFlattenedAffineExpr
/// since local variables added to deal with div's and mod's will be reused
/// across expressions.
LogicalResult getFlattenedAffineExprs(
AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
FlatAffineConstraints *cst = nullptr);
LogicalResult getFlattenedAffineExprs(
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
FlatAffineConstraints *cst = nullptr);
} // end namespace mlir
#endif // MLIR_ANALYSIS_AFFINE_STRUCTURES_H

View File

@ -0,0 +1,144 @@
//===- Dominance.h - Dominator analysis for CFGs ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_ANALYSIS_DOMINANCE_H
#define MLIR_ANALYSIS_DOMINANCE_H
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/Support/GenericDomTree.h"
extern template class llvm::DominatorTreeBase<mlir::Block, false>;
extern template class llvm::DominatorTreeBase<mlir::Block, true>;
namespace mlir {
using DominanceInfoNode = llvm::DomTreeNodeBase<Block>;
class Operation;
namespace detail {
template <bool IsPostDom> class DominanceInfoBase {
using base = llvm::DominatorTreeBase<Block, IsPostDom>;
public:
DominanceInfoBase(Operation *op) { recalculate(op); }
DominanceInfoBase(DominanceInfoBase &&) = default;
DominanceInfoBase &operator=(DominanceInfoBase &&) = default;
DominanceInfoBase(const DominanceInfoBase &) = delete;
DominanceInfoBase &operator=(const DominanceInfoBase &) = delete;
/// Recalculate the dominance info.
void recalculate(Operation *op);
/// Get the root dominance node of the given region.
DominanceInfoNode *getRootNode(Region *region) {
assert(dominanceInfos.count(region) != 0);
return dominanceInfos[region]->getRootNode();
}
protected:
using super = DominanceInfoBase<IsPostDom>;
/// Return true if the specified block A properly dominates block B.
bool properlyDominates(Block *a, Block *b);
/// A mapping of regions to their base dominator tree.
llvm::DenseMap<Region *, std::unique_ptr<base>> dominanceInfos;
};
} // end namespace detail
/// A class for computing basic dominance information.
class DominanceInfo : public detail::DominanceInfoBase</*IsPostDom=*/false> {
public:
using super::super;
/// Return true if operation A properly dominates operation B.
bool properlyDominates(Operation *a, Operation *b);
/// Return true if operation A dominates operation B.
bool dominates(Operation *a, Operation *b) {
return a == b || properlyDominates(a, b);
}
/// Return true if value A properly dominates operation B.
bool properlyDominates(Value *a, Operation *b);
/// Return true if value A dominates operation B.
bool dominates(Value *a, Operation *b) {
return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b);
}
/// Return true if the specified block A dominates block B.
bool dominates(Block *a, Block *b) {
return a == b || properlyDominates(a, b);
}
/// Return true if the specified block A properly dominates block B.
bool properlyDominates(Block *a, Block *b) {
return super::properlyDominates(a, b);
}
};
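// A minimal usage sketch (illustrative; `parentOp`, `defOp`, and `userOp` are
// assumed to come from the surrounding pass):
//
//   DominanceInfo domInfo(parentOp);
//   if (domInfo.properlyDominates(defOp, userOp)) {
//     // Every execution reaching `userOp` has executed `defOp` first, so
//     // `userOp` may use values defined by `defOp`.
//   }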
/// A class for computing basic postdominance information.
class PostDominanceInfo : public detail::DominanceInfoBase</*IsPostDom=*/true> {
public:
using super::super;
/// Return true if operation A properly postdominates operation B.
bool properlyPostDominates(Operation *a, Operation *b);
/// Return true if operation A postdominates operation B.
bool postDominates(Operation *a, Operation *b) {
return a == b || properlyPostDominates(a, b);
}
/// Return true if the specified block A properly postdominates block B.
bool properlyPostDominates(Block *a, Block *b) {
return super::properlyDominates(a, b);
}
/// Return true if the specified block A postdominates block B.
bool postDominates(Block *a, Block *b) {
return a == b || properlyPostDominates(a, b);
}
};
} // end namespace mlir
namespace llvm {
/// DominatorTree GraphTraits specialization so the DominatorTree can be
/// iterated by generic graph iterators.
template <> struct GraphTraits<mlir::DominanceInfoNode *> {
using ChildIteratorType = mlir::DominanceInfoNode::iterator;
using NodeRef = mlir::DominanceInfoNode *;
static NodeRef getEntryNode(NodeRef N) { return N; }
static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); }
static inline ChildIteratorType child_end(NodeRef N) { return N->end(); }
};
template <> struct GraphTraits<const mlir::DominanceInfoNode *> {
using ChildIteratorType = mlir::DominanceInfoNode::const_iterator;
using NodeRef = const mlir::DominanceInfoNode *;
static NodeRef getEntryNode(NodeRef N) { return N; }
static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); }
static inline ChildIteratorType child_end(NodeRef N) { return N->end(); }
};
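// With these specializations, the generic LLVM graph iterators apply
// directly, e.g. (illustrative; `rootNode` may come from
// DominanceInfo::getRootNode(region)):
//
//   for (mlir::DominanceInfoNode *node : llvm::depth_first(rootNode)) {
//     // visit `node`; children are visited after their parent
//   }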
} // end namespace llvm
#endif

View File

@ -0,0 +1,111 @@
//===- LoopAnalysis.h - loop analysis methods -------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes for methods to analyze loops.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_LOOP_ANALYSIS_H
#define MLIR_ANALYSIS_LOOP_ANALYSIS_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
namespace mlir {
class AffineExpr;
class AffineForOp;
class AffineMap;
class Operation;
class MemRefType;
class Value;
/// Returns in `map` and `operands` the trip count of the loop as an affine
/// map with its corresponding operands if the trip count is expressible as an
/// affine expression, and a null map otherwise. This method always succeeds as
/// long as the lower bound is not a multi-result map. The trip count
/// expression is simplified before returning.
/// This method only utilizes map composition to construct lower and upper
/// bounds before computing the trip count expressions.
// TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a
// pure analysis method relying on FlatAffineConstraints
void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
SmallVectorImpl<Value *> *operands);
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// uses affine expression analysis and is able to determine constant trip count
/// in non-trivial cases.
llvm::Optional<uint64_t> getConstantTripCount(AffineForOp forOp);
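// For example (illustrative, pseudo-IR): `affine.for %i = 0 to 128` has a
// constant trip count of 128, while `affine.for %i = 0 to %N` yields None.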
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t getLargestDivisorOfTripCount(AffineForOp forOp);
/// Given an induction variable `iv` of type AffineForOp and an `index` of type
/// IndexType, returns `true` if `index` is independent of `iv` and false
/// otherwise.
/// The determination supports composition with at most one AffineApplyOp.
/// The "at most one AffineApplyOp" restriction comes from the fact that
/// compositions of AffineApplyOps need to be canonicalized by construction to
/// avoid writing code
/// that composes arbitrary numbers of AffineApplyOps everywhere. To achieve
/// this, at the very least, the compose-affine-apply pass must have been run.
///
/// Prerequisites:
/// 1. `iv` and `index` of the proper type;
/// 2. at most one reachable AffineApplyOp from index;
///
/// Returns false in cases with more than one AffineApplyOp; this is
/// conservative.
bool isAccessInvariant(Value *iv, Value *index);
/// Given an induction variable `iv` of type AffineForOp and `indices` of type
/// IndexType, returns the set of `indices` that are independent of `iv`.
///
/// Prerequisites (inherited from `isAccessInvariant` above):
/// 1. `iv` and `indices` of the proper type;
/// 2. at most one affine.apply is reachable from each index in `indices`;
///
/// Emits a note if it encounters a chain of affine.apply and conservatively
/// excludes those cases.
llvm::DenseSet<Value *, llvm::DenseMapInfo<Value *>>
getInvariantAccesses(Value *iv, llvm::ArrayRef<Value *> indices);
using VectorizableLoopFun = std::function<bool(AffineForOp)>;
/// Checks whether the loop is structurally vectorizable; i.e.:
/// 1. no conditionals are nested under the loop;
/// 2. all nested load/stores are to scalar MemRefs.
/// TODO(ntv): relax the no-conditionals restriction
bool isVectorizableLoopBody(AffineForOp loop);
/// Checks whether the loop is structurally vectorizable and that all the LoadOp
/// and StoreOp matched have access indexing functions that are either:
/// 1. invariant along the loop induction variable created by 'loop';
/// 2. varying along at most one memory dimension. If such a unique dimension
/// is found, it is written into `memRefDim`.
bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim);
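// A minimal usage sketch (illustrative):
//
//   int memRefDim = -1;
//   if (isVectorizableLoopBody(loop, &memRefDim)) {
//     // All matched accesses vary along at most one memref dimension,
//     // recorded in `memRefDim`.
//   }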
/// Checks whether SSA dominance would be violated if a for op's body
/// operations are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
bool isInstwiseShiftValid(AffineForOp forOp, llvm::ArrayRef<uint64_t> shifts);
} // end namespace mlir
#endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H

View File

@ -0,0 +1,193 @@
//===- NestedMatcher.h - Nested matcher for MLFunction ----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Allocator.h"
namespace mlir {
struct NestedPattern;
class Operation;
/// A NestedPattern captures nested patterns in the IR.
/// It is used in conjunction with a scoped NestedPatternContext, which wraps
/// an llvm::BumpPtrAllocator that handles memory allocations efficiently and
/// avoids ownership issues.
///
/// In order to use NestedPatterns, first create a scoped context.
/// When the context goes out of scope, everything is freed.
/// This design simplifies the API by avoiding references to the context and
/// makes it clear that references to matchers must not escape.
///
/// Example:
/// {
/// NestedPatternContext context;
/// auto gemmLike = Doall(Doall(Red(LoadStores())));
/// auto matches = gemmLike.match(f);
/// // do work on matches
/// } // everything is freed
///
///
/// Nested abstraction for matching results.
/// Provides access to the nested Operation* captured by a Matcher.
///
/// A NestedMatch contains an Operation* and the children NestedMatch and is
/// thus cheap to copy. NestedMatch is stored in a scoped bumper allocator whose
/// lifetime is managed by an RAII NestedPatternContext.
struct NestedMatch {
static NestedMatch build(Operation *operation,
ArrayRef<NestedMatch> nestedMatches);
NestedMatch(const NestedMatch &) = default;
NestedMatch &operator=(const NestedMatch &) = default;
explicit operator bool() { return matchedOperation != nullptr; }
Operation *getMatchedOperation() { return matchedOperation; }
ArrayRef<NestedMatch> getMatchedChildren() { return matchedChildren; }
private:
friend struct NestedPattern;
friend struct NestedPatternContext;
/// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
NestedMatch() = default;
/// Payload, holds a NestedMatch and all its children along this branch.
Operation *matchedOperation;
ArrayRef<NestedMatch> matchedChildren;
};
/// A NestedPattern is a nested operation walker that:
/// 1. recursively matches a substructure in the tree;
/// 2. uses a filter function to refine matches with extra semantic
/// constraints (passed via a lambda of type FilterFunctionType);
/// 3. TODO(ntv) optionally applies actions (lambda).
///
/// Nested patterns are meant to capture imperfectly nested loops while matching
/// properties over the whole loop nest. For instance, in vectorization we are
/// interested in capturing all the imperfectly nested loops of a certain type
/// such that all the loads and stores have certain access patterns along the
/// loops' induction variables. Such NestedMatches are first captured using the
/// `match` function and are later processed to analyze properties and apply
/// transformations in a non-greedy way.
///
/// The NestedMatches captured in the IR can grow large, especially after
/// aggressive unrolling. As experience has shown, it is generally better to use
/// a plain walk over operations to match flat patterns but the current
/// implementation is competitive nonetheless.
using FilterFunctionType = std::function<bool(Operation &)>;
inline bool defaultFilterFunction(Operation &) { return true; }
struct NestedPattern {
NestedPattern(ArrayRef<NestedPattern> nested,
FilterFunctionType filter = defaultFilterFunction);
NestedPattern(const NestedPattern &) = default;
NestedPattern &operator=(const NestedPattern &) = default;
/// Returns all the top-level matches in `func`.
void match(FuncOp func, SmallVectorImpl<NestedMatch> *matches) {
func.walk([&](Operation *op) { matchOne(op, matches); });
}
/// Returns all the top-level matches in `op`.
void match(Operation *op, SmallVectorImpl<NestedMatch> *matches) {
op->walk([&](Operation *child) { matchOne(child, matches); });
}
/// Returns the depth of the pattern.
unsigned getDepth() const;
private:
friend struct NestedPatternContext;
friend struct NestedMatch;
friend struct State;
/// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
/// Matches this pattern against a single `op` and fills matches with the
/// result.
void matchOne(Operation *op, SmallVectorImpl<NestedMatch> *matches);
/// Nested patterns to be matched.
ArrayRef<NestedPattern> nestedPatterns;
/// Extra filter function to apply to prune patterns as the IR is walked.
FilterFunctionType filter;
/// skip is an implementation detail needed so that we can implement match
/// without switching on the type of the Operation. The idea is that a
/// NestedPattern first checks if it matches locally and then recursively
/// applies its nested matchers to its elem->nested. Since we want to rely on
/// the existing operation walking functionality rather than duplicate
/// it, we allow an off-by-one traversal to account for the fact that we
/// write:
///
/// void match(Operation *elem) {
/// for (auto &c : getNestedPatterns()) {
/// NestedPattern childPattern(...);
/// ^~~~ Needs off-by-one skip.
///
Operation *skip;
};
/// RAII structure to transparently manage the bump allocator for
/// NestedPattern and NestedMatch classes. This avoids passing a context to
/// all the API functions.
struct NestedPatternContext {
NestedPatternContext() {
assert(NestedMatch::allocator() == nullptr &&
"Only a single NestedPatternContext is supported");
assert(NestedPattern::allocator() == nullptr &&
"Only a single NestedPatternContext is supported");
NestedMatch::allocator() = &allocator;
NestedPattern::allocator() = &allocator;
}
~NestedPatternContext() {
NestedMatch::allocator() = nullptr;
NestedPattern::allocator() = nullptr;
}
llvm::BumpPtrAllocator allocator;
};
namespace matcher {
// Syntactic sugar NestedPattern builder functions.
NestedPattern Op(FilterFunctionType filter = defaultFilterFunction);
NestedPattern If(NestedPattern child);
NestedPattern If(FilterFunctionType filter, NestedPattern child);
NestedPattern If(ArrayRef<NestedPattern> nested = {});
NestedPattern If(FilterFunctionType filter,
ArrayRef<NestedPattern> nested = {});
NestedPattern For(NestedPattern child);
NestedPattern For(FilterFunctionType filter, NestedPattern child);
NestedPattern For(ArrayRef<NestedPattern> nested = {});
NestedPattern For(FilterFunctionType filter,
ArrayRef<NestedPattern> nested = {});
bool isParallelLoop(Operation &op);
bool isReductionLoop(Operation &op);
bool isLoadOrStore(Operation &op);
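// A minimal usage sketch (illustrative; `funcOp` is assumed to be the FuncOp
// under inspection): matching doubly nested loops whose bodies contain loads
// or stores.
//
//   NestedPatternContext context;
//   auto pattern = For(For(Op(isLoadOrStore)));
//   SmallVector<NestedMatch, 8> matches;
//   pattern.match(funcOp, &matches);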
} // end namespace matcher
} // end namespace mlir
#endif // MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_

View File

@ -0,0 +1,43 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes that expose pass constructors in the
// analysis library.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_PASSES_H
#define MLIR_ANALYSIS_PASSES_H
#include "mlir/Support/LLVM.h"
namespace mlir {
class FunctionPassBase;
/// Creates a pass to check memref accesses in a Function.
FunctionPassBase *createMemRefBoundCheckPass();
/// Creates a pass to check memref access dependences in a Function.
FunctionPassBase *createTestMemRefDependenceCheckPass();
/// Creates a pass to test parallelism detection; emits note for parallel loops.
FunctionPassBase *createParallelismDetectionTestPass();
} // end namespace mlir
#endif // MLIR_ANALYSIS_PASSES_H

View File

@ -0,0 +1,215 @@
//===- SliceAnalysis.h - Analysis for Transitive UseDef chains --*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_ANALYSIS_SLICEANALYSIS_H_
#define MLIR_ANALYSIS_SLICEANALYSIS_H_
#include <functional>
#include <vector>
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
namespace mlir {
class Operation;
/// Type of the condition to limit the propagation of transitive use-defs.
/// This can be used in particular to limit the propagation to a given Scope or
/// to avoid passing through certain types of operation in a configurable
/// manner.
using TransitiveFilter = std::function<bool(Operation *)>;
/// Fills `forwardSlice` with the computed forward slice (i.e. all
/// the transitive uses of op), **without** including that operation.
///
/// This additionally takes a TransitiveFilter which acts as a frontier:
/// when looking at uses transitively, an operation that does not pass the
/// filter is never propagated through. This allows in particular to carve out
/// the scope within a ForInst or the scope within an IfInst.
///
/// The implementation traverses the use chains in postorder traversal for
/// efficiency reasons: if an operation is already in `forwardSlice`, no
/// need to traverse its uses again. Since use-def chains form a DAG, this
/// terminates.
///
/// Upon return to the root call, `forwardSlice` is filled with a
/// postorder list of uses (i.e. a reverse topological order). To get a proper
/// topological order, we just reverse the order in `forwardSlice` before
/// returning.
///
/// Example starting from node 0
/// ============================
///
/// 0
/// ___________|___________
/// 1 2 3 4
/// |_______| |______|
/// | | |
/// | 5 6
/// |___|_____________|
/// | |
/// 7 8
/// |_______________|
/// |
/// 9
///
/// Assuming all local orders match the numbering order:
/// 1. after getting back to the root getForwardSlice, `forwardSlice` may
/// contain:
/// {9, 7, 8, 5, 1, 2, 6, 3, 4}
/// 2. reversing the result of 1. gives:
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
///
void getForwardSlice(
Operation *op, llvm::SetVector<Operation *> *forwardSlice,
TransitiveFilter filter = /* pass-through*/
[](Operation *) { return true; });
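// Example usage (sketch): compute the forward slice of a hypothetical
// `rootOp`, refusing to propagate through any 'affine.for' operation. The
// variable names and the filter are illustrative assumptions, not part of
// this API.
//
// ```c++
// llvm::SetVector<Operation *> forwardSlice;
// getForwardSlice(rootOp, &forwardSlice, [](Operation *op) {
//   return op->getName().getStringRef() != "affine.for";
// });
// // `forwardSlice` now holds the transitive uses of `rootOp`, ordered as
// // described above.
// ```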
/// Fills `backwardSlice` with the computed backward slice (i.e.
/// all the transitive defs of op), **without** including that operation.
///
/// This additionally takes a TransitiveFilter which acts as a frontier:
/// when looking at defs transitively, an operation that does not pass the
/// filter is never propagated through. In particular, this allows carving out
/// the scope within an 'affine.for' operation or within an 'affine.if'
/// operation.
///
/// The implementation traverses the def chains in postorder traversal for
/// efficiency reasons: if an operation is already in `backwardSlice`, there is
/// no need to traverse its definitions again. Since use-def chains form a DAG,
/// this terminates.
///
/// Upon return to the root call, `backwardSlice` is filled with a
/// postorder list of defs. This happens to be a topological order, from the
/// point of view of the use-def chains.
///
/// Example starting from node 8
/// ============================
///
/// 1 2 3 4
/// |_______| |______|
/// | | |
/// | 5 6
/// |___|_____________|
/// | |
/// 7 8
/// |_______________|
/// |
/// 9
///
/// Assuming all local orders match the numbering order:
/// {1, 2, 5, 3, 4, 6}
///
void getBackwardSlice(
Operation *op, llvm::SetVector<Operation *> *backwardSlice,
TransitiveFilter filter = /* pass-through*/
[](Operation *) { return true; });
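// Example usage (sketch; `defOp` is an illustrative assumption):
//
// ```c++
// llvm::SetVector<Operation *> backwardSlice;
// // Default pass-through filter: follow all transitive defs of `defOp`.
// getBackwardSlice(defOp, &backwardSlice);
// ```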
/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `llvm::SetVector<Operation *>` which
/// **includes** the original operation.
///
/// This allows building a slice (i.e. a multi-root DAG where everything
/// that is reachable from a Value in forward and backward direction is
/// contained in the slice).
/// This is the abstraction we need to materialize all the operations for
/// supervectorization without worrying about orderings and Value
/// replacements.
///
/// Example starting from any node
/// ==============================
///
/// 1 2 3 4
/// |_______| |______|
/// | | | |
/// | 5 6___|
/// |___|_____________| |
/// | | |
/// 7 8 |
/// |_______________| |
/// | |
/// 9 10
///
/// Return the whole DAG in some topological order.
///
/// The implementation works by just filling up a worklist with iterative
/// alternate calls to `getBackwardSlice` and `getForwardSlice`.
///
/// The following section describes some additional implementation
/// considerations for a potentially more efficient implementation, but they
/// are just an intuition without proof; we still use a worklist for now.
///
/// Additional implementation considerations
/// ========================================
/// Consider the defs-op-uses hourglass.
/// ____
/// \ / defs (in some topological order)
/// \/
/// op
/// /\
/// / \ uses (in some topological order)
/// /____\
///
/// We want to iteratively apply `getSlice` to construct the whole
/// list of Operations that are reachable by (use|def)+ from op.
/// We want the resulting slice in topological order.
/// Ideally we would like the ordering to be maintained in-place to avoid
/// copying Operations at each step. Keeping this ordering by construction
/// seems very unclear, so we list invariants in the hope of seeing whether
/// useful properties pop up.
///
/// In the following:
/// we use |= for set inclusion;
/// we use << for set topological ordering (i.e. each pair is ordered).
///
/// Assumption:
/// ===========
/// We wish to maintain the following property by a recursive argument:
/// """
///    defs << {op} << uses are in topological order.
/// """
/// The property clearly holds for 0- and 1-sized uses and defs.
///
/// Invariants:
///   1. defs and uses are in topological order internally, by construction;
///   2. for any {x} |= defs, defs(x) |= defs;    because all go through op
///   3. for any {x} |= uses, defs |= defs(x);    because all go through op
///   4. for any {x} |= defs, uses |= uses(x);    because all go through op
///   5. for any {x} |= uses, uses(x) |= uses;    because all go through op
///
/// Intuitively, we should be able to recurse like:
/// preorder(defs) - op - postorder(uses)
/// and keep things ordered but this is still hand-wavy and not worth the
/// trouble for now: punt to a simple worklist-based solution.
///
llvm::SetVector<Operation *> getSlice(
Operation *op,
TransitiveFilter backwardFilter = /* pass-through*/
[](Operation *) { return true; },
TransitiveFilter forwardFilter = /* pass-through*/
[](Operation *) { return true; });
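// Example usage (sketch): gather everything reachable from a hypothetical
// `op` through use-def chains in both directions, with no frontier:
//
// ```c++
// llvm::SetVector<Operation *> slice = getSlice(op);
// ```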
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operations in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.
llvm::SetVector<Operation *>
topologicalSort(const llvm::SetVector<Operation *> &toSort);
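// Example usage (sketch): re-establish a topological order over an arbitrary
// set of operations, e.g. the `slice` from the `getSlice` example above:
//
// ```c++
// llvm::SetVector<Operation *> sorted = topologicalSort(slice);
// ```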
} // end namespace mlir
#endif // MLIR_ANALYSIS_SLICEANALYSIS_H_


@ -0,0 +1,304 @@
//===- Utils.h - General analysis utilities ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes for various transformation utilities for
// memrefs and non-loop IR structures. These are not passes by themselves but
// are used either by passes, optimization sequences, or in turn by other
// transformation utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_UTILS_H
#define MLIR_ANALYSIS_UTILS_H
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
namespace mlir {
class AffineForOp;
class Block;
class FlatAffineConstraints;
class Location;
struct MemRefAccess;
class Operation;
class Value;
/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
/// the outermost 'affine.for' operation to the innermost one.
// TODO(bondhugula): handle 'affine.if' ops.
void getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
/// Returns the nesting depth of this operation, i.e., the number of loops
/// surrounding this operation.
unsigned getNestingDepth(Operation &op);
/// Returns in 'sequentialLoops' all sequential loops in the loop nest rooted
/// at 'forOp'.
void getSequentialLoops(AffineForOp forOp,
llvm::SmallDenseSet<Value *, 8> *sequentialLoops);
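// Example usage (sketch; `op` is assumed to be nested inside affine.for
// operations only):
//
// ```c++
// SmallVector<AffineForOp, 4> loops;
// getLoopIVs(*op, &loops);               // outermost affine.for first
// unsigned depth = getNestingDepth(*op);
// assert(loops.size() == depth && "expected a purely affine loop nest");
// ```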
/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
/// associated operands for a set of loops within a loop nest (typically the
/// set of loops surrounding a store operation). Loop bound AffineMaps which
/// are non-null represent slices of that loop's iteration space.
struct ComputationSliceState {
// List of sliced loop IVs (ordered from outermost to innermost).
// EX: 'ivs[i]' has lower bound 'lbs[i]' and upper bound 'ubs[i]'.
SmallVector<Value *, 4> ivs;
// List of lower bound AffineMaps.
SmallVector<AffineMap, 4> lbs;
// List of upper bound AffineMaps.
SmallVector<AffineMap, 4> ubs;
// List of lower bound operands (lbOperands[i] are used by 'lbs[i]').
std::vector<SmallVector<Value *, 4>> lbOperands;
// List of upper bound operands (ubOperands[i] are used by 'ubs[i]').
std::vector<SmallVector<Value *, 4>> ubOperands;
// Slice loop nest insertion point in target loop nest.
Block::iterator insertPoint;
// Adds to 'cst' with constraints which represent the slice bounds on 'ivs'
// in 'this'. Specifically, the values in 'ivs' are added to 'cst' as dim
// identifiers and the values in 'lb/ubOperands' are added as symbols.
// Constraints are added for all loop IV bounds (dim or symbol), and
// constraints are added for slice bounds in 'lbs'/'ubs'.
// Returns failure if we cannot add loop bounds because of unsupported cases.
LogicalResult getAsConstraints(FlatAffineConstraints *cst);
// Clears all bounds and operands in slice state.
void clearBounds();
};
/// Computes the computation slice loop bounds for one loop nest as affine maps
/// of the other loop nest's IVs and symbols, using 'dependenceConstraints'
/// computed between 'depSourceAccess' and 'depSinkAccess'.
/// If 'isBackwardSlice' is true, a backwards slice is computed in which the
/// slice bounds of loop nest surrounding 'depSourceAccess' are computed in
/// terms of loop IVs and symbols of the loop nest surrounding 'depSinkAccess'
/// at 'loopDepth'.
/// If 'isBackwardSlice' is false, a forward slice is computed in which the
/// slice bounds of loop nest surrounding 'depSinkAccess' are computed in terms
/// of loop IVs and symbols of the loop nest surrounding 'depSourceAccess' at
/// 'loopDepth'.
/// The slice loop bounds and associated operands are returned in 'sliceState'.
//
// Backward slice example:
//
// affine.for %i0 = 0 to 10 {
// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess'
// }
// affine.for %i1 = 0 to 10 {
// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess'
// }
//
// // Backward computation slice of loop nest '%i0'.
// affine.for %i0 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 1)(%i1) {
// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess'
// }
//
// Forward slice example:
//
// affine.for %i0 = 0 to 10 {
// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess'
// }
// affine.for %i1 = 0 to 10 {
// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess'
// }
//
// // Forward computation slice of loop nest '%i1'.
// affine.for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) {
// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess'
// }
//
void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp,
FlatAffineConstraints *dependenceConstraints,
unsigned loopDepth, bool isBackwardSlice,
ComputationSliceState *sliceState);
/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
/// The parameter 'numCommonLoops' is the number of loops common to the
/// operations in 'opsA' and 'opsB'.
/// If 'isBackwardSlice' is true, computes slice bounds for loop nest
/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest
/// surrounding ops in 'opsB' at 'loopDepth'.
/// If 'isBackwardSlice' is false, computes slice bounds for loop nest
/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest
/// surrounding ops in 'opsA' at 'loopDepth'.
/// Returns 'success' if union was computed, 'failure' otherwise.
// TODO(andydavis) Change this API to take 'forOpA'/'forOpB'.
LogicalResult computeSliceUnion(ArrayRef<Operation *> opsA,
ArrayRef<Operation *> opsB, unsigned loopDepth,
unsigned numCommonLoops, bool isBackwardSlice,
ComputationSliceState *sliceUnion);
/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcOpInst', slices the iteration space of src loop based on slice bounds
/// in 'sliceState', and inserts the computation slice at the beginning of the
/// operation block of the loop at 'dstLoopDepth' in the loop nest surrounding
/// 'dstOpInst'. Returns the top-level loop of the computation slice on
/// success, returns nullptr otherwise.
// Loop depth is a crucial optimization choice that determines where to
// materialize the results of the backward slice, presenting a trade-off between
// storage and redundant computation in several cases.
// TODO(andydavis) Support computation slices with common surrounding loops.
AffineForOp insertBackwardComputationSlice(Operation *srcOpInst,
Operation *dstOpInst,
unsigned dstLoopDepth,
ComputationSliceState *sliceState);
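// Example usage (sketch): fuse the loop nest around `srcOp` into the nest
// around `dstOp` at depth 1. All names are illustrative assumptions;
// `dependenceConstraints` is assumed to have been filled in by a prior
// dependence check between the two accesses.
//
// ```c++
// ComputationSliceState sliceState;
// getComputationSliceState(srcOp, dstOp, &dependenceConstraints,
//                          /*loopDepth=*/1, /*isBackwardSlice=*/true,
//                          &sliceState);
// if (AffineForOp sliceLoop = insertBackwardComputationSlice(
//         srcOp, dstOp, /*dstLoopDepth=*/1, &sliceState)) {
//   // `sliceLoop` is the top-level loop of the inserted computation slice.
//   (void)sliceLoop;
// }
// ```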
/// A region of a memref's data space; this is typically constructed by
/// analyzing load/store ops on this memref and the index space of loops
/// surrounding such ops.
// For example, the memref region for a load operation at loop depth = 1:
//
// affine.for %i = 0 to 32 {
// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
// affine.load %A[%ii]
// }
// }
//
// Region: {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineConstraints symbolic in %i.
//
struct MemRefRegion {
explicit MemRefRegion(Location loc) : loc(loc) {}
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst. The computed region's 'cst' field has exactly as many
/// dimensional identifiers as the rank of the memref, and *potentially*
/// additional symbolic identifiers, which could include any of the loop IVs
/// surrounding opInst up until 'loopDepth', as well as any other Function
/// symbols involved in the access (e.g., those appearing in affine.apply ops,
/// loop bounds, etc.). If 'sliceState' is non-null, operands from
/// 'sliceState' are added as symbols, and the following constraints are added
/// to the system:
/// *) Inequality constraints which represent loop bounds for 'sliceState'
/// operands which are loop IVS (these represent the destination loop IVs
/// of the slice, and are added as symbols to MemRefRegion's constraint
/// system).
/// *) Inequality constraints for the slice bounds in 'sliceState', which
/// represent the bounds on the loop IVs in this constraint system w.r.t.
/// slice operands (which correspond to symbols).
/// If 'addMemRefDimBounds' is true, constant upper/lower bounds
/// [0, memref.getDimSize(i)) are added for each MemRef dimension 'i'.
///
/// For example, the memref region for this operation at loopDepth = 1 will
/// be:
///
/// affine.for %i = 0 to 32 {
/// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
/// load %A[%ii]
/// }
/// }
///
/// {memref = %A, write = false, {%i <= m0 <= %i + 7} }
/// The last field is a 2-d FlatAffineConstraints symbolic in %i.
///
LogicalResult compute(Operation *op, unsigned loopDepth,
ComputationSliceState *sliceState = nullptr,
bool addMemRefDimBounds = true);
FlatAffineConstraints *getConstraints() { return &cst; }
const FlatAffineConstraints *getConstraints() const { return &cst; }
bool isWrite() const { return write; }
void setWrite(bool flag) { write = flag; }
/// Returns a constant upper bound on the number of elements in this region if
/// bounded by a known constant (always possible for static shapes), None
/// otherwise. Note that the symbols of the region are treated specially,
/// i.e., the returned bounding constant holds for *any given* value of the
/// symbol identifiers. The 'shape' vector is set to the corresponding
/// dimension-wise bounds major to minor. We use int64_t instead of uint64_t
/// since index types can be at most int64_t.
Optional<int64_t> getConstantBoundingSizeAndShape(
SmallVectorImpl<int64_t> *shape = nullptr,
std::vector<SmallVector<int64_t, 4>> *lbs = nullptr,
SmallVectorImpl<int64_t> *lbDivisors = nullptr) const;
/// A wrapper around FlatAffineConstraints::getConstantBoundOnDimSize(). 'pos'
/// corresponds to the position of the memref shape's dimension (major to
/// minor) which matches 1:1 with the dimensional identifier positions in
/// 'cst'.
Optional<int64_t>
getConstantBoundOnDimSize(unsigned pos,
SmallVectorImpl<int64_t> *lb = nullptr,
int64_t *lbFloorDivisor = nullptr) const {
assert(pos < getRank() && "invalid position");
return cst.getConstantBoundOnDimSize(pos, lb, lbFloorDivisor);
}
/// Returns the size of this MemRefRegion in bytes.
Optional<int64_t> getRegionSize();
// Wrapper around FlatAffineConstraints::unionBoundingBox.
LogicalResult unionBoundingBox(const MemRefRegion &other);
/// Returns the rank of the memref that this region corresponds to.
unsigned getRank() const;
/// Memref that this region corresponds to.
Value *memref;
/// Read or write.
bool write;
/// If there is more than one load/store op associated with the region, the
/// location information would correspond to one of those ops.
Location loc;
/// Region (data space) of the memref accessed. This set will thus have at
/// least as many dimensional identifiers as the shape dimensionality of the
/// memref, and these are the leading dimensions of the set appearing in that
/// order (major to minor / outermost to innermost). There may be additional
/// identifiers since getMemRefRegion() is called with a specific loop depth,
/// and thus the region is symbolic in the outer surrounding loops at that
/// depth.
// TODO(bondhugula): Replace this to exploit HyperRectangularSet.
FlatAffineConstraints cst;
};
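// Example usage (sketch; `loadOp` stands for some affine.load operation):
//
// ```c++
// MemRefRegion region(loadOp->getLoc());
// if (succeeded(region.compute(loadOp, /*loopDepth=*/1))) {
//   if (Optional<int64_t> size = region.getRegionSize())
//     llvm::dbgs() << "accessed region size in bytes: " << *size << "\n";
// }
// ```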
/// Returns the size of memref data in bytes if it's statically shaped, None
/// otherwise.
Optional<uint64_t> getMemRefSizeInBytes(MemRefType memRefType);
/// Checks a load or store op for an out-of-bounds access; returns failure if
/// the access is out of bounds along any of the dimensions, success otherwise.
/// Emits a diagnostic error (with location information) if emitError is true.
template <typename LoadOrStoreOpPointer>
LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
bool emitError = true);
/// Returns the number of surrounding loops common to both A and B.
unsigned getNumCommonSurroundingLoops(Operation &A, Operation &B);
/// Gets the memory footprint of all data touched in the specified memory space
/// in bytes; if the memory space is unspecified, considers all memory spaces.
Optional<int64_t> getMemoryFootprintBytes(AffineForOp forOp,
int memorySpace = -1);
/// Returns true if `forOp' is a parallel loop.
bool isLoopParallel(AffineForOp forOp);
} // end namespace mlir
#endif // MLIR_ANALYSIS_UTILS_H


@ -0,0 +1,143 @@
//===- VectorAnalysis.h - Analysis for Vectorization ------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_ANALYSIS_VECTORANALYSIS_H_
#define MLIR_ANALYSIS_VECTORANALYSIS_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
class AffineApplyOp;
class AffineForOp;
class AffineMap;
class Location;
class MemRefType;
class OpBuilder;
class Operation;
class Value;
class VectorType;
/// Computes and returns the multi-dimensional ratio of `superShape` to
/// `subShape`. This is calculated by performing a traversal from minor to major
/// dimensions (i.e. in reverse shape order). If integral division is not
/// possible, returns None.
/// The ArrayRefs are assumed (and enforced) to contain only values greater
/// than 1. This constraint comes from the fact that they are meant to be used
/// with VectorTypes, for which the property holds by construction.
///
/// Examples:
/// - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
/// - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
/// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
llvm::Optional<llvm::SmallVector<unsigned, 4>>
shapeRatio(ArrayRef<int64_t> superShape, ArrayRef<int64_t> subShape);
/// Computes and returns the multi-dimensional ratio of the shapes of
/// `superVector` to `subVector`. If integral division is not possible, returns
/// None.
/// Assumes and enforces that the VectorTypes have the same elemental type.
llvm::Optional<llvm::SmallVector<unsigned, 4>>
shapeRatio(VectorType superVectorType, VectorType subVectorType);
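// Example (sketch), matching the first example above. The ratio is computed
// minor-to-major as 8/2, 5/5, 4/2, with the leftover major dimension 3 kept:
//
// ```c++
// llvm::Optional<llvm::SmallVector<unsigned, 4>> ratio =
//     shapeRatio({3, 4, 5, 8}, {2, 5, 2});
// assert(ratio.hasValue() && "expected integral division");  // {3, 2, 1, 4}
// ```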
/// Constructs a permutation map of invariant memref indices to vector
/// dimension.
///
/// If no index is found to be invariant, 0 is added to the permutation_map and
/// corresponds to a vector broadcast along that dimension.
///
/// The implementation uses the knowledge of the mapping of loops to
/// vector dimension. `loopToVectorDim` carries this information as a map with:
/// - keys representing "vectorized enclosing loops";
/// - values representing the corresponding vector dimension.
/// Note that loopToVectorDim is a whole function map from which only enclosing
/// loop information is extracted.
///
/// Prerequisites: `op` is a vectorizable load or store operation (i.e. at
/// most one invariant index along each AffineForOp of `loopToVectorDim`).
///
/// Example 1:
/// The following MLIR snippet:
///
/// ```mlir
/// affine.for %i3 = 0 to %0 {
/// affine.for %i4 = 0 to %1 {
/// affine.for %i5 = 0 to %2 {
/// %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
/// }}}
/// ```
///
/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
///
/// ```mlir
/// affine.for %i3 = 0 to %0 step 32 {
/// affine.for %i4 = 0 to %1 {
/// affine.for %i5 = 0 to %2 step 256 {
/// %4 = vector.transfer_read %arg0, %i4, %i5, %i3
/// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
/// (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
/// }}}
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice:
/// `%arg0[%i4, %i5:%i5+256, %i3:%i3+32]` into vector<32x256xf32>.
///
/// Example 2:
/// The following MLIR snippet:
///
/// ```mlir
/// %cst0 = constant 0 : index
/// affine.for %i0 = 0 to %0 {
/// %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
/// }
/// ```
///
/// may vectorize with {permutation_map: (d0) -> (0)} into:
///
/// ```mlir
/// affine.for %i0 = 0 to %0 step 128 {
/// %3 = vector.transfer_read %arg0, %c0_0, %c0_0
/// {permutation_map: (d0, d1) -> (0)} :
/// (memref<?x?xf32>, index, index) -> vector<128xf32>
/// }
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
///
AffineMap makePermutationMap(
Operation *op, ArrayRef<Value *> indices,
const llvm::DenseMap<Operation *, unsigned> &loopToVectorDim);
namespace matcher {
/// Matches vector.transfer_read, vector.transfer_write and ops that return a
/// vector type that is a multiple of the sub-vector type. This allows passing
/// over other smaller vector types in the function and avoids interfering with
/// operations on those.
/// This is a first approximation; it can easily be extended in the future.
/// TODO(ntv): this could all be much simpler if we added a bit to the vector
/// type to mark that a vector is a strict super-vector, but it still does not
/// warrant adding even 1 extra bit in the IR for now.
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
} // end namespace matcher
} // end namespace mlir
#endif // MLIR_ANALYSIS_VECTORANALYSIS_H_


@ -0,0 +1,31 @@
//===- Verifier.h - Verifier analysis for MLIR structures -------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_ANALYSIS_VERIFIER_H
#define MLIR_ANALYSIS_VERIFIER_H
namespace mlir {
struct LogicalResult;
class Operation;
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs, on this operation and any nested operations. On error, this
/// reports the error through the MLIRContext and returns failure.
LogicalResult verify(Operation *op);
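// Example usage (sketch; `module` is assumed to be a ModuleOp):
//
// ```c++
// if (failed(verify(module.getOperation())))
//   return; // Errors were already reported through the MLIRContext.
// ```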
} // end namespace mlir
#endif


@ -0,0 +1,6 @@
add_subdirectory(AffineOps)
add_subdirectory(Dialect)
add_subdirectory(EDSC)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(StandardOps)


@ -0,0 +1,45 @@
//===- ConvertControlFlowToCFG.h - Pass entrypoint --------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
#define MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
#include <memory>
#include <vector>
namespace mlir {
class FuncOp;
class FunctionPassBase;
struct LogicalResult;
class MLIRContext;
class RewritePattern;
// Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
/// Collect a set of patterns to lower from loop.for, loop.if, and
/// loop.terminator to CFG operations within the Standard dialect; in
/// particular, this converts structured control flow into branch-based
/// control flow.
void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
/// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG.
FunctionPassBase *createConvertToCFGPass();
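// Example usage (sketch; the PassManager API shown matches MLIR at the time
// of this import and may differ in later revisions):
//
// ```c++
// mlir::PassManager pm;
// pm.addPass(createConvertToCFGPass());
// if (failed(pm.run(module)))
//   llvm::errs() << "loop-to-CFG lowering failed\n";
// ```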
} // namespace mlir
#endif // MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_


@ -0,0 +1,58 @@
//===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
namespace mlir {
class ModulePassBase;
class FuncOp;
using OwnedCubin = std::unique_ptr<std::vector<char>>;
using CubinGenerator = std::function<OwnedCubin(const std::string &, FuncOp &)>;
/// Creates a pass to convert kernel functions into CUBIN blobs.
///
/// This transformation takes the body of each function that is annotated with
/// the 'nvvm.kernel' attribute, copies it to a new LLVM module, compiles the
/// module with the help of the NVPTX backend to PTX, and then invokes the
/// provided cubinGenerator to produce a binary blob (the cubin). This blob is
/// then attached as a string attribute named 'nvvm.cubin' to the kernel
/// function.
/// After the transformation, the body of the kernel function is removed (i.e.,
/// it is turned into a declaration).
ModulePassBase *
createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator);
/// Creates a pass to convert a gpu.launch_func operation into a sequence of
/// CUDA calls.
///
/// This pass does not generate code to call CUDA directly but instead uses a
/// small wrapper library that exports a stable and conveniently typed ABI
/// on top of CUDA.
ModulePassBase *createConvertGpuLaunchFuncToCudaCallsPass();
/// Creates a pass to augment a module with getter functions for all contained
/// cubins as encoded via the 'nvvm.cubin' attribute.
ModulePassBase *createGenerateCubinAccessorPass();
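// Example pipeline wiring (sketch; `compilePtxToCubin` is an assumed helper
// with the CubinGenerator signature, not part of this header):
//
// ```c++
// OwnedCubin compilePtxToCubin(const std::string &ptx, FuncOp &kernel);
//
// void buildGpuToCudaPipeline(mlir::PassManager &pm) {
//   pm.addPass(createConvertGPUKernelToCubinPass(compilePtxToCubin));
//   pm.addPass(createGenerateCubinAccessorPass());
//   pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass());
// }
// ```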
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_


@ -0,0 +1,28 @@
//===- GPUToNVVMPass.h - Convert GPU kernel to NVVM dialect -----*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
namespace mlir {
class FunctionPassBase;
/// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
FunctionPassBase *createLowerGpuOpsToNVVMOpsPass();
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_


@ -0,0 +1,57 @@
//===- LoopsToGPU.h - Convert loop nests to GPU kernels ---------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
namespace mlir {
class AffineForOp;
struct LogicalResult;
namespace loop {
class ForOp;
} // end namespace loop
/// Convert a perfect affine loop nest with the outermost loop identified by
/// `forOp` into a gpu::Launch operation. Map `numBlockDims` outer loops to
/// GPU blocks and `numThreadDims` to GPU threads. The bounds of the loops that
/// are mapped should be independent of the induction variables of the other
/// mapped loops.
///
/// No check on the size of the block or grid, or on the validity of
/// parallelization, is performed; it is the caller's responsibility to
/// strip-mine the loops and to perform the dependence analysis before
/// calling the conversion.
LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
unsigned numBlockDims,
unsigned numThreadDims);
/// Convert a perfect 'loop.for' loop nest with the outermost loop identified by
/// `forOp` into a gpu::Launch operation. Map `numBlockDims` outer loops to
/// GPU blocks and `numThreadDims` to GPU threads. The bounds of the loops that
/// are mapped should be independent of the induction variables of the other
/// mapped loops.
///
/// No check on the size of the block or grid, or on the validity of
/// parallelization, is performed; it is the caller's responsibility to
/// strip-mine the loops and to perform the dependence analysis before
/// calling the conversion.
LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp,
unsigned numBlockDims,
unsigned numThreadDims);
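// Example usage (sketch): map the two outermost loops of `forOp` to the GPU
// grid and the next loop to threads, assuming the nest is at least three
// loops deep and already strip-mined appropriately:
//
// ```c++
// if (failed(convertAffineLoopNestToGPULaunch(forOp, /*numBlockDims=*/2,
//                                             /*numThreadDims=*/1)))
//   forOp.emitError("could not map the loop nest to a GPU launch");
// ```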
} // namespace mlir
#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_


@ -0,0 +1,35 @@
//===- LoopsToGPUPass.h - Pass converting loops to GPU kernels --*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
namespace mlir {
class FunctionPassBase;
/// Create a pass that converts loop nests into GPU kernels. It considers
/// top-level affine.for and loop.for operations as roots of loop nests and
/// converts them to gpu.launch operations if possible.
///
/// No check on the size of the block or grid, or on the validity of
/// parallelization, is performed; it is the caller's responsibility to
/// strip-mine the loops and to perform the dependence analysis before
/// calling the conversion.
FunctionPassBase *createSimpleLoopsToGPUPass(unsigned numBlockDims,
unsigned numThreadDims);
} // namespace mlir
#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_


@ -0,0 +1,129 @@
//===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Provides a dialect conversion targeting the LLVM IR dialect. By default, it
// converts Standard ops and types and provides hooks for dialect-specific
// extensions to the conversion.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#include "mlir/Transforms/DialectConversion.h"
namespace llvm {
class IntegerType;
class LLVMContext;
class Module;
class Type;
} // namespace llvm
namespace mlir {
namespace LLVM {
class LLVMDialect;
class LLVMType;
} // namespace LLVM
/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
public:
using TypeConverter::convertType;
LLVMTypeConverter(MLIRContext *ctx);
/// Convert types to LLVM IR. This calls `convertAdditionalType` to convert
/// non-standard or non-builtin types.
Type convertType(Type t) override;
/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one value is
/// returned, create an LLVM IR structure type with elements that correspond
/// to each of the MLIR types converted with `convertType`.
Type packFunctionResults(ArrayRef<Type> types);
/// Returns the LLVM context.
llvm::LLVMContext &getLLVMContext();
/// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; }
protected:
/// LLVM IR module used to parse/create types.
llvm::Module *module;
LLVM::LLVMDialect *llvmDialect;
private:
Type convertStandardType(Type type);
// Convert a function type. The arguments and results are converted one by
// one. Additionally, if the function returns more than one value, pack the
// results into an LLVM IR structure type so that the converted function type
// returns at most one result.
Type convertFunctionType(FunctionType type);
// Convert the index type. Uses the LLVM module's data layout to create an
// integer of the pointer bitwidth.
Type convertIndexType(IndexType type);
// Convert an integer type `i*` to `!llvm<"i*">`.
Type convertIntegerType(IntegerType type);
// Convert a floating point type: `f16` to `!llvm.half`, `f32` to
// `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported
// by LLVM.
Type convertFloatType(FloatType type);
// Convert a memref type into an LLVM type that captures the relevant data.
// For statically-shaped memrefs, the resulting type is a pointer to the
// (converted) memref element type. For dynamically-shaped memrefs, the
// resulting type is an LLVM structure type that contains:
// 1. a pointer to the (converted) memref element type
// 2. as many index types as memref has dynamic dimensions.
Type convertMemRefType(MemRefType type);
// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type);
// Get the LLVM representation of the index type based on the bitwidth of the
// pointer as defined by the data layout of the module.
LLVM::LLVMType getIndexType();
// Wrap the given LLVM IR type into an LLVM IR dialect type.
Type wrap(llvm::Type *llvmType);
// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);
};
/// Base class for operation conversions targeting the LLVM IR dialect.
/// Provides conversion patterns with access to the containing
/// LLVMTypeConverter for the purpose of type conversions.
class LLVMOpLowering : public ConversionPattern {
public:
LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
LLVMTypeConverter &lowering);
protected:
// Back-reference to the lowering class, used to call type and function
// conversions accounting for potential extensions.
LLVMTypeConverter &lowering;
};
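// Hypothetical subclass (sketch): a pattern lowering a fictional "foo.id" op
// that simply forwards its converted operand. The signatures follow the
// ConversionPattern API as of this import and may differ in later revisions.
//
// ```c++
// struct FooIdOpLowering : public LLVMOpLowering {
//   FooIdOpLowering(MLIRContext *ctx, LLVMTypeConverter &converter)
//       : LLVMOpLowering("foo.id", ctx, converter) {}
//
//   PatternMatchResult
//   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
//                   ConversionPatternRewriter &rewriter) const override {
//     rewriter.replaceOp(op, operands.front());
//     return matchSuccess();
//   }
// };
// ```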
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H


@ -0,0 +1,92 @@
//===- ConvertStandardToLLVMPass.h - Pass entrypoint ------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
#include "llvm/ADT/STLExtras.h"
#include <memory>
#include <vector>
namespace llvm {
class Module;
} // namespace llvm
namespace mlir {
class DialectConversion;
class FuncOp;
class LLVMTypeConverter;
struct LogicalResult;
class MLIRContext;
class ModuleOp;
class ModulePassBase;
class RewritePattern;
class Type;
// Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
/// Type for a callback constructing the owning list of patterns for the
/// conversion to the LLVMIR dialect. The callback is expected to append
/// patterns to the owning list provided as the second argument.
using LLVMPatternListFiller =
std::function<void(LLVMTypeConverter &, OwningRewritePatternList &)>;
/// Type for a callback constructing the type converter for the conversion to
/// the LLVMIR dialect. The callback is expected to return an instance of the
/// converter.
using LLVMTypeConverterMaker =
std::function<std::unique_ptr<LLVMTypeConverter>(MLIRContext *)>;
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
ModulePassBase *createConvertToLLVMIRPass();
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
/// is defined by a list of patterns and a type converter that will be obtained
/// during the pass using the provided callbacks.
ModulePassBase *
createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller,
LLVMTypeConverterMaker typeConverterMaker);
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
/// is defined by a list of patterns obtained during the pass using the provided
/// callback, and by an optional type converter class, an instance of which is
/// created during the pass.
template <typename TypeConverter = LLVMTypeConverter>
ModulePassBase *
createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller) {
return createConvertToLLVMIRPass(patternListFiller, [](MLIRContext *context) {
return llvm::make_unique<TypeConverter>(context);
});
}
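// Example usage (sketch): create the pass with the standard patterns plus a
// hypothetical dialect-specific extension (the commented-out populate
// function is an assumption, not part of this header):
//
// ```c++
// ModulePassBase *pass = createConvertToLLVMIRPass(
//     [](LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
//       populateStdToLLVMConversionPatterns(converter, patterns);
//       // populateMyDialectToLLVMPatterns(converter, patterns);  // assumed
//     });
// ```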
namespace LLVM {
/// Make argument-taking successors of each block distinct. PHI nodes in LLVM
/// IR use the predecessor ID to identify which value to take. They do not
/// support different values coming from the same predecessor. If a block has
/// another block as a successor more than once with different values, insert
/// a new dummy block for LLVM PHI nodes to tell the sources apart.
void ensureDistinctSuccessors(ModuleOp m);
} // namespace LLVM
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_


@ -0,0 +1,35 @@
//===- StdOpsToSPIRVConversion.h - Convert StandardOps to SPIR-V *- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines utility functions to populate patterns for converting
// StandardOps to SPIR-V ops.
//
//===----------------------------------------------------------------------===//
#ifndef STANDARD_OPS_TO_SPIRV_H_
#define STANDARD_OPS_TO_SPIRV_H_
#include "mlir/IR/PatternMatch.h"
namespace mlir {
/// Method to append to a pattern list additional patterns for translating
/// StandardOps to SPIR-V ops.
void populateStdOpsToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
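// Example usage (sketch; the greedy rewrite driver call is an assumption and
// its exact signature may differ across revisions):
//
// ```c++
// OwningRewritePatternList patterns;
// populateStdOpsToSPIRVPatterns(func.getContext(), patterns);
// applyPatternsGreedily(func, std::move(patterns));
// ```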
} // namespace mlir
#endif // STANDARD_OPS_TO_SPIRV_H_


@ -0,0 +1,5 @@
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
add_subdirectory(LoopOps)
add_subdirectory(QuantOps)
add_subdirectory(SPIRV)


@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS FxpMathOps.td)
mlir_tablegen(FxpMathOps.h.inc -gen-op-decls)
mlir_tablegen(FxpMathOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRFxpMathOpsIncGen)


@ -0,0 +1,40 @@
//===- FxpMathOps.h - Fixed point ops ---------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
#define MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
namespace fxpmath {
/// Defines the 'FxpMathOps' dialect.
class FxpMathOpsDialect : public Dialect {
public:
FxpMathOpsDialect(MLIRContext *context);
};
#define GET_OP_CLASSES
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"
} // namespace fxpmath
} // namespace mlir
#endif // MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_


@ -0,0 +1,290 @@
//===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the operation definition file for fixed point ops (and real
// equivalents).
//
//===----------------------------------------------------------------------===//
#ifdef DIALECT_FXPMATHOPS_FXPMATH_OPS_
#else
#define DIALECT_FXPMATHOPS_FXPMATH_OPS_
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
include "mlir/Dialect/QuantOps/QuantPredicates.td"
def fxpmath_Dialect : Dialect {
let name = "fxpmath";
}
//===----------------------------------------------------------------------===//
// Attributes
//===----------------------------------------------------------------------===//
// Real value for an (inclusive) min/max clamp limit.
def fxpmath_ClampValueAttr : OptionalAttr<F64Attr>;
// Element-wise activation function to apply.
// Note that RELU activations are not here: they are expressed as clamps.
def fxpmath_EwUnaryFnAttr :
StringBasedAttr<CPred<"true">, "element-wise unary function"> {
let returnType = [{ StringRef }];
let defaultValue = "IDENTITY";
}
class fxpmath_ConstEwUnaryFn<string val> : ConstantAttr<fxpmath_EwUnaryFnAttr, val>;
def fxpmath_EwUnaryFn_Abs : fxpmath_ConstEwUnaryFn<"ABS">;
def fxpmath_EwUnaryFn_Exp : fxpmath_ConstEwUnaryFn<"EXP">;
def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">;
def fxpmath_EwUnaryFn_Log : fxpmath_ConstEwUnaryFn<"LOG">;
def fxpmath_EwUnaryFn_Neg : fxpmath_ConstEwUnaryFn<"NEG">;
def fxpmath_EwUnaryFn_Rsqrt : fxpmath_ConstEwUnaryFn<"RSQRT">;
def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">;
def fxpmath_EwUnaryFn_Sign : fxpmath_ConstEwUnaryFn<"SIGN">;
def fxpmath_EwUnaryFn_Sin : fxpmath_ConstEwUnaryFn<"SIN">;
def fxpmath_EwUnaryFn_Sqrt : fxpmath_ConstEwUnaryFn<"SQRT">;
def fxpmath_EwUnaryFn_Square : fxpmath_ConstEwUnaryFn<"SQUARE">;
def fxpmath_EwUnaryFn_Tanh : fxpmath_ConstEwUnaryFn<"TANH">;
//===----------------------------------------------------------------------===//
// Comparison functions (compares relative to zero on a subtraction result).
//===----------------------------------------------------------------------===//
def fxpmath_CompareZ : StrEnumAttrCase<"CMPZ">;
def fxpmath_CompareNZ : StrEnumAttrCase<"CMPNZ">;
def fxpmath_CompareLZ : StrEnumAttrCase<"CMPLZ">;
def fxpmath_CompareLZE : StrEnumAttrCase<"CMPLZE">;
def fxpmath_CompareGZ : StrEnumAttrCase<"CMPGZ">;
def fxpmath_CompareGZE : StrEnumAttrCase<"CMPGZE">;
def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
"Type of subtraction-result comparison to perform.",
[
fxpmath_CompareZ,
fxpmath_CompareNZ,
fxpmath_CompareLZ,
fxpmath_CompareLZE,
fxpmath_CompareGZ,
fxpmath_CompareGZE
]>;
//===----------------------------------------------------------------------===//
// Base classes
//===----------------------------------------------------------------------===//
class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
Op<fxpmath_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Fixed-point (fxp) arithmetic ops used by kernels.
// Some of these are temporary pending inclusion into a more core dialect.
//===----------------------------------------------------------------------===//
def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameOperandsAndResultType]> {
let summary =
"Clamps a signed-integer like argument to a min/max range.";
let description = [{
Element-wise equivalent to:
r = std::min(clamp_max, std::max(e, clamp_min))
}];
let arguments = (ins IntegerLike:$operand,
APIntAttr:$clamp_min,
APIntAttr:$clamp_max);
let results = (outs IntegerLike);
}
def fxpmath_ConvertISOp :
fxpmath_Op<"convertis",
[NoSideEffect, SameOperandsAndResultShape]> {
let summary =
"Does an element-wise conversion from a signed integer to signed integer";
let description = [{
Similar to an element-wise static_cast in C++, from one signed integer
element type to another.
}];
let arguments = (ins IntegerLike:$operand);
let results = (outs IntegerLike);
}
def fxpmath_ConvertISToFOp :
fxpmath_Op<"convertistof",
[NoSideEffect, SameOperandsAndResultShape]> {
let summary =
"Does an element-wise conversion from a signed integer to a float";
let description = [{
Similar to an element-wise static_cast in C++, from a signed integer
element type to a floating point element type, rounding to the nearest
floating point value.
}];
let arguments = (ins IntegerLike:$operand);
let results = (outs FloatLike);
}
def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis",
[NoSideEffect, SameOperandsAndResultType]> {
let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH";
let description = [{
Equivalent to the ARMv7 NEON VQRDMULH instruction.
See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
implementation.
}];
let arguments = (ins IntegerLike:$a, APIntAttr:$b);
let results = (outs IntegerLike);
}
def fxpmath_RoundingDivideByPotISOp :
fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes a rounding arithmetic right shift.";
let description = [{
Computes integer division by a power-of-two, correctly rounded-to-nearest.
Also known as a rounding arithmetic right shift. See
gemmlowp::RoundingDivideByPOT for a reference implementation.
}];
let arguments = (ins IntegerLike:$operand, APIntAttr:$exponent);
let results = (outs IntegerLike:$res);
let verifier = [{
auto verifyExponent = exponent().getSExtValue();
if (verifyExponent < 0 || verifyExponent > 31) {
return emitOpError("exponent must be in range [0..31]");
}
return success();
}];
}
//===----------------------------------------------------------------------===//
// Real math ops.
//
// Math ops on real numbers which may have a representation in quantized
// arithmetic. It is expected that eligible ops are lowered from a source
// dialect to this set of ops prior to the process of converting a computation
// to a quantized form. It is a non-goal of these ops to preserve enough
// information to convert back to the higher level, source dialect.
//
// These ops support either real/floating point or QuantizedTypes as operands
// and results. Since not all transformations are supported (globally or
// sometimes for specific targets), a computation may end up with
// untransformable RealMathOps, in which case they need to be lowered as is
// (using floating point math).
//
// This op set takes advantage of the fact that it is typically trivial to
// combine a math function with a compatible bias addition and real-valued
// clamp (which can be done at a higher accumulation bit depth).
//
// In addition, all element-wise unary functions are collapsed into a single
// fxpmath_RealUnaryEwOp and selected via an enum-like attribute. Especially at
// low bit depths, this makes matching simpler and allows the construction of
// generic LUT-based implementations. It also allows specific lowering rules
// to consolidate runs of chained unary ops and fuse them to preceding math
// ops, potentially allowing them to operate directly on higher precision
// intermediates without resorting to lots of custom kernels for common
// formulas that can suffer from insufficient precision at low bit depths.
//
// Comparison operators are modeled as element-wise unary functions (i.e.
// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit
// quantized value. It is expected that lowering rules can fuse them with
// the preceding sub.
//===----------------------------------------------------------------------===//
class fxpmath_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
fxpmath_Op<mnemonic, traits>,
Arguments<!con(args, (ins
fxpmath_ClampValueAttr:$clamp_min, fxpmath_ClampValueAttr:$clamp_max))>;
//===----------------------------------------------------------------------===//
// Element wise binary real math ops.
//===----------------------------------------------------------------------===//
class fxpmath_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
fxpmath_RealMathOp<mnemonic, traits,
(ins quant_RealValueType:$lhs,
quant_RealValueType:$rhs)>,
Results<(outs quant_RealValueType:$res)>;
class fxpmath_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
fxpmath_RealMathOp<mnemonic, traits,
(ins quant_RealValueType:$lhs, quant_RealValueType:$rhs,
quant_RealValueType:$bias)>,
Results<(outs quant_RealValueType:$res)>;
def fxpmath_RealAddEwOp :
fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
def fxpmath_RealSubEwOp :
fxpmath_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;
def fxpmath_RealMulEwOp :
fxpmath_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;
def fxpmath_RealDivEwOp :
fxpmath_RealBinaryOp<"real_div_ew", [NoSideEffect]>;
//===----------------------------------------------------------------------===//
// Element wise unary real math op.
//===----------------------------------------------------------------------===//
def fxpmath_RealUnaryEwOp :
fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect],
(ins quant_RealValueType:$operand, fxpmath_EwUnaryFnAttr:$fn)>,
Results<(outs quant_RealValueType:$res)>;
def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>,
Arguments<(ins quant_RealValueType:$operand, fxpmath_CompareFnAttr:$fn)>,
Results<(outs I1Tensor:$res)> {
let description = [{
Compares a real value to zero, returning an I1 (boolean) tensor with the
result of applying the comparison function.
}];
}
//===----------------------------------------------------------------------===//
// Dot op with fused bias addition.
//===----------------------------------------------------------------------===//
def fxpmath_RealMatMulOp :
fxpmath_RealBinaryOp<"real_matmul", [NoSideEffect]> {
let summary = "Matmul";
let description = [{
A matrix multiply of [m, k] and [k, n] -> [m, n].
Also accepts rank 3 or more input tensors, in which case
the leading dimensions are batch dims.
Many real systems have specific library calls optimized for this precise
operation, which is why it is handled explicitly versus purely as a
generalized tensor contraction.
}];
}
def fxpmath_RealMatMulBiasOp :
fxpmath_RealBinaryBiasOp<"real_matmul_bias", [NoSideEffect]> {
let summary = "Matmul with bias";
let description = [{
A specialization of a RealMatMulOp that also accepts an [n] dimension
bias vector.
In addition, there is often special support for a fused bias and clamp,
which is why they are included.
}];
}
#endif // DIALECT_FXPMATHOPS_FXPMATH_OPS_


@ -0,0 +1,43 @@
//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines all of the passes owned by the FxpMathOps dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FXPMATHOPS_PASSES_H
#define MLIR_DIALECT_FXPMATHOPS_PASSES_H
namespace mlir {
class FunctionPassBase;
namespace fxpmath {
/// Creates a pass that lowers uniform-quantized real math ops to integer
/// arithmetic. This will leave unrecognized real math ops as-is and is
/// typically followed by a pass that lowers any unrecognized ops to a pure
/// floating point form.
FunctionPassBase *createLowerUniformRealMathPass();
/// Creates a pass that lowers uniform-quantized qcast/dcast ops to equivalent
/// operations that perform quantize/dequantize.
FunctionPassBase *createLowerUniformCastsPass();
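// A minimal usage sketch (hypothetical: the PassManager spelling below is an
// assumption about the surrounding pass infrastructure, not part of this
// header):
//   PassManager pm(...);
//   pm.addPass(createLowerUniformRealMathPass());
//   pm.addPass(createLowerUniformCastsPass());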
} // namespace fxpmath
} // namespace mlir
#endif // MLIR_DIALECT_FXPMATHOPS_PASSES_H

View File

@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS GPUOps.td)
mlir_tablegen(GPUOps.h.inc -gen-op-decls)
mlir_tablegen(GPUOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRGPUOpsIncGen)

View File

@ -0,0 +1,174 @@
//===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the GPU kernel-related operations and puts them in the
// corresponding dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_GPU_GPUDIALECT_H
#define MLIR_DIALECT_GPU_GPUDIALECT_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
class FuncOp;
namespace gpu {
/// The dialect containing GPU kernel launching operations and related
/// facilities.
class GPUDialect : public Dialect {
public:
/// Create the dialect in the given `context`.
GPUDialect(MLIRContext *context);
/// Get the canonical string name of the dialect.
static StringRef getDialectName();
/// Get the name of the attribute used to annotate outlined kernel functions.
static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
/// Returns whether the given function is a kernel function, i.e., has the
/// 'gpu.kernel' attribute.
static bool isKernel(FuncOp function);
};
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
struct KernelDim3 {
Value *x;
Value *y;
Value *z;
};
/// GPU kernel launch operation. Takes a 3D grid of thread blocks as leading
/// operands, followed by kernel data operands. Has one region representing
/// the kernel to be executed. This region is not allowed to use values defined
/// outside it.
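/// An illustrative example of the custom syntax (a sketch only; the exact
/// printed form is defined by the print/parse methods and the keyword
/// accessors below):
///
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
///              threads(%tx, %ty, %tz) in (%sx = %3, %sy = %4, %sz = %5)
///              args(%a = %6) : f32 {
///     ...
///     gpu.return
///   }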
class LaunchOp : public Op<LaunchOp, OpTrait::AtLeastNOperands<6>::Impl,
OpTrait::ZeroResult, OpTrait::IsIsolatedFromAbove> {
public:
using Op::Op;
static void build(Builder *builder, OperationState *result, Value *gridSizeX,
Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
Value *blockSizeY, Value *blockSizeZ,
ArrayRef<Value *> operands);
/// Get the kernel region.
Region &getBody();
/// Get the SSA values corresponding to kernel block identifiers.
KernelDim3 getBlockIds();
/// Get the SSA values corresponding to kernel thread identifiers.
KernelDim3 getThreadIds();
/// Get the SSA values corresponding to kernel grid size.
KernelDim3 getGridSize();
/// Get the SSA values corresponding to kernel block size.
KernelDim3 getBlockSize();
/// Get the operand values passed as kernel arguments.
operand_range getKernelOperandValues();
/// Get the operand types passed as kernel arguments.
operand_type_range getKernelOperandTypes();
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();
/// Get the SSA values passed as operands to specify the block size.
KernelDim3 getBlockSizeOperandValues();
/// Get the SSA values of the kernel arguments.
llvm::iterator_range<Block::args_iterator> getKernelArguments();
LogicalResult verify();
/// Custom syntax support.
void print(OpAsmPrinter *p);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
static StringRef getOperationName() { return "gpu.launch"; }
/// Erase the `index`-th kernel argument. Both the entry block argument and
/// the operand will be dropped. The block argument must not have any uses.
void eraseKernelArgument(unsigned index);
/// Append canonicalization patterns to `results`.
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
private:
static StringRef getBlocksKeyword() { return "blocks"; }
static StringRef getThreadsKeyword() { return "threads"; }
static StringRef getArgsKeyword() { return "args"; }
/// The number of launch configuration operands, placed at the leading
/// positions of the operand list.
static constexpr unsigned kNumConfigOperands = 6;
/// The number of region attributes containing the launch configuration,
/// placed in the leading positions of the argument list.
static constexpr unsigned kNumConfigRegionAttributes = 12;
};
/// Operation to launch a kernel given as outlined function.
class LaunchFuncOp : public Op<LaunchFuncOp, OpTrait::AtLeastNOperands<6>::Impl,
OpTrait::ZeroResult> {
public:
using Op::Op;
static void build(Builder *builder, OperationState *result, FuncOp kernelFunc,
Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ,
Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ,
ArrayRef<Value *> kernelOperands);
static void build(Builder *builder, OperationState *result, FuncOp kernelFunc,
KernelDim3 gridSize, KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands);
/// The kernel function specified by the operation's `kernel` attribute.
StringRef kernel();
/// The number of operands passed to the kernel function.
unsigned getNumKernelOperands();
/// The i-th operand passed to the kernel function.
Value *getKernelOperand(unsigned i);
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();
/// Get the SSA values passed as operands to specify the block size.
KernelDim3 getBlockSizeOperandValues();
LogicalResult verify();
static StringRef getOperationName() { return "gpu.launch_func"; }
/// The number of launch configuration operands, placed at the leading
/// positions of the operand list.
static constexpr unsigned kNumConfigOperands = 6;
private:
/// The name of the function attribute specifying the kernel to launch.
static StringRef getKernelAttrName() { return "kernel"; }
};
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.h.inc"
} // end namespace gpu
} // end namespace mlir
#endif // MLIR_DIALECT_GPU_GPUDIALECT_H

View File

@ -0,0 +1,60 @@
//===-- GPUOps.td - GPU dialect operation definitions ------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines some operations of the GPU dialect.
//
//===----------------------------------------------------------------------===//
#ifdef GPU_OPS
#else
#define GPU_OPS
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def GPU_Dialect : Dialect {
let name = "gpu";
}
class GPU_Op<string mnemonic, list<OpTrait> traits = []> :
Op<GPU_Dialect, mnemonic, traits>;
class GPU_IndexOp<string mnemonic, list<OpTrait> traits = []> :
GPU_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
Arguments<(ins StrAttr:$dimension)>, Results<(outs Index)>;
def gpu_BlockDim : GPU_IndexOp<"block_dim">;
def gpu_BlockId : GPU_IndexOp<"block_id">;
def gpu_GridDim : GPU_IndexOp<"grid_dim">;
def gpu_ThreadId : GPU_IndexOp<"thread_id">;
def gpu_Return : GPU_Op<"return", [Terminator]>, Arguments<(ins)>,
Results<(outs)> {
let summary = "Terminator for GPU launch regions.";
let description = [{
A terminator operation for regions that appear in the body of `gpu.launch`
operation. These regions are not expected to return any value so the
terminator takes no operands.
}];
let parser = [{ return success(); }];
let printer = [{ *p << getOperationName(); }];
}
#endif // GPU_OPS

View File

@ -0,0 +1,33 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes that expose pass constructors.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_GPU_PASSES_H_
#define MLIR_DIALECT_GPU_PASSES_H_
namespace mlir {
class ModulePassBase;
ModulePassBase *createGpuKernelOutliningPass();
} // namespace mlir
#endif // MLIR_DIALECT_GPU_PASSES_H_

View File

@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS LoopOps.td)
mlir_tablegen(LoopOps.h.inc -gen-op-decls)
mlir_tablegen(LoopOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRLoopOpsIncGen)

View File

@ -0,0 +1,56 @@
//===- Ops.h - Loop MLIR Operations -----------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines convenience types for working with loop operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LOOPOPS_OPS_H_
#define MLIR_LOOPOPS_OPS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace loop {
class TerminatorOp;
class LoopOpsDialect : public Dialect {
public:
LoopOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "loop"; }
};
#define GET_OP_CLASSES
#include "mlir/Dialect/LoopOps/LoopOps.h.inc"
// Insert `loop.terminator` at the end of the only region's only block if it
// does not have a terminator already. If a new `loop.terminator` is inserted,
// the location is specified by `loc`. If the region is empty, insert a new
// block first.
void ensureLoopTerminator(Region &region, Builder &builder, Location loc);
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
ForOp getForInductionVarOwner(Value *val);
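/// A minimal usage sketch of the helpers above (assuming a Builder `b`, a
/// Location `loc`, and an induction-variable Value `iv` from the caller):
///   ensureLoopTerminator(someRegion, b, loc);
///   if (ForOp forOp = getForInductionVarOwner(iv))
///     ...; // `iv` is the induction variable of `forOp`.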
} // end namespace loop
} // end namespace mlir
#endif // MLIR_LOOPOPS_OPS_H_

View File

@ -0,0 +1,158 @@
//===- Ops.td - Loop operation definitions ---------------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines MLIR loop operations.
//
//===----------------------------------------------------------------------===//
#ifdef LOOP_OPS
#else
#define LOOP_OPS
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def Loop_Dialect : Dialect {
let name = "loop";
let cppNamespace = "";
}
// Base class for Loop dialect ops.
class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Loop_Dialect, mnemonic, traits> {
// For every standard op, there needs to be a:
// * void print(OpAsmPrinter *p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
// OperationState *result)
// functions.
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def ForOp : Loop_Op<"for",
[SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "for operation";
let description = [{
The "loop.for" operation represents a loop nest taking 3 SSA value as
operands that represent the lower bound, upper bound and step respectively.
The operation defines an SSA value for its induction variable. It has one
region capturing the loop body. The induction variable is represented as an
argument of this region. This SSA value always has type index, which is the
size of the machine word. The step is a value of type index, required to be
positive.
The lower and upper bounds specify a half-open range: the range includes the
lower bound but does not include the upper bound.
    The body region must contain exactly one block that terminates with
    "loop.terminator". Calling ForOp::build will create such a region and
    insert the terminator; the parser likewise inserts it when it is absent
    from the custom format. For example:
loop.for %iv = %lb to %ub step %step {
... // body
}
}];
let arguments = (ins Index:$lowerBound, Index:$upperBound, Index:$step);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState *result, "
"Value *lowerBound, Value *upperBound, Value *step">
];
let extraClassDeclaration = [{
Block *getBody() { return &region().front(); }
Value *getInductionVar() { return getBody()->getArgument(0); }
OpBuilder getBodyBuilder() {
return OpBuilder(getBody(), std::prev(getBody()->end()));
}
void setLowerBound(Value *bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value *bound) { getOperation()->setOperand(1, bound); }
void setStep(Value *step) { getOperation()->setOperand(2, step); }
}];
}
def IfOp : Loop_Op<"if",
[SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "if-then-else operation";
let description = [{
The "loop.if" operation represents an if-then-else construct for
conditionally executing two regions of code. The operand to an if operation
is a boolean value. The operation produces no results. For example:
loop.if %b {
...
} else {
...
}
    The 'else' region is optional and may be omitted. For example:
loop.if %b {
...
}
}];
let arguments = (ins I1:$condition);
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState *result, "
"Value *cond, bool withElseRegion">
];
let extraClassDeclaration = [{
OpBuilder getThenBodyBuilder() {
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
Block &body = thenRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
OpBuilder getElseBodyBuilder() {
assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
Block &body = elseRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
}];
}
def TerminatorOp :
Loop_Op<"terminator", [NativeOpTrait<"IsTerminator">]> {
let summary = "cf terminator operation";
let description = [{
"loop.terminator" is a special terminator operation for blocks inside
loops. It terminates the region. This operation does _not_ have a custom
    syntax. However, `loop` control operations omit the terminator in their
custom syntax for brevity.
loop.terminator
}];
// No custom parsing/printing form.
let parser = ?;
let printer = ?;
// Fully specified by traits.
let verifier = ?;
}
#endif // LOOP_OPS

View File

@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS QuantOps.td)
mlir_tablegen(QuantOps.h.inc -gen-op-decls)
mlir_tablegen(QuantOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRQuantOpsIncGen)

View File

@ -0,0 +1,68 @@
//===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines support utilities for interoperating with FakeQuant* based
// QAT (Quantization-Aware Training) computations, as implemented by TFLite.
// Note that FakeQuant* operators mix multiple concerns specific to how TFLite
// originally implemented quantization. As such, utilities here enforce
// opinions taken by that codebase (rather than providing generic behavior).
//
// Specifically, it combines the following concerns, each of which would be
// independent variables in a more generic setup:
// - numBits and isSigned imply storage data type (uint8, int8, int16)
// - numBits < 8 is promoted to uint8 or int8
// - "narrow_range" narrows the lower bound of the storage type's range by
// 1
// - the specified min/max values are "nudged" so that the result has a zero
// that can be exactly expressed
// - min=max=0 implies scale=0 and zero_point=0
//
// With the above assumptions applied, every conforming specified FakeQuant op
// can be represented by a UniformQuantizedType. This scheme is not expected to
// be generalized further in the future and should be considered to be a
// legacy set of rules.
//
// As canonically used in TensorFlow graphs, the presence of a FakeQuant node
// is a hint that the specific math represented here has been simulated at
// training time. As such, it is usually not advised to arbitrarily change
// quantization parameters derived from FakeQuant.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_
#define MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h"
namespace mlir {
namespace quant {
/// Converts per-layer FakeQuant attributes to the corresponding type.
/// In the event that the parameters cannot be converted, returns a nullptr
/// convertible Type and issues an appropriate error.
/// Note that there are multiple variants of a per-layer FakeQuant op, so
/// this function takes the attributes discretely vs taking a reference to the
/// originating op.
UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double rmin, double rmax,
bool narrowRange, Type expressedType,
bool isSigned = false);
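/// Illustrative expectation (approximate, for intuition only): a ReLU6-like
/// range of rmin=0.0, rmax=6.0 with numBits=8, narrowRange=false and
/// isSigned=false should yield a uniform type with u8 storage, a scale of
/// roughly 6.0/255, a zero point of 0 and a storage range of [0, 255].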
} // namespace quant
} // namespace mlir
#endif // MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_

View File

@ -0,0 +1,47 @@
//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines all of the passes owned by the quantization dialect. As
// things mature, it is expected that passes specific to certain frontend or
// backend dialects will move to those dialects directly. For now, they are
// incubated here.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_QUANTOPS_PASSES_H
#define MLIR_DIALECT_QUANTOPS_PASSES_H
namespace mlir {
class FunctionPassBase;
namespace quant {
/// Creates a pass that converts quantization simulation operations (i.e.
/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes.
FunctionPassBase *createConvertSimulatedQuantPass();
/// Creates a pass that converts constants followed by a qcast to a
/// constant whose value is quantized. This is typically one of the last
/// passes done when lowering to express actual quantized arithmetic in a
/// low level representation. Because it modifies the constant, it is
/// destructive and cannot be undone.
FunctionPassBase *createConvertConstPass();
} // namespace quant
} // namespace mlir
#endif // MLIR_DIALECT_QUANTOPS_PASSES_H

View File

@ -0,0 +1,50 @@
//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_DIALECT_QUANTOPS_QUANTOPS_H_
#define MLIR_DIALECT_QUANTOPS_QUANTOPS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace quant {
/// Defines the 'Quantization' dialect
class QuantizationDialect : public Dialect {
public:
QuantizationDialect(MLIRContext *context);
/// Parse a type registered to this dialect.
Type parseType(StringRef spec, Location loc) const override;
/// Print a type registered to this dialect.
void printType(Type type, raw_ostream &os) const override;
};
#define GET_OP_CLASSES
#include "mlir/Dialect/QuantOps/QuantOps.h.inc"
} // namespace quant
} // namespace mlir
#endif // MLIR_DIALECT_QUANTOPS_QUANTOPS_H_

View File

@ -0,0 +1,227 @@
//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the operation definition file for Quantization.
//
//===----------------------------------------------------------------------===//
#ifdef DIALECT_QUANTOPS_QUANT_OPS_
#else
#define DIALECT_QUANTOPS_QUANT_OPS_
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantPredicates.td"
#endif // OP_BASE
def quant_Dialect : Dialect {
let name = "quant";
}
//===----------------------------------------------------------------------===//
// Base classes
//===----------------------------------------------------------------------===//
class quant_Op<string mnemonic, list<OpTrait> traits> :
Op<quant_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Quantization casts
//===----------------------------------------------------------------------===//
// A QuantizeCast (qcast) represents a potential type shift from a quantizable
// type to a quantized type.
//
// At runtime, a qcast will apply the transformation expressed by its
// operand and result type. For flexibility during transformation, it is also
// possible to have a qcast that performs no transformation (both its
// operand and result type are quantizable).
//
// A qcast will typically originate from either:
// a) An expressed or implied constraint in the source dialect which signals
// that a certain level of quantization is possible or required.
// b) An inference made by a quantization algorithm indicating that a
// quantized representation may be acceptable.
//
// Especially early in transformation, it is common to have pairs of
// qcast/dcast at points where a transition to a quantized type is
// required. In addition, it is also common to have an identity qcast
// (where the operand and result type are not quantized) at all points where
// it is legal to use a quantized representation (but is not known to be
// acceptable).
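//
// Generic-form example (illustrative only; types abbreviated):
//   %1 = "quant.qcast"(%0)
//       : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 1.0>>
//   %2 = "quant.dcast"(%1)
//       : (tensor<4x!quant.uniform<i8:f32, 1.0>>) -> tensor<4xf32>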
def quant_QuantizeCastOp : quant_Op<"qcast", [NoSideEffect]> {
let arguments = (ins quant_RealValueType:$arg);
let results = (outs quant_RealValueType);
}
// A DequantizeCast op (dcast) represents the inverse of a qcast,
// converting back from a quantized to quantizable (expressed) type.
//
// Like qcasts, a dcast is allowed to have both its operand and result
// as non quantized types. This facilitates transformations and marks edges
// where the computation must be carried out in the expressed type.
//
// Especially early in transformation, it is common to have dcasts on
// all operands to ops that must operate with the expressed type (typically
// math ops prior to lowering to target-specific, quantized kernels).
def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
let arguments = (ins quant_RealValueType:$arg);
let results = (outs quant_RealValueType);
}
// A StorageCast (scast) represents a cast from or to a type based on the
// storage type and a type based on a corresponding quantized type.
//
// This op exists to ensure type coherency between parts of the computation
// which are operating directly on an underlying storage type and those which
// operate on quantized values.
//
// Examples from storage to quantized type:
// i8 -> !quant<"uniform[i8:f32]{1.0}">
// tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
// vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
let arguments = (ins quant_RealOrStorageValueType:$arg);
let results = (outs quant_RealOrStorageValueType);
  let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// Training integration and instrumentation ops
//===----------------------------------------------------------------------===//
def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
[SameOperandsAndResultType, NoSideEffect]> {
let summary =
"Simulates the effect of uniform quantization with const range.";
let description = [{
Given a const min, max, num_bits and narrow_range attribute, applies the
same uniform quantization simulation as is done by the TensorFlow
fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
    method and the quant-convert-simulated-quantization pass for further details.
}];
let arguments = (ins
F32Tensor:$inputs,
F32Attr:$min,
F32Attr:$max,
// The bitwidth of the quantization; between 2 and 16, inclusive.
I64Attr:$num_bits,
// Quantization range starts from 0 or 1; starts from 1 if true.
DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
// The sign of the quantization.
DefaultValuedAttr<BoolAttr, "false">:$is_signed
);
let results = (outs
F32Tensor:$outputs
);
}
def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
let summary =
"Indicates that statistics are resolved by reference.";
let description = [{
This op acts as an identity that, when encountered at runtime, should result
    in statistics being collected about the value of its operand/result.
Such statistics will be stored with the provided key, allowing this node
to later be converted to a 'stats' op if statistics with that key have been
encountered.
}];
let arguments = (ins
quant_RealValueType:$arg,
StrAttr:$statsKey
);
let results = (outs quant_RealValueType);
}
def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> {
let summary =
"Identity op which associates statistics with the value.";
let description = [{
Associates statistics about the runtime ranges of values observed for
evaluations of this node.
Statistics about the entire type are reported in the 'layerStats' attribute
and those for each axis, in the (optional) `axisStats` attribute. The
interpretation of each is determined by the last dimension of its shape.
Currently, only dim=2 is supported, which is interpreted as [min, max].
`layerStats` must be a rank 1 tensor: [2]
`axisStats` must be a rank 2 tensor: [N, 2], where N=the rank of `arg`.
}];
let arguments = (ins
quant_RealValueType:$arg,
ElementsAttr:$layerStats,
OptionalAttr<ElementsAttr>:$axisStats);
let results = (outs quant_RealValueType);
let verifier = [{
auto tensorArg = arg()->getType().dyn_cast<TensorType>();
auto argRank = tensorArg ? tensorArg.getRank() : 0;
// Verify layerStats attribute.
{
auto layerStatsType = layerStats().getType();
if (!layerStatsType.getElementType().isa<FloatType>()) {
return emitOpError(
"layerStats must have a floating point element type");
}
if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
return emitOpError("layerStats must have shape [2]");
}
}
// Verify axisStats (optional) attribute.
if (axisStats()) {
auto axisStatsType = axisStats()->getType();
if (!axisStatsType.getElementType().isa<FloatType>()) {
return emitOpError("axisStats must have a floating point element type");
}
if (axisStatsType.getRank() != 2 ||
axisStatsType.getDimSize(1) != 2 ||
axisStatsType.getDimSize(0) != argRank) {
return emitOpError("axisStats must have shape [N,2] "
"where N = the argument rank");
}
}
return success();
}];
}
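// Generic-form example of attaching layer statistics (illustrative only):
//   %1 = "quant.stats"(%0) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>}
//       : (tensor<8xf32>) -> tensor<8xf32>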
def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> {
let summary =
"Indicates that one point of the computation is coupled to another.";
let description = [{
    Ordinarily, relationships between ops for the purposes of determining
    compatible quantized types are explicit in the use-def chain. However,
    in some situations, a use may be separated from its def by arbitrary
    external connections. In such a case, during analysis, all coupled_ref
    nodes in a module which share a coupledKey will be considered to be
    directly connected, as though via an identity op, for the purpose of type
    inference.
}];
let arguments = (ins
quant_RealValueType:$arg,
StrAttr:$coupledKey);
let results = (outs quant_RealValueType);
}
#endif // DIALECT_QUANTOPS_QUANT_OPS_

View File

@ -0,0 +1,72 @@
//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Predicates for types in the Quantization dialect.
//
//===----------------------------------------------------------------------===//
#ifdef DIALECT_QUANTOPS_QUANT_PREDICATES_
#else
#define DIALECT_QUANTOPS_QUANT_PREDICATES_
//===----------------------------------------------------------------------===//
// Quantization type definitions
//===----------------------------------------------------------------------===//
class quant_TypedPrimitiveOrContainer<Type etype> :
Type<Or<[etype.predicate,
TensorOf<[etype]>.predicate,
VectorOf<[etype]>.predicate]>,
"primitive/tensor/vector of " # etype.description>;
// An implementation of QuantizedType.
def quant_QuantizedType :
Type<CPred<"$_self.isa<mlir::quant::QuantizedType>()">, "QuantizedType">;
// A primitive type that can represent a real value. This is either a
// floating point value or a quantized type.
def quant_RealPrimitiveType :
Type<Or<[AnyFloat.predicate, quant_QuantizedType.predicate]>,
"real valued primitive (float or quantized type)">;
// A primitive type that can represent a storage value. This is either an
// integer or quantized type.
def quant_StoragePrimitiveType :
Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>,
"quantized storage primitive (integer or quantized type)">;
// A primitive or container of RealPrimitiveType.
def quant_RealValueType :
quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
// A primitive or container of StoragePrimitiveType.
def quant_StorageValueType :
quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
// Either a real valued or storage primitive or container type.
def quant_RealOrStorageValueType :
Type<Or<[quant_RealValueType.predicate,
quant_StorageValueType.predicate]>>;
// An implementation of UniformQuantizedType.
def quant_UniformQuantizedType :
Type<CPred<"$_self.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
// Predicate for detecting a container or primitive of UniformQuantizedType.
def quant_UniformQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
#endif // DIALECT_QUANTOPS_QUANT_PREDICATES_

View File

@ -0,0 +1,411 @@
//===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
#define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace quant {
class QuantizedIntegerType;
namespace detail {
struct QuantizedTypeStorage;
struct AnyQuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
} // namespace detail
namespace QuantizationTypes {
enum Kind {
Any = Type::FIRST_QUANTIZATION_TYPE,
UniformQuantized,
UniformQuantizedPerAxis,
LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
};
} // namespace QuantizationTypes
/// Enumeration of bit-mapped flags related to quantized types.
namespace QuantizationFlags {
enum FlagValue {
// Indicates that the storage type should be interpreted as a signed
// integer. The default is to interpret it as an unsigned value.
Signed = 1,
};
} // namespace QuantizationFlags
/// Base class for all quantized types known to this dialect.
/// All quantized types have:
/// - storageType: The (narrower) numeric type that is being used to
/// approximate some expressed type.
/// - expressedType: The type that is being approximated.
///
/// The base class provides generic support for manipulating the types based
/// on these fields.
class QuantizedType : public Type {
public:
using ImplType = detail::QuantizedTypeStorage;
using Type::Type;
/// The maximum number of bits supported for storage types.
static constexpr unsigned MaxStorageBits = 32;
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, unsigned flags,
Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool classof(Type type) {
return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
}
  /// Gets the minimum value that can be stored by a storageType.
  /// storageTypeMin must be greater than or equal to this value.
static int64_t getDefaultMininumForInteger(bool isSigned,
unsigned integralWidth) {
if (isSigned) {
return llvm::minIntN(integralWidth);
}
return 0;
}
  /// Gets the maximum value that can be stored by a storageType.
  /// storageTypeMax must be less than or equal to this value.
static int64_t getDefaultMaxinumForInteger(bool isSigned,
unsigned integralWidth) {
if (isSigned) {
return llvm::maxIntN(integralWidth);
}
return llvm::maxUIntN(integralWidth);
}
/// Gets the original expressed type that this quantized type approximates.
/// Note that this presumes that the quantized type was always derived from
  /// a floating point type, which, in the broadest definition, is not true (i.e.
/// it could be some form of integral, fixed type or affine type in its own
/// right); however, at the high level, no examples of such usage are
/// presently known and the restriction serves some useful purposes (such as
/// always being able to reverse a transformation or measure error). In most
/// cases, this will be f32.
Type getExpressedType() const;
/// Gets the flags associated with this type. Typically a more specific
/// accessor is appropriate.
unsigned getFlags() const;
// Convenience helpers.
/// Whether the storage type should be interpreted as a signed quantity
/// (true) or an unsigned value (false).
bool isSigned() const {
return (getFlags() & QuantizationFlags::Signed) ==
QuantizationFlags::Signed;
}
  /// Gets the underlying type used to store values. Note that this may
/// be signed or unsigned. Use the isSigned() accessor to differentiate.
Type getStorageType() const;
/// The minimum value that storageType can take.
int64_t getStorageTypeMin() const;
/// The maximum value that storageType can take.
int64_t getStorageTypeMax() const;
/// Gets the integral bit width that the underlying storage type can exactly
/// represent. For integral storage types, this will just be their width.
unsigned getStorageTypeIntegralWidth() const;
/// Returns whether the candidateExpressedType is a match for this
/// QuantizedType. This will be true if the candidate type is either a
/// primitive type or a container type whose element type equals this
/// QuantizedType's expressed type.
/// Examples of compatible candidateExpressedType:
/// !quant.uniform<i8:f32, 1.0> =~ f32
/// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
bool isCompatibleExpressedType(Type candidateExpressedType);
/// Returns the element type as a QuantizedType or nullptr if it is not
/// a quantized type. If the type is primitive, returns that. If it is a
/// container (vector/tensor), return the element type.
/// Examples:
/// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
  ///   tensor<4x!quant.uniform<i8:f32, 1.0>> -> !quant.uniform<i8:f32, 1.0>
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
/// Casts from a type based on the storageType to a corresponding type based
/// on this type (returns nullptr if the cast is not valid).
/// Examples:
/// i8 -> !quant.uniform<i8:f32, 1.0>
  ///   tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
/// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
Type castFromStorageType(Type candidateType);
/// Casts from a type based on a QuantizedType to a corresponding type based
/// on the storageType (returns nullptr if the cast is not valid).
/// This is the inverse of castFromStorageType().
static Type castToStorageType(Type quantizedType);
/// Casts from a type based on the expressedType to a corresponding type based
/// on this type (returns nullptr if the cast is not valid).
/// Examples:
/// f32 -> !quant.uniform<i8:f32, 1.0>
/// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
/// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
Type castFromExpressedType(Type candidateType);
/// Casts from a type based on QuantizedType to a corresponding type based
/// on the expressedType (returns nullptr if the cast is not valid).
/// This is the inverse of castFromExpressedType.
static Type castToExpressedType(Type quantizedType);
/// Casts from a type based on the expressedType to the equivalent type
/// based on storageType by way of this QuantizedType. Equivalent to:
/// QuantizedType::castToStorageType(castFromExpressedType(candidateType))
/// (but with validity checks).
/// Example (for this = !quant.uniform<i8:f32, 1.0>):
/// tensor<4xf32> -> tensor<4xi8>
Type castExpressedToStorageType(Type candidateType);
private:
/// Hide the following methods inherited from `Type`. It is almost certainly
/// a bug to call them from a `QuantizedType` object. Users should call
/// `getStorageType` or `getExpressedType` to get the underlying types
/// they want to inspect.
using Type::isBF16;
using Type::isF16;
using Type::isF32;
using Type::isF64;
using Type::isIndex;
using Type::isInteger;
};
/// A quantized type that maps storage to/from expressed types in an
/// unspecified way.
///
/// Typical syntax:
/// quant.any<i8:f32>
/// quant.any<i8>
/// quant.any<i8<-16,15>>
///
/// Note that for the any type, the expressed type is optional.
class AnyQuantizedType
: public Type::TypeBase<AnyQuantizedType, QuantizedType,
detail::AnyQuantizedTypeStorage> {
public:
using Base::Base;
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; }
/// Gets an instance of the type with all parameters specified but not
/// checked.
static AnyQuantizedType get(unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Gets an instance of the type with all specified parameters checked.
/// Returns a nullptr convertible type on failure.
static AnyQuantizedType getChecked(unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax, Location location);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, unsigned flags,
Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax);
};
/// Represents a family of uniform, quantized types.
///
/// Each instance of this type expresses a mapping between real values (most
/// often expressed in floating point f32) and quantized values (either fixed
/// point or affine).
///
/// The relationship is:
/// real_value = scale * (quantized_value - zero_point)
///
/// It is used as part of high level graph transformations that have the goal
/// of re-expressing parts of a computation in terms of this common form for
/// more efficient execution at runtime. In addition, it is designed to be
/// expressive enough to facilitate lowering to precise types and operations
/// in target hardware.
///
/// As a high-level type, focused on intermediate passes, this type holds
/// opinions consistent with high-level usage. If lowering math kernels below
/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
/// instruction sets), it is expected that the information expressed here
/// will be used to drive low level codegen and target specific type selection,
/// but this type will likely be erased in the process.
///
/// Syntax synopsis:
/// Per-layer, all parameters expressed:
/// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
/// Per-layer, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Scale: A legal double value
/// ZeroPoint: An integer value
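///
/// Worked example (illustrative): with u8 storage, scale = 0.5 and
/// zero_point = 127, the stored value 200 represents the real value
/// 0.5 * (200 - 127) = 36.5.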
class UniformQuantizedType
: public Type::TypeBase<UniformQuantizedType, QuantizedType,
detail::UniformQuantizedTypeStorage> {
public:
using Base::Base;
/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedType get(unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Gets an instance of the type with all specified parameters checked.
/// Returns a nullptr convertible type on failure.
static UniformQuantizedType
getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
Location location);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) {
return kind == QuantizationTypes::UniformQuantized;
}
/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
double getScale() const;
/// Gets the storage value corresponding to the real value 0 in the affine
/// equation.
int64_t getZeroPoint() const;
// Fixed point values are real numbers divided by a scale.
// Currently, only signed storage types are treated as fixed point.
// A fixed point value can be obtained from an affine value by subtracting
// the zeroPoint.
// In the future, this may be explicit versus implied by type and zeroPoint.
bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
};
/// Represents per-axis (also known as per-channel quantization).
///
/// Syntax synopsis:
/// Per-axis, all parameters expressed:
/// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
/// Per-axis, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// QuantParams: (Scale ':' ZeroPoint)+
/// Scale: A legal double value
/// ZeroPoint: An integer value
class UniformQuantizedPerAxisType
: public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
detail::UniformQuantizedPerAxisTypeStorage> {
public:
using Base::Base;
/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedPerAxisType
get(unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Gets an instance of the type with all specified parameters checked.
/// Returns a nullptr convertible type on failure.
static UniformQuantizedPerAxisType
getChecked(unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax, Location location);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) {
return kind == QuantizationTypes::UniformQuantizedPerAxis;
}
/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
/// by 1. The ith scale corresponds to the ith slice in the
/// quantized_dimension.
ArrayRef<double> getScales() const;
/// Gets the storage values corresponding to the real value 0 in the affine
/// equation. The ith zero point corresponds to the ith slice in the
/// quantized_dimension.
ArrayRef<int64_t> getZeroPoints() const;
/// Specifies the dimension of the Tensor's shape that the scales and
/// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
/// with quantization params:
/// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
/// will be quantized across the second dimension of t.
/// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
/// t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
/// t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
int32_t getQuantizedDimension() const;
/// Fixed point values are real numbers divided by a scale.
/// Currently, only signed storage types are treated as fixed point.
/// A fixed point value can be obtained from an affine value by subtracting
/// the zeroPoint.
/// In the future, this may be explicit versus implied by type and zeroPoint.
bool isFixedPoint() const {
if (!isSigned())
return false;
    return llvm::all_of(getZeroPoints(),
                        [](int64_t zeroPoint) { return zeroPoint == 0; });
}
};
} // namespace quant
} // namespace mlir
#endif // MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_

View File

@ -0,0 +1,70 @@
//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_
#define MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_
namespace mlir {
class Attribute;
class Type;
namespace quant {
class QuantizedType;
class UniformQuantizedType;
class UniformQuantizedValueConverter;
/// Converts an attribute from a type based on
/// quantizedElementType.getExpressedType() to one based on
/// quantizedElementType.getStorageType(), where quantizedElementType is as from
/// QuantizedType::getQuantizedElementType().
/// Returns nullptr if the conversion is not supported. On success, stores the
/// converted type in outConvertedType.
///
/// Examples:
/// 1. realValue is a primitive value attribute:
/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
/// -> (IntegerAttr, outConvertedType: i8)
/// 2. realValue is an elements attribute:
/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
/// quantizedElementType: UniformQuantizedType[i8:f32])
/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
Type &outConvertedType);
/// Converts an attribute from a type based on
/// quantizedElementType.getExpressedType() to one based on
/// quantizedElementType.getStorageType(), where quantizedElementType is as from
/// QuantizedType::getQuantizedElementType() and casted to an
/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On
/// success, stores the converted type in outConvertedType.
///
/// Examples:
/// 1. realValue is a primitive value attribute:
/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
/// -> (IntegerAttr, outConvertedType: i8)
/// 2. realValue is an elements attribute:
/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
/// quantizedElementType: UniformQuantizedType[i8:f32])
/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
Attribute quantizeAttrUniform(Attribute realValue,
UniformQuantizedType quantizedElementType,
const UniformQuantizedValueConverter &converter,
Type &outConvertedType);
} // namespace quant
} // namespace mlir
#endif // MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_

View File

@ -0,0 +1,119 @@
//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
#define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
namespace mlir {
namespace quant {
/// Performs type conversion from an arbitrary input type to a type
/// that is expressed by a UniformQuantizedType.
///
/// This handles cases where the inputType is a supported primitive type
/// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
/// elemental type.
///
/// Since conversion often involves introspecting some attributes of the
/// input type in order to determine how to represent it, this is a two step
/// process.
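///
/// A minimal usage sketch (assuming a Type `t` and a UniformQuantizedType `q`
/// obtained elsewhere):
///   auto converter = ExpressedToUniformQuantizedConverter::forInputType(t);
///   if (converter)
///     Type newType = converter.convert(q);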
struct ExpressedToUniformQuantizedConverter {
/// Creates a converter for the given input type.
static const ExpressedToUniformQuantizedConverter
forInputType(Type inputType);
/// Converts the inputType to be based on the given elemental type,
/// returning the new type (or nullptr, with an error emitted, on failure).
Type convert(UniformQuantizedType elementalType) const;
/// Whether the conversion is legal.
explicit operator bool() const { return (bool)expressedType; }
/// The input type that is being converted from.
/// This may be an elemental or composite type.
const Type inputType;
/// The supported elemental expressed type (e.g. f32).
/// Will be nullptr if conversion is not supported.
const Type expressedType;
};
/// Reference implementation of converting between real numbers and values
/// represented by a UniformQuantizedType.
/// Note that this is not expected to be speedy and may be superseded
/// eventually by a more optimal implementation.
/// Also, the interface assumes that quantization is done per-layer; it will
/// need to be widened for various per-channel schemes. As such, this is a
/// placeholder.
class UniformQuantizedValueConverter {
public:
UniformQuantizedValueConverter(UniformQuantizedType uniformType)
: scale(uniformType.getScale()),
zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
isSigned(uniformType.isSigned()) {
assert(uniformType.getExpressedType().isa<FloatType>());
assert(uniformType.getStorageType().isa<IntegerType>());
}
virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
bool lossy;
expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
&lossy);
// fixedpoint = clamp(clampMin, clampMax, (
// roundHalfToEven(expressed / scale) + zeroPoint))
APFloat scaled = (expressedValue / scale);
scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
APFloat fixedpoint = llvm::minimum(scaled, clampMax);
fixedpoint = llvm::maximum(fixedpoint, clampMin);
llvm::APSInt result(storageBitWidth, !isSigned);
fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
return result;
}
int64_t quantizeFloatToInt64(APFloat expressedValue) const {
APInt qValue = quantizeFloatToInt(expressedValue);
return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
}
virtual ~UniformQuantizedValueConverter() {}
private:
const APFloat scale;
const APFloat zeroPoint;
const APFloat clampMin;
const APFloat clampMax;
const uint32_t storageBitWidth;
const bool isSigned;
};
} // namespace quant
} // namespace mlir
#endif // MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
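
For clarity, the fixed-point formula implemented by `quantizeFloatToInt` above can be sketched in plain C++ (a per-layer reference only; APFloat's round-half-to-even is approximated here by `std::nearbyint` under the default rounding mode):

```c++
#include <algorithm>
#include <cmath>
#include <cstdint>

// quantized = clamp(clampMin, clampMax,
//                   roundHalfToEven(expressed / scale) + zeroPoint)
int64_t quantize(double expressed, double scale, double zeroPoint,
                 double clampMin, double clampMax) {
  double q = std::nearbyint(expressed / scale) + zeroPoint;
  return static_cast<int64_t>(std::min(std::max(q, clampMin), clampMax));
}
```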

View File

@ -0,0 +1,17 @@
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVOps.h.inc -gen-op-decls)
mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRSPIRVOpsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
add_public_tablegen_target(MLIRSPIRVSerializationGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils)
add_public_tablegen_target(MLIRSPIRVOpUtilsGen)

View File

@ -0,0 +1,35 @@
//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes that expose pass constructors.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_PASSES_H_
#define MLIR_DIALECT_SPIRV_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace spirv {
FunctionPassBase *createStdOpsToSPIRVConversionPass();
} // namespace spirv
} // namespace mlir
#endif // MLIR_DIALECT_SPIRV_PASSES_H_
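
A hedged sketch of scheduling this pass; the raw `FunctionPassBase *` returned by the factory matches the raw-pointer `addPass` overloads of this snapshot's PassManager, but the exact pass-manager surface is an assumption here:

```c++
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"

void runStdToSPIRV(mlir::ModuleOp module) {
  mlir::PassManager pm;
  pm.addPass(mlir::spirv::createStdOpsToSPIRVConversionPass());
  if (mlir::failed(pm.run(module))) {
    // Diagnostics were already emitted through the MLIRContext.
  }
}
```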

View File

@ -0,0 +1,580 @@
//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the base file for SPIR-V operation definition specification.
// This file defines the SPIR-V dialect, common SPIR-V types, and utilities
// for facilitating defining SPIR-V ops.
//
//===----------------------------------------------------------------------===//
#ifdef SPIRV_BASE
#else
#define SPIRV_BASE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
//===----------------------------------------------------------------------===//
// SPIR-V dialect definitions
//===----------------------------------------------------------------------===//
def SPV_Dialect : Dialect {
let name = "spv";
let description = [{
The SPIR-V dialect in MLIR.
SPIR-V is the Khronos Group's binary intermediate language for representing
graphical-shader stages and compute kernels for multiple Khronos APIs,
including OpenCL, OpenGL, and Vulkan.
See https://www.khronos.org/registry/spir-v for more details.
This dialect aims to be a simple proxy for the SPIR-V binary format to
enable straightforward and lightweight conversion from/to the binary
format. Ops in this dialect should stay at the same semantic level and
try to be a mechanical mapping to the corresponding SPIR-V instructions;
but they may deviate representationally to allow using MLIR mechanisms.
As a convention, if such deviation happens, the op name follows "snake_case"
style; otherwise, the op name just follows the SPIR-V mnemonic (by removing
the leading `Op` prefix) to use "CamelCase" style.
}];
let cppNamespace = "spirv";
}
//===----------------------------------------------------------------------===//
// SPIR-V opcode specification
//===----------------------------------------------------------------------===//
class SPV_OpCode<string name, int val> {
// Name used as reference to retrieve the opcode
string opname = name;
// Opcode associated with the name
int opcode = val;
}
// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>;
def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>;
def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>;
def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>;
def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>;
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
SPV_OC_OpExecutionMode, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpFMul, SPV_OC_OpReturn
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
//===----------------------------------------------------------------------===//
// SPIR-V type definitions
//===----------------------------------------------------------------------===//
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
def SPV_Void : TypeAlias<NoneType, "void type">;
def SPV_Bool : IntOfWidths<[1]>;
def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
def SPV_Float : FloatOfWidths<[16, 32, 64]>;
def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>;
def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyStruct]>;
def SPV_Type : AnyTypeOf<[
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
]>;
class SPV_ScalarOrVectorOf<Type type> :
Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
"scalar/vector of " # type.description>;
// TODO(antiagainst): Use a more appropriate way to model optional operands
class SPV_Optional<Type type> : Variadic<type>;
def SPV_IsEntryPointType :
CPred<"$_self.isa<::mlir::spirv::EntryPointType>()">;
def SPV_EntryPoint : Type<SPV_IsEntryPointType, "SPIR-V entry point type">;
//===----------------------------------------------------------------------===//
// SPIR-V enum definitions
//===----------------------------------------------------------------------===//
// Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
def SPV_AM_Physical32 : I32EnumAttrCase<"Physical32", 1>;
def SPV_AM_Physical64 : I32EnumAttrCase<"Physical64", 2>;
def SPV_AM_PhysicalStorageBuffer64EXT : I32EnumAttrCase<"PhysicalStorageBuffer64EXT", 5348>;
def SPV_AddressingModelAttr :
I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
SPV_AM_PhysicalStorageBuffer64EXT
]> {
let returnType = "::mlir::spirv::AddressingModel";
let convertFromStorage = "static_cast<::mlir::spirv::AddressingModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_D_1D : I32EnumAttrCase<"1D", 0>;
def SPV_D_2D : I32EnumAttrCase<"2D", 1>;
def SPV_D_3D : I32EnumAttrCase<"3D", 2>;
def SPV_D_Cube : I32EnumAttrCase<"Cube", 3>;
def SPV_D_Rect : I32EnumAttrCase<"Rect", 4>;
def SPV_D_Buffer : I32EnumAttrCase<"Buffer", 5>;
def SPV_D_SubpassData : I32EnumAttrCase<"SubpassData", 6>;
def SPV_DimAttr :
I32EnumAttr<"Dim", "valid SPIR-V Dim", [
SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
SPV_D_SubpassData
]> {
let returnType = "::mlir::spirv::Dim";
let convertFromStorage = "static_cast<::mlir::spirv::Dim>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_EM_Invocations : I32EnumAttrCase<"Invocations", 0>;
def SPV_EM_SpacingEqual : I32EnumAttrCase<"SpacingEqual", 1>;
def SPV_EM_SpacingFractionalEven : I32EnumAttrCase<"SpacingFractionalEven", 2>;
def SPV_EM_SpacingFractionalOdd : I32EnumAttrCase<"SpacingFractionalOdd", 3>;
def SPV_EM_VertexOrderCw : I32EnumAttrCase<"VertexOrderCw", 4>;
def SPV_EM_VertexOrderCcw : I32EnumAttrCase<"VertexOrderCcw", 5>;
def SPV_EM_PixelCenterInteger : I32EnumAttrCase<"PixelCenterInteger", 6>;
def SPV_EM_OriginUpperLeft : I32EnumAttrCase<"OriginUpperLeft", 7>;
def SPV_EM_OriginLowerLeft : I32EnumAttrCase<"OriginLowerLeft", 8>;
def SPV_EM_EarlyFragmentTests : I32EnumAttrCase<"EarlyFragmentTests", 9>;
def SPV_EM_PointMode : I32EnumAttrCase<"PointMode", 10>;
def SPV_EM_Xfb : I32EnumAttrCase<"Xfb", 11>;
def SPV_EM_DepthReplacing : I32EnumAttrCase<"DepthReplacing", 12>;
def SPV_EM_DepthGreater : I32EnumAttrCase<"DepthGreater", 14>;
def SPV_EM_DepthLess : I32EnumAttrCase<"DepthLess", 15>;
def SPV_EM_DepthUnchanged : I32EnumAttrCase<"DepthUnchanged", 16>;
def SPV_EM_LocalSize : I32EnumAttrCase<"LocalSize", 17>;
def SPV_EM_LocalSizeHint : I32EnumAttrCase<"LocalSizeHint", 18>;
def SPV_EM_InputPoints : I32EnumAttrCase<"InputPoints", 19>;
def SPV_EM_InputLines : I32EnumAttrCase<"InputLines", 20>;
def SPV_EM_InputLinesAdjacency : I32EnumAttrCase<"InputLinesAdjacency", 21>;
def SPV_EM_Triangles : I32EnumAttrCase<"Triangles", 22>;
def SPV_EM_InputTrianglesAdjacency : I32EnumAttrCase<"InputTrianglesAdjacency", 23>;
def SPV_EM_Quads : I32EnumAttrCase<"Quads", 24>;
def SPV_EM_Isolines : I32EnumAttrCase<"Isolines", 25>;
def SPV_EM_OutputVertices : I32EnumAttrCase<"OutputVertices", 26>;
def SPV_EM_OutputPoints : I32EnumAttrCase<"OutputPoints", 27>;
def SPV_EM_OutputLineStrip : I32EnumAttrCase<"OutputLineStrip", 28>;
def SPV_EM_OutputTriangleStrip : I32EnumAttrCase<"OutputTriangleStrip", 29>;
def SPV_EM_VecTypeHint : I32EnumAttrCase<"VecTypeHint", 30>;
def SPV_EM_ContractionOff : I32EnumAttrCase<"ContractionOff", 31>;
def SPV_EM_Initializer : I32EnumAttrCase<"Initializer", 33>;
def SPV_EM_Finalizer : I32EnumAttrCase<"Finalizer", 34>;
def SPV_EM_SubgroupSize : I32EnumAttrCase<"SubgroupSize", 35>;
def SPV_EM_SubgroupsPerWorkgroup : I32EnumAttrCase<"SubgroupsPerWorkgroup", 36>;
def SPV_EM_SubgroupsPerWorkgroupId : I32EnumAttrCase<"SubgroupsPerWorkgroupId", 37>;
def SPV_EM_LocalSizeId : I32EnumAttrCase<"LocalSizeId", 38>;
def SPV_EM_LocalSizeHintId : I32EnumAttrCase<"LocalSizeHintId", 39>;
def SPV_EM_PostDepthCoverage : I32EnumAttrCase<"PostDepthCoverage", 4446>;
def SPV_EM_DenormPreserve : I32EnumAttrCase<"DenormPreserve", 4459>;
def SPV_EM_DenormFlushToZero : I32EnumAttrCase<"DenormFlushToZero", 4460>;
def SPV_EM_SignedZeroInfNanPreserve : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4461>;
def SPV_EM_RoundingModeRTE : I32EnumAttrCase<"RoundingModeRTE", 4462>;
def SPV_EM_RoundingModeRTZ : I32EnumAttrCase<"RoundingModeRTZ", 4463>;
def SPV_EM_StencilRefReplacingEXT : I32EnumAttrCase<"StencilRefReplacingEXT", 5027>;
def SPV_EM_OutputLinesNV : I32EnumAttrCase<"OutputLinesNV", 5269>;
def SPV_EM_OutputPrimitivesNV : I32EnumAttrCase<"OutputPrimitivesNV", 5270>;
def SPV_EM_DerivativeGroupQuadsNV : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289>;
def SPV_EM_DerivativeGroupLinearNV : I32EnumAttrCase<"DerivativeGroupLinearNV", 5290>;
def SPV_EM_OutputTrianglesNV : I32EnumAttrCase<"OutputTrianglesNV", 5298>;
def SPV_EM_PixelInterlockOrderedEXT : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366>;
def SPV_EM_PixelInterlockUnorderedEXT : I32EnumAttrCase<"PixelInterlockUnorderedEXT", 5367>;
def SPV_EM_SampleInterlockOrderedEXT : I32EnumAttrCase<"SampleInterlockOrderedEXT", 5368>;
def SPV_EM_SampleInterlockUnorderedEXT : I32EnumAttrCase<"SampleInterlockUnorderedEXT", 5369>;
def SPV_EM_ShadingRateInterlockOrderedEXT : I32EnumAttrCase<"ShadingRateInterlockOrderedEXT", 5370>;
def SPV_EM_ShadingRateInterlockUnorderedEXT : I32EnumAttrCase<"ShadingRateInterlockUnorderedEXT", 5371>;
def SPV_ExecutionModeAttr :
I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", [
SPV_EM_Invocations, SPV_EM_SpacingEqual, SPV_EM_SpacingFractionalEven,
SPV_EM_SpacingFractionalOdd, SPV_EM_VertexOrderCw, SPV_EM_VertexOrderCcw,
SPV_EM_PixelCenterInteger, SPV_EM_OriginUpperLeft, SPV_EM_OriginLowerLeft,
SPV_EM_EarlyFragmentTests, SPV_EM_PointMode, SPV_EM_Xfb, SPV_EM_DepthReplacing,
SPV_EM_DepthGreater, SPV_EM_DepthLess, SPV_EM_DepthUnchanged, SPV_EM_LocalSize,
SPV_EM_LocalSizeHint, SPV_EM_InputPoints, SPV_EM_InputLines,
SPV_EM_InputLinesAdjacency, SPV_EM_Triangles, SPV_EM_InputTrianglesAdjacency,
SPV_EM_Quads, SPV_EM_Isolines, SPV_EM_OutputVertices, SPV_EM_OutputPoints,
SPV_EM_OutputLineStrip, SPV_EM_OutputTriangleStrip, SPV_EM_VecTypeHint,
SPV_EM_ContractionOff, SPV_EM_Initializer, SPV_EM_Finalizer,
SPV_EM_SubgroupSize, SPV_EM_SubgroupsPerWorkgroup,
SPV_EM_SubgroupsPerWorkgroupId, SPV_EM_LocalSizeId, SPV_EM_LocalSizeHintId,
SPV_EM_PostDepthCoverage, SPV_EM_DenormPreserve, SPV_EM_DenormFlushToZero,
SPV_EM_SignedZeroInfNanPreserve, SPV_EM_RoundingModeRTE,
SPV_EM_RoundingModeRTZ, SPV_EM_StencilRefReplacingEXT, SPV_EM_OutputLinesNV,
SPV_EM_OutputPrimitivesNV, SPV_EM_DerivativeGroupQuadsNV,
SPV_EM_DerivativeGroupLinearNV, SPV_EM_OutputTrianglesNV,
SPV_EM_PixelInterlockOrderedEXT, SPV_EM_PixelInterlockUnorderedEXT,
SPV_EM_SampleInterlockOrderedEXT, SPV_EM_SampleInterlockUnorderedEXT,
SPV_EM_ShadingRateInterlockOrderedEXT, SPV_EM_ShadingRateInterlockUnorderedEXT
]> {
let returnType = "::mlir::spirv::ExecutionMode";
let convertFromStorage = "static_cast<::mlir::spirv::ExecutionMode>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_EM_Vertex : I32EnumAttrCase<"Vertex", 0>;
def SPV_EM_TessellationControl : I32EnumAttrCase<"TessellationControl", 1>;
def SPV_EM_TessellationEvaluation : I32EnumAttrCase<"TessellationEvaluation", 2>;
def SPV_EM_Geometry : I32EnumAttrCase<"Geometry", 3>;
def SPV_EM_Fragment : I32EnumAttrCase<"Fragment", 4>;
def SPV_EM_GLCompute : I32EnumAttrCase<"GLCompute", 5>;
def SPV_EM_Kernel : I32EnumAttrCase<"Kernel", 6>;
def SPV_EM_TaskNV : I32EnumAttrCase<"TaskNV", 5267>;
def SPV_EM_MeshNV : I32EnumAttrCase<"MeshNV", 5268>;
def SPV_EM_RayGenerationNV : I32EnumAttrCase<"RayGenerationNV", 5313>;
def SPV_EM_IntersectionNV : I32EnumAttrCase<"IntersectionNV", 5314>;
def SPV_EM_AnyHitNV : I32EnumAttrCase<"AnyHitNV", 5315>;
def SPV_EM_ClosestHitNV : I32EnumAttrCase<"ClosestHitNV", 5316>;
def SPV_EM_MissNV : I32EnumAttrCase<"MissNV", 5317>;
def SPV_EM_CallableNV : I32EnumAttrCase<"CallableNV", 5318>;
def SPV_ExecutionModelAttr :
I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", [
SPV_EM_Vertex, SPV_EM_TessellationControl, SPV_EM_TessellationEvaluation,
SPV_EM_Geometry, SPV_EM_Fragment, SPV_EM_GLCompute, SPV_EM_Kernel,
SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationNV, SPV_EM_IntersectionNV,
SPV_EM_AnyHitNV, SPV_EM_ClosestHitNV, SPV_EM_MissNV, SPV_EM_CallableNV
]> {
let returnType = "::mlir::spirv::ExecutionModel";
let convertFromStorage = "static_cast<::mlir::spirv::ExecutionModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_FC_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_FC_Inline : I32EnumAttrCase<"Inline", 0x0001>;
def SPV_FC_DontInline : I32EnumAttrCase<"DontInline", 0x0002>;
def SPV_FC_Pure : I32EnumAttrCase<"Pure", 0x0004>;
def SPV_FC_Const : I32EnumAttrCase<"Const", 0x0008>;
def SPV_FunctionControlAttr :
I32EnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
]> {
let returnType = "::mlir::spirv::FunctionControl";
let convertFromStorage = "static_cast<::mlir::spirv::FunctionControl>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_IF_Unknown : I32EnumAttrCase<"Unknown", 0>;
def SPV_IF_Rgba32f : I32EnumAttrCase<"Rgba32f", 1>;
def SPV_IF_Rgba16f : I32EnumAttrCase<"Rgba16f", 2>;
def SPV_IF_R32f : I32EnumAttrCase<"R32f", 3>;
def SPV_IF_Rgba8 : I32EnumAttrCase<"Rgba8", 4>;
def SPV_IF_Rgba8Snorm : I32EnumAttrCase<"Rgba8Snorm", 5>;
def SPV_IF_Rg32f : I32EnumAttrCase<"Rg32f", 6>;
def SPV_IF_Rg16f : I32EnumAttrCase<"Rg16f", 7>;
def SPV_IF_R11fG11fB10f : I32EnumAttrCase<"R11fG11fB10f", 8>;
def SPV_IF_R16f : I32EnumAttrCase<"R16f", 9>;
def SPV_IF_Rgba16 : I32EnumAttrCase<"Rgba16", 10>;
def SPV_IF_Rgb10A2 : I32EnumAttrCase<"Rgb10A2", 11>;
def SPV_IF_Rg16 : I32EnumAttrCase<"Rg16", 12>;
def SPV_IF_Rg8 : I32EnumAttrCase<"Rg8", 13>;
def SPV_IF_R16 : I32EnumAttrCase<"R16", 14>;
def SPV_IF_R8 : I32EnumAttrCase<"R8", 15>;
def SPV_IF_Rgba16Snorm : I32EnumAttrCase<"Rgba16Snorm", 16>;
def SPV_IF_Rg16Snorm : I32EnumAttrCase<"Rg16Snorm", 17>;
def SPV_IF_Rg8Snorm : I32EnumAttrCase<"Rg8Snorm", 18>;
def SPV_IF_R16Snorm : I32EnumAttrCase<"R16Snorm", 19>;
def SPV_IF_R8Snorm : I32EnumAttrCase<"R8Snorm", 20>;
def SPV_IF_Rgba32i : I32EnumAttrCase<"Rgba32i", 21>;
def SPV_IF_Rgba16i : I32EnumAttrCase<"Rgba16i", 22>;
def SPV_IF_Rgba8i : I32EnumAttrCase<"Rgba8i", 23>;
def SPV_IF_R32i : I32EnumAttrCase<"R32i", 24>;
def SPV_IF_Rg32i : I32EnumAttrCase<"Rg32i", 25>;
def SPV_IF_Rg16i : I32EnumAttrCase<"Rg16i", 26>;
def SPV_IF_Rg8i : I32EnumAttrCase<"Rg8i", 27>;
def SPV_IF_R16i : I32EnumAttrCase<"R16i", 28>;
def SPV_IF_R8i : I32EnumAttrCase<"R8i", 29>;
def SPV_IF_Rgba32ui : I32EnumAttrCase<"Rgba32ui", 30>;
def SPV_IF_Rgba16ui : I32EnumAttrCase<"Rgba16ui", 31>;
def SPV_IF_Rgba8ui : I32EnumAttrCase<"Rgba8ui", 32>;
def SPV_IF_R32ui : I32EnumAttrCase<"R32ui", 33>;
def SPV_IF_Rgb10a2ui : I32EnumAttrCase<"Rgb10a2ui", 34>;
def SPV_IF_Rg32ui : I32EnumAttrCase<"Rg32ui", 35>;
def SPV_IF_Rg16ui : I32EnumAttrCase<"Rg16ui", 36>;
def SPV_IF_Rg8ui : I32EnumAttrCase<"Rg8ui", 37>;
def SPV_IF_R16ui : I32EnumAttrCase<"R16ui", 38>;
def SPV_IF_R8ui : I32EnumAttrCase<"R8ui", 39>;
def SPV_ImageFormatAttr :
I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [
SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8,
SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f,
SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8,
SPV_IF_R16, SPV_IF_R8, SPV_IF_Rgba16Snorm, SPV_IF_Rg16Snorm, SPV_IF_Rg8Snorm,
SPV_IF_R16Snorm, SPV_IF_R8Snorm, SPV_IF_Rgba32i, SPV_IF_Rgba16i, SPV_IF_Rgba8i,
SPV_IF_R32i, SPV_IF_Rg32i, SPV_IF_Rg16i, SPV_IF_Rg8i, SPV_IF_R16i, SPV_IF_R8i,
SPV_IF_Rgba32ui, SPV_IF_Rgba16ui, SPV_IF_Rgba8ui, SPV_IF_R32ui,
SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui,
SPV_IF_R8ui
]> {
let returnType = "::mlir::spirv::ImageFormat";
let convertFromStorage = "static_cast<::mlir::spirv::ImageFormat>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_LT_Export : I32EnumAttrCase<"Export", 0>;
def SPV_LT_Import : I32EnumAttrCase<"Import", 1>;
def SPV_LinkageTypeAttr :
I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
SPV_LT_Export, SPV_LT_Import
]> {
let returnType = "::mlir::spirv::LinkageType";
let convertFromStorage = "static_cast<::mlir::spirv::LinkageType>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>;
def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>;
def SPV_MA_Nontemporal : I32EnumAttrCase<"Nontemporal", 0x0004>;
def SPV_MA_MakePointerAvailableKHR : I32EnumAttrCase<"MakePointerAvailableKHR", 0x0008>;
def SPV_MA_MakePointerVisibleKHR : I32EnumAttrCase<"MakePointerVisibleKHR", 0x0010>;
def SPV_MA_NonPrivatePointerKHR : I32EnumAttrCase<"NonPrivatePointerKHR", 0x0020>;
def SPV_MemoryAccessAttr :
I32EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal,
SPV_MA_MakePointerAvailableKHR, SPV_MA_MakePointerVisibleKHR,
SPV_MA_NonPrivatePointerKHR
]> {
let returnType = "::mlir::spirv::MemoryAccess";
let convertFromStorage = "static_cast<::mlir::spirv::MemoryAccess>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_MM_Simple : I32EnumAttrCase<"Simple", 0>;
def SPV_MM_GLSL450 : I32EnumAttrCase<"GLSL450", 1>;
def SPV_MM_OpenCL : I32EnumAttrCase<"OpenCL", 2>;
def SPV_MM_VulkanKHR : I32EnumAttrCase<"VulkanKHR", 3>;
def SPV_MemoryModelAttr :
I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_VulkanKHR
]> {
let returnType = "::mlir::spirv::MemoryModel";
let convertFromStorage = "static_cast<::mlir::spirv::MemoryModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_SC_UniformConstant : I32EnumAttrCase<"UniformConstant", 0>;
def SPV_SC_Input : I32EnumAttrCase<"Input", 1>;
def SPV_SC_Uniform : I32EnumAttrCase<"Uniform", 2>;
def SPV_SC_Output : I32EnumAttrCase<"Output", 3>;
def SPV_SC_Workgroup : I32EnumAttrCase<"Workgroup", 4>;
def SPV_SC_CrossWorkgroup : I32EnumAttrCase<"CrossWorkgroup", 5>;
def SPV_SC_Private : I32EnumAttrCase<"Private", 6>;
def SPV_SC_Function : I32EnumAttrCase<"Function", 7>;
def SPV_SC_Generic : I32EnumAttrCase<"Generic", 8>;
def SPV_SC_PushConstant : I32EnumAttrCase<"PushConstant", 9>;
def SPV_SC_AtomicCounter : I32EnumAttrCase<"AtomicCounter", 10>;
def SPV_SC_Image : I32EnumAttrCase<"Image", 11>;
def SPV_SC_StorageBuffer : I32EnumAttrCase<"StorageBuffer", 12>;
def SPV_SC_CallableDataNV : I32EnumAttrCase<"CallableDataNV", 5328>;
def SPV_SC_IncomingCallableDataNV : I32EnumAttrCase<"IncomingCallableDataNV", 5329>;
def SPV_SC_RayPayloadNV : I32EnumAttrCase<"RayPayloadNV", 5338>;
def SPV_SC_HitAttributeNV : I32EnumAttrCase<"HitAttributeNV", 5339>;
def SPV_SC_IncomingRayPayloadNV : I32EnumAttrCase<"IncomingRayPayloadNV", 5342>;
def SPV_SC_ShaderRecordBufferNV : I32EnumAttrCase<"ShaderRecordBufferNV", 5343>;
def SPV_SC_PhysicalStorageBufferEXT : I32EnumAttrCase<"PhysicalStorageBufferEXT", 5349>;
def SPV_StorageClassAttr :
I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [
SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output,
SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function,
SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image,
SPV_SC_StorageBuffer, SPV_SC_CallableDataNV, SPV_SC_IncomingCallableDataNV,
SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV,
SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBufferEXT
]> {
let returnType = "::mlir::spirv::StorageClass";
let convertFromStorage = "static_cast<::mlir::spirv::StorageClass>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
// End enum section. Generated from SPIR-V spec; DO NOT MODIFY!
// Enums added manually that are not part of SPIRV spec
def SPV_IDI_NoDepth : I32EnumAttrCase<"NoDepth", 0>;
def SPV_IDI_IsDepth : I32EnumAttrCase<"IsDepth", 1>;
def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>;
def SPV_DepthAttr :
I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",
[SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]> {
let cppNamespace = "::mlir::spirv";
}
def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>;
def SPV_IAI_Arrayed : I32EnumAttrCase<"Arrayed", 1>;
def SPV_ArrayedAttr :
I32EnumAttr<"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
[SPV_IAI_NonArrayed, SPV_IAI_Arrayed]> {
let cppNamespace = "::mlir::spirv";
}
def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
def SPV_ISI_MultiSampled : I32EnumAttrCase<"MultiSampled", 1>;
def SPV_SamplingAttr:
I32EnumAttr<"ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
[SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]> {
let cppNamespace = "::mlir::spirv";
}
def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>;
def SPV_ISUI_NeedSampler : I32EnumAttrCase<"NeedSampler", 1>;
def SPV_ISUI_NoSampler : I32EnumAttrCase<"NoSampler", 2>;
def SPV_SamplerUseAttr:
I32EnumAttr<"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
[SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]> {
let cppNamespace = "::mlir::spirv";
}
//===----------------------------------------------------------------------===//
// SPIR-V OpTrait definitions
//===----------------------------------------------------------------------===//
// Check that an op can only be used with SPIR-V ModuleOp
def IsModuleOnlyPred :
CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">;
def ModuleOnly :
PredOpTrait<"op can only be used in a 'spv.module' block", IsModuleOnlyPred>;
//===----------------------------------------------------------------------===//
// SPIR-V op definitions
//===----------------------------------------------------------------------===//
// Base class for all SPIR-V ops.
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
Op<SPV_Dialect, mnemonic, traits> {
// For each SPIR-V op, the following static functions need to be defined
// in SPVOps.cpp:
//
// * static ParseResult parse<op-c++-class-name>(OpAsmParser *parser,
// OperationState *result)
// * static void print(OpAsmPrinter *p, <op-c++-class-name> op)
// * static LogicalResult verify(<op-c++-class-name> op)
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(*this, p); }];
let verifier = [{ return ::verify(*this); }];
// Specifies whether this op has a direct corresponding SPIR-V binary
// instruction opcode. The (de)serializer uses this field to determine whether
// to auto-generate an entry in the (de)serialization dispatch table for this
// op. If set, this field also further enables `autogenSerialization` (see
// below for details).
bit hasOpcode = 1;
// Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1.
string spirvOpName = "Op" # mnemonic;
// Controls whether to auto-generate this op's (de)serialization method.
// If set, it results in generation of the following methods:
//
// ```c++
// template<typename OpTy> Serializer::processOp(OpTy op);
// template<typename OpTy> Deserializer::processOp(ArrayRef<uint32_t>);
// ```
//
// If this field is not set, then manual implementation of a specialization of
// these methods is required.
//
// Note:
//
// 1) If hasOpcode is set but autogenSerialization is not set, the
// (de)serializer dispatch method still calls the above method for
// (de)serializing this op.
//
// 2) If hasOpcode is not set, then this field is not interpreted; this op's
// (de)serialization method will not be auto-generated regardless, nor is an
// entry added to the (de)serialization dispatch table. Both
// (de)serializing this op and its dispatch should be handled manually.
bit autogenSerialization = 1;
}
#endif // SPIRV_BASE
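
On the C++ side, `mlir-tblgen -gen-enum-decls`/`-gen-enum-defs` turns each `I32EnumAttr` above into an enum class plus string conversion helpers. A sketch of their expected use; the helper names follow the enum generator's convention and are assumptions here:

```c++
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"  // pulls in SPIRVEnums.h.inc

void enumDemo() {
  auto model = mlir::spirv::AddressingModel::Logical;
  // Enum -> string form.
  llvm::StringRef name = mlir::spirv::stringifyAddressingModel(model);
  // String form -> enum; yields llvm::None on an unknown name.
  auto parsed = mlir::spirv::symbolizeAddressingModel(name);
  (void)parsed;
}
```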

View File

@ -0,0 +1,46 @@
//===- SPIRVDialect.h - MLIR SPIR-V dialect ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file declares the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_
#define MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_
#include "mlir/IR/Dialect.h"
namespace mlir {
namespace spirv {
class SPIRVDialect : public Dialect {
public:
explicit SPIRVDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "spv"; }
/// Parses a type registered to this dialect.
Type parseType(llvm::StringRef spec, Location loc) const override;
/// Prints a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
};
} // end namespace spirv
} // end namespace mlir
#endif // MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_
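
A sketch of making the dialect available, assuming the static `DialectRegistration` mechanism this snapshot uses for other dialects:

```c++
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/IR/Dialect.h"

// Hooks the SPIR-V dialect into MLIR's global registry when this
// translation unit is linked in.
static mlir::DialectRegistration<mlir::spirv::SPIRVDialect> spirvDialect;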

View File

@ -0,0 +1,37 @@
//===- SPIRVOps.h - MLIR SPIR-V operations ----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file declares the operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_
#define MLIR_DIALECT_SPIRV_SPIRVOPS_H_
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Function.h"
namespace mlir {
namespace spirv {
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
} // end namespace spirv
} // end namespace mlir
#endif // MLIR_DIALECT_SPIRV_SPIRVOPS_H_
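
The definitions side follows the usual ODS pattern; a sketch of what SPIRVOps.cpp is expected to contain to materialize the generated op classes declared above:

```c++
// In SPIRVOps.cpp (sketch): materialize the ODS-generated op classes
// whose declarations were pulled in via GET_OP_CLASSES in the header.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
```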

View File

@ -0,0 +1,468 @@
//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the main operation definition specification file for SPIR-V
// operations.
//
//===----------------------------------------------------------------------===//
// Note that for each op in this file, we use a tool to automatically generate
// certain sections in its definition: basic structure, summary, description.
// So modifications to these sections will not be respected. Modifications to
// op traits, arguments, results, and sections after the results are retained.
// Besides, ops in this file must be separated via the '// -----' marker.
#ifdef SPIRV_OPS
#else
#define SPIRV_OPS
#ifdef SPIRV_BASE
#else
include "mlir/Dialect/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE
#ifdef SPIRV_STRUCTURE_OPS
#else
// Pull in ops for defining the SPIR-V module structure
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
#endif // SPIRV_STRUCTURE_OPS
// -----
def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
let summary = [{
Create a pointer into a composite object that can be used with OpLoad
and OpStore.
}];
let description = [{
Result Type must be an OpTypePointer. Its Type operand must be the type
reached by walking the Base's type hierarchy down to the last provided
index in Indexes, and its Storage Class operand must be the same as the
Storage Class of Base.
Base must be a pointer, pointing to the base of a composite object.
Indexes walk the type hierarchy to the desired depth, potentially down
to scalar granularity. The first index in Indexes will select the
top-level member/element/component/element of the base composite. All
composite constituents use zero-based numbering, as described by their
OpType instruction. The second index will apply similarly to that
result, and so on. Once any non-composite type is reached, there must be
no remaining (unused) indexes.
Each index in Indexes
- must be a scalar integer type,
- is treated as a signed count, and
- must be an OpConstant when indexing into a structure.
### Custom assembly form
``` {.ebnf}
access-chain-op ::= ssa-id `=` `spv.AccessChain` ssa-use
`[` ssa-use (',' ssa-use)* `]`
`:` pointer-type
```
For example:
```
%0 = "spv.constant"() { value = 1: i32} : () -> i32
%1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
%2 = spv.AccessChain %1[%0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
%3 = spv.Load "Function" %2 ["Volatile"] : !spv.array<4xf32>
```
}];
let arguments = (ins
SPV_AnyPtr:$base_ptr,
Variadic<SPV_Integer>:$indices
);
let results = (outs
SPV_AnyPtr:$component_ptr
);
}
// -----
def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
let summary = "Extract a part of a composite object.";
let description = [{
Result Type must be the type of object selected by the last provided
index. The instruction result is the extracted object.
Composite is the composite to extract from.
Indexes walk the type hierarchy, potentially down to component
granularity, to select the part to extract. All indexes must be in
bounds. All composite constituents use zero-based numbering, as
described by their OpType instruction.
### Custom assembly form
``` {.ebnf}
composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use
`[` integer-literal (',' integer-literal)* `]`
`:` composite-type
```
For example:
```
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
%2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>>
```
}];
let arguments = (ins
SPV_Composite:$composite,
I32ArrayAttr:$indices
);
let results = (outs
SPV_Type:$component
);
}
// -----
def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
let summary = [{
Declare an entry point, its execution model, and its interface.
}];
let description = [{
Execution Model is the execution model for the entry point and its
static call tree. See Execution Model.
Entry Point must be the Result <id> of an OpFunction instruction.
Name is a name string for the entry point. A module cannot have two
OpEntryPoint instructions with the same Execution Model and the same
Name string.
Interface is a list of <id> of global OpVariable instructions. These
declare the set of global variables from a module that form the
interface of this entry point. The set of Interface <id> must be equal
to or a superset of the global OpVariable Result <id> referenced by the
entry point's static call tree, within the interface's storage classes.
Before version 1.4, the interface's storage classes are limited to the
Input and Output storage classes. Starting with version 1.4, the
interface's storage classes are all storage classes used in declaring
all global variables referenced by the entry point's call tree.
Interface <id> are forward references. Before version 1.4, duplication
of these <id> is tolerated. Starting with version 1.4, an <id> must not
appear more than once.
### Custom assembly form
``` {.ebnf}
execution-model ::= "Vertex" | "TesellationControl" |
<and other SPIR-V execution models...>
entry-point-op ::= ssa-id ` = spv.EntryPoint ` execution-model fn-name
(ssa-use ( `, ` ssa-use)* ` : `
pointer-type ( `, ` pointer-type)* )?
```
For example:
```
spv.EntryPoint "GLCompute" @foo
spv.EntryPoint "Kernel" @foo, %1, %2 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
```
}];
let arguments = (ins
SPV_ExecutionModelAttr:$execution_model,
SymbolRefAttr:$fn,
Variadic<SPV_AnyPtr>:$interface
);
let results = (outs);
let autogenSerialization = 0;
}
// -----
def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
let summary = "Declare an execution mode for an entry point.";
let description = [{
Entry Point must be the Entry Point <id> operand of an OpEntryPoint
instruction.
Mode is the execution mode. See Execution Mode.
This instruction is only valid when the Mode operand is an execution
mode that takes no Extra Operands, or takes Extra Operands that are not
<id> operands.
### Custom assembly form
``` {.ebnf}
execution-mode ::= "Invocations" | "SpacingEqual" |
<and other SPIR-V execution modes...>
execution-mode-op ::= `spv.ExecutionMode ` ssa-use execution-mode
(integer-literal (`, ` integer-literal)* )?
```
For example:
```
spv.ExecutionMode @foo "ContractionOff"
spv.ExecutionMode @bar "LocalSizeHint", 3, 4, 5
```
}];
let arguments = (ins
SymbolRefAttr:$fn,
SPV_ExecutionModeAttr:$execution_mode,
OptionalAttr<I32ArrayAttr>:$values
);
let results = (outs);
let verifier = [{ return success(); }];
let autogenSerialization = 0;
}
// -----
def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
let description = [{
Result Type must be a scalar or vector of floating-point type.
The types of Operand 1 and Operand 2 both must be the same as Result
Type.
Results are computed per component.
### Custom assembly form
``` {.ebnf}
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
execution-mode-op ::= `spv.FMul` ssa-use, ssa-use
`:` float-scalar-vector-type
```
For example:
```
spv.FMul %0, %1 : f32
spv.FMul %2, %3 : vector<4xf32>
```
}];
let arguments = (ins
SPV_ScalarOrVectorOf<SPV_Float>:$operand1,
SPV_ScalarOrVectorOf<SPV_Float>:$operand2
);
let results = (outs
SPV_ScalarOrVectorOf<AnyFloat>:$result
);
let parser = [{ return impl::parseBinaryOp(parser, result); }];
let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
// No additional verification needed in addition to the ODS-generated ones.
let verifier = [{ return success(); }];
}
// -----
def SPV_LoadOp : SPV_Op<"Load", []> {
let summary = "Load through a pointer.";
let description = [{
Result Type is the type of the loaded object. It must be a type with
fixed size; i.e., it cannot be, nor include, any OpTypeRuntimeArray
types.
Pointer is the pointer to load through. Its type must be an
OpTypePointer whose Type operand is the same as Result Type.
If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory operand
None.
### Custom assembly form
``` {.ebnf}
memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` integer-literal
| `"NonTemporal"`
load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use
(`[` memory-access `]`)? ` : ` spirv-element-type
```
For example:
```
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Load "Function" %0 : f32
%2 = spv.Load "Function" %0 ["Volatile"] : f32
%3 = spv.Load "Function" %0 ["Aligned", 4] : f32
```
}];
let arguments = (ins
SPV_AnyPtr:$ptr,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
);
let results = (outs
SPV_Type:$value
);
}
// -----
def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
let summary = "Return with no value from a function with void return type.";
let description = [{
This instruction must be the last instruction in a block.
### Custom assembly form
``` {.ebnf}
return-op ::= `spv.Return`
```
}];
let arguments = (ins);
let results = (outs);
let parser = [{ return parseNoIOOp(parser, result); }];
let printer = [{ printNoIOOp(getOperation(), p); }];
let verifier = [{ return verifyReturn(*this); }];
}
// -----
def SPV_StoreOp : SPV_Op<"Store", []> {
let summary = "Store through a pointer.";
let description = [{
Pointer is the pointer to store through. Its type must be an
OpTypePointer whose Type operand is the same as the type of Object.
Object is the object to store.
If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory operand
None.
### Custom assembly form
``` {.ebnf}
store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use
(`[` memory-access `]`)? `:` spirv-element-type
```
For example:
```
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.FMul ... : f32
spv.Store "Function" %0, %1 : f32
spv.Store "Function" %0, %1 ["Volatile"] : f32
spv.Store "Function" %0, %1 ["Aligned", 4] : f32
}];
let arguments = (ins
SPV_AnyPtr:$ptr,
SPV_Type:$value,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
);
let results = (outs);
}
// -----
def SPV_VariableOp : SPV_Op<"Variable", []> {
let summary = [{
Allocate an object in memory, resulting in a pointer to it, which can be
used with OpLoad and OpStore.
}];
let description = [{
Result Type must be an OpTypePointer. Its Type operand is the type of
object in memory.
Storage Class is the Storage Class of the memory holding the object. It
cannot be Generic. It must be the same as the Storage Class operand of
the Result Type.
Initializer is optional. If Initializer is present, it will be the
initial value of the variable's memory content. Initializer must be an
<id> from a constant instruction or a global (module scope) OpVariable
instruction. Initializer must have the same type as the type pointed to
by Result Type.
### Custom assembly form
``` {.ebnf}
variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)?
(`bind(` integer-literal, integer-literal `)`)?
attribute-dict? `:` spirv-pointer-type
```
where `init` specifies the initializer and `bind` specifies the descriptor
set and binding number.
For example:
```
%0 = spv.constant ...
%1 = spv.Variable : !spv.ptr<f32, Function>
%2 = spv.Variable init(%0): !spv.ptr<f32, Private>
%3 = spv.Variable init(%0) bind(1, 2): !spv.ptr<f32, Uniform>
```
}];
let arguments = (ins
SPV_StorageClassAttr:$storage_class,
SPV_Optional<AnyType>:$initializer
);
let results = (outs
SPV_AnyPtr:$pointer
);
}
// -----
#endif // SPIRV_OPS
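
Per the SPV_Op base class, each op above needs three static hooks in SPVOps.cpp. A skeleton for spv.AccessChain, with the signatures listed in the base class comment and placeholder bodies only:

```c++
// Placeholder bodies; real implementations live in SPVOps.cpp.
static ParseResult parseAccessChainOp(OpAsmParser *parser,
                                      OperationState *result) {
  return failure();  // TODO: parse base-ptr `[` indices `]` `:` pointer-type
}
static void print(OpAsmPrinter *p, spirv::AccessChainOp op) {}
static LogicalResult verify(spirv::AccessChainOp op) { return success(); }
```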

View File

@ -0,0 +1,175 @@
//===-- SPIRVStructureOps.td - MLIR SPIR-V Structure Ops ---*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains ops for defining the SPIR-V structure: module, function,
// and module-level operations. The representational form of these ops deviates
// from the SPIR-V binary format in order to utilize MLIR mechanisms.
//
//===----------------------------------------------------------------------===//
#ifdef SPIRV_STRUCTURE_OPS
#else
#define SPIRV_STRUCTURE_OPS
#ifdef SPIRV_BASE
#else
include "mlir/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE
def SPV_ModuleOp : SPV_Op<"module", []> {
let summary = "The top-level op that defines a SPIR-V module";
let description = [{
This op defines a SPIR-V module using an MLIR region. The region contains
one block. Module-level operations, including function definitions,
are all placed in this block.
Using an op with a region to define a SPIR-V module enables "embedding"
SPIR-V modules in other dialects in a clean manner: this op guarantees
the validity and serializability of a SPIR-V module and thus serves as
a clear-cut boundary.
This op takes no operands and generates no results. This op should not
implicitly capture values from the enclosing environment.
This op has only one region, which only contains one block. The block
must be terminated via the `spv._module_end` op.
### Custom assembly form
``` {.ebnf}
addressing-model ::= `"Logical"` | `"Physical32"` | `"Physical64"`
memory-model ::= `"Simple"` | `"GLSL450"` | `"OpenCL"` | `"VulkanKHR"`
spv-module-op ::= `spv.module` addressing-model memory-model
region
(`attributes` attribute-dict)?
```
For example:
```
spv.module "Logical" "VulkanKHR" { }
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
} attributes {
capability = ["Shader"],
extension = ["SPV_KHR_16bit_storage"]
}
```
}];
let arguments = (ins
OptionalAttr<StrArrayAttr>:$capabilities,
OptionalAttr<StrArrayAttr>:$extensions,
OptionalAttr<StrArrayAttr>:$extended_instruction_sets,
SPV_AddressingModelAttr:$addressing_model,
SPV_MemoryModelAttr:$memory_model
);
let results = (outs);
let regions = (region SizedRegion<1>:$body);
let builders = [OpBuilder<"Builder *, OperationState *state">];
// We need to ensure the block inside the region is properly terminated;
// the auto-generated builders do not guarantee that.
let skipDefaultBuilders = 1;
let hasOpcode = 0;
let extraClassDeclaration = [{
Block& getBlock() {
return this->getOperation()->getRegion(0).front();
}
}];
}
def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> {
let summary = "The pseudo op that ends a SPIR-V module";
let description = [{
This op terminates the only block inside a `spv.module`'s only region.
This op does not have a corresponding SPIR-V instruction and thus will
not be serialized into the binary format; it is used solely to satisfy
the structural requirement that a block must end with a terminator.
}];
let arguments = (ins);
let results = (outs);
let parser = [{ return parseNoIOOp(parser, result); }];
let printer = [{ printNoIOOp(getOperation(), p); }];
let verifier = [{ return success(); }];
let hasOpcode = 0;
}
def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
let summary = "The op that declares a SPIR-V constant";
let description = [{
This op declares a SPIR-V constant. SPIR-V has multiple constant
instructions covering different constant types:
* `OpConstantTrue` and `OpConstantFalse` for boolean constants
* `OpConstant` for scalar constants
* `OpConstantComposite` for composite constants
* `OpConstantNull` for null constants
* ...
Having such a plethora of constant instructions renders IR transformations
more tedious. Therefore, we use a single `spv.constant` op to represent
them all. Note that conversion between those SPIR-V constant instructions
and this op is purely mechanical, so it can be scoped to the binary
(de)serialization process.
### Custom assembly form
``` {.ebnf}
spv-constant-op ::= ssa-id `=` `spv.constant` attribute-value
(`:` spirv-type)?
```
For example:
```
%0 = spv.constant true
%1 = spv.constant dense<vector<2xf32>, [2, 3]>
%2 = spv.constant [dense<vector<2xf32>, 3.0>] : !spv.array<1xvector<2xf32>>
```
TODO(antiagainst): support constant structs
}];
let arguments = (ins
AnyAttr:$value
);
let results = (outs
SPV_Type:$constant
);
let hasOpcode = 0;
}
#endif // SPIRV_STRUCTURE_OPS
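
A small sketch of consuming the structure ops from C++, using the `getBlock()` helper declared in SPV_ModuleOp's extraClassDeclaration above:

```c++
// Walk and print every module-level op inside a spv.module.
void walkSPIRVModule(mlir::spirv::ModuleOp module) {
  for (mlir::Operation &op : module.getBlock())
    op.dump();
}
```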

View File

@ -0,0 +1,185 @@
//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file declares the types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
#define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
// Pull in all enum type definitions and utility function declarations
#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc"
#include <tuple>
namespace mlir {
namespace spirv {
namespace detail {
struct ArrayTypeStorage;
struct ImageTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct StructTypeStorage;
} // namespace detail
namespace TypeKind {
enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
Image,
Pointer,
RuntimeArray,
Struct,
};
}
// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
class CompositeType : public Type {
public:
using Type::Type;
static bool classof(Type type) {
return (type.getKind() == TypeKind::Array ||
type.getKind() == TypeKind::Struct ||
type.getKind() == StandardTypes::Vector);
}
unsigned getNumElements() const;
Type getElementType(unsigned) const;
};
// SPIR-V array type
class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
detail::ArrayTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
static ArrayType get(Type elementType, unsigned elementCount);
unsigned getNumElements() const;
Type getElementType() const;
};
// SPIR-V image type
class ImageType
: public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == TypeKind::Image; }
static ImageType
get(Type elementType, Dim dim,
ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
ImageFormat format = ImageFormat::Unknown) {
return ImageType::get(
std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
elementType, dim, depth, arrayed, samplingInfo, samplerUse,
format));
}
static ImageType
get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
Type getElementType() const;
Dim getDim() const;
ImageDepthInfo getDepthInfo() const;
ImageArrayedInfo getArrayedInfo() const;
ImageSamplingInfo getSamplingInfo() const;
ImageSamplerUseInfo getSamplerUseInfo() const;
ImageFormat getImageFormat() const;
// TODO(ravishankarm): Add support for Access qualifier
};
// SPIR-V pointer type
class PointerType
: public Type::TypeBase<PointerType, Type, detail::PointerTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == TypeKind::Pointer; }
static PointerType get(Type pointeeType, StorageClass storageClass);
Type getPointeeType() const;
StorageClass getStorageClass() const;
};
// SPIR-V run-time array type
class RuntimeArrayType
: public Type::TypeBase<RuntimeArrayType, Type,
detail::RuntimeArrayTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; }
static RuntimeArrayType get(Type elementType);
Type getElementType() const;
};
// SPIR-V struct type
class StructType : public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage> {
public:
using Base::Base;
// Layout information used for members in a struct in SPIR-V
//
// TODO(ravishankarm) : For now this only supports the offset type, so uses
// uint64_t value to represent the offset, with
// std::numeric_limits<uint64_t>::max() indicating no offset. Change this to
// something that can hold all the information needed for different member
// types
using LayoutInfo = uint64_t;
static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
static StructType get(ArrayRef<Type> memberTypes);
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<LayoutInfo> layoutInfo);
unsigned getNumElements() const;
Type getElementType(unsigned) const;
bool hasLayout() const;
uint64_t getOffset(unsigned) const;
};
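// A minimal construction sketch (a hedged example with illustrative values;
// assumes an MLIRContext `context` with the SPIR-V dialect registered):
//
//   Type f32 = FloatType::getF32(context);
//   ArrayType arr = ArrayType::get(f32, /*elementCount=*/4);
//   PointerType ptr = PointerType::get(arr, StorageClass::Function);
//   StructType st = StructType::get({f32, arr});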
} // end namespace spirv
} // end namespace mlir
#endif // MLIR_DIALECT_SPIRV_SPIRVTYPES_H_

View File

@ -0,0 +1,49 @@
//===- Serialization.h - MLIR SPIR-V (De)serialization ----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file declares the entry points for serializing and deserializing
// SPIR-V binary modules.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_SERIALIZATION_H_
#define MLIR_DIALECT_SPIRV_SERIALIZATION_H_
#include "mlir/Support/LLVM.h"
namespace mlir {
struct LogicalResult;
class MLIRContext;
namespace spirv {
class ModuleOp;
/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
/// reports errors to the error handler registered with the MLIR context for
/// `module`.
LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary);
/// Deserializes the given SPIR-V `binary` module and creates an MLIR ModuleOp
/// in the given `context`. Returns the ModuleOp on success; otherwise, reports
/// errors to the error handler registered with `context` and returns
/// llvm::None.
Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
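/// A round-trip sketch (a hedged example; assumes `module` is a valid
/// spirv::ModuleOp and `context` is its MLIRContext):
///
///   SmallVector<uint32_t, 0> binary;
///   if (failed(serialize(module, binary)))
///     return;
///   Optional<ModuleOp> roundTripped = deserialize(binary, context);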
} // end namespace spirv
} // end namespace mlir
#endif // MLIR_DIALECT_SPIRV_SERIALIZATION_H_

View File

@ -0,0 +1,89 @@
//===- Traits.h - Common op traits shared by dialects -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file declares common op traits that are not core to MLIR but can be
// shared by multiple dialects.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRAITS
#define MLIR_DIALECT_TRAITS
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being instantiated and
// duplicated for each instantiation of the trait template.
namespace impl {
LogicalResult verifyCompatibleOperandBroadcast(Operation *op);
} // namespace impl
namespace util {
/// Returns true and sets `resultShape` to the broadcasted shape from the two
/// given shapes if they are broadcast compatible. Returns false and clears
/// `resultShape` otherwise.
///
/// The rules for determining the result shape are:
///
/// Zip together the dimensions in the two given shapes by prepending 1s to
/// the shape with fewer dimensions. For each dimension pair, deduce the
/// result dimension according to the following order:
/// - If there are unknown dimensions, follow the TensorFlow behavior:
///   - If either dimension is greater than 1, we assume that the program is
///     correct, and the other dimension will be broadcast to match it.
///   - If either dimension is 1, the other dimension is the result.
///   - Otherwise, the result dimension is unknown.
/// - If one of the dimensions is 1, the other dimension is the result.
/// - If the two dimensions are the same, that is the result.
/// - Otherwise, the shapes are incompatible.
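/// For example, broadcasting shapes [2, 1, 4] and [3, 1] prepends a 1 to the
/// shorter shape and zips the pairs (2,1), (1,3), (4,1), yielding [2, 3, 4];
/// shapes [2, 3] and [4, 3] are incompatible because 2 and 4 differ and
/// neither is 1.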
bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
SmallVectorImpl<int64_t> &resultShape);
/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
Type getBroadcastedType(Type type1, Type type2);
} // namespace util
/// This class provides the API for ops that are known to have broadcast-
/// compatible operand and result types. Specifically, starting from the
/// most varying dimension, each dimension pair of the two operands' types
/// should either be the same or one of them should be one. Also, the result
/// type should have the corresponding dimension equal to the larger one, if
/// known. Shapes are checked partially if ranks or dimensions are not known.
/// For example, an op with tensor<? x 2 x f32> and tensor<2 x f32> as operand
/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible.
///
/// This trait assumes the op has two operands and one result, and it asserts
/// if the precondition is not satisfied.
template <typename ConcreteType>
class BroadcastableTwoOperandsOneResult
: public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyCompatibleOperandBroadcast(op);
}
};
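/// A usage sketch (a hypothetical op definition; the trait plugs into op
/// verification via its `verifyTrait` hook):
///
/// ```c++
/// class MyAddOp : public Op<MyAddOp, OpTrait::NOperands<2>::Impl,
///                           OpTrait::OneResult,
///                           OpTrait::BroadcastableTwoOperandsOneResult> {
///   // ...
/// };
/// ```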
} // end namespace OpTrait
} // end namespace mlir
#endif // MLIR_DIALECT_TRAITS

View File

@ -0,0 +1,500 @@
//===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Provides intuitive composable interfaces for building structured MLIR
// snippets in a declarative fashion.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EDSC_BUILDERS_H_
#define MLIR_EDSC_BUILDERS_H_
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/VectorOps/VectorOps.h"
namespace mlir {
namespace edsc {
struct index_t {
explicit index_t(int64_t v) : v(v) {}
explicit operator int64_t() { return v; }
int64_t v;
};
class BlockHandle;
class CapturableHandle;
class NestedBuilder;
class ValueHandle;
/// Helper class to transparently handle builder insertion points by RAII.
/// As its name indicates, a ScopedContext is meant to be used locally in a
/// scoped fashion. This abstracts away all the boilerplate related to
/// checking proper usage of captures and NestedBuilders, as well as handling
/// the setting and restoring of insertion points.
class ScopedContext {
public:
ScopedContext(OpBuilder &builder, Location location);
/// Sets the insertion point of the builder to 'newInsertPt' for the duration
/// of the scope. The existing insertion point of the builder is restored on
/// destruction.
ScopedContext(OpBuilder &builder, OpBuilder::InsertPoint newInsertPt,
Location location);
~ScopedContext();
static MLIRContext *getContext();
static OpBuilder &getBuilder();
static Location getLocation();
private:
/// Only NestedBuilder (which is used to create an operation with a body)
/// may access private members in order to implement scoping.
friend class NestedBuilder;
ScopedContext() = delete;
ScopedContext(const ScopedContext &) = delete;
ScopedContext &operator=(const ScopedContext &) = delete;
static ScopedContext *&getCurrentScopedContext();
/// Top level OpBuilder.
OpBuilder &builder;
/// The previous insertion point of the builder.
llvm::Optional<OpBuilder::InsertPoint> prevBuilderInsertPoint;
/// Current location.
Location location;
/// Parent context we return into.
ScopedContext *enclosingScopedContext;
/// Defensively keeps track of the current NestedBuilder to ensure proper
/// scoping usage.
NestedBuilder *nestedBuilder;
// TODO: Implement scoping of ValueHandles. To do this we need a proper data
// structure to hold ValueHandle objects. We can emulate one but there should
// already be something available in LLVM for this purpose.
};
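/// A minimal usage sketch (a hedged example; assumes `func` is a FuncOp whose
/// body already contains an entry block):
///
/// ```c++
/// OpBuilder builder(func.getBody());
/// ScopedContext scope(builder, func.getLoc());
/// // EDSC builders and handles constructed in this scope insert at the
/// // builder's insertion point.
/// ```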
/// A NestedBuilder is a scoping abstraction to create an idiomatic syntax
/// embedded in C++ that serves the purpose of building nested MLIR.
/// Nesting and compositionality are obtained by using the strict ordering
/// that exists between object construction and method invocation on said
/// object (in our case, the call to `operator()`).
/// This ordering allows implementing an abstraction that decouples definition
/// from declaration (in a PL sense) on placeholders of type ValueHandle and
/// BlockHandle.
class NestedBuilder {
protected:
NestedBuilder() = default;
NestedBuilder(const NestedBuilder &) = delete;
NestedBuilder(NestedBuilder &&other) : bodyScope(other.bodyScope) {
other.bodyScope = nullptr;
}
NestedBuilder &operator=(const NestedBuilder &) = delete;
NestedBuilder &operator=(NestedBuilder &&other) {
std::swap(bodyScope, other.bodyScope);
return *this;
}
/// Enter an mlir::Block and set up a ScopedContext to insert operations at
/// the end of it. Since we cannot use C++ language-level scoping to implement
/// scoping itself, we use enter/exit pairs of operations.
/// As a consequence we must allocate a new OpBuilder + ScopedContext and
/// let them escape.
/// Step back "prev" times from the end of the block to set up the insertion
/// point, which is useful for non-empty blocks.
void enter(mlir::Block *block, int prev = 0) {
bodyScope = new ScopedContext(
ScopedContext::getBuilder(),
OpBuilder::InsertPoint(block, std::prev(block->end(), prev)),
ScopedContext::getLocation());
bodyScope->nestedBuilder = this;
}
/// Exit the current mlir::Block by explicitly deleting the dynamically
/// allocated OpBuilder and ScopedContext.
void exit() {
// Reclaim now to exit the scope.
bodyScope->nestedBuilder = nullptr;
delete bodyScope;
bodyScope = nullptr;
}
/// Custom destructor does nothing because we already destroyed bodyScope
/// manually in `exit`. Insert an assertion to defensively guard against
/// improper usage of scoping.
~NestedBuilder() {
assert(!bodyScope &&
"Illegal use of NestedBuilder; must have called exit()");
}
private:
ScopedContext *bodyScope = nullptr;
};
/// A LoopBuilder is a generic NestedBuilder for loop-like MLIR operations.
/// More specifically it is meant to be used as a temporary object for
/// representing any nested MLIR construct that is "related to" an mlir::Value*
/// (for now an induction variable).
/// This is extensible and will evolve in the future as MLIR evolves, hence
/// the name LoopBuilder (as opposed to say ForBuilder or AffineForBuilder).
class LoopBuilder : public NestedBuilder {
public:
/// Constructs a new AffineForOp and captures the associated induction
/// variable. A ValueHandle pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable.
LoopBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles, int64_t step);
LoopBuilder(const LoopBuilder &) = delete;
LoopBuilder(LoopBuilder &&) = default;
LoopBuilder &operator=(const LoopBuilder &) = delete;
LoopBuilder &operator=(LoopBuilder &&) = default;
/// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
/// scoped within a LoopBuilder.
ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr);
};
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
/// explicitly writing all the loops in a nest. This simple functionality is
/// also useful to write rank-agnostic custom ops.
///
/// Usage:
///
/// ```c++
/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})(
/// [&](){
/// ...
/// });
/// ```
///
/// ```c++
/// LoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){
/// LoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){
/// LoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){
/// ...
/// }),
/// }),
/// });
/// ```
class LoopNestBuilder {
public:
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr);
private:
SmallVector<LoopBuilder, 4> loops;
};
// This class exists solely to handle the C++ "most vexing parse" case when
// trying to enter a Block that has already been constructed.
class Append {};
/// A BlockBuilder is a NestedBuilder for mlir::Block*.
/// This exists in contrast to LoopBuilder, which is related not to an
/// mlir::Block* but to an mlir::Value*.
/// It is meant to be used as a temporary object for representing any nested
/// MLIR construct that is "related to" an mlir::Block*.
class BlockBuilder : public NestedBuilder {
public:
/// Enters the mlir::Block* previously captured by `bh` and sets the insertion
/// point to its end.
BlockBuilder(BlockHandle bh, Append);
/// Constructs a new mlir::Block with argument types derived from `args`.
/// Captures the new block in `bh` and its arguments into `args`.
/// Enters the new mlir::Block* and sets the insertion point to its end.
///
/// Prerequisites:
/// The ValueHandles in `args` are typed delayed ValueHandles; i.e. they are
/// not yet bound to mlir::Value*.
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
/// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
/// scoped within a BlockBuilder.
void operator()(llvm::function_ref<void(void)> fun = nullptr);
private:
BlockBuilder(BlockBuilder &) = delete;
BlockBuilder &operator=(BlockBuilder &other) = delete;
};
/// Base class for ValueHandle, OperationHandle and BlockHandle.
/// Not meant to be used outside of these classes.
class CapturableHandle {
protected:
CapturableHandle() = default;
};
/// ValueHandle implements a (potentially "delayed") typed Value abstraction.
/// ValueHandle should be captured by pointer but otherwise passed by Value
/// everywhere.
/// A ValueHandle can have 3 states:
/// 1. null state (empty type and empty value), in which case it does not hold
/// a value and must never hold a Value (now or in the future). This is
/// used for MLIR operations with zero returns as well as the result of
/// calling a NestedBuilder::operator(). In both cases the objective is to
/// have an object that can be inserted in an ArrayRef<ValueHandle> to
/// implement nesting;
/// 2. delayed state (empty value), in which case it represents an eagerly
/// typed "delayed" value that can hold a Value in the future;
/// 3. constructed state, in which case it holds a Value.
///
/// A ValueHandle is meant to capture a single Value* and should be used for
/// operations that have a single result. For convenience of use, we also
/// include AffineForOp in this category although it does not return a value.
/// In the case of AffineForOp, the captured Value* is the loop induction
/// variable.
class ValueHandle : public CapturableHandle {
public:
/// A ValueHandle in a null state can never be captured.
static ValueHandle null() { return ValueHandle(); }
/// A ValueHandle that is constructed from a Type represents a typed "delayed"
/// Value. A delayed Value can only capture Values of the specified type.
/// Such a delayed value represents the declaration (in the PL sense) of a
/// placeholder for an mlir::Value* that will be constructed and captured at
/// some later point in the program.
explicit ValueHandle(Type t) : t(t), v(nullptr) {}
/// A ValueHandle that is constructed from an mlir::Value* is an "eager"
/// Value. An eager Value represents both the declaration and the definition
/// (in the PL sense) of a placeholder for an mlir::Value* that has already
/// been constructed in the past and that is captured "now" in the program.
explicit ValueHandle(Value *v) : t(v->getType()), v(v) {}
/// Builds a ConstantIndexOp of value `cst`. The constant is created at the
/// current insertion point.
/// This implicit constructor is provided to allow building an eager Value for
/// a constant at the current insertion point in the IR. An implicit
/// constructor allows idiomatic expressions mixing ValueHandle and literals.
ValueHandle(index_t cst);
/// ValueHandle is a value type, use the default copy constructor.
ValueHandle(const ValueHandle &other) = default;
/// ValueHandle is a value type, the assignment operator typechecks before
/// assigning.
ValueHandle &operator=(const ValueHandle &other);
/// Provide a swap operator.
void swap(ValueHandle &other) {
if (this == &other)
return;
std::swap(t, other.t);
std::swap(v, other.v);
}
/// Implicit conversion useful for automatic conversion to Container<Value*>.
operator Value *() const { return getValue(); }
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the op definitions.
template <typename Op, typename... Args>
static ValueHandle create(Args... args);
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the op definitions.
template <typename Op, typename... Args>
static ValueHandle create(OperationFolder &folder, Args... args);
/// Special case to build composed AffineApply operations.
// TODO: createOrFold when available and move inside of the `create` method.
static ValueHandle createComposedAffineApply(AffineMap map,
ArrayRef<Value *> operands);
/// Generic create for a named operation producing a single value.
static ValueHandle create(StringRef name, ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes = {});
bool hasValue() const { return v != nullptr; }
Value *getValue() const {
assert(hasValue() && "Unexpected null value");
return v;
}
bool hasType() const { return t != Type(); }
Type getType() const { return t; }
Operation *getOperation() const {
if (!v)
return nullptr;
return v->getDefiningOp();
}
protected:
ValueHandle() : t(), v(nullptr) {}
Type t;
Value *v;
};
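/// A sketch of the three states (a hedged example; assumes an enclosing
/// ScopedContext for the constant):
///
/// ```c++
/// ValueHandle n = ValueHandle::null();                      // null state
/// ValueHandle d(ScopedContext::getBuilder().getF32Type());  // delayed state
/// ValueHandle c(index_t(42));                // constructed, eager constant
/// ```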
/// An OperationHandle can be used in lieu of ValueHandle to capture the
/// operation in cases when one does not care about, or cannot extract, a
/// unique Value* from the operation.
/// This can be used for capturing zero result operations as well as
/// multi-result operations that are not supported by ValueHandle.
/// We do not distinguish further between zero and multi-result operations at
/// this time.
struct OperationHandle : public CapturableHandle {
OperationHandle() : op(nullptr) {}
OperationHandle(Operation *op) : op(op) {}
OperationHandle(const OperationHandle &) = default;
OperationHandle &operator=(const OperationHandle &) = default;
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the op definitions.
template <typename Op, typename... Args>
static OperationHandle create(Args... args);
template <typename Op, typename... Args> static Op createOp(Args... args);
/// Generic create for a named operation.
static OperationHandle create(StringRef name, ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes = {});
operator Operation *() { return op; }
Operation *getOperation() const { return op; }
private:
Operation *op;
};
/// Simple wrapper to build a generic operation without successor blocks.
template <typename HandleType> struct CustomOperation {
CustomOperation(StringRef name) : name(name) {
static_assert(std::is_same<HandleType, ValueHandle>() ||
std::is_same<HandleType, OperationHandle>(),
"Only CustomOperation<ValueHandle> or "
"CustomOperation<OperationHandle> can be constructed.");
}
HandleType operator()(ArrayRef<ValueHandle> operands = {},
ArrayRef<Type> resultTypes = {},
ArrayRef<NamedAttribute> attributes = {}) {
return HandleType::create(name, operands, resultTypes, attributes);
}
std::string name;
};
/// A BlockHandle represents a (potentially "delayed") Block abstraction.
/// This extra abstraction is necessary because an mlir::Block is not an
/// mlir::Value.
/// A BlockHandle should be captured by pointer but otherwise passed by Value
/// everywhere.
class BlockHandle : public CapturableHandle {
public:
/// A BlockHandle constructed without an mlir::Block* represents a "delayed"
/// Block. A delayed Block represents the declaration (in the PL sense) of a
/// placeholder for an mlir::Block* that will be constructed and captured at
/// some later point in the program.
BlockHandle() : block(nullptr) {}
/// A BlockHandle constructed with an mlir::Block* represents an "eager"
/// Block. An eager Block represents both the declaration and the definition
/// (in the PL sense) of a placeholder for an mlir::Block* that has already
/// been constructed in the past and that is captured "now" in the program.
BlockHandle(mlir::Block *block) : block(block) {}
/// BlockHandle is a value type, use the default copy constructor and
/// assignment operator.
BlockHandle(const BlockHandle &) = default;
BlockHandle &operator=(const BlockHandle &) = default;
/// Delegates block creation to MLIR and wraps the resulting mlir::Block.
static BlockHandle create(ArrayRef<Type> argTypes);
operator bool() { return block != nullptr; }
operator mlir::Block *() { return block; }
mlir::Block *getBlock() { return block; }
private:
mlir::Block *block;
};
template <typename Op, typename... Args>
OperationHandle OperationHandle::create(Args... args) {
return OperationHandle(ScopedContext::getBuilder()
.create<Op>(ScopedContext::getLocation(), args...)
.getOperation());
}
template <typename Op, typename... Args>
Op OperationHandle::createOp(Args... args) {
return cast<Op>(
OperationHandle(ScopedContext::getBuilder()
.create<Op>(ScopedContext::getLocation(), args...)
.getOperation())
.getOperation());
}
template <typename Op, typename... Args>
ValueHandle ValueHandle::create(Args... args) {
Operation *op = ScopedContext::getBuilder()
.create<Op>(ScopedContext::getLocation(), args...)
.getOperation();
if (op->getNumResults() == 1) {
return ValueHandle(op->getResult(0));
} else if (op->getNumResults() == 0) {
if (auto f = dyn_cast<AffineForOp>(op)) {
return ValueHandle(f.getInductionVar());
}
}
llvm_unreachable("unsupported operation, use an OperationHandle instead");
}
template <typename Op, typename... Args>
ValueHandle ValueHandle::create(OperationFolder &folder, Args... args) {
return ValueHandle(folder.create<Op>(ScopedContext::getBuilder(),
ScopedContext::getLocation(), args...));
}
namespace op {
ValueHandle operator+(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator-(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator*(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator/(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator%(ValueHandle lhs, ValueHandle rhs);
ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs);
ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator!(ValueHandle value);
ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator||(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator^(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator==(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator<(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator>(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs);
} // namespace op
} // namespace edsc
} // namespace mlir
#endif // MLIR_EDSC_BUILDERS_H_

View File

@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS "${MLIR_SOURCE_DIR}/test/mlir-tblgen/reference-impl.td")
mlir_tablegen("reference-impl.inc" -gen-reference-implementations)
add_public_tablegen_target(MLIRReferenceImplementationTestGen)

View File

@ -0,0 +1,264 @@
//===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Provides helper classes and syntactic sugar for declarative builders.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EDSC_HELPERS_H_
#define MLIR_EDSC_HELPERS_H_
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
namespace mlir {
namespace edsc {
// A TemplatedIndexedValue brings an index notation over the template Load and
// Store parameters.
template <typename Load, typename Store> class TemplatedIndexedValue;
// By default, edsc::IndexedValue provides an index notation around affine
// loads and stores.
using IndexedValue =
TemplatedIndexedValue<intrinsics::affine_load, intrinsics::affine_store>;
// Base class for MemRefView and VectorView.
class View {
public:
unsigned rank() const { return lbs.size(); }
ValueHandle lb(unsigned idx) { return lbs[idx]; }
ValueHandle ub(unsigned idx) { return ubs[idx]; }
int64_t step(unsigned idx) { return steps[idx]; }
std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) {
return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
}
void swapRanges(unsigned i, unsigned j) {
if (i == j)
return;
lbs[i].swap(lbs[j]);
ubs[i].swap(ubs[j]);
std::swap(steps[i], steps[j]);
}
ArrayRef<ValueHandle> getLbs() { return lbs; }
ArrayRef<ValueHandle> getUbs() { return ubs; }
ArrayRef<int64_t> getSteps() { return steps; }
protected:
SmallVector<ValueHandle, 8> lbs;
SmallVector<ValueHandle, 8> ubs;
SmallVector<int64_t, 8> steps;
};
/// A MemRefView represents the information required to step through a
/// MemRef. It has placeholders for non-contiguous tensors that fit within the
/// Fortran subarray model.
/// At the moment it can only capture a MemRef with an identity layout map.
// TODO(ntv): Support MemRefs with layoutMaps.
class MemRefView : public View {
public:
explicit MemRefView(Value *v);
MemRefView(const MemRefView &) = default;
MemRefView &operator=(const MemRefView &) = default;
unsigned fastestVarying() const { return rank() - 1; }
private:
friend IndexedValue;
ValueHandle base;
};
/// A VectorView represents the information required to step through a
/// Vector, accessing each scalar element one at a time. It is the counterpart
/// a MemRefView but for vectors. This exists purely for boilerplate avoidance.
class VectorView : public View {
public:
explicit VectorView(Value *v);
VectorView(const VectorView &) = default;
VectorView &operator=(const VectorView &) = default;
private:
friend IndexedValue;
ValueHandle base;
};
/// A TemplatedIndexedValue brings an index notation over the template Load and
/// Store parameters. This helper class is an abstraction purely for sugaring
/// purposes and allows writing compact expressions such as:
///
/// ```mlir
/// // `IndexedValue` provided by default in the mlir::edsc namespace.
/// using IndexedValue =
/// TemplatedIndexedValue<intrinsics::affine_load, intrinsics::affine_store>;
/// IndexedValue A(...), B(...), C(...);
/// For(ivs, zeros, shapeA, ones, {
/// C(ivs) = A(ivs) + B(ivs)
/// });
/// ```
///
/// Assigning to an IndexedValue emits an actual `Store` operation, while
/// converting an IndexedValue to a ValueHandle emits an actual `Load`
/// operation.
template <typename Load, typename Store> class TemplatedIndexedValue {
public:
explicit TemplatedIndexedValue(Type t) : base(t) {}
explicit TemplatedIndexedValue(Value *v)
: TemplatedIndexedValue(ValueHandle(v)) {}
explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
TemplatedIndexedValue operator()() { return *this; }
/// Returns a new `TemplatedIndexedValue`.
TemplatedIndexedValue operator()(ValueHandle index) {
TemplatedIndexedValue res(base);
res.indices.push_back(index);
return res;
}
template <typename... Args>
TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
return TemplatedIndexedValue(base, index).append(indices...);
}
TemplatedIndexedValue operator()(llvm::ArrayRef<ValueHandle> indices) {
return TemplatedIndexedValue(base, indices);
}
TemplatedIndexedValue operator()(llvm::ArrayRef<IndexHandle> indices) {
return TemplatedIndexedValue(
base, llvm::ArrayRef<ValueHandle>(indices.begin(), indices.end()));
}
/// Emits a `store`.
// NOLINTNEXTLINE: unconventional-assign-operator
OperationHandle operator=(const TemplatedIndexedValue &rhs) {
ValueHandle rrhs(rhs);
return Store(rrhs, getBase(), {indices.begin(), indices.end()});
}
// NOLINTNEXTLINE: unconventional-assign-operator
OperationHandle operator=(ValueHandle rhs) {
return Store(rhs, getBase(), {indices.begin(), indices.end()});
}
/// Emits a `load` when converting to a ValueHandle.
operator ValueHandle() const {
return Load(getBase(), {indices.begin(), indices.end()});
}
/// Emits a `load` when converting to a Value*.
Value *operator*() const {
return Load(getBase(), {indices.begin(), indices.end()}).getValue();
}
ValueHandle getBase() const { return base; }
/// Operator overloadings.
ValueHandle operator+(ValueHandle e);
ValueHandle operator-(ValueHandle e);
ValueHandle operator*(ValueHandle e);
ValueHandle operator/(ValueHandle e);
OperationHandle operator+=(ValueHandle e);
OperationHandle operator-=(ValueHandle e);
OperationHandle operator*=(ValueHandle e);
OperationHandle operator/=(ValueHandle e);
ValueHandle operator+(TemplatedIndexedValue e) {
return *this + static_cast<ValueHandle>(e);
}
ValueHandle operator-(TemplatedIndexedValue e) {
return *this - static_cast<ValueHandle>(e);
}
ValueHandle operator*(TemplatedIndexedValue e) {
return *this * static_cast<ValueHandle>(e);
}
ValueHandle operator/(TemplatedIndexedValue e) {
return *this / static_cast<ValueHandle>(e);
}
OperationHandle operator+=(TemplatedIndexedValue e) {
return this->operator+=(static_cast<ValueHandle>(e));
}
OperationHandle operator-=(TemplatedIndexedValue e) {
return this->operator-=(static_cast<ValueHandle>(e));
}
OperationHandle operator*=(TemplatedIndexedValue e) {
return this->operator*=(static_cast<ValueHandle>(e));
}
OperationHandle operator/=(TemplatedIndexedValue e) {
return this->operator/=(static_cast<ValueHandle>(e));
}
private:
TemplatedIndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
: base(base), indices(indices.begin(), indices.end()) {}
TemplatedIndexedValue &append() { return *this; }
template <typename T, typename... Args>
TemplatedIndexedValue &append(T index, Args... indices) {
this->indices.push_back(static_cast<ValueHandle>(index));
append(indices...);
return *this;
}
ValueHandle base;
llvm::SmallVector<ValueHandle, 8> indices;
};
/// Operator overloadings.
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator+(ValueHandle e) {
using op::operator+;
return static_cast<ValueHandle>(*this) + e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator-(ValueHandle e) {
using op::operator-;
return static_cast<ValueHandle>(*this) - e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator*(ValueHandle e) {
using op::operator*;
return static_cast<ValueHandle>(*this) * e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator/(ValueHandle e) {
using op::operator/;
return static_cast<ValueHandle>(*this) / e;
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(ValueHandle e) {
using op::operator+;
return Store(*this + e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(ValueHandle e) {
using op::operator-;
return Store(*this - e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(ValueHandle e) {
using op::operator*;
return Store(*this * e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(ValueHandle e) {
using op::operator/;
return Store(*this / e, getBase(), {indices.begin(), indices.end()});
}
} // namespace edsc
} // namespace mlir
#endif // MLIR_EDSC_HELPERS_H_

View File

@ -0,0 +1,265 @@
//===- Intrinsics.h - MLIR Operations for Declarative Builders ---*- C++-*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Provides intuitive composable intrinsics for building snippets of MLIR
// declaratively.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EDSC_INTRINSICS_H_
#define MLIR_EDSC_INTRINSICS_H_
#include "mlir/EDSC/Builders.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class MemRefType;
class Type;
namespace edsc {
/// An IndexHandle is a simple wrapper around a ValueHandle.
/// IndexHandles are ubiquitous enough to justify a new type to allow simple
/// declarations without boilerplate such as:
///
/// ```c++
/// IndexHandle i, j, k;
/// ```
struct IndexHandle : public ValueHandle {
explicit IndexHandle()
: ValueHandle(ScopedContext::getBuilder().getIndexType()) {}
explicit IndexHandle(index_t v) : ValueHandle(v) {}
explicit IndexHandle(Value *v) : ValueHandle(v) {
assert(v->getType() == ScopedContext::getBuilder().getIndexType() &&
"Expected index type");
}
explicit IndexHandle(ValueHandle v) : ValueHandle(v) {
assert(v.getType() == ScopedContext::getBuilder().getIndexType() &&
"Expected index type");
}
IndexHandle &operator=(const ValueHandle &v) {
assert(v.getType() == ScopedContext::getBuilder().getIndexType() &&
"Expected index type");
/// Creating a new IndexHandle(v) and then std::swap rightly complains that
/// the binding has already occurred and that we should use another name.
this->t = v.getType();
this->v = v.getValue();
return *this;
}
static SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
return SmallVector<IndexHandle, 8>(rank);
}
static SmallVector<ValueHandle *, 8>
makeIndexHandlePointers(SmallVectorImpl<IndexHandle> &ivs) {
SmallVector<ValueHandle *, 8> pivs;
pivs.reserve(ivs.size());
for (auto &iv : ivs) {
pivs.push_back(&iv);
}
return pivs;
}
};
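/// A sketch of preparing induction variables for a LoopNestBuilder using the
/// helpers above:
///
/// ```c++
/// auto ivs = IndexHandle::makeIndexHandles(/*rank=*/3);
/// auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
/// // `pivs` can now be passed as the `ivs` argument of a LoopNestBuilder.
/// ```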
/// Provides a set of first class intrinsics.
/// In the future, most of the intrinsics related to Operations that don't
/// contain other operations should be Tablegen'd.
namespace intrinsics {
namespace detail {
/// Helper structure to be used with ValueBuilder / OperationBuilder.
/// It removes the boilerplate specialization that would otherwise be needed
/// solely to implicitly convert ArrayRef<ValueHandle> -> ArrayRef<Value*>.
class ValueHandleArray {
public:
ValueHandleArray(ArrayRef<ValueHandle> vals) {
values.append(vals.begin(), vals.end());
}
ValueHandleArray(ArrayRef<IndexHandle> vals) {
values.append(vals.begin(), vals.end());
}
ValueHandleArray(ArrayRef<index_t> vals) {
llvm::SmallVector<IndexHandle, 8> tmp(vals.begin(), vals.end());
values.append(tmp.begin(), tmp.end());
}
operator ArrayRef<Value *>() { return values; }
private:
ValueHandleArray() = default;
llvm::SmallVector<Value *, 8> values;
};
template <typename T> inline T unpack(T value) { return value; }
inline detail::ValueHandleArray unpack(ArrayRef<ValueHandle> values) {
return detail::ValueHandleArray(values);
}
} // namespace detail
/// Helper variadic abstraction to allow extending to any MLIR op without
/// boilerplate or Tablegen.
/// Arguably a builder is not a ValueHandle but in practice it is only used as
/// an alias to a notional ValueHandle<Op>.
/// Implementing it as a subclass allows it to compose all the way to Value*.
/// Without subclassing, implicit conversion to Value* would fail when composing
/// in patterns such as: `select(a, b, select(c, d, e))`.
template <typename Op> struct ValueBuilder : public ValueHandle {
// Builder-based
template <typename... Args>
ValueBuilder(Args... args)
: ValueHandle(ValueHandle::create<Op>(detail::unpack(args)...)) {}
ValueBuilder(ArrayRef<ValueHandle> vs)
: ValueBuilder(ValueBuilder::create<Op>(detail::unpack(vs))) {}
template <typename... Args>
ValueBuilder(ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T, typename... Args>
ValueBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(
detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
template <typename T1, typename T2, typename... Args>
ValueBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(
detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
detail::unpack(args)...)) {}
/// Folder-based
template <typename... Args>
ValueBuilder(OperationFolder &folder, Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {}
ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs)
: ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {}
template <typename... Args>
ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T, typename... Args>
ValueBuilder(OperationFolder &folder, T t, ArrayRef<ValueHandle> vs,
Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t),
detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T1, typename T2, typename... Args>
ValueBuilder(OperationFolder &folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
Args... args)
: ValueHandle(ValueHandle::create<Op>(
folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
detail::unpack(args)...)) {}
ValueBuilder() : ValueHandle(ValueHandle::create<Op>()) {}
};
template <typename Op> struct OperationBuilder : public OperationHandle {
template <typename... Args>
OperationBuilder(Args... args)
: OperationHandle(OperationHandle::create<Op>(detail::unpack(args)...)) {}
OperationBuilder(ArrayRef<ValueHandle> vs)
: OperationHandle(OperationHandle::create<Op>(detail::unpack(vs))) {}
template <typename... Args>
OperationBuilder(ArrayRef<ValueHandle> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T, typename... Args>
OperationBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(
detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
template <typename T1, typename T2, typename... Args>
OperationBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(
detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
detail::unpack(args)...)) {}
OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
};
using alloc = ValueBuilder<AllocOp>;
using affine_apply = ValueBuilder<AffineApplyOp>;
using affine_load = ValueBuilder<AffineLoadOp>;
using affine_store = OperationBuilder<AffineStoreOp>;
using constant_float = ValueBuilder<ConstantFloatOp>;
using constant_index = ValueBuilder<ConstantIndexOp>;
using constant_int = ValueBuilder<ConstantIntOp>;
using dealloc = OperationBuilder<DeallocOp>;
using dim = ValueBuilder<DimOp>;
using muli = ValueBuilder<MulIOp>;
using ret = OperationBuilder<ReturnOp>;
using select = ValueBuilder<SelectOp>;
using std_load = ValueBuilder<LoadOp>;
using std_store = OperationBuilder<StoreOp>;
using subi = ValueBuilder<SubIOp>;
using vector_type_cast = ValueBuilder<VectorTypeCastOp>;
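/// A minimal sketch combining the aliases above (a hedged example; assumes an
/// active ScopedContext and a MemRefType `memrefType`):
///
/// ```c++
/// ValueHandle buf = alloc(memrefType);
/// ValueHandle idx = constant_index(0);
/// ValueHandle v = std_load(buf, {idx});
/// std_store(v, buf, {idx});
/// dealloc(buf);
/// ```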
/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
///
/// Prerequisites:
/// All Handles have already captured previously constructed IR objects.
OperationHandle br(BlockHandle bh, ArrayRef<ValueHandle> operands);
/// Creates a new mlir::Block* and branches to it from the current block.
/// Argument types are specified by `operands`.
/// Captures the new block in `bh` and the actual `operands` in `captures`. To
/// insert the new mlir::Block*, a local ScopedContext is constructed and
/// released to the current block. The branch operation is then added to the
/// new block.
///
/// Prerequisites:
/// `bh` has not yet captured an mlir::Block*.
/// No `captures` have captured any mlir::Value*.
/// All `operands` have already captured an mlir::Value*.
/// captures.size() == operands.size()
/// captures and operands are pairwise of the same type.
OperationHandle br(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands);
/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with
/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and
/// `falseOperands` if `cond` evaluates to `false`).
///
/// Prerequisites:
/// All Handles have captured previously constructed IR objects.
OperationHandle cond_br(ValueHandle cond, BlockHandle trueBranch,
ArrayRef<ValueHandle> trueOperands,
BlockHandle falseBranch,
ArrayRef<ValueHandle> falseOperands);
/// Eagerly creates new mlir::Block* with argument types specified by
/// `trueOperands`/`falseOperands`.
/// Captures the new blocks in `trueBranch`/`falseBranch` and the arguments in
/// `trueCaptures/falseCaptures`.
/// To insert the new mlir::Block*, a local ScopedContext is constructed and
/// released. The branch operation is then added in the original location and
/// targeting the eagerly constructed blocks.
///
/// Prerequisites:
/// `trueBranch`/`falseBranch` has not yet captured an mlir::Block*.
/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value*.
/// All `trueOperands`/`falseOperands` have already captured an mlir::Value*.
/// `trueCaptures`.size() == `trueOperands`.size()
/// `falseCaptures`.size() == `falseOperands`.size()
/// `trueCaptures` and `trueOperands` are pairwise of the same type
/// `falseCaptures` and `falseOperands` are pairwise of the same type.
OperationHandle cond_br(ValueHandle cond, BlockHandle *trueBranch,
ArrayRef<ValueHandle *> trueCaptures,
ArrayRef<ValueHandle> trueOperands,
BlockHandle *falseBranch,
ArrayRef<ValueHandle *> falseCaptures,
ArrayRef<ValueHandle> falseOperands);
} // namespace intrinsics
} // namespace edsc
} // namespace mlir
#endif // MLIR_EDSC_INTRINSICS_H_

View File

@ -0,0 +1,111 @@
//===- ExecutionEngine.h - MLIR Execution engine and utils -----*- C++ -*--===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file provides a JIT-backed execution engine for MLIR modules.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_
#define MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_
#include "mlir/Support/LLVM.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include <functional>
#include <memory>
namespace llvm {
template <typename T> class Expected;
class Module;
} // namespace llvm
namespace mlir {
class ModuleOp;
namespace impl {
class OrcJIT;
} // end namespace impl
/// JIT-backed execution engine for MLIR modules. Assumes the module can be
/// converted to LLVM IR. For each function, creates a wrapper function with
/// the fixed interface
///
/// void _mlir_funcName(void **)
///
/// where the only argument is interpreted as a list of pointers to the actual
/// arguments of the function, followed by a pointer to the result. This allows
/// the engine to provide the caller with a generic function pointer that can
/// be used to invoke the JIT-compiled function.
class ExecutionEngine {
public:
~ExecutionEngine();
/// Creates an execution engine for the given module. If `transformer` is
/// provided, it will be called on the LLVM module during JIT-compilation and
/// can be used, e.g., for reporting or optimization.
/// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
/// and link the shared libraries for symbol resolution.
static llvm::Expected<std::unique_ptr<ExecutionEngine>>
create(ModuleOp m,
std::function<llvm::Error(llvm::Module *)> transformer = {},
ArrayRef<StringRef> sharedLibPaths = {});
/// Looks up a packed-argument function with the given name and returns a
/// pointer to it. Propagates errors in case of failure.
llvm::Expected<void (*)(void **)> lookup(StringRef name) const;
/// Invokes the function with the given name passing it the list of arguments.
/// The arguments are accepted by lvalue-reference since the packed function
/// interface expects a list of non-null pointers.
template <typename... Args>
llvm::Error invoke(StringRef name, Args &... args);
/// Invokes the function with the given name passing it the list of arguments
/// as a list of opaque pointers. This is the arity-agnostic equivalent of
/// the templated `invoke`.
llvm::Error invoke(StringRef name, MutableArrayRef<void *> args);
/// Set the target triple on the module. This is implicitly done when creating
/// the engine.
static bool setupTargetTriple(llvm::Module *llvmModule);
private:
// Ordering of llvmContext and jit is important for destruction purposes: the
// jit must be destroyed before the context.
llvm::LLVMContext llvmContext;
// Private implementation of the JIT (PIMPL)
std::unique_ptr<impl::OrcJIT> jit;
};
template <typename... Args>
llvm::Error ExecutionEngine::invoke(StringRef name, Args &... args) {
auto expectedFPtr = lookup(name);
if (!expectedFPtr)
return expectedFPtr.takeError();
auto fptr = *expectedFPtr;
llvm::SmallVector<void *, 8> packedArgs{static_cast<void *>(&args)...};
(*fptr)(packedArgs.data());
return llvm::Error::success();
}
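// A minimal usage sketch (a hedged example; assumes `module` can be lowered
// to LLVM IR and contains a hypothetical function "foo" taking one float
// argument through the packed interface):
//
//   auto expectedEngine = ExecutionEngine::create(module);
//   if (!expectedEngine)
//     /* handle expectedEngine.takeError() */;
//   float value = 0.0f;
//   float *data = &value;
//   llvm::Error error = (*expectedEngine)->invoke("foo", data);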
} // end namespace mlir
#endif // MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_

View File

@ -0,0 +1,54 @@
//===- MemRefUtils.h - MLIR runtime utilities for memrefs -------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a set of utilities for working with objects of memref type in a
// JIT context using the MLIR execution engine.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
#include "mlir/Support/LLVM.h"
namespace llvm {
template <typename T> class Expected;
}
namespace mlir {
class FuncOp;
/// Simple memref descriptor class compatible with the ABI of functions emitted
/// by MLIR to LLVM IR conversion for statically-shaped memrefs of float type.
struct StaticFloatMemRef {
float *data;
};
/// Given an MLIR function that takes only statically-shaped memrefs with
/// element type f32, allocate the memref descriptor and the data storage for
/// each of the arguments, initialize the storage with `initialValue`, and
/// return a list of type-erased descriptor pointers.
llvm::Expected<SmallVector<void *, 8>>
allocateMemRefArguments(FuncOp func, float initialValue = 0.0);
/// Free a list of type-erased descriptors to statically-shaped memrefs with
/// element type f32.
void freeMemRefArguments(ArrayRef<void *> args);
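/// A usage sketch (a hedged example; assumes `func` is such a FuncOp and
/// `engine` is an ExecutionEngine created for the enclosing module):
///
///   auto expectedArgs = allocateMemRefArguments(func, /*initialValue=*/1.0f);
///   if (!expectedArgs)
///     /* handle expectedArgs.takeError() */;
///   llvm::Error error = engine->invoke(func.getName(), *expectedArgs);
///   freeMemRefArguments(*expectedArgs);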
} // namespace mlir
#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_

View File

@ -0,0 +1,59 @@
//===- OptUtils.h - MLIR Execution Engine opt pass utilities ----*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file declares the utility functions to trigger LLVM optimizations from
// the MLIR ExecutionEngine.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EXECUTIONENGINE_OPTUTILS_H_
#define MLIR_EXECUTIONENGINE_OPTUTILS_H_
#include "llvm/Pass.h"
#include <functional>
#include <string>
namespace llvm {
class Module;
class Error;
} // namespace llvm
namespace mlir {
/// Initialize LLVM passes that can be used when running MLIR code using the
/// ExecutionEngine.
void initializeLLVMPasses();
/// Create a module transformer function for MLIR ExecutionEngine that runs
/// LLVM IR passes corresponding to the given speed and size optimization
/// levels (e.g. -O2 or -Os).
std::function<llvm::Error(llvm::Module *)>
makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel);
/// Create a module transformer function for MLIR ExecutionEngine that runs
/// the explicitly specified LLVM IR passes, plus an optional optimization
/// level. Any optimization passes, if present, will be inserted before the
/// pass at position optPassesInsertPos.
std::function<llvm::Error(llvm::Module *)>
makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
llvm::Optional<unsigned> mbOptLevel,
unsigned optPassesInsertPos = 0);
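/// A typical pairing with the ExecutionEngine (a sketch; assumes `module` is
/// a lowerable ModuleOp):
///
///   initializeLLVMPasses();
///   auto transformer = makeOptimizingTransformer(/*optLevel=*/2,
///                                                /*sizeLevel=*/0);
///   auto expectedEngine = ExecutionEngine::create(module, transformer);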
} // end namespace mlir
#endif // MLIR_EXECUTIONENGINE_OPTUTILS_H_

View File

@ -0,0 +1,311 @@
//===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// An affine expression is an affine combination of dimension identifiers and
// symbols, including ceildiv/floordiv/mod by a constant integer.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_AFFINE_EXPR_H
#define MLIR_IR_AFFINE_EXPR_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/Support/Casting.h"
#include <type_traits>
namespace mlir {
class MLIRContext;
class AffineMap;
class IntegerSet;
namespace detail {
struct AffineExprStorage;
struct AffineBinaryOpExprStorage;
struct AffineDimExprStorage;
struct AffineSymbolExprStorage;
struct AffineConstantExprStorage;
} // namespace detail
enum class AffineExprKind {
Add,
/// RHS of mul is always a constant or a symbolic expression.
Mul,
/// RHS of mod is always a constant or a symbolic expression with a positive
/// value.
Mod,
/// RHS of floordiv is always a constant or a symbolic expression.
FloorDiv,
/// RHS of ceildiv is always a constant or a symbolic expression.
CeilDiv,
/// This is a marker for the last affine binary op. The range of binary
/// ops is expected to be this element and earlier.
LAST_AFFINE_BINARY_OP = CeilDiv,
/// Constant integer.
Constant,
/// Dimensional identifier.
DimId,
/// Symbolic identifier.
SymbolId,
};
/// Base type for affine expression.
/// AffineExprs are immutable value types with intuitive operators to
/// operate on chainable, lightweight compositions.
/// An AffineExpr is an interface to the underlying storage type pointer.
class AffineExpr {
public:
using ImplType = detail::AffineExprStorage;
AffineExpr() : expr(nullptr) {}
/* implicit */ AffineExpr(const ImplType *expr)
: expr(const_cast<ImplType *>(expr)) {}
AffineExpr(const AffineExpr &other) : expr(other.expr) {}
AffineExpr &operator=(AffineExpr other) {
expr = other.expr;
return *this;
}
bool operator==(AffineExpr other) const { return expr == other.expr; }
bool operator!=(AffineExpr other) const { return !(*this == other); }
explicit operator bool() const { return expr; }
bool operator!() const { return expr == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U cast() const;
MLIRContext *getContext() const;
/// Return the classification (kind) of this expression.
AffineExprKind getKind() const;
void print(raw_ostream &os) const;
void dump() const;
/// Returns true if this expression is made out of only symbols and
/// constants, i.e., it does not involve dimensional identifiers.
bool isSymbolicOrConstant() const;
/// Returns true if this is a pure affine expression, i.e., multiplication,
/// floordiv, ceildiv, and mod are only allowed w.r.t. constants.
bool isPureAffine() const;
/// Returns the greatest known integral divisor of this affine expression.
uint64_t getLargestKnownDivisor() const;
/// Return true if the affine expression is a multiple of 'factor'.
bool isMultipleOf(int64_t factor) const;
/// Return true if the affine expression involves AffineDimExpr `position`.
bool isFunctionOfDim(unsigned position) const;
/// Walk all of the AffineExprs in this expression in postorder.
void walk(std::function<void(AffineExpr)> callback) const;
/// This method substitutes any uses of dimensions and symbols (e.g.
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
ArrayRef<AffineExpr> symReplacements) const;
AffineExpr operator+(int64_t v) const;
AffineExpr operator+(AffineExpr other) const;
AffineExpr operator-() const;
AffineExpr operator-(int64_t v) const;
AffineExpr operator-(AffineExpr other) const;
AffineExpr operator*(int64_t v) const;
AffineExpr operator*(AffineExpr other) const;
AffineExpr floorDiv(uint64_t v) const;
AffineExpr floorDiv(AffineExpr other) const;
AffineExpr ceilDiv(uint64_t v) const;
AffineExpr ceilDiv(AffineExpr other) const;
AffineExpr operator%(uint64_t v) const;
AffineExpr operator%(AffineExpr other) const;
/// Compose with an AffineMap.
/// Returns the composition of this AffineExpr with `map`.
///
/// Prerequisites:
/// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
/// `this` is smaller than the number of results of `map`. If a result of a
/// map does not have a corresponding AffineDimExpr, that result simply does
/// not appear in the produced AffineExpr.
///
/// Example:
/// expr: `d0 + d2`
/// map: `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
/// returned expr: `d0 * 2 + d1 + d2 + s1`
AffineExpr compose(AffineMap map) const;
friend ::llvm::hash_code hash_value(AffineExpr arg);
protected:
ImplType *expr;
};
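/// Example (a minimal sketch, assuming an MLIRContext *ctx is in scope and
/// using the free functions declared below):
///
///   AffineExpr d0 = getAffineDimExpr(0, ctx);
///   AffineExpr d1 = getAffineDimExpr(1, ctx);
///   AffineExpr e = (d0 + d1 * 2).floorDiv(4) % 8;  // chained composition
///   assert(e.getKind() == AffineExprKind::Mod);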
/// Affine binary operation expression. An affine binary operation could be an
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
/// represented through a multiply by -1 and add.) These expressions are always
/// constructed in a simplified form. For example, the LHS and RHS operands can't
/// both be constants. There are additional canonicalizing rules depending on
/// the op type: see checks in the constructor.
class AffineBinaryOpExpr : public AffineExpr {
public:
using ImplType = detail::AffineBinaryOpExprStorage;
/* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
AffineExpr getLHS() const;
AffineExpr getRHS() const;
};
/// A dimensional identifier appearing in an affine expression.
class AffineDimExpr : public AffineExpr {
public:
using ImplType = detail::AffineDimExprStorage;
/* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
unsigned getPosition() const;
};
/// A symbolic identifier appearing in an affine expression.
class AffineSymbolExpr : public AffineExpr {
public:
using ImplType = detail::AffineDimExprStorage;
/* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
unsigned getPosition() const;
};
/// An integer constant appearing in an affine expression.
class AffineConstantExpr : public AffineExpr {
public:
using ImplType = detail::AffineConstantExprStorage;
/* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr);
int64_t getValue() const;
};
/// Make AffineExpr hashable.
inline ::llvm::hash_code hash_value(AffineExpr arg) {
return ::llvm::hash_value(arg.expr);
}
inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
inline AffineExpr operator-(int64_t val, AffineExpr expr) {
return expr * (-1) + val;
}
/// These free functions allow clients of the API to not use classes in detail.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
AffineExpr rhs);
/// Constructs an affine expression from a flat ArrayRef. If there are local
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
/// products expression, 'localExprs' is expected to have the AffineExpr for
/// each of them, which is substituted in. The ArrayRef 'eq' is expected to be
/// in the format [dims, symbols, locals, constant term].
AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
unsigned numSymbols, ArrayRef<AffineExpr> localExprs,
MLIRContext *context);
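/// Example (a small worked sketch): with numDims = 2, numSymbols = 1, and no
/// local identifiers, eq = {1, 2, 0, 5} reconstructs d0 + 2 * d1 + 5.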
raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
template <typename U> bool AffineExpr::isa() const {
if (std::is_same<U, AffineBinaryOpExpr>::value) {
return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
}
if (std::is_same<U, AffineDimExpr>::value) {
return getKind() == AffineExprKind::DimId;
}
if (std::is_same<U, AffineSymbolExpr>::value) {
return getKind() == AffineExprKind::SymbolId;
}
if (std::is_same<U, AffineConstantExpr>::value) {
return getKind() == AffineExprKind::Constant;
}
return false;
}
template <typename U> U AffineExpr::dyn_cast() const {
if (isa<U>()) {
return U(expr);
}
return U(nullptr);
}
template <typename U> U AffineExpr::cast() const {
assert(isa<U>());
return U(expr);
}
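/// Example (a minimal sketch; `process` stands in for arbitrary client code):
///
///   if (auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>())
///     process(binExpr.getLHS(), binExpr.getRHS());
///   else if (auto cstExpr = expr.dyn_cast<AffineConstantExpr>())
///     process(cstExpr.getValue());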
/// Simplify an affine expression by flattening and some amount of
/// simple analysis. This has complexity linear in the number of nodes in
/// 'expr'. Returns the simplified expression, which is the same as the input
/// expression if it can't be simplified.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols);
/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled).
/// 'cst' contains constraints that connect newly introduced local identifiers
/// to existing dimensional and symbolic identifiers. See documentation for
/// AffineExprFlattener on how mod's and div's are flattened.
bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols,
llvm::SmallVectorImpl<int64_t> *flattenedExpr);
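/// Example (a minimal sketch): flattening d0 + 2 * d1 over 2 dims and 0
/// symbols yields {1, 2, 0}, laid out as [dim coeffs, symbol coeffs, locals,
/// constant term].
///
///   llvm::SmallVector<int64_t, 8> flat;
///   if (!getFlattenedAffineExpr(expr, /*numDims=*/2, /*numSymbols=*/0, &flat))
///     return;  // semi-affine; not handled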
/// Flattens the result expressions of the map to their corresponding flattened
/// forms and set in 'flattenedExprs'. Returns true on success or false
/// if any expression in the map could not be flattened (i.e., semi-affine is
/// not yet handled). For all affine expressions that share the same operands
/// (like those of an affine map), this method should be used instead of
/// repeatedly calling getFlattenedAffineExpr since local variables added to
/// deal with div's and mod's will be reused across expressions.
bool getFlattenedAffineExprs(
AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs);
bool getFlattenedAffineExprs(
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs);
} // namespace mlir
namespace llvm {
// AffineExprs hash just like pointers.
template <> struct DenseMapInfo<mlir::AffineExpr> {
static mlir::AffineExpr getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
}
static mlir::AffineExpr getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::AffineExpr val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
return LHS == RHS;
}
};
} // namespace llvm
#endif // MLIR_IR_AFFINE_EXPR_H


@ -0,0 +1,334 @@
//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the AffineExpr visitor class.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H
#define MLIR_IR_AFFINE_EXPR_VISITOR_H
#include "mlir/IR/AffineExpr.h"
namespace mlir {
/// Base class for AffineExpr visitors/walkers.
///
/// AffineExpr visitors are used when you want to perform different actions
/// for different kinds of AffineExprs without having to use lots of casts
/// and a big switch instruction.
///
/// To define your own visitor, inherit from this class, specifying your
/// new type for the 'SubClass' template parameter, and "override" visitXXX
/// functions in your class. This class is defined in terms of statically
/// resolved overloading, not virtual functions.
///
/// For example, here is a visitor that counts the number of AffineDimExprs
/// in an AffineExpr.
///
/// /// Declare the class. Note that we derive from AffineExprVisitor
/// /// instantiated with our new subclass's type.
///
/// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
/// unsigned numDimExprs;
/// DimExprCounter() : numDimExprs(0) {}
/// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
/// };
///
/// And this class would be used like this:
/// DimExprCounter dec;
/// dec.visit(affineExpr);
/// numDimExprs = dec.numDimExprs;
///
/// AffineExprVisitor provides visit methods for the following binary affine
/// op expressions:
/// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
/// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
/// AffineBinaryCeilDivOpExpr. Note that default implementations of these
/// methods will call the general AffineBinaryOpExpr method.
///
/// In addition, visit methods are provided for the following affine
/// expressions: AffineConstantExpr, AffineDimExpr, and
/// AffineSymbolExpr.
///
/// Note that if you don't implement visitXXX for some affine expression type,
/// the visitXXX method for the expression's superclass will be invoked.
///
/// Note that this class is specifically designed as a template to avoid
/// virtual function call overhead. Defining and using an AffineExprVisitor is
/// just as efficient as writing your own switch over the expression kind.
template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the AffineExprVisitor
// that you use to visit affine expressions...
public:
// Function to walk an AffineExpr (in post order).
RetTy walkPostOrder(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
expr.cast<AffineConstantExpr>());
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
expr.cast<AffineDimExpr>());
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
expr.cast<AffineSymbolExpr>());
}
llvm_unreachable("Unknown AffineExpr");
}
// Function to visit an AffineExpr.
RetTy visit(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
expr.cast<AffineConstantExpr>());
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
expr.cast<AffineDimExpr>());
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
expr.cast<AffineSymbolExpr>());
}
llvm_unreachable("Unknown AffineExpr");
}
//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular expression type.
// The default behavior is to generalize the expression type to its superclass
// and try visiting the superclass. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
//
// Default visit methods. Note that the default op-specific binary op visit
// methods call the general visitAffineBinaryOpExpr visit method.
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
void visitAddExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitMulExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitModExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitFloorDivExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitCeilDivExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitConstantExpr(AffineConstantExpr expr) {}
void visitDimExpr(AffineDimExpr expr) {}
void visitSymbolExpr(AffineSymbolExpr expr) {}
private:
// Walk the operands - each operand is itself walked in post order.
void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
walkPostOrder(expr.getLHS());
walkPostOrder(expr.getRHS());
}
};
// This class is used to flatten a pure affine expression (AffineExpr,
// which is in a tree form) into a sum of products (w.r.t constants) when
// possible, and in that process simplifying the expression. For a modulo,
// floordiv, or a ceildiv expression, an additional identifier, called a local
// identifier, is introduced to rewrite the expression as a sum of product
// affine expression. Each local identifier is always and by construction a
// floordiv of a pure add/mul affine function of dimensional, symbolic, and
// other local identifiers, in a non-mutually recursive way. Hence, every local
// identifier can ultimately always be recovered as an affine function of
// dimensional and symbolic identifiers (involving floordiv's); note however
// that by AffineExpr construction, some floordiv combinations are converted to
// mod's. The result of the flattening is a flattened expression and a set of
// constraints involving just the local variables.
//
// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
//
// The simplification performed includes the accumulation of contributions for
// each dimensional and symbolic identifier together, the simplification of
// floordiv/ceildiv/mod expressions and other simplifications that in turn
// happen as a result. A simplification that this flattening naturally performs
// is of simplifying the numerator and denominator of floordiv/ceildiv, and
// folding a modulo expression to a zero, if possible. Three examples are below:
//
// (((d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
// (d0 - d0 mod 4 + 4) mod 4 simplified to 0
// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
//
// The way the flattening works for the second example is as follows: d0 % 4 is
// replaced by d0 - 4*q with q being introduced: the expression then simplifies
// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
// zero. Note that an affine expression may not always be expressible purely as
// a sum of products involving just the original dimensional and symbolic
// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
// may not be eliminated after simplification; in such cases, the final
// expression can be reconstructed by replacing the local identifiers with their
// corresponding explicit form stored in 'localExprs' (note that each of the
// explicit forms itself would have been simplified).
//
// The expression walk method here performs a linear time post order walk that
// performs the above simplifications through visit methods, with partial
// results being stored in 'operandExprStack'. When a parent expr is visited,
// the flattened expressions corresponding to its two operands would already be
// on the stack - the parent expression looks at the two flattened expressions
// and combines the two. It pops off the operand expressions and pushes the
// combined result (although this is done in-place on its LHS operand expr).
// When the walk is completed, the flattened form of the top-level expression
// would be left on the stack.
//
// A flattener can be repeatedly used for multiple affine expressions that bind
// to the same operands, for example, for all result expressions of an
// AffineMap or AffineValueMap. In such cases, using it for multiple expressions
// is more efficient than creating a new flattener for each expression since
// common identical div and mod expressions appearing across different
// expressions are mapped to the same local identifier (same column position in
// 'localVarCst').
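// Example (a minimal sketch, assuming `expr` uses 2 dims and 0 symbols):
//
//   SimpleAffineExprFlattener flattener(/*numDims=*/2, /*numSymbols=*/0);
//   flattener.walkPostOrder(expr);
//   // The flattened coefficients of `expr`, in [dims, symbols, locals,
//   // constant] layout, are now in flattener.operandExprStack.back().
//   ArrayRef<int64_t> flat = flattener.operandExprStack.back();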
class SimpleAffineExprFlattener
: public AffineExprVisitor<SimpleAffineExprFlattener> {
public:
// Flattened expression layout: [dims, symbols, locals, constant]
// Stack that holds the LHS and RHS operands while visiting a binary op expr.
// In the future, consider adding a prepass to determine how big the
// SmallVectors will be, and linearizing this to std::vector<int64_t> to
// prevent SmallVector moves on re-allocation.
std::vector<SmallVector<int64_t, 8>> operandExprStack;
unsigned numDims;
unsigned numSymbols;
// Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
unsigned numLocals;
// AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
// which new identifiers were introduced; if the latter do not get canceled
// out, these expressions can be readily used to reconstruct the AffineExpr
// (tree) form. Note that these expressions themselves would have been
// simplified (recursively) by this pass. E.g. d0 + (d0 + 2*d1 + d0) ceildiv 4
// will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
// ceildiv 2 would be the local expression stored for q.
SmallVector<AffineExpr, 4> localExprs;
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
virtual ~SimpleAffineExprFlattener() = default;
// Visitor method overrides.
void visitMulExpr(AffineBinaryOpExpr expr);
void visitAddExpr(AffineBinaryOpExpr expr);
void visitDimExpr(AffineDimExpr expr);
void visitSymbolExpr(AffineSymbolExpr expr);
void visitConstantExpr(AffineConstantExpr expr);
void visitCeilDivExpr(AffineBinaryOpExpr expr);
void visitFloorDivExpr(AffineBinaryOpExpr expr);
//
// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
//
// A mod expression "expr mod c" is thus flattened by introducing a new local
// variable q (= expr floordiv c), such that expr mod c is replaced with
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
void visitModExpr(AffineBinaryOpExpr expr);
protected:
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
// The local identifier added is always a floordiv of a pure add/mul affine
// function of other identifiers, coefficients of which are specified in
// dividend and with respect to a positive constant divisor. localExpr is the
// simplified tree expression (AffineExpr) corresponding to the quantifier.
virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
AffineExpr localExpr);
private:
// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
// A floordiv is thus flattened by introducing a new local variable q, and
// replacing that expression with 'q' while adding the constraints
// c * q <= expr <= c * q + c - 1 to localVarCst (done by
// FlatAffineConstraints::addLocalFloorDiv).
//
// A ceildiv is similarly flattened:
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
int findLocalId(AffineExpr localExpr);
inline unsigned getNumCols() const {
return numDims + numSymbols + numLocals + 1;
}
inline unsigned getConstantIndex() const { return getNumCols() - 1; }
inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
inline unsigned getSymbolStartIndex() const { return numDims; }
inline unsigned getDimStartIndex() const { return 0; }
};
} // end namespace mlir
#endif // MLIR_IR_AFFINE_EXPR_VISITOR_H


@ -0,0 +1,241 @@
//===- AffineMap.h - MLIR Affine Map Class ----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Affine maps are mathematical functions which map a list of dimension
// identifiers and symbols, to multidimensional affine expressions.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_AFFINE_MAP_H
#define MLIR_IR_AFFINE_MAP_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace mlir {
namespace detail {
struct AffineMapStorage;
} // end namespace detail
class AffineExpr;
class Attribute;
struct LogicalResult;
class MLIRContext;
/// A multi-dimensional affine map.
/// Affine maps are immutable like Types, and they are uniqued.
/// Eg: (d0, d1) -> (d0/128, d0 mod 128, d1)
/// The names used (d0, d1) don't matter - it's the mathematical function that
/// is unique to this affine map.
class AffineMap {
public:
using ImplType = detail::AffineMapStorage;
AffineMap() : map(nullptr) {}
explicit AffineMap(ImplType *map) : map(map) {}
AffineMap(const AffineMap &other) : map(other.map) {}
AffineMap &operator=(const AffineMap &other) = default;
static AffineMap get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results);
/// Returns a single constant result affine map.
static AffineMap getConstantMap(int64_t val, MLIRContext *context);
/// Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap getMultiDimIdentityMap(unsigned numDims,
MLIRContext *context);
MLIRContext *getContext() const;
explicit operator bool() { return map != nullptr; }
bool operator==(AffineMap other) const { return other.map == map; }
bool operator!=(AffineMap other) const { return !(other.map == map); }
/// Returns true if this affine map is an identity affine map.
/// An identity affine map corresponds to an identity affine function on the
/// dimensional identifiers.
bool isIdentity() const;
/// Returns true if this affine map is a single result constant function.
bool isSingleConstant() const;
/// Returns the constant result of this map. This method asserts that the map
/// has a single constant result.
int64_t getSingleConstantResult() const;
// Prints affine map to 'os'.
void print(raw_ostream &os) const;
void dump() const;
unsigned getNumDims() const;
unsigned getNumSymbols() const;
unsigned getNumResults() const;
unsigned getNumInputs() const;
ArrayRef<AffineExpr> getResults() const;
AffineExpr getResult(unsigned idx) const;
/// Walk all of the AffineExpr's in this mapping. Each node in an expression
/// tree is visited in postorder.
void walkExprs(std::function<void(AffineExpr)> callback) const;
/// This method substitutes any uses of dimensions and symbols (e.g.
/// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
/// expression mapping. Because this can be used to eliminate dims and
/// symbols, the client needs to specify the number of dims and symbols in
/// the result. The returned map always has the same number of results.
AffineMap replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
ArrayRef<AffineExpr> symReplacements,
unsigned numResultDims,
unsigned numResultSyms);
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible.
LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute> &results) const;
/// Returns the AffineMap resulting from composing `this` with `map`.
/// The resulting AffineMap has as many AffineDimExpr as `map` and as many
/// AffineSymbolExpr as the concatenation of `this` and `map` (in which case
/// the symbols of `this` map come first).
///
/// Prerequisites:
/// The maps are composable, i.e. that the number of AffineDimExpr of `this`
/// matches the number of results of `map`.
///
/// Example:
/// map1: `(d0, d1)[s0, s1] -> (d0 + 1 + s1, d1 - 1 - s0)`
/// map2: `(d0)[s0] -> (d0 + s0, d0 - s0)`
/// map1.compose(map2):
/// `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)`
AffineMap compose(AffineMap map);
/// Returns true if the AffineMap represents a subset (i.e. a projection) of a
/// symbol-less permutation map.
bool isProjectedPermutation();
/// Returns true if the AffineMap represents a symbol-less permutation map.
bool isPermutation();
/// Returns the map consisting of the `resultPos` subset.
AffineMap getSubMap(ArrayRef<unsigned> resultPos);
friend ::llvm::hash_code hash_value(AffineMap arg);
private:
ImplType *map;
};
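/// Example (a minimal sketch, assuming an MLIRContext *ctx is in scope and
/// mlir/IR/AffineExpr.h is included for getAffineDimExpr):
///
///   auto d0 = getAffineDimExpr(0, ctx);
///   auto d1 = getAffineDimExpr(1, ctx);
///   // (d0, d1) -> (d0 floordiv 128, d0 mod 128, d1)
///   AffineMap m = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
///                                {d0.floorDiv(128), d0 % 128, d1});
///   assert(m.getNumResults() == 3 && !m.isIdentity());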
// Make AffineMap hashable.
inline ::llvm::hash_code hash_value(AffineMap arg) {
return ::llvm::hash_value(arg.map);
}
/// Simplify an affine map by simplifying its underlying AffineExpr results.
AffineMap simplifyAffineMap(AffineMap map);
/// Returns a map of codomain to domain dimensions such that the first codomain
/// dimension for a particular domain dimension is selected.
/// Returns an empty map if the input map is empty.
///
/// Prerequisites:
/// 1. `map` must contain a subset that is a permutation of full domain rank.
/// 2. `map` has no symbols.
///
/// Example 1:
///
/// ```{.mlir}
/// (d0, d1, d2) -> (d1, d1, d0, d2, d1, d2, d1, d0)
///   (positions 0, 2, and 3 are the first occurrences of d1, d0, and d2)
/// ```
///
/// returns:
///
/// ```{.mlir}
/// (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3)
/// ```
///
/// Example 2:
///
/// ```{.mlir}
/// (d0, d1, d2) -> (d1, d0 + d1, d0, d2, d1, d2, d1, d0)
///   (positions 0, 2, and 3 are the first occurrences of d1, d0, and d2)
/// ```
///
/// returns:
///
/// ```{.mlir}
/// (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3)
/// ```
AffineMap inversePermutation(AffineMap map);
/// Concatenates a list of `maps` into a single AffineMap, stepping over
/// potentially empty maps. Assumes each of the underlying maps has 0 symbols.
/// The resulting map has a number of dims equal to the max of `maps`' dims and
/// the concatenated results as its results.
/// Returns an empty map if all input `maps` are empty.
///
/// Example:
/// When applied to the following list of 3 affine maps,
///
/// ```{.mlir}
/// {
/// (i, j, k) -> (i, k),
/// (i, j, k) -> (k, j),
/// (i, j, k) -> (i, j)
/// }
/// ```
///
/// Returns the map:
///
/// ```{.mlir}
/// (i, j, k) -> (i, k, k, j, i, j)
/// ```
AffineMap concatAffineMaps(llvm::ArrayRef<AffineMap> maps);
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os);
return os;
}
} // end namespace mlir
namespace llvm {
// AffineMaps hash just like pointers.
template <> struct DenseMapInfo<mlir::AffineMap> {
static mlir::AffineMap getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
}
static mlir::AffineMap getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::AffineMap val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::AffineMap LHS, mlir::AffineMap RHS) {
return LHS == RHS;
}
};
} // namespace llvm
#endif // MLIR_IR_AFFINE_MAP_H


@ -0,0 +1,116 @@
//===- AttributeSupport.h ---------------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines support types for registering dialect extended attributes.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_ATTRIBUTESUPPORT_H
#define MLIR_IR_ATTRIBUTESUPPORT_H
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StorageUniquerSupport.h"
#include "llvm/ADT/PointerIntPair.h"
namespace mlir {
class MLIRContext;
class Type;
//===----------------------------------------------------------------------===//
// AttributeStorage
//===----------------------------------------------------------------------===//
namespace detail {
class AttributeUniquer;
} // end namespace detail
/// Base storage class appearing in an attribute. Derived storage classes should
/// only be constructed within the context of the AttributeUniquer.
class AttributeStorage : public StorageUniquer::BaseStorage {
friend detail::AttributeUniquer;
friend StorageUniquer;
public:
/// Get the type of this attribute.
Type getType() const;
/// Get the dialect of this attribute.
Dialect &getDialect() const {
assert(dialect && "Malformed attribute storage object.");
return const_cast<Dialect &>(*dialect);
}
protected:
/// Construct a new attribute storage instance with the given type.
/// Note: All attributes require a valid type. If no type is provided here,
/// the type of the attribute will automatically default to NoneType
/// upon initialization in the uniquer.
AttributeStorage(Type type);
AttributeStorage();
/// Set the type of this attribute.
void setType(Type type);
// Set the dialect for this storage instance. This is used by the
// AttributeUniquer when initializing a newly constructed storage object.
void initializeDialect(Dialect &newDialect) { dialect = &newDialect; }
private:
/// The dialect for this attribute.
Dialect *dialect;
/// The opaque type of the attribute value.
const void *type;
};
/// Default storage type for attributes that require no additional
/// initialization or storage.
using DefaultAttributeStorage = AttributeStorage;
//===----------------------------------------------------------------------===//
// AttributeStorageAllocator
//===----------------------------------------------------------------------===//
// This is a utility allocator used to allocate memory for instances of derived
// Attributes.
using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
//===----------------------------------------------------------------------===//
// AttributeUniquer
//===----------------------------------------------------------------------===//
namespace detail {
// A utility class to get, or create, unique instances of attributes within an
// MLIRContext. This class manages all creation and uniquing of attributes.
class AttributeUniquer {
public:
/// Get a uniqued instance of attribute T.
template <typename T, typename... Args>
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
return ctx->getAttributeUniquer().get<typename T::ImplType>(
getInitFn(ctx, T::getClassID()), kind, std::forward<Args>(args)...);
}
private:
/// Returns a functor used to initialize new attribute storage instances.
static std::function<void(AttributeStorage *)>
getInitFn(MLIRContext *ctx, const ClassID *const attrID);
};
} // namespace detail
} // end namespace mlir
#endif // MLIR_IR_ATTRIBUTESUPPORT_H


@ -0,0 +1,954 @@
//===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_IR_ATTRIBUTES_H
#define MLIR_IR_ATTRIBUTES_H
#include "mlir/IR/AttributeSupport.h"
#include "llvm/ADT/APFloat.h"
namespace mlir {
class AffineMap;
class Dialect;
class FunctionType;
class Identifier;
class IntegerSet;
class Location;
class MLIRContext;
class ShapedType;
class Type;
namespace detail {
struct AffineMapAttributeStorage;
struct ArrayAttributeStorage;
struct BoolAttributeStorage;
struct DictionaryAttributeStorage;
struct IntegerAttributeStorage;
struct IntegerSetAttributeStorage;
struct FloatAttributeStorage;
struct OpaqueAttributeStorage;
struct StringAttributeStorage;
struct TypeAttributeStorage;
/// Elements Attributes.
struct DenseElementsAttributeStorage;
struct OpaqueElementsAttributeStorage;
struct SparseElementsAttributeStorage;
} // namespace detail
/// Attributes are known-constant values of operations and functions.
///
/// Instances of the Attribute class are references to immutable, uniqued,
/// and immortal values owned by MLIRContext. As such, an Attribute is a thin
/// wrapper around an underlying storage pointer. Attributes are usually passed
/// by value.
class Attribute {
public:
/// Integer identifier for all the concrete attribute kinds.
enum Kind {
// Reserve attribute kinds for dialect specific extensions.
#define DEFINE_SYM_KIND_RANGE(Dialect) \
FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
#include "DialectSymbolRegistry.def"
};
/// Utility class for implementing attributes.
template <typename ConcreteType, typename BaseType = Attribute,
typename StorageType = AttributeStorage>
using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
detail::AttributeUniquer>;
using ImplType = AttributeStorage;
using ValueType = void;
Attribute() : impl(nullptr) {}
/* implicit */ Attribute(const ImplType *impl)
: impl(const_cast<ImplType *>(impl)) {}
Attribute(const Attribute &other) : impl(other.impl) {}
Attribute &operator=(Attribute other) {
impl = other.impl;
return *this;
}
bool operator==(Attribute other) const { return impl == other.impl; }
bool operator!=(Attribute other) const { return !(*this == other); }
explicit operator bool() const { return impl; }
bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
// Support dyn_cast'ing Attribute to itself.
static bool classof(Attribute) { return true; }
/// Return the classification for this attribute.
unsigned getKind() const { return impl->getKind(); }
/// Return the type of this attribute.
Type getType() const;
/// Return the context this attribute belongs to.
MLIRContext *getContext() const;
/// Get the dialect this attribute is registered to.
Dialect &getDialect() const;
/// Print the attribute.
void print(raw_ostream &os) const;
void dump() const;
/// Get an opaque pointer to the attribute.
const void *getAsOpaquePointer() const { return impl; }
/// Construct an attribute from the opaque pointer representation.
static Attribute getFromOpaquePointer(const void *ptr) {
return Attribute(reinterpret_cast<const ImplType *>(ptr));
}
friend ::llvm::hash_code hash_value(Attribute arg);
protected:
ImplType *impl;
};
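/// Example (a hedged sketch of a dialect-specific attribute built on the
/// AttrBase utility above; `MyDialectAttrKind` is a hypothetical kind value,
/// and DefaultAttributeStorage is used since no extra data is stored):
///
///   class MyUnitLikeAttr : public Attribute::AttrBase<MyUnitLikeAttr> {
///   public:
///     using Base::Base;
///     static MyUnitLikeAttr get(MLIRContext *ctx) {
///       return Base::get(ctx, MyDialectAttrKind);
///     }
///     static bool kindof(unsigned kind) { return kind == MyDialectAttrKind; }
///   };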
inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
attr.print(os);
return os;
}
namespace StandardAttributes {
enum Kind {
AffineMap = Attribute::FIRST_STANDARD_ATTR,
Array,
Bool,
Dictionary,
Float,
Integer,
IntegerSet,
Opaque,
String,
SymbolRef,
Type,
Unit,
/// Elements Attributes.
DenseElements,
OpaqueElements,
SparseElements,
FIRST_ELEMENTS_ATTR = DenseElements,
LAST_ELEMENTS_ATTR = SparseElements,
/// Locations.
CallSiteLocation,
FileLineColLocation,
FusedLocation,
NameLocation,
UnknownLocation,
// Represents a location as a 'void*' pointer to a front-end's opaque
// location information, which must live longer than the MLIR objects that
// refer to it. OpaqueLocation's are never serialized.
//
// TODO: OpaqueLocation,
// Represents a value inlined through a function call.
// TODO: InlinedLocation,
FIRST_LOCATION_ATTR = CallSiteLocation,
LAST_LOCATION_ATTR = UnknownLocation,
};
} // namespace StandardAttributes
class AffineMapAttr
: public Attribute::AttrBase<AffineMapAttr, Attribute,
detail::AffineMapAttributeStorage> {
public:
using Base::Base;
using ValueType = AffineMap;
static AffineMapAttr get(AffineMap value);
AffineMap getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::AffineMap;
}
};
/// Array attributes are lists of other attributes. They are not necessarily
/// type homogeneous, given that attributes don't, in general, carry types.
class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
detail::ArrayAttributeStorage> {
public:
using Base::Base;
using ValueType = ArrayRef<Attribute>;
static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
ArrayRef<Attribute> getValue() const;
/// Support range iteration.
using iterator = llvm::ArrayRef<Attribute>::iterator;
iterator begin() const { return getValue().begin(); }
iterator end() const { return getValue().end(); }
size_t size() const { return getValue().size(); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Array;
}
};
class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
detail::BoolAttributeStorage> {
public:
using Base::Base;
using ValueType = bool;
static BoolAttr get(bool value, MLIRContext *context);
bool getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; }
};
/// NamedAttribute is used for dictionary attributes; it holds an identifier for
/// the name and a value for the attribute. The attribute pointer should always
/// be non-null.
using NamedAttribute = std::pair<Identifier, Attribute>;
/// Dictionary attribute is an attribute that represents a sorted collection of
/// named attribute values. The elements are sorted by name, and each name must
/// be unique within the collection.
class DictionaryAttr
: public Attribute::AttrBase<DictionaryAttr, Attribute,
detail::DictionaryAttributeStorage> {
public:
using Base::Base;
using ValueType = ArrayRef<NamedAttribute>;
static DictionaryAttr get(ArrayRef<NamedAttribute> value,
MLIRContext *context);
ArrayRef<NamedAttribute> getValue() const;
/// Return the specified attribute if present, null otherwise.
Attribute get(StringRef name) const;
Attribute get(Identifier name) const;
/// Support range iteration.
using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
iterator begin() const;
iterator end() const;
bool empty() const { return size() == 0; }
size_t size() const;
/// Methods for supporting type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Dictionary;
}
};
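/// Example (a minimal sketch, assuming an MLIRContext *ctx and attributes `a`
/// and `b` are in scope; Identifier::get is assumed from mlir/IR/Identifier.h):
///
///   NamedAttribute entries[] = {{Identifier::get("alpha", ctx), a},
///                               {Identifier::get("beta", ctx), b}};
///   auto dict = DictionaryAttr::get(entries, ctx);
///   if (Attribute v = dict.get("alpha"))
///     process(v);  // `process` stands in for client code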
class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
detail::FloatAttributeStorage> {
public:
using Base::Base;
using ValueType = APFloat;
/// Return a float attribute for the specified value in the specified type.
/// These methods should only be used for simple constant values, e.g. 1.0/2.0,
/// that are known-valid both as host double and the 'type' format.
static FloatAttr get(Type type, double value);
static FloatAttr getChecked(Type type, double value, Location loc);
/// Return a float attribute for the specified value in the specified type.
static FloatAttr get(Type type, const APFloat &value);
static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
APFloat getValue() const;
/// This function is used to convert the value to a double, even if it loses
/// precision.
double getValueAsDouble() const;
static double getValueAsDouble(APFloat val);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Float;
}
/// Verify the construction invariants for a double value.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
Type type, double value);
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
Type type, const APFloat &value);
};
class IntegerAttr
: public Attribute::AttrBase<IntegerAttr, Attribute,
detail::IntegerAttributeStorage> {
public:
using Base::Base;
using ValueType = APInt;
static IntegerAttr get(Type type, int64_t value);
static IntegerAttr get(Type type, const APInt &value);
APInt getValue() const;
// TODO(jpienaar): Change callers to use getValue instead.
int64_t getInt() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Integer;
}
};
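/// Example (a minimal sketch; `i32Ty` and `f32Ty` are assumed to be previously
/// constructed 32-bit integer and float types):
///
///   IntegerAttr answer = IntegerAttr::get(i32Ty, 42);
///   FloatAttr half = FloatAttr::get(f32Ty, 0.5);
///   assert(answer.getInt() == 42 && half.getValueAsDouble() == 0.5);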
class IntegerSetAttr
: public Attribute::AttrBase<IntegerSetAttr, Attribute,
detail::IntegerSetAttributeStorage> {
public:
using Base::Base;
using ValueType = IntegerSet;
static IntegerSetAttr get(IntegerSet value);
IntegerSet getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::IntegerSet;
}
};
/// Opaque attributes represent attributes of non-registered dialects. These
/// attributes are represented in their raw string form, and can only usefully
/// be tested for attribute equality.
class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
detail::OpaqueAttributeStorage> {
public:
using Base::Base;
/// Get or create a new OpaqueAttr with the provided dialect and string data.
static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
MLIRContext *context);
/// Get or create a new OpaqueAttr with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
/// null attribute is returned.
static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
Type type, Location location);
/// Returns the dialect namespace of the opaque attribute.
Identifier getDialectNamespace() const;
/// Returns the raw attribute data of the opaque attribute.
StringRef getAttrData() const;
/// Verify the construction of an opaque attribute.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Identifier dialect,
StringRef attrData, Type type);
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Opaque;
}
};
class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
detail::StringAttributeStorage> {
public:
using Base::Base;
using ValueType = StringRef;
/// Get an instance of a StringAttr with the given string.
static StringAttr get(StringRef bytes, MLIRContext *context);
/// Get an instance of a StringAttr with the given string and Type.
static StringAttr get(StringRef bytes, Type type);
StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::String;
}
};
/// A symbol reference attribute represents a symbolic reference to another
/// operation.
class SymbolRefAttr
: public Attribute::AttrBase<SymbolRefAttr, Attribute,
detail::StringAttributeStorage> {
public:
using Base::Base;
using ValueType = StringRef;
static SymbolRefAttr get(StringRef value, MLIRContext *ctx);
/// Returns the name of the held symbol reference.
StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::SymbolRef;
}
};
class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
detail::TypeAttributeStorage> {
public:
using Base::Base;
using ValueType = Type;
static TypeAttr get(Type value);
Type getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; }
};
/// Unit attributes are attributes that hold no specific value and are given
/// meaning by their existence.
class UnitAttr : public Attribute::AttrBase<UnitAttr> {
public:
using Base::Base;
static UnitAttr get(MLIRContext *context);
static bool kindof(unsigned kind) { return kind == StandardAttributes::Unit; }
};
//===----------------------------------------------------------------------===//
// Elements Attributes
//===----------------------------------------------------------------------===//
/// A base attribute that represents a reference to a static shaped tensor or
/// vector constant.
class ElementsAttr : public Attribute {
public:
using Attribute::Attribute;
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
ShapedType getType() const;
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Generates a new ElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This ElementsAttr should contain integers.
ElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Generates a new ElementsAttr by mapping each float value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This ElementsAttr should contain floats.
ElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
}
};
/// An attribute that represents a reference to a dense vector or tensor object.
///
class DenseElementsAttr
: public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
detail::DenseElementsAttributeStorage> {
public:
using Base::Base;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() == StandardAttributes::DenseElements;
}
/// Constructs a dense elements attribute from an array of element values.
/// Each element attribute value is expected to be an element of 'type'.
/// 'type' must be a vector or tensor with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
/// Constructs a dense integer elements attribute from an array of integer
/// or floating-point values. Each value is expected to be the same bitwidth
/// of the element type of 'type'. 'type' must be a vector or tensor with
/// static shape.
template <typename T, typename = typename std::enable_if<
std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double>::value>::type>
static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
const char *data = reinterpret_cast<const char *>(values.data());
return getRawIntOrFloat(
type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
/*isInt=*/std::numeric_limits<T>::is_integer);
}
/// Constructs a dense integer elements attribute from a single element.
template <typename T, typename = typename std::enable_if<
std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double>::value>::type>
static DenseElementsAttr get(const ShapedType &type, T value) {
return get(type, llvm::makeArrayRef(value));
}
/// Overload of the above 'get' method that is specialized for boolean values.
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
/// Construct a dense elements attribute for an initializer_list of values.
/// Each value is expected to be the same bitwidth of the element type of
/// 'type'. 'type' must be a vector or tensor with static shape.
template <typename T>
static DenseElementsAttr get(const ShapedType &type,
const std::initializer_list<T> &list) {
return get(type, ArrayRef<T>(list));
}
//===--------------------------------------------------------------------===//
// Iterators
//===--------------------------------------------------------------------===//
/// A utility iterator that allows walking over the internal Attribute values
/// of a DenseElementsAttr.
class AttributeElementIterator
: public indexed_accessor_iterator<AttributeElementIterator, const void *,
Attribute, Attribute, Attribute> {
public:
/// Accesses the Attribute value at this iterator position.
Attribute operator*() const;
private:
friend DenseElementsAttr;
/// Constructs a new iterator.
AttributeElementIterator(DenseElementsAttr attr, size_t index);
};
/// A utility iterator that allows walking over the internal raw APInt values.
class IntElementIterator
: public indexed_accessor_iterator<IntElementIterator, const char *,
APInt, APInt, APInt> {
public:
/// Accesses the raw APInt value at this iterator position.
APInt operator*() const;
private:
friend DenseElementsAttr;
/// Constructs a new iterator.
IntElementIterator(DenseElementsAttr attr, size_t index);
/// The bitwidth of the element type.
size_t bitWidth;
};
/// Iterator for walking over APFloat values.
class FloatElementIterator final
: public llvm::mapped_iterator<IntElementIterator,
std::function<APFloat(const APInt &)>> {
friend DenseElementsAttr;
/// Initializes the float element iterator to the specified iterator.
FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
public:
using reference = APFloat;
};
//===--------------------------------------------------------------------===//
// Value Querying
//===--------------------------------------------------------------------===//
/// Returns the number of raw elements held by this attribute.
size_t rawSize() const;
/// Returns true if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool isSplat() const;
/// If this attribute corresponds to a splat, then get the splat value.
/// Otherwise, return null.
Attribute getSplatValue() const;
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Return the held element values as an array of integer or floating-point
/// values.
template <typename T, typename = typename std::enable_if<
(!std::is_same<T, bool>::value &&
std::numeric_limits<T>::is_integer) ||
llvm::is_one_of<T, float, double>::value>::type>
ArrayRef<T> getValues() const {
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
auto rawData = getRawData();
return ArrayRef<T>(reinterpret_cast<const T *>(rawData.data()),
rawData.size() / sizeof(T));
}
/// Return the held element values as a range of Attributes.
llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, Attribute>::value>::type>
llvm::iterator_range<AttributeElementIterator> getValues() const {
return getAttributeValues();
}
AttributeElementIterator attr_value_begin() const;
AttributeElementIterator attr_value_end() const;
/// Return the held element values as a range of APInts. The element type of
/// this attribute must be of integer type.
llvm::iterator_range<IntElementIterator> getIntValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, APInt>::value>::type>
llvm::iterator_range<IntElementIterator> getValues() const {
return getIntValues();
}
IntElementIterator int_value_begin() const;
IntElementIterator int_value_end() const;
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, APFloat>::value>::type>
llvm::iterator_range<FloatElementIterator> getValues() const {
return getFloatValues();
}
FloatElementIterator float_value_begin() const;
FloatElementIterator float_value_end() const;
//===--------------------------------------------------------------------===//
// Mutation Utilities
//===--------------------------------------------------------------------===//
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr reshape(ShapedType newType);
/// Generates a new DenseElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This underlying type must be a DenseIntElementsAttr.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Generates a new DenseElementsAttr by mapping each float value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This underlying type must be a DenseFPElementsAttr.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
protected:
/// Return the raw storage data held by this attribute.
ArrayRef<char> getRawData() const;
/// Get iterators to the raw APInt values for each element in this attribute.
IntElementIterator raw_int_begin() const {
return IntElementIterator(*this, 0);
}
IntElementIterator raw_int_end() const {
return IntElementIterator(*this, rawSize());
}
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);
/// Get or create a new dense elements attribute instance with the given raw
/// data buffer. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
bool isSplat);
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
int64_t dataEltSize, bool isInt);
/// Checks the information for a C++ data type to see if it is valid for the
/// current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
};
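/// A minimal sketch of the typed accessors and 'mapValues' above, assuming
/// 'attr' holds integer elements and 'newType' pairs a wider integer element
/// type with the same shape; the names here are illustrative only.
inline DenseElementsAttr signExtendElements(DenseElementsAttr attr,
                                            Type newType, unsigned newWidth) {
  // Each held APInt is widened; iteration order matches getValues<APInt>().
  return attr.mapValues(
      newType, [&](const APInt &value) { return value.sext(newWidth); });
}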
/// An attribute that represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr {
public:
using iterator = DenseElementsAttr::FloatElementIterator;
using DenseElementsAttr::DenseElementsAttr;
/// Generates a new DenseElementsAttr by mapping each value attribute, and
/// constructing the DenseElementsAttr given the new element type.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Iterator access to the float element values.
iterator begin() const { return float_value_begin(); }
iterator end() const { return float_value_end(); }
/// Method for supporting type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr);
};
/// An attribute that represents a reference to a dense integer vector or tensor
/// object.
class DenseIntElementsAttr : public DenseElementsAttr {
public:
/// DenseIntElementsAttr iterates on APInt, so we can use the raw element
/// iterator directly.
using iterator = DenseElementsAttr::IntElementIterator;
using DenseElementsAttr::DenseElementsAttr;
/// Generates a new DenseElementsAttr by mapping each value attribute, and
/// constructing the DenseElementsAttr given the new element type.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Iterator access to the integer element values.
iterator begin() const { return raw_int_begin(); }
iterator end() const { return raw_int_end(); }
/// Method for supporting type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr);
};
/// An opaque attribute that represents a reference to a vector or tensor
/// constant with opaque content. This representation is for tensor constants
/// which the compiler may not need to interpret. This attribute is always
/// associated with a particular dialect, which provides a method to convert
/// tensor representation to a non-opaque format.
class OpaqueElementsAttr
: public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
detail::OpaqueElementsAttributeStorage> {
public:
using Base::Base;
using ValueType = StringRef;
static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
StringRef bytes);
StringRef getValue() const;
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Decodes the attribute value using dialect-specific decoding hook.
/// Returns false if decoding is successful. If not, returns true and leaves
/// 'result' argument unspecified.
bool decode(ElementsAttr &result);
/// Returns dialect associated with this opaque constant.
Dialect *getDialect() const;
/// Method for supporting type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::OpaqueElements;
}
};
/// An attribute that represents a reference to a sparse vector or tensor
/// object.
///
/// This class uses COO (coordinate list) encoding to represent the sparse
/// elements in an element attribute. Specifically, the sparse vector/tensor
/// stores the indices and values as two separate dense elements attributes of
/// tensor type (even if the sparse attribute is of vector type, in order to
/// support empty lists). The dense elements attribute indices is a 2-D tensor
/// of 64-bit integer elements with shape [N, ndims], which specifies the
/// indices of the elements in the sparse tensor that contains nonzero values.
/// The dense elements attribute values is a 1-D tensor with shape [N], and it
/// supplies the corresponding values for the indices.
///
/// For example,
/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
/// [[1, 0, 0, 0],
/// [0, 0, 5, 0],
/// [0, 0, 0, 0]].
class SparseElementsAttr
: public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
detail::SparseElementsAttributeStorage> {
public:
using Base::Base;
/// 'type' must be a vector or tensor with static shape.
static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
DenseElementsAttr values);
DenseIntElementsAttr getIndices() const;
DenseElementsAttr getValues() const;
/// Return the value of the element at the given index.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Method for supporting type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::SparseElements;
}
};
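/// A minimal sketch that reads back an element of the COO example documented
/// above, assuming 'sparse' encodes
/// sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>; illustrative only.
inline Attribute readStoredElement(SparseElementsAttr sparse) {
  // Index (1, 2) appears in the indices tensor, so this returns the i32
  // attribute holding 5.
  uint64_t index[] = {1, 2};
  return sparse.getValue(index);
}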
/// An attribute that represents a reference to a splat vector or tensor
/// constant, meaning all of the elements have the same value.
class SplatElementsAttr : public DenseElementsAttr {
public:
using DenseElementsAttr::DenseElementsAttr;
/// Method for supporting type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
return denseAttr && denseAttr.isSplat();
}
};
template <typename U> bool Attribute::isa() const {
assert(impl && "isa<> used on a null attribute.");
return U::classof(*this);
}
template <typename U> U Attribute::dyn_cast() const {
return isa<U>() ? U(impl) : U(nullptr);
}
template <typename U> U Attribute::dyn_cast_or_null() const {
return (impl && isa<U>()) ? U(impl) : U(nullptr);
}
template <typename U> U Attribute::cast() const {
assert(isa<U>());
return U(impl);
}
// Make Attribute hashable.
inline ::llvm::hash_code hash_value(Attribute arg) {
return ::llvm::hash_value(arg.impl);
}
/// A NamedAttributeList is used to manage a list of named attributes. This
/// provides simple interfaces for adding/removing/finding attributes from
/// within a DictionaryAttr.
///
/// We assume there will be relatively few attributes on a given operation
/// (maybe a dozen or so, but not hundreds or thousands) so we use linear
/// searches for everything.
class NamedAttributeList {
public:
NamedAttributeList(DictionaryAttr attrs = nullptr)
: attrs((attrs && !attrs.empty()) ? attrs : nullptr) {}
NamedAttributeList(ArrayRef<NamedAttribute> attributes);
/// Return the underlying dictionary attribute. This may be null if this list
/// has no attributes.
DictionaryAttr getDictionary() const { return attrs; }
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const;
/// Replace the held attributes with ones provided in 'newAttrs'.
void setAttrs(ArrayRef<NamedAttribute> attributes);
/// Return the specified attribute if present, null otherwise.
Attribute get(StringRef name) const;
Attribute get(Identifier name) const;
/// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void set(Identifier name, Attribute value);
enum class RemoveResult { Removed, NotFound };
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
RemoveResult remove(Identifier name);
private:
DictionaryAttr attrs;
};
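/// A minimal sketch of the list interface above, assuming 'name' was uniqued
/// by the owning MLIRContext; purely illustrative.
inline void setAttrIfAbsent(NamedAttributeList &attrs, Identifier name,
                            Attribute value) {
  // 'get' returns a null Attribute when the name is absent; 'set' then adds
  // the new name/value pair without disturbing existing entries.
  if (!attrs.get(name))
    attrs.set(name, value);
}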
} // end namespace mlir.
namespace llvm {
// Attribute hash just like pointers.
template <> struct DenseMapInfo<mlir::Attribute> {
static mlir::Attribute getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
}
static mlir::Attribute getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::Attribute val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
return LHS == RHS;
}
};
/// Allow LLVM to steal the low bits of Attributes.
template <> struct PointerLikeTypeTraits<mlir::Attribute> {
public:
static inline void *getAsVoidPointer(mlir::Attribute attr) {
return const_cast<void *>(attr.getAsOpaquePointer());
}
static inline mlir::Attribute getFromVoidPointer(void *ptr) {
return mlir::Attribute::getFromOpaquePointer(ptr);
}
enum { NumLowBitsAvailable = 3 };
};
} // namespace llvm
#endif

457
third_party/mlir/include/mlir/IR/Block.h vendored Normal file
View File

@ -0,0 +1,457 @@
//===- Block.h - MLIR Block Class -------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the Block class.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_BLOCK_H
#define MLIR_IR_BLOCK_H
#include "mlir/IR/Value.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"
//===----------------------------------------------------------------------===//
// ilist_traits for Operation
//===----------------------------------------------------------------------===//
namespace llvm {
namespace ilist_detail {
// Explicitly define the node access for the operation list so that we can
// break the dependence on the Operation class in this header. This allows for
// operations to have trailing Regions without a circular include
// dependence.
template <>
struct SpecificNodeAccess<
typename compute_node_options<::mlir::Operation>::type> : NodeAccess {
protected:
using OptionsT = typename compute_node_options<mlir::Operation>::type;
using pointer = typename OptionsT::pointer;
using const_pointer = typename OptionsT::const_pointer;
using node_type = ilist_node_impl<OptionsT>;
static node_type *getNodePtr(pointer N);
static const node_type *getNodePtr(const_pointer N);
static pointer getValuePtr(node_type *N);
static const_pointer getValuePtr(const node_type *N);
};
} // end namespace ilist_detail
template <> struct ilist_traits<::mlir::Operation> {
using Operation = ::mlir::Operation;
using op_iterator = simple_ilist<Operation>::iterator;
static void deleteNode(Operation *op);
void addNodeToList(Operation *op);
void removeNodeFromList(Operation *op);
void transferNodesFromList(ilist_traits<Operation> &otherList,
op_iterator first, op_iterator last);
private:
mlir::Block *getContainingBlock();
};
} // end namespace llvm
namespace mlir {
using BlockOperand = IROperandImpl<Block>;
class PredecessorIterator;
class SuccessorIterator;
/// `Block` represents an ordered list of `Operation`s.
class Block : public IRObjectWithUseList,
public llvm::ilist_node_with_parent<Block, Region> {
public:
explicit Block() {}
~Block();
void clear() {
// Drop all references from within this block.
dropAllReferences();
// Clear operations in the reverse order so that uses are destroyed
// before their defs.
while (!empty())
operations.pop_back();
}
/// Blocks are maintained in a Region.
Region *getParent();
/// Returns the closest surrounding operation that contains this block or
/// nullptr if this is a top-level block.
Operation *getContainingOp();
/// Return if this block is the entry block in the parent region.
bool isEntryBlock();
/// Insert this block (which must not already be in a function) right before
/// the specified block.
void insertBefore(Block *block);
/// Unlink this Block from its parent region and delete it.
void erase();
//===--------------------------------------------------------------------===//
// Block argument management
//===--------------------------------------------------------------------===//
// This is the list of arguments to the block.
using BlockArgListType = ArrayRef<BlockArgument *>;
BlockArgListType getArguments() { return arguments; }
using args_iterator = BlockArgListType::iterator;
using reverse_args_iterator = BlockArgListType::reverse_iterator;
args_iterator args_begin() { return getArguments().begin(); }
args_iterator args_end() { return getArguments().end(); }
reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
reverse_args_iterator args_rend() { return getArguments().rend(); }
bool args_empty() { return arguments.empty(); }
/// Add one value to the argument list.
BlockArgument *addArgument(Type type);
/// Add one argument to the argument list for each type specified in the list.
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
/// Erase the argument at 'index' and remove it from the argument list. If
/// 'updatePredTerms' is set to true, this argument is also removed from the
/// terminators of each predecessor to this block.
void eraseArgument(unsigned index, bool updatePredTerms = true);
unsigned getNumArguments() { return arguments.size(); }
BlockArgument *getArgument(unsigned i) { return arguments[i]; }
//===--------------------------------------------------------------------===//
// Operation list management
//===--------------------------------------------------------------------===//
/// This is the list of operations in the block.
using InstListType = llvm::iplist<Operation>;
InstListType &getOperations() { return operations; }
// Iteration over the operations in the block.
using iterator = InstListType::iterator;
using reverse_iterator = InstListType::reverse_iterator;
iterator begin() { return operations.begin(); }
iterator end() { return operations.end(); }
reverse_iterator rbegin() { return operations.rbegin(); }
reverse_iterator rend() { return operations.rend(); }
bool empty() { return operations.empty(); }
void push_back(Operation *op) { operations.push_back(op); }
void push_front(Operation *op) { operations.push_front(op); }
Operation &back() { return operations.back(); }
Operation &front() { return operations.front(); }
/// Returns 'op' if 'op' lies in this block, or otherwise finds the
/// ancestor operation of 'op' that lies in this block. Returns nullptr if
/// the latter fails.
/// TODO: This is very specific functionality that should live somewhere else,
/// probably in Dominance.cpp.
Operation *findAncestorInstInBlock(Operation &op);
/// This drops all operand uses from operations within this block, which is
/// an essential step in breaking cyclic dependences between references when
/// they are to be deleted.
void dropAllReferences();
/// This drops all uses of values defined in this block or in the blocks of
/// nested regions wherever the uses are located.
void dropAllDefinedValueUses();
/// Returns true if the ordering of the child operations is valid, false
/// otherwise.
bool isInstOrderValid();
/// Invalidates the current ordering of operations.
void invalidateInstOrder();
/// Verifies the current ordering of child operations matches the
/// validInstOrder flag. Returns false if the order is valid, true otherwise.
bool verifyInstOrder();
/// Recomputes the ordering of child operations within the block.
void recomputeInstOrder();
private:
/// A utility iterator that filters out operations that are not 'OpT'.
template <typename OpT>
class op_filter_iterator
: public llvm::filter_iterator<Block::iterator, bool (*)(Operation &)> {
static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
public:
op_filter_iterator(Block::iterator it, Block::iterator end)
: llvm::filter_iterator<Block::iterator, bool (*)(Operation &)>(
it, end, &filter) {}
/// Allow implicit conversion to the underlying block iterator.
operator Block::iterator() const { return this->wrapped(); }
};
public:
/// This class provides iteration over the held instructions of a block for a
/// specific operation type.
template <typename OpT>
class op_iterator : public llvm::mapped_iterator<op_filter_iterator<OpT>,
OpT (*)(Operation &)> {
static OpT unwrap(Operation &op) { return llvm::cast<OpT>(op); }
public:
using reference = OpT;
/// Initializes the iterator to the specified filter iterator.
op_iterator(op_filter_iterator<OpT> it)
: llvm::mapped_iterator<op_filter_iterator<OpT>, OpT (*)(Operation &)>(
it, &unwrap) {}
/// Allow implicit conversion to the underlying block iterator.
operator Block::iterator() const { return this->wrapped(); }
};
/// Return an iterator range over the operations within this block that are of
/// 'OpT'.
template <typename OpT> llvm::iterator_range<op_iterator<OpT>> getOps() {
auto endIt = end();
return {op_filter_iterator<OpT>(begin(), endIt),
op_filter_iterator<OpT>(endIt, endIt)};
}
template <typename OpT> op_iterator<OpT> op_begin() {
return op_filter_iterator<OpT>(begin(), end());
}
template <typename OpT> op_iterator<OpT> op_end() {
return op_filter_iterator<OpT>(end(), end());
}
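/// A minimal sketch of the filtered ranges above, assuming 'OpTy' names a
/// registered operation class; purely illustrative.
template <typename OpTy> unsigned countOps() {
  unsigned count = 0;
  // getOps<OpTy>() skips operations that are not of type 'OpTy'.
  for (OpTy op : getOps<OpTy>()) {
    (void)op;
    ++count;
  }
  return count;
}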
//===--------------------------------------------------------------------===//
// Terminator management
//===--------------------------------------------------------------------===//
/// Get the terminator operation of this block. This function asserts that
/// the block has a valid terminator operation.
Operation *getTerminator();
//===--------------------------------------------------------------------===//
// Predecessors and successors.
//===--------------------------------------------------------------------===//
// Predecessor iteration.
using pred_iterator = PredecessorIterator;
pred_iterator pred_begin();
pred_iterator pred_end();
llvm::iterator_range<pred_iterator> getPredecessors();
/// Return true if this block has no predecessors.
bool hasNoPredecessors();
/// If this block has exactly one predecessor, return it. Otherwise, return
/// null.
///
/// Note that a block with duplicate predecessors from a single block (e.g. a
/// conditional branch with the same block as both the true and false
/// destinations) is not considered to have a single predecessor.
Block *getSinglePredecessor();
// Indexed successor access.
unsigned getNumSuccessors();
Block *getSuccessor(unsigned i);
// Successor iteration.
using succ_iterator = SuccessorIterator;
succ_iterator succ_begin();
succ_iterator succ_end();
llvm::iterator_range<succ_iterator> getSuccessors();
//===--------------------------------------------------------------------===//
// Operation Walkers
//===--------------------------------------------------------------------===//
/// Walk the operations in this block in postorder, calling the callback for
/// each operation.
void walk(llvm::function_ref<void(Operation *)> callback);
/// Specialization of walk to only visit operations of 'OpTy'.
template <typename OpTy> void walk(llvm::function_ref<void(OpTy)> callback) {
walk([&](Operation *opInst) {
if (auto op = dyn_cast<OpTy>(opInst))
callback(op);
});
}
/// Walk the operations in the specified [begin, end) range of this block in
/// postorder, calling the callback for each operation.
void walk(Block::iterator begin, Block::iterator end,
llvm::function_ref<void(Operation *)> callback);
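/// A minimal sketch pairing with the walkers above: counts the operations
/// visited by walk (illustrative, assuming the block is fully constructed).
unsigned countWalkedOps() {
  unsigned count = 0;
  walk([&](Operation *) { ++count; });
  return count;
}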
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
/// Split the block into two blocks before the specified operation or
/// iterator.
///
/// Note that all operations BEFORE the specified iterator stay as part of
/// the original basic block, and the rest of the operations in the original
/// block are moved to the new block, including the old terminator. The
/// original block is left without a terminator.
///
/// The newly formed Block is returned, and the specified iterator is
/// invalidated.
Block *splitBlock(iterator splitBefore);
Block *splitBlock(Operation *splitBeforeInst) {
return splitBlock(iterator(splitBeforeInst));
}
/// Returns pointer to member of operation list.
static InstListType Block::*getSublistAccess(Operation *) {
return &Block::operations;
}
void print(raw_ostream &os);
void dump();
/// Print out the name of the block without printing its body.
/// NOTE: The printType argument is ignored. We keep it for compatibility
/// with LLVM dominator machinery that expects it to exist.
void printAsOperand(raw_ostream &os, bool printType = true);
private:
/// Pair of the parent object that owns this block and a bit that signifies if
/// the operations within this block have a valid ordering.
llvm::PointerIntPair<Region *, /*IntBits=*/1, bool> parentValidInstOrderPair;
/// This is the list of operations in the block.
InstListType operations;
/// This is the list of arguments to the block.
std::vector<BlockArgument *> arguments;
Block(Block &) = delete;
void operator=(Block &) = delete;
friend struct llvm::ilist_traits<Block>;
};
} // end namespace mlir
//===----------------------------------------------------------------------===//
// ilist_traits for Block
//===----------------------------------------------------------------------===//
namespace llvm {
template <>
struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> {
using Block = ::mlir::Block;
using block_iterator = simple_ilist<::mlir::Block>::iterator;
void addNodeToList(Block *block);
void removeNodeFromList(Block *block);
void transferNodesFromList(ilist_traits<Block> &otherList,
block_iterator first, block_iterator last);
private:
mlir::Region *getContainingRegion();
};
} // end namespace llvm
namespace mlir {
//===----------------------------------------------------------------------===//
// Predecessors
//===----------------------------------------------------------------------===//
/// Implement a predecessor iterator for blocks. This works by walking the use
/// lists of the blocks. The entries on this list are the BlockOperands that
/// are embedded into terminator operations. From the operand, we can get the
/// terminator that contains it, and its parent block is the predecessor.
class PredecessorIterator final
: public llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
Block *(*)(BlockOperand &)> {
static Block *unwrap(BlockOperand &value);
public:
using reference = Block *;
/// Initializes the operand type iterator to the specified operand iterator.
PredecessorIterator(ValueUseIterator<BlockOperand> it)
: llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
Block *(*)(BlockOperand &)>(it, &unwrap) {}
explicit PredecessorIterator(BlockOperand *operand)
: PredecessorIterator(ValueUseIterator<BlockOperand>(operand)) {}
/// Get the successor number in the predecessor terminator.
unsigned getSuccessorIndex() const;
};
inline auto Block::pred_begin() -> pred_iterator {
return pred_iterator((BlockOperand *)getFirstUse());
}
inline auto Block::pred_end() -> pred_iterator {
return pred_iterator(nullptr);
}
inline auto Block::getPredecessors() -> llvm::iterator_range<pred_iterator> {
return {pred_begin(), pred_end()};
}
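/// A minimal sketch over the predecessor range above, assuming 'block' is
/// linked into a region; purely illustrative.
inline bool isPredecessor(Block *candidate, Block *block) {
  for (Block *pred : block->getPredecessors())
    if (pred == candidate)
      return true;
  return false;
}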
//===----------------------------------------------------------------------===//
// Successors
//===----------------------------------------------------------------------===//
/// This template implements the successor iterators for Block.
class SuccessorIterator final
: public indexed_accessor_iterator<SuccessorIterator, Block *, Block *,
Block *, Block *> {
public:
/// Initializes the result iterator to the specified index.
SuccessorIterator(Block *object, unsigned index)
: indexed_accessor_iterator<SuccessorIterator, Block *, Block *, Block *,
Block *>(object, index) {}
SuccessorIterator(const SuccessorIterator &other)
: SuccessorIterator(other.object, other.index) {}
Block *operator*() const { return this->object->getSuccessor(this->index); }
/// Get the successor number in the terminator.
unsigned getSuccessorIndex() const { return this->index; }
};
inline auto Block::succ_begin() -> succ_iterator {
return succ_iterator(this, 0);
}
inline auto Block::succ_end() -> succ_iterator {
return succ_iterator(this, getNumSuccessors());
}
inline auto Block::getSuccessors() -> llvm::iterator_range<succ_iterator> {
return {succ_begin(), succ_end()};
}
} // end namespace mlir
#endif // MLIR_IR_BLOCK_H

93
third_party/mlir/include/mlir/IR/BlockAndValueMapping.h vendored Normal file
View File

@ -0,0 +1,93 @@
//===- BlockAndValueMapping.h -----------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines a utility class for maintaining a mapping for multiple
// value types.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_BLOCKANDVALUEMAPPING_H
#define MLIR_IR_BLOCKANDVALUEMAPPING_H
#include "mlir/IR/Block.h"
namespace mlir {
// This is a utility class for mapping one set of values to another. New
// mappings can be inserted via 'map'. Existing mappings can be
// found via the 'lookup*' functions. There are two variants that differ only in
// return value when an existing mapping is not found for the provided key:
// 'lookupOrNull' returns nullptr, whereas 'lookupOrDefault' returns the
// lookup key itself.
class BlockAndValueMapping {
public:
/// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping,
/// it is overwritten.
void map(Block *from, Block *to) { valueMap[from] = to; }
void map(Value *from, Value *to) { valueMap[from] = to; }
/// Erases a mapping for 'from'.
void erase(IRObjectWithUseList *from) { valueMap.erase(from); }
/// Checks to see if a mapping for 'from' exists.
bool contains(IRObjectWithUseList *from) const {
return valueMap.count(from);
}
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return nullptr.
Block *lookupOrNull(Block *from) const {
return lookupOrValue(from, (Block *)nullptr);
}
Value *lookupOrNull(Value *from) const {
return lookupOrValue(from, (Value *)nullptr);
}
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return the provided value.
Block *lookupOrDefault(Block *from) const {
return lookupOrValue(from, from);
}
Value *lookupOrDefault(Value *from) const {
return lookupOrValue(from, from);
}
/// Lookup a mapped value within the map. This asserts the provided value
/// exists within the map.
template <typename T> T *lookup(T *from) const {
auto *result = lookupOrNull(from);
assert(result && "expected 'from' to be contained within the map");
return result;
}
/// Clears all mappings held by the mapper.
void clear() { valueMap.clear(); }
private:
/// Utility lookupOrValue that looks up an existing key or returns the
/// provided value. This function assumes that if a mapping does exist, then
/// it is of 'T' type.
template <typename T> T *lookupOrValue(T *from, T *value) const {
auto it = valueMap.find(from);
return it != valueMap.end() ? static_cast<T *>(it->second) : value;
}
llvm::DenseMap<IRObjectWithUseList *, IRObjectWithUseList *> valueMap;
};
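/// A minimal sketch of the lookup contract described above, assuming
/// 'operand' comes from IR that may or may not have been remapped.
inline Value *remapOperand(const BlockAndValueMapping &mapper,
                           Value *operand) {
  // Returns the recorded replacement if one exists; otherwise the original
  // value is used unchanged.
  return mapper.lookupOrDefault(operand);
}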
} // end namespace mlir
#endif // MLIR_IR_BLOCKANDVALUEMAPPING_H

384
third_party/mlir/include/mlir/IR/Builders.h vendored Normal file
View File

@ -0,0 +1,384 @@
//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_IR_BUILDERS_H
#define MLIR_IR_BUILDERS_H
#include "mlir/IR/OpDefinition.h"
namespace mlir {
class AffineExpr;
class BlockAndValueMapping;
class ModuleOp;
class UnknownLoc;
class FileLineColLoc;
class Type;
class PrimitiveType;
class IntegerType;
class FunctionType;
class MemRefType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;
class TupleType;
class NoneType;
class BoolAttr;
class IntegerAttr;
class FloatAttr;
class StringAttr;
class TypeAttr;
class ArrayAttr;
class SymbolRefAttr;
class ElementsAttr;
class DenseElementsAttr;
class DenseIntElementsAttr;
class AffineMapAttr;
class AffineMap;
class UnitAttr;
/// This class is a general helper class for creating context-global objects
/// like types, attributes, and affine expressions.
class Builder {
public:
explicit Builder(MLIRContext *context) : context(context) {}
explicit Builder(ModuleOp module);
MLIRContext *getContext() const { return context; }
Identifier getIdentifier(StringRef str);
// Locations.
Location getUnknownLoc();
Location getFileLineColLoc(Identifier filename, unsigned line,
unsigned column);
Location getFusedLoc(ArrayRef<Location> locs,
Attribute metadata = Attribute());
// Types.
FloatType getBF16Type();
FloatType getF16Type();
FloatType getF32Type();
FloatType getF64Type();
IndexType getIndexType();
IntegerType getI1Type();
IntegerType getIntegerType(unsigned width);
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
MemRefType getMemRefType(ArrayRef<int64_t> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition = {},
unsigned memorySpace = 0);
VectorType getVectorType(ArrayRef<int64_t> shape, Type elementType);
RankedTensorType getTensorType(ArrayRef<int64_t> shape, Type elementType);
UnrankedTensorType getTensorType(Type elementType);
TupleType getTupleType(ArrayRef<Type> elementTypes);
NoneType getNoneType();
/// Get or construct an instance of the type 'ty' with provided arguments.
template <typename Ty, typename... Args> Ty getType(Args... args) {
return Ty::get(context, args...);
}
// Attributes.
NamedAttribute getNamedAttr(StringRef name, Attribute val);
UnitAttr getUnitAttr();
BoolAttr getBoolAttr(bool value);
DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
IntegerAttr getIntegerAttr(Type type, int64_t value);
IntegerAttr getIntegerAttr(Type type, const APInt &value);
FloatAttr getFloatAttr(Type type, double value);
FloatAttr getFloatAttr(Type type, const APFloat &value);
StringAttr getStringAttr(StringRef bytes);
StringAttr getStringAttr(StringRef bytes, Type type);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
AffineMapAttr getAffineMapAttr(AffineMap map);
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type type);
SymbolRefAttr getSymbolRefAttr(Operation *value);
SymbolRefAttr getSymbolRefAttr(StringRef value);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);
ElementsAttr getDenseIntElementsAttr(ShapedType type,
ArrayRef<int64_t> values);
ElementsAttr getSparseElementsAttr(ShapedType type,
DenseIntElementsAttr indices,
DenseElementsAttr values);
ElementsAttr getOpaqueElementsAttr(Dialect *dialect, ShapedType type,
StringRef bytes);
// Returns a 0-valued attribute of the given `type`. This function only
// supports boolean, integer, and 16-/32-/64-bit float types, and vectors or
// ranked tensors of them. Returns a null attribute otherwise.
Attribute getZeroAttr(Type type);
// Convenience methods for fixed types.
FloatAttr getF16FloatAttr(float value);
FloatAttr getF32FloatAttr(float value);
FloatAttr getF64FloatAttr(double value);
IntegerAttr getI32IntegerAttr(int32_t value);
IntegerAttr getI64IntegerAttr(int64_t value);
ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);
AffineExpr getAffineSymbolExpr(unsigned position);
AffineExpr getAffineConstantExpr(int64_t constant);
AffineMap getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results);
// Special cases of affine maps and integer sets
/// Returns a single constant result affine map with 0 dimensions and 0
/// symbols. One constant result: () -> (val).
AffineMap getConstantAffineMap(int64_t val);
// One dimension id identity map: (i) -> (i).
AffineMap getDimIdentityMap();
// Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
AffineMap getMultiDimIdentityMap(unsigned rank);
// One symbol identity map: ()[s] -> (s).
AffineMap getSymbolIdentityMap();
/// Returns a map that shifts its (single) input dimension by 'shift'.
/// (d0) -> (d0 + shift)
AffineMap getSingleDimShiftAffineMap(int64_t shift);
/// Returns an affine map that is a translation (shift) of all result
/// expressions in 'map' by 'shift'.
/// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
/// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
// Integer set.
IntegerSet getIntegerSet(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> constraints,
ArrayRef<bool> isEq);
// TODO: Helpers for affine map/exprs, etc.
protected:
MLIRContext *context;
};
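/// A minimal sketch of the context-global helpers above, assuming a live
/// MLIRContext; the shape and element type are illustrative.
inline MemRefType make2DFloatMemRef(MLIRContext *ctx, int64_t rows,
                                    int64_t cols) {
  Builder builder(ctx);
  // Uses the default (identity) layout and memory space 0.
  return builder.getMemRefType({rows, cols}, builder.getF32Type());
}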
/// This class helps build Operations. Operations that are created are
/// automatically inserted at an insertion point. The builder is copyable.
class OpBuilder : public Builder {
public:
/// Create a builder with the given context.
explicit OpBuilder(MLIRContext *ctx) : Builder(ctx) {}
/// Create a builder and set the insertion point to the start of the region.
explicit OpBuilder(Region *region) : Builder(region->getContext()) {
if (!region->empty())
setInsertionPoint(&region->front(), region->front().begin());
}
explicit OpBuilder(Region &region) : OpBuilder(&region) {}
virtual ~OpBuilder();
/// Create a builder and set insertion point to the given operation, which
/// will cause subsequent insertions to go right before it.
explicit OpBuilder(Operation *op) : Builder(op->getContext()) {
setInsertionPoint(op);
}
explicit OpBuilder(Block *block) : OpBuilder(block, block->end()) {}
OpBuilder(Block *block, Block::iterator insertPoint)
: OpBuilder(block->getParent()) {
setInsertionPoint(block, insertPoint);
}
/// This class represents a saved insertion point.
class InsertPoint {
public:
/// Creates a new insertion point which doesn't point to anything.
InsertPoint() = default;
/// Creates a new insertion point at the given location.
InsertPoint(Block *insertBlock, Block::iterator insertPt)
: block(insertBlock), point(insertPt) {}
/// Returns true if this insert point is set.
bool isSet() const { return (block != nullptr); }
Block *getBlock() const { return block; }
Block::iterator getPoint() const { return point; }
private:
Block *block = nullptr;
Block::iterator point;
};
/// Reset the insertion point to no location. Creating an operation without a
/// set insertion point is an error, but this can still be useful when the
/// current insertion point a builder refers to is being removed.
void clearInsertionPoint() {
this->block = nullptr;
insertPoint = Block::iterator();
}
/// Return a saved insertion point.
InsertPoint saveInsertionPoint() const {
return InsertPoint(getInsertionBlock(), getInsertionPoint());
}
/// Restore the insert point to a previously saved point.
void restoreInsertionPoint(InsertPoint ip) {
if (ip.isSet())
setInsertionPoint(ip.getBlock(), ip.getPoint());
else
clearInsertionPoint();
}
/// Set the insertion point to the specified location.
void setInsertionPoint(Block *block, Block::iterator insertPoint) {
// TODO: check that insertPoint is in this rather than some other block.
this->block = block;
this->insertPoint = insertPoint;
}
/// Sets the insertion point to the specified operation, which will cause
/// subsequent insertions to go right before it.
void setInsertionPoint(Operation *op) {
setInsertionPoint(op->getBlock(), Block::iterator(op));
}
/// Sets the insertion point to the start of the specified block.
void setInsertionPointToStart(Block *block) {
setInsertionPoint(block, block->begin());
}
/// Sets the insertion point to the end of the specified block.
void setInsertionPointToEnd(Block *block) {
setInsertionPoint(block, block->end());
}
/// Return the block the current insertion point belongs to. Note that the
/// insertion point is not necessarily the end of the block.
Block *getInsertionBlock() const { return block; }
/// Returns the current insertion point of the builder.
Block::iterator getInsertionPoint() const { return insertPoint; }
/// Add a new block and set the insertion point to the end of it. The block is
/// inserted at the provided insertion point of 'parent'.
Block *createBlock(Region *parent, Region::iterator insertPt = {});
/// Add a new block and set the insertion point to the end of it. The block is
/// placed before 'insertBefore'.
Block *createBlock(Block *insertBefore);
/// Returns the current block of the builder.
Block *getBlock() const { return block; }
/// Creates an operation given the fields represented as an OperationState.
virtual Operation *createOperation(const OperationState &state);
/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Location location, Args&&... args) {
OperationState state(location, OpTy::getOperationName());
OpTy::build(this, &state, std::forward<Args>(args)...);
auto *op = createOperation(state);
auto result = dyn_cast<OpTy>(op);
assert(result && "Builder didn't return the right type");
return result;
}
/// Create an operation of specific op type at the current insertion point,
/// and immediately try to fold it. This functions populates 'results' with
/// the results after folding the operation.
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value *> &results, Location location,
Args &&... args) {
auto op = create<OpTy>(location, std::forward<Args>(args)...);
tryFold(op.getOperation(), results);
}
/// Overload to create or fold a single result operation.
template <typename OpTy, typename... Args>
typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
Value *>::type
createOrFold(Location location, Args &&... args) {
SmallVector<Value *, 1> results;
createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
return results.front();
}
/// Overload to create or fold a zero result operation.
template <typename OpTy, typename... Args>
typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
OpTy>::type
createOrFold(Location location, Args &&... args) {
auto op = create<OpTy>(location, std::forward<Args>(args)...);
SmallVector<Value *, 0> unused;
tryFold(op.getOperation(), unused);
// Folding cannot remove a zero-result operation, so for convenience we
// continue to return it.
return op;
}
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
/// (leaving them alone if no entry is present). Replaces references to
/// cloned sub-operations with the corresponding copied operations, and adds
/// those mappings to the map.
Operation *clone(Operation &op, BlockAndValueMapping &mapper) {
Operation *cloneOp = op.clone(mapper);
insert(cloneOp);
return cloneOp;
}
Operation *clone(Operation &op) {
Operation *cloneOp = op.clone();
insert(cloneOp);
return cloneOp;
}
/// Creates a deep copy of the specified operation but keeps its regions
/// empty. Operands are remapped using `mapper` (if present), and `mapper` is
/// updated to contain the results.
Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
Operation *cloneOp = op.cloneWithoutRegions(mapper);
insert(cloneOp);
return cloneOp;
}
Operation *cloneWithoutRegions(Operation &op) {
Operation *cloneOp = op.cloneWithoutRegions();
insert(cloneOp);
return cloneOp;
}
private:
/// Attempts to fold the given operation and places new results within
/// 'results'.
void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
/// Insert the given operation at the current insertion point.
void insert(Operation *op);
Block *block = nullptr;
Block::iterator insertPoint;
};
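/// A minimal sketch of scoped insertion-point management with the save and
/// restore methods above, assuming 'builder' already has a valid insertion
/// point; purely illustrative.
inline void buildAtStartOf(OpBuilder &builder, Block *block,
                           llvm::function_ref<void(OpBuilder &)> buildFn) {
  OpBuilder::InsertPoint savedIp = builder.saveInsertionPoint();
  builder.setInsertionPointToStart(block);
  buildFn(builder);
  // Put subsequent insertions back where they were before this call.
  builder.restoreInsertionPoint(savedIp);
}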
} // namespace mlir
#endif

604
third_party/mlir/include/mlir/IR/Diagnostics.h vendored Normal file
View File

@ -0,0 +1,604 @@
//===- Diagnostics.h - MLIR Diagnostics -------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines utilities for emitting diagnostics.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIAGNOSTICS_H
#define MLIR_IR_DIAGNOSTICS_H
#include "mlir/IR/Location.h"
#include "mlir/Support/STLExtras.h"
#include <functional>
namespace llvm {
class MemoryBuffer;
class SMLoc;
class SourceMgr;
} // end namespace llvm
namespace mlir {
class DiagnosticEngine;
class Identifier;
struct LogicalResult;
class MLIRContext;
class Operation;
class OperationName;
class Type;
namespace detail {
struct DiagnosticEngineImpl;
} // end namespace detail
/// Defines the different supported severity of a diagnostic.
enum class DiagnosticSeverity {
Note,
Warning,
Error,
Remark,
};
//===----------------------------------------------------------------------===//
// DiagnosticArgument
//===----------------------------------------------------------------------===//
/// A variant type that holds a single argument for a diagnostic.
class DiagnosticArgument {
public:
/// Enum that represents the different kinds of diagnostic arguments
/// supported.
enum class DiagnosticArgumentKind {
Attribute,
Double,
Integer,
Operation,
String,
Type,
Unsigned,
};
/// Outputs this argument to a stream.
void print(raw_ostream &os) const;
/// Returns the kind of this argument.
DiagnosticArgumentKind getKind() const { return kind; }
/// Returns this argument as an Attribute.
Attribute getAsAttribute() const;
/// Returns this argument as a double.
double getAsDouble() const {
assert(getKind() == DiagnosticArgumentKind::Double);
return doubleVal;
}
/// Returns this argument as a signed integer.
int64_t getAsInteger() const {
assert(getKind() == DiagnosticArgumentKind::Integer);
return static_cast<int64_t>(opaqueVal);
}
/// Returns this argument as an operation.
Operation &getAsOperation() const {
assert(getKind() == DiagnosticArgumentKind::Operation);
return *reinterpret_cast<Operation *>(opaqueVal);
}
/// Returns this argument as a string.
StringRef getAsString() const {
assert(getKind() == DiagnosticArgumentKind::String);
return stringVal;
}
/// Returns this argument as a Type.
Type getAsType() const;
/// Returns this argument as an unsigned integer.
uint64_t getAsUnsigned() const {
assert(getKind() == DiagnosticArgumentKind::Unsigned);
return static_cast<uint64_t>(opaqueVal);
}
private:
friend class Diagnostic;
// Construct from an Attribute.
explicit DiagnosticArgument(Attribute attr);
// Construct from a floating point number.
explicit DiagnosticArgument(double val)
: kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
// Construct from a signed integer.
template <typename T>
explicit DiagnosticArgument(
T val, typename std::enable_if<std::is_signed<T>::value &&
std::numeric_limits<T>::is_integer &&
sizeof(T) <= sizeof(int64_t)>::type * = 0)
: kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
// Construct from an unsigned integer.
template <typename T>
explicit DiagnosticArgument(
T val, typename std::enable_if<std::is_unsigned<T>::value &&
std::numeric_limits<T>::is_integer &&
sizeof(T) <= sizeof(uint64_t)>::type * = 0)
: kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
// Construct from an operation reference.
explicit DiagnosticArgument(Operation &val) : DiagnosticArgument(&val) {}
explicit DiagnosticArgument(Operation *val)
: kind(DiagnosticArgumentKind::Operation),
opaqueVal(reinterpret_cast<intptr_t>(val)) {
assert(val && "expected valid operation");
}
// Construct from a string reference.
explicit DiagnosticArgument(StringRef val)
: kind(DiagnosticArgumentKind::String), stringVal(val) {}
// Construct from a Type.
explicit DiagnosticArgument(Type val);
/// The kind of this argument.
DiagnosticArgumentKind kind;
/// The value of this argument.
union {
double doubleVal;
intptr_t opaqueVal;
StringRef stringVal;
};
};
inline raw_ostream &operator<<(raw_ostream &os, const DiagnosticArgument &arg) {
arg.print(os);
return os;
}
//===----------------------------------------------------------------------===//
// Diagnostic
//===----------------------------------------------------------------------===//
/// This class contains all of the information necessary to report a diagnostic
/// to the DiagnosticEngine. It should generally not be constructed directly,
/// and instead used transitively via InFlightDiagnostic.
class Diagnostic {
using NoteVector = std::vector<std::unique_ptr<Diagnostic>>;
/// This class implements a wrapper iterator around NoteVector::iterator to
/// implicitly dereference the unique_ptr.
template <typename IteratorTy, typename NotePtrTy = decltype(*IteratorTy()),
typename ResultTy = decltype(**IteratorTy())>
class NoteIteratorImpl
: public llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)> {
static ResultTy &unwrap(NotePtrTy note) { return *note; }
public:
NoteIteratorImpl(IteratorTy it)
: llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)>(it,
&unwrap) {}
};
public:
Diagnostic(Location loc, DiagnosticSeverity severity)
: loc(loc), severity(severity) {}
Diagnostic(Diagnostic &&) = default;
Diagnostic &operator=(Diagnostic &&) = default;
/// Returns the severity of this diagnostic.
DiagnosticSeverity getSeverity() const { return severity; }
/// Returns the source location for this diagnostic.
Location getLocation() const { return loc; }
/// Returns the current list of diagnostic arguments.
MutableArrayRef<DiagnosticArgument> getArguments() { return arguments; }
ArrayRef<DiagnosticArgument> getArguments() const { return arguments; }
/// Stream operator for inserting new diagnostic arguments.
template <typename Arg>
typename std::enable_if<!std::is_convertible<Arg, StringRef>::value,
Diagnostic &>::type
operator<<(Arg &&val) {
arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
return *this;
}
/// Stream in a string literal.
Diagnostic &operator<<(const char *val) {
arguments.push_back(DiagnosticArgument(val));
return *this;
}
/// Stream in a Twine argument.
Diagnostic &operator<<(char val);
Diagnostic &operator<<(const Twine &val);
Diagnostic &operator<<(Twine &&val);
/// Stream in an Identifier.
Diagnostic &operator<<(Identifier val);
/// Stream in an OperationName.
Diagnostic &operator<<(OperationName val);
/// Stream in a range.
template <typename T> Diagnostic &operator<<(llvm::iterator_range<T> range) {
return appendRange(range);
}
template <typename T> Diagnostic &operator<<(llvm::ArrayRef<T> range) {
return appendRange(range);
}
/// Append a range to the diagnostic. The default delimiter between elements
/// is ', '.
template <typename T, template <typename> class Container>
Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
interleave(
c, [&](const detail::ValueOfRange<Container<T>> &a) { *this << a; },
[&]() { *this << delim; });
return *this;
}
/// Append arguments to the diagnostic.
template <typename Arg1, typename Arg2, typename... Args>
Diagnostic &append(Arg1 &&arg1, Arg2 &&arg2, Args &&... args) {
append(std::forward<Arg1>(arg1));
return append(std::forward<Arg2>(arg2), std::forward<Args>(args)...);
}
/// Append one argument to the diagnostic.
template <typename Arg> Diagnostic &append(Arg &&arg) {
*this << std::forward<Arg>(arg);
return *this;
}
/// Outputs this diagnostic to a stream.
void print(raw_ostream &os) const;
/// Converts the diagnostic to a string.
std::string str() const;
/// Attaches a note to this diagnostic. A new location may optionally be
/// provided; if not, the location defaults to the one specified for this
/// diagnostic. Notes may not be attached to other notes.
Diagnostic &attachNote(llvm::Optional<Location> noteLoc = llvm::None);
using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
/// Returns the notes held by this diagnostic.
llvm::iterator_range<note_iterator> getNotes() {
return {notes.begin(), notes.end()};
}
llvm::iterator_range<const_note_iterator> getNotes() const {
return {notes.begin(), notes.end()};
}
/// Allow a diagnostic to be converted to 'failure'.
operator LogicalResult() const;
private:
Diagnostic(const Diagnostic &rhs) = delete;
Diagnostic &operator=(const Diagnostic &rhs) = delete;
/// The source location.
Location loc;
/// The severity of this diagnostic.
DiagnosticSeverity severity;
/// The current list of arguments.
SmallVector<DiagnosticArgument, 4> arguments;
/// A list of string values used as arguments. This is used to guarantee the
/// liveness of non-constant strings used in diagnostics.
std::vector<std::unique_ptr<char[]>> strings;
/// A list of attached notes.
NoteVector notes;
};
inline raw_ostream &operator<<(raw_ostream &os, const Diagnostic &diag) {
diag.print(os);
return os;
}
//===----------------------------------------------------------------------===//
// InFlightDiagnostic
//===----------------------------------------------------------------------===//
/// This class represents a diagnostic that is inflight and set to be reported.
/// This allows for last minute modifications of the diagnostic before it is
/// emitted by a DiagnosticEngine.
class InFlightDiagnostic {
public:
InFlightDiagnostic() = default;
InFlightDiagnostic(InFlightDiagnostic &&rhs)
: owner(rhs.owner), impl(std::move(rhs.impl)) {
// Reset the rhs diagnostic.
rhs.impl.reset();
rhs.abandon();
}
~InFlightDiagnostic() {
if (isInFlight())
report();
}
/// Stream operator for new diagnostic arguments.
template <typename Arg> InFlightDiagnostic &operator<<(Arg &&arg) & {
return append(std::forward<Arg>(arg));
}
template <typename Arg> InFlightDiagnostic &&operator<<(Arg &&arg) && {
return std::move(append(std::forward<Arg>(arg)));
}
/// Append arguments to the diagnostic.
template <typename... Args> InFlightDiagnostic &append(Args &&... args) & {
assert(isActive() && "diagnostic not active");
if (isInFlight())
impl->append(std::forward<Args>(args)...);
return *this;
}
template <typename... Args> InFlightDiagnostic &&append(Args &&... args) && {
return std::move(append(std::forward<Args>(args)...));
}
/// Attaches a note to this diagnostic.
Diagnostic &attachNote(llvm::Optional<Location> noteLoc = llvm::None) {
assert(isActive() && "diagnostic not active");
return impl->attachNote(noteLoc);
}
/// Reports the diagnostic to the engine.
void report();
/// Abandons this diagnostic so that it will no longer be reported.
void abandon();
/// Allow an inflight diagnostic to be converted to 'failure', otherwise
/// 'success' if this is an empty diagnostic.
operator LogicalResult() const;
private:
InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
InFlightDiagnostic(DiagnosticEngine *owner, Diagnostic &&rhs)
: owner(owner), impl(std::move(rhs)) {}
/// Returns true if the diagnostic is still active, i.e. it holds a live
/// diagnostic.
bool isActive() const { return impl.hasValue(); }
/// Returns if the diagnostic is still in flight to be reported.
bool isInFlight() const { return owner; }
// Allow access to the constructor.
friend DiagnosticEngine;
/// The engine that this diagnostic is to report to.
DiagnosticEngine *owner;
/// The raw diagnostic that is inflight to be reported.
llvm::Optional<Diagnostic> impl;
};
//===----------------------------------------------------------------------===//
// DiagnosticEngine
//===----------------------------------------------------------------------===//
/// This class is the main interface for diagnostics. The DiagnosticEngine
/// manages the registration of diagnostic handlers as well as the core API for
/// diagnostic emission. This class should not be constructed directly, but
/// instead interfaced with via an MLIRContext instance.
class DiagnosticEngine {
public:
~DiagnosticEngine();
// Diagnostic handler registration and use. MLIR supports the ability for the
// IR to carry arbitrary metadata about operation location information. If a
// problem is detected by the compiler, it can invoke the emitError /
// emitWarning / emitRemark method on an Operation and have it get reported
// through this interface.
//
// Tools using MLIR are encouraged to register error handlers and define a
// schema for their location information. If they don't, then warnings and
// notes will be dropped and errors will be emitted to errs.
using HandlerTy = std::function<void(Diagnostic)>;
/// Set the diagnostic handler for this engine. Note that this replaces any
/// existing handler.
void setHandler(const HandlerTy &handler);
/// Return the current diagnostic handler, or null if none is present.
HandlerTy getHandler();
/// Create a new inflight diagnostic with the given location and severity.
InFlightDiagnostic emit(Location loc, DiagnosticSeverity severity) {
assert(severity != DiagnosticSeverity::Note &&
"notes should not be emitted directly");
return InFlightDiagnostic(this, Diagnostic(loc, severity));
}
/// Emit a diagnostic using the registered issue handler if present, or with
/// the default behavior if not.
void emit(Diagnostic diag);
private:
friend class MLIRContextImpl;
DiagnosticEngine();
/// The internal implementation of the DiagnosticEngine.
std::unique_ptr<detail::DiagnosticEngineImpl> impl;
};
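/// A minimal sketch of handler registration with the API above, assuming the
/// engine is obtained from an MLIRContext and that 'os' outlives the handler.
inline void printDiagnosticsTo(DiagnosticEngine &engine, raw_ostream &os) {
  engine.setHandler([&os](Diagnostic diag) {
    // Diagnostic::print renders the message arguments; severity and location
    // handling is left to the handler.
    diag.print(os);
    os << "\n";
  });
}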
//===----------------------------------------------------------------------===//
// ScopedDiagnosticHandler
//===----------------------------------------------------------------------===//
/// This diagnostic handler is a simple RAII class that saves and restores the
/// current diagnostic handler registered to a given context. This class can
/// either be used directly, or in conjunction with a derived diagnostic
/// handler.
class ScopedDiagnosticHandler {
public:
ScopedDiagnosticHandler(MLIRContext *ctx);
ScopedDiagnosticHandler(MLIRContext *ctx,
const DiagnosticEngine::HandlerTy &handler);
~ScopedDiagnosticHandler();
/// Propagate a diagnostic to the existing diagnostic handler.
void propagateDiagnostic(Diagnostic diag) {
if (existingHandler)
existingHandler(std::move(diag));
}
private:
/// The existing diagnostic handler registered with the context at the time of
/// construction.
DiagnosticEngine::HandlerTy existingHandler;
/// The context to register the handler back to.
MLIRContext *ctx;
};
/// Utility method to emit an error message using this location.
InFlightDiagnostic emitError(Location loc);
InFlightDiagnostic emitError(Location loc, const Twine &message);
/// Utility method to emit a warning message using this location.
InFlightDiagnostic emitWarning(Location loc);
InFlightDiagnostic emitWarning(Location loc, const Twine &message);
/// Utility method to emit a remark message using this location.
InFlightDiagnostic emitRemark(Location loc);
InFlightDiagnostic emitRemark(Location loc, const Twine &message);
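/// A minimal sketch chaining the utility methods above, assuming 'loc' is
/// the failing operation's location; the message text is illustrative.
inline void reportShapeMismatch(Location loc) {
  InFlightDiagnostic diag = emitError(loc, "operand shapes do not match");
  diag.attachNote(loc) << "while verifying the enclosing region";
  // The error is reported automatically when 'diag' goes out of scope.
}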
//===----------------------------------------------------------------------===//
// SourceMgrDiagnosticHandler
//===----------------------------------------------------------------------===//
namespace detail {
struct SourceMgrDiagnosticHandlerImpl;
} // end namespace detail
/// This class is a utility diagnostic handler for use with llvm::SourceMgr.
class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler {
public:
SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
llvm::raw_ostream &os);
SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx);
~SourceMgrDiagnosticHandler();
/// Emit the given diagnostic information with the held source manager.
void emitDiagnostic(Location loc, Twine message, DiagnosticSeverity kind);
protected:
/// Emit the given diagnostic with the held source manager.
void emitDiagnostic(Diagnostic &diag);
/// Get a memory buffer for the given file, or nullptr if no file is
/// available.
const llvm::MemoryBuffer *getBufferForFile(StringRef filename);
/// The source manager that we are wrapping.
llvm::SourceMgr &mgr;
/// The output stream to use when printing diagnostics.
llvm::raw_ostream &os;
private:
/// Convert a location within the given memory buffer into an SMLoc.
llvm::SMLoc convertLocToSMLoc(FileLineColLoc loc);
/// The maximum depth that a call stack will be printed.
/// TODO(riverriddle) This should be a tunable flag.
unsigned callStackLimit = 10;
std::unique_ptr<detail::SourceMgrDiagnosticHandlerImpl> impl;
};
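/// Example usage (sketch, with a hypothetical `buffer` holding the source
/// and `ctx` a valid MLIRContext*): diagnostics emitted on the context are
/// rendered together with the corresponding source snippet.
///
///   llvm::SourceMgr sourceMgr;
///   sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
///   SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, ctx);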
//===----------------------------------------------------------------------===//
// SourceMgrDiagnosticVerifierHandler
//===----------------------------------------------------------------------===//
namespace detail {
struct SourceMgrDiagnosticVerifierHandlerImpl;
} // end namespace detail
/// This class is a utility diagnostic handler for use with llvm::SourceMgr that
/// verifies that emitted diagnostics match 'expected-*' lines on the
/// corresponding line of the source file.
class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
public:
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
llvm::raw_ostream &out);
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx);
~SourceMgrDiagnosticVerifierHandler();
/// Returns the status of the handler and verifies that all expected
/// diagnostics were emitted. This returns success if all diagnostics were
/// verified correctly, failure otherwise.
LogicalResult verify();
private:
/// Process a single diagnostic.
void process(Diagnostic &diag);
/// Process a FileLineColLoc diagnostic.
void process(FileLineColLoc loc, StringRef msg, DiagnosticSeverity kind);
std::unique_ptr<detail::SourceMgrDiagnosticVerifierHandlerImpl> impl;
};
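/// Example usage (sketch): the test input carries 'expected-*' annotations
/// that emitted diagnostics are matched against:
///
///   // In the test input:
///   //   %0 = "foo.op"() : () -> i32  // expected-error {{unknown operation}}
///
///   SourceMgrDiagnosticVerifierHandler handler(sourceMgr, ctx);
///   // ... run the pipeline that should emit the annotated diagnostics ...
///   LogicalResult status = handler.verify();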
//===----------------------------------------------------------------------===//
// ParallelDiagnosticHandler
//===----------------------------------------------------------------------===//
namespace detail {
struct ParallelDiagnosticHandlerImpl;
} // end namespace detail
/// This class is a utility diagnostic handler for use when multi-threading some
/// part of the compiler where diagnostics may be emitted. This handler ensures
/// a deterministic ordering to the emitted diagnostics that mirrors that of a
/// single-threaded compilation.
class ParallelDiagnosticHandler {
public:
ParallelDiagnosticHandler(MLIRContext *ctx);
~ParallelDiagnosticHandler();
/// Set the order id for the current thread. This is required to be set by
/// each thread that will be emitting diagnostics to this handler. The orderID
/// corresponds to the order in which diagnostics would be emitted when
/// executing synchronously. For example, if we were processing a list
/// of operations [a, b, c] on a single thread, diagnostics emitted while
/// processing operation 'a' would be emitted before those for 'b' or 'c'.
/// This corresponds 1-1 with the 'orderID'. The thread that is processing 'a'
/// should set the orderID to '0'; the thread processing 'b' should set it to
/// '1'; and so on and so forth. This provides a way for the handler to
/// deterministically order the diagnostics that it receives given the thread
/// that it is receiving on.
void setOrderIDForThread(size_t orderID);
private:
std::unique_ptr<detail::ParallelDiagnosticHandlerImpl> impl;
};
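/// Example usage (sketch): when processing operations [a, b, c] in parallel,
/// the worker handling the i-th operation registers its position before
/// emitting, so the handler can reorder diagnostics deterministically.
///
///   ParallelDiagnosticHandler handler(ctx);
///   // Inside the worker for the i-th operation:
///   handler.setOrderIDForThread(i);
///   // Diagnostics emitted here are buffered and replayed in order.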
} // namespace mlir
#endif


@ -0,0 +1,286 @@
//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the 'dialect' abstraction.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECT_H
#define MLIR_IR_DIALECT_H
#include "mlir/IR/OperationSupport.h"
namespace mlir {
class OpBuilder;
class Type;
using DialectConstantDecodeHook =
std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
using DialectConstantFoldHook = std::function<LogicalResult(
Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectExtractElementHook =
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
/// default named types for asm printing, etc.
///
/// Instances of the dialect object are global across all MLIRContexts that may
/// be active in the process.
///
class Dialect {
public:
virtual ~Dialect();
/// Utility function that returns whether the given string is a valid dialect
/// namespace.
static bool isValidNamespace(StringRef str);
MLIRContext *getContext() const { return context; }
StringRef getNamespace() const { return name; }
/// Returns true if this dialect allows for unregistered operations, i.e.
/// operations prefixed with the dialect namespace but not registered with
/// addOperation.
bool allowsUnknownOperations() const { return allowUnknownOps; }
//===--------------------------------------------------------------------===//
// Constant Hooks
//===--------------------------------------------------------------------===//
/// Registered fallback constant fold hook for the dialect. Like the constant
/// fold hook of each operation, it attempts to constant fold the operation
/// with the specified constant operand values - the elements in "operands"
/// will correspond directly to the operands of the operation, but may be null
/// if non-constant. If constant folding is successful, this fills in the
/// `results` vector. If not, this returns failure and `results` is
/// unspecified.
DialectConstantFoldHook constantFoldHook =
[](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) { return failure(); };
/// Registered hook to decode opaque constants associated with this
/// dialect. The hook function attempts to decode an opaque constant tensor
/// into a tensor with non-opaque content. If decoding is successful, this
/// method returns false and sets the 'output' attribute. If not, it returns true
/// and leaves 'output' unspecified. The default hook fails to decode.
DialectConstantDecodeHook decodeHook =
[](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
/// Registered hook to extract an element from an opaque constant associated
/// with this dialect. If the element has been successfully extracted, this
/// method returns that element. If not, it returns an empty attribute.
/// The default hook fails to extract an element.
DialectExtractElementHook extractElementHook =
[](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
return Attribute();
};
/// Registered hook to materialize a single constant operation from a given
/// attribute value with the desired resultant type. This method should use
/// the provided builder to create the operation without changing the
/// insertion position. The generated operation is expected to be
/// constant-like, i.e. to have a single result, zero operands, and no side
/// effects. On
/// success, this hook should return the value generated to represent the
/// constant value. Otherwise, it should return null on failure.
virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
return nullptr;
}
//===--------------------------------------------------------------------===//
// Parsing Hooks
//===--------------------------------------------------------------------===//
/// Parse an attribute registered to this dialect. If 'type' is nonnull, it
/// refers to the expected type of the attribute.
virtual Attribute parseAttribute(StringRef attrData, Type type,
Location loc) const;
/// Print an attribute registered to this dialect. Note: The type of the
/// attribute need not be printed by this method as it is always printed by
/// the caller.
virtual void printAttribute(Attribute, raw_ostream &) const {
llvm_unreachable("dialect has no registered attribute printing hook");
}
/// Parse a type registered to this dialect.
virtual Type parseType(StringRef tyData, Location loc) const;
/// Print a type registered to this dialect.
virtual void printType(Type, raw_ostream &) const {
llvm_unreachable("dialect has no registered type printing hook");
}
/// Registered hooks for getting identifier aliases for symbols. The
/// identifier is used in place of the symbol when printing textual IR.
///
/// Hook for defining Attribute kind aliases. This will generate an alias for
/// all attributes of the given kind in the form: <alias>[0-9]+. These
/// aliases must not contain `.`.
virtual void getAttributeKindAliases(
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) {}
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+).
virtual void getAttributeAliases(
SmallVectorImpl<std::pair<Attribute, StringRef>> &aliases) {}
/// Hook for defining Type aliases.
virtual void
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) {}
//===--------------------------------------------------------------------===//
// Verification Hooks
//===--------------------------------------------------------------------===//
/// Verify an attribute from this dialect on the argument at 'argIndex' for
/// the region at 'regionIndex' on the given operation. Returns failure if
/// the verification failed, success otherwise. This hook may optionally be
/// invoked from any operation containing a region.
virtual LogicalResult verifyRegionArgAttribute(Operation *,
unsigned regionIndex,
unsigned argIndex,
NamedAttribute);
/// Verify an attribute from this dialect on the given operation. Returns
/// failure if the verification failed, success otherwise.
virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
return success();
}
protected:
/// The constructor takes a unique namespace for this dialect as well as the
/// context to bind to.
/// Note: The namespace must not contain '.' characters.
/// Note: All operations belonging to this dialect must have names starting
/// with the namespace followed by '.'.
/// Example:
/// - "tf" for the TensorFlow ops like "tf.add".
Dialect(StringRef name, MLIRContext *context);
/// This method is used by derived classes to add their operations to the set.
template <typename... Args> void addOperations() {
VariadicOperationAdder<Args...>::addToSet(*this);
}
// It would be nice to define this as variadic functions instead of a nested
// variadic type, but we can't do that: function template partial
// specialization is not allowed, and we can't define an overload set because
// we don't have any arguments of the types we are pushing around.
template <typename First, typename... Rest> class VariadicOperationAdder {
public:
static void addToSet(Dialect &dialect) {
dialect.addOperation(AbstractOperation::get<First>(dialect));
VariadicOperationAdder<Rest...>::addToSet(dialect);
}
};
template <typename First> class VariadicOperationAdder<First> {
public:
static void addToSet(Dialect &dialect) {
dialect.addOperation(AbstractOperation::get<First>(dialect));
}
};
void addOperation(AbstractOperation opInfo);
/// This method is used by derived classes to add their types to the set.
template <typename... Args> void addTypes() {
VariadicSymbolAdder<Args...>::addToSet(*this);
}
/// This method is used by derived classes to add their attributes to the set.
template <typename... Args> void addAttributes() {
VariadicSymbolAdder<Args...>::addToSet(*this);
}
// It would be nice to define this as variadic functions instead of a nested
// variadic type, but we can't do that: function template partial
// specialization is not allowed, and we can't define an overload set
// because we don't have any arguments of the types we are pushing around.
template <typename First, typename... Rest> struct VariadicSymbolAdder {
static void addToSet(Dialect &dialect) {
VariadicSymbolAdder<First>::addToSet(dialect);
VariadicSymbolAdder<Rest...>::addToSet(dialect);
}
};
template <typename First> struct VariadicSymbolAdder<First> {
static void addToSet(Dialect &dialect) {
dialect.addSymbol(First::getClassID());
}
};
// Enable support for unregistered operations.
void allowUnknownOperations(bool allow = true) { allowUnknownOps = allow; }
private:
// Register a symbol (e.g. a type) with its given unique class identifier.
void addSymbol(const ClassID *const classID);
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;
/// Register this dialect object with the specified context. The context
/// takes ownership of the heap allocated dialect.
void registerDialect(MLIRContext *context);
/// The namespace of this dialect.
StringRef name;
/// This is the context that owns this Dialect object.
MLIRContext *context;
/// Flag that toggles if this dialect supports unregistered operations, i.e.
/// operations prefixed with the dialect namespace but not registered with
/// addOperation.
bool allowUnknownOps;
};
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void registerDialectAllocator(const DialectAllocatorFunction &function);
/// Registers all dialects with the specified MLIRContext.
void registerAllDialects(MLIRContext *context);
/// Utility to register a dialect. Client can register their dialect with the
/// global registry by calling registerDialect<MyDialect>();
template <typename ConcreteDialect> void registerDialect() {
registerDialectAllocator([](MLIRContext *ctx) {
// Just allocate the dialect, the context takes ownership of it.
new ConcreteDialect(ctx);
});
}
/// DialectRegistration provides a global initialiser that registers a Dialect
/// allocation routine.
///
/// Usage:
///
/// // At namespace scope.
/// static DialectRegistration<MyDialect> Unused;
template <typename ConcreteDialect> struct DialectRegistration {
DialectRegistration() { registerDialect<ConcreteDialect>(); }
};
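/// Example (an illustrative sketch; the dialect, op, and type names are
/// hypothetical):
///
///   class ToyDialect : public Dialect {
///   public:
///     explicit ToyDialect(MLIRContext *ctx) : Dialect("toy", ctx) {
///       addOperations<ConstantOp, AddOp>();
///       addTypes<StructType>();
///     }
///   };
///   static DialectRegistration<ToyDialect> toyDialectReg;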
} // namespace mlir
#endif


@ -0,0 +1,82 @@
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines abstraction and registration mechanism for dialect hooks.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECT_HOOKS_H
#define MLIR_IR_DIALECT_HOOKS_H
#include "mlir/IR/Dialect.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
using DialectHooksSetter = std::function<void(MLIRContext *)>;
/// Dialect hooks allow external components to register their functions to
/// be called for specific tasks specialized per dialect, such as decoding
/// of opaque constants. To register concrete dialect hooks, one should
/// define a DialectHooks subclass and use it as a template
/// argument to DialectHooksRegistration. For example,
/// class MyHooks : public DialectHooks {...};
/// static DialectHooksRegistration<MyHooks> hooksReg("my_dialect");
/// The subclass should override DialectHook methods for supported hooks.
class DialectHooks {
public:
// Returns hook to constant fold an operation.
DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
// Returns hook to decode opaque constant tensor.
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
// Returns hook to extract an element of an opaque constant tensor.
DialectExtractElementHook getExtractElementHook() { return nullptr; }
};
/// Registers a function that will set hooks in the registered dialects
/// based on information coming from DialectHooksRegistration.
void registerDialectHooksSetter(const DialectHooksSetter &function);
/// DialectHooksRegistration provides a global initialiser that registers
/// a dialect hooks setter routine.
/// Usage:
///
/// // At namespace scope.
/// static DialectHooksRegistration<MyHooks> unused("my_dialect");
template <typename ConcreteHooks> struct DialectHooksRegistration {
DialectHooksRegistration(StringRef dialectName) {
registerDialectHooksSetter([dialectName](MLIRContext *ctx) {
Dialect *dialect = ctx->getRegisteredDialect(dialectName);
if (!dialect) {
llvm::errs() << "error: cannot register hooks for unknown dialect '"
<< dialectName << "'\n";
abort();
}
// Set hooks.
ConcreteHooks hooks;
if (auto h = hooks.getConstantFoldHook())
dialect->constantFoldHook = h;
if (auto h = hooks.getDecodeHook())
dialect->decodeHook = h;
if (auto h = hooks.getExtractElementHook())
dialect->extractElementHook = h;
});
}
};
} // namespace mlir
#endif


@ -0,0 +1,47 @@
//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file enumerates the different dialects that define custom classes
// within the attribute or type system.
//
//===----------------------------------------------------------------------===//
DEFINE_SYM_KIND_RANGE(STANDARD)
DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL)
DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR)
DEFINE_SYM_KIND_RANGE(TENSORFLOW)
DEFINE_SYM_KIND_RANGE(LLVM)
DEFINE_SYM_KIND_RANGE(QUANTIZATION)
DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
// The following ranges are reserved for experimenting with MLIR dialects in a
// private context without having to register them here.
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9)
#undef DEFINE_SYM_KIND_RANGE


@ -0,0 +1,159 @@
//===- Function.h - MLIR Function Class -------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Functions are the basic unit of composition in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_FUNCTION_H
#define MLIR_IR_FUNCTION_H
#include "mlir/IR/Block.h"
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
//===--------------------------------------------------------------------===//
// Function Operation.
//===--------------------------------------------------------------------===//
/// FuncOp represents a function, or an operation containing one region that
/// forms a CFG (Control Flow Graph). The region of a function is not allowed
/// to implicitly capture global values, and all external references must use
/// function arguments or attributes that establish a symbolic connection (e.g.
/// symbols referenced by name via a string attribute).
class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike> {
public:
using Op::Op;
using Op::print;
static StringRef getOperationName() { return "func"; }
static FuncOp create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
static FuncOp create(Location location, StringRef name, FunctionType type,
llvm::iterator_range<dialect_attr_iterator> attrs);
static FuncOp create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs);
static void build(Builder *builder, OperationState *result, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs);
static void build(Builder *builder, OperationState *result, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs);
/// Operation hooks.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
/// Returns the type of this function.
FunctionType getType() {
return getAttrOfType<TypeAttr>(getTypeAttrName())
.getValue()
.cast<FunctionType>();
}
/// Change the type of this function in place. This is an extremely dangerous
/// operation and it is up to the caller to ensure that this is legal for this
/// function, and to restore invariants:
/// - the entry block args must be updated to match the function params.
/// - the argument attributes may need an update: if the new type has fewer
/// parameters, we drop the extra attributes; if there are more parameters,
/// they won't have any attributes.
void setType(FunctionType newType) {
setAttr(getTypeAttrName(), TypeAttr::get(newType));
}
/// Create a deep copy of this function and all of its blocks, remapping
/// any operands that use values outside of the function using the map that is
/// provided (leaving them alone if no entry is present). If the mapper
/// contains entries for function arguments, these arguments are not included
/// in the new function. Replaces references to cloned sub-values with the
/// corresponding value that is copied, and adds those mappings to the mapper.
FuncOp clone(BlockAndValueMapping &mapper);
FuncOp clone();
/// Clone the internal blocks and attributes from this function into dest. Any
/// cloned blocks are appended to the back of dest. This function asserts that
/// the attributes of the current function and dest are compatible.
void cloneInto(FuncOp dest, BlockAndValueMapping &mapper);
//===--------------------------------------------------------------------===//
// Body Handling
//===--------------------------------------------------------------------===//
/// Add an entry block to an empty function, and set up the block arguments
/// to match the signature of the function.
void addEntryBlock();
private:
// This trait needs access to `getNumFuncArguments` and `verifyType` hooks
// defined below.
friend class OpTrait::FunctionLike<FuncOp>;
/// Returns the number of arguments. This is a hook for OpTrait::FunctionLike.
unsigned getNumFuncArguments() { return getType().getInputs().size(); }
/// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
/// attribute is present and checks if it holds a function type. Ensures
/// getType and getNumFuncArguments can be called safely.
LogicalResult verifyType() {
auto type = getTypeAttr().getValue();
if (!type.isa<FunctionType>())
return emitOpError("requires '" + getTypeAttrName() +
"' attribute of function type");
return success();
}
};
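/// Example (sketch; `builder`, `loc`, and `i32Ty` are assumed to be a valid
/// Builder, Location, and integer Type):
///
///   FunctionType fnType = builder.getFunctionType({i32Ty}, {i32Ty});
///   FuncOp func = FuncOp::create(loc, "add_one", fnType);
///   func.addEntryBlock();
///   Block &entry = func.front(); // Has one argument matching the input type.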
} // end namespace mlir
namespace llvm {
// Functions hash just like pointers.
template <> struct DenseMapInfo<mlir::FuncOp> {
static mlir::FuncOp getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::FuncOp::getFromOpaquePointer(pointer);
}
static mlir::FuncOp getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::FuncOp::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::FuncOp val) {
return hash_value(val.getAsOpaquePointer());
}
static bool isEqual(mlir::FuncOp LHS, mlir::FuncOp RHS) { return LHS == RHS; }
};
/// Allow stealing the low bits of FuncOp.
template <> struct PointerLikeTypeTraits<mlir::FuncOp> {
public:
static inline void *getAsVoidPointer(mlir::FuncOp I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::FuncOp getFromVoidPointer(void *P) {
return mlir::FuncOp::getFromOpaquePointer(P);
}
enum { NumLowBitsAvailable = 3 };
};
} // namespace llvm
#endif // MLIR_IR_FUNCTION_H


@ -0,0 +1,390 @@
//===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines support types for Operations that represent function-like
// constructs to use.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_FUNCTIONSUPPORT_H
#define MLIR_IR_FUNCTIONSUPPORT_H
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallString.h"
namespace mlir {
namespace impl {
/// Return the name of the attribute used for function types.
inline StringRef getTypeAttrName() { return "type"; }
/// Return the name of the attribute used for function arguments.
inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
out.clear();
return ("arg" + Twine(arg)).toStringRef(out);
}
/// Returns the dictionary attribute corresponding to the argument at 'index'.
/// If there are no argument attributes at 'index', a null attribute is
/// returned.
inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) {
SmallString<8> nameOut;
return op->getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
}
/// Return all of the attributes for the argument at 'index'.
inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
auto argDict = getArgAttrDict(op, index);
return argDict ? argDict.getValue() : llvm::None;
}
/// Callback type for `parseFunctionLikeOp`; the callback should produce the
/// type that will be associated with a function-like operation from lists of
/// function arguments and results.
using FuncTypeBuilder =
llvm::function_ref<Type(Builder &, ArrayRef<Type>, ArrayRef<Type>)>;
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
/// input and output types. If the builder returns a null type, `result` will
/// not contain the `type` attribute. The caller can then either add the type
/// or use op's verifier to report errors.
ParseResult parseFunctionLikeOp(OpAsmParser *parser, OperationState *result,
FuncTypeBuilder funcTypeBuilder);
/// Printer implementation for function-like operations. Accepts lists of
/// argument and result types to use while printing.
void printFunctionLikeOp(OpAsmPrinter *p, Operation *op,
ArrayRef<Type> argTypes, ArrayRef<Type> results);
} // namespace impl
namespace OpTrait {
/// This trait provides APIs for Ops that behave like functions. In particular:
/// - Ops can be used with SymbolTable in the parent Op and have names;
/// - Ops have a single region with multiple blocks that corresponds to the body
/// of the function;
/// - the absence of a region corresponds to an external function;
/// - arguments of the first block of the region are treated as function
/// arguments;
/// - they can have argument attributes that are stored in a dictionary
/// attribute on the Op itself.
/// This trait does *NOT* provide type support for the functions, meaning that
/// concrete Ops must handle the type of the declared or defined function.
/// `getTypeAttrName()` is a convenience function that returns the name of the
/// attribute that can be used to store the function type, but the trait makes
/// no assumption based on it.
///
/// - Concrete ops *must* define a member function `getNumFuncArguments()` that
/// returns the number of function arguments based exclusively on type (so that
/// it can be called on function declarations).
/// - To verify that the type respects op-specific invariants, concrete ops may
/// redefine the `verifyType()` hook that will be called after verifying the
/// presence of the `type` attribute and before any call to
/// `getNumFuncArguments` from the verifier.
template <typename ConcreteType>
class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
public:
/// Verify that all of the argument attributes are dialect attributes.
static LogicalResult verifyTrait(Operation *op);
//===--------------------------------------------------------------------===//
// Name Handling.
//===--------------------------------------------------------------------===//
/// Returns the name of this function.
StringRef getName() {
return this->getOperation()
->template getAttrOfType<StringAttr>(
mlir::SymbolTable::getSymbolAttrName())
.getValue();
}
/// Set the name of this function.
void setName(StringRef name) {
this->getOperation()->setAttr(
mlir::SymbolTable::getSymbolAttrName(),
StringAttr::get(name, this->getOperation()->getContext()));
}
//===--------------------------------------------------------------------===//
// Body Handling
//===--------------------------------------------------------------------===//
/// Returns true if this function is external, i.e. it has no body.
bool isExternal() { return empty(); }
Region &getBody() { return this->getOperation()->getRegion(0); }
/// Delete all blocks from this function.
void eraseBody() {
getBody().dropAllReferences();
getBody().getBlocks().clear();
}
/// This is the list of blocks in the function.
using RegionType = Region::RegionType;
RegionType &getBlocks() { return getBody().getBlocks(); }
// Iteration over the blocks in the function.
using iterator = RegionType::iterator;
using reverse_iterator = RegionType::reverse_iterator;
iterator begin() { return getBody().begin(); }
iterator end() { return getBody().end(); }
reverse_iterator rbegin() { return getBody().rbegin(); }
reverse_iterator rend() { return getBody().rend(); }
bool empty() { return getBody().empty(); }
void push_back(Block *block) { getBody().push_back(block); }
void push_front(Block *block) { getBody().push_front(block); }
Block &back() { return getBody().back(); }
Block &front() { return getBody().front(); }
//===--------------------------------------------------------------------===//
// Type Attribute Handling
//===--------------------------------------------------------------------===//
/// Return the name of the attribute used for function types.
static StringRef getTypeAttrName() { return ::mlir::impl::getTypeAttrName(); }
TypeAttr getTypeAttr() {
return this->getOperation()->template getAttrOfType<TypeAttr>(
getTypeAttrName());
}
bool isTypeAttrValid() {
auto typeAttr = getTypeAttr();
if (!typeAttr)
return false;
return typeAttr.getValue() != Type{};
}
//===--------------------------------------------------------------------===//
// Argument Handling
//===--------------------------------------------------------------------===//
unsigned getNumArguments() {
return static_cast<ConcreteType *>(this)->getNumFuncArguments();
}
/// Gets the argument at 'idx'.
BlockArgument *getArgument(unsigned idx) {
return getBlocks().front().getArgument(idx);
}
// Supports non-const argument iteration.
using args_iterator = Block::args_iterator;
args_iterator args_begin() { return front().args_begin(); }
args_iterator args_end() { return front().args_end(); }
llvm::iterator_range<args_iterator> getArguments() {
return {args_begin(), args_end()};
}
//===--------------------------------------------------------------------===//
// Argument Attributes
//===--------------------------------------------------------------------===//
/// FunctionLike operations allow for attaching attributes to each of the
/// respective function arguments. These argument attributes are stored as
/// DictionaryAttrs in the main operation attribute dictionary. The name of
/// these entries is `arg` followed by the index of the argument. These
/// argument attribute dictionaries are optional, and will generally only
/// exist if they are non-empty.
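/// For example (an illustrative sketch with a hypothetical dialect
/// attribute), an argument attribute printed in the textual IR as
///
///   func @example(%arg0: i32 {my_dialect.noalias = true})
///
/// is stored on the operation as a DictionaryAttr under the name "arg0".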
/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
return ::mlir::impl::getArgAttrs(this->getOperation(), index);
}
/// Return all argument attributes of this function.
void getAllArgAttrs(SmallVectorImpl<NamedAttributeList> &result) {
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
result.emplace_back(getArgAttrDict(i));
}
/// Return the specified attribute, if present, for the argument at 'index',
/// null otherwise.
Attribute getArgAttr(unsigned index, Identifier name) {
auto argDict = getArgAttrDict(index);
return argDict ? argDict.get(name) : nullptr;
}
Attribute getArgAttr(unsigned index, StringRef name) {
auto argDict = getArgAttrDict(index);
return argDict ? argDict.get(name) : nullptr;
}
template <typename AttrClass>
AttrClass getArgAttrOfType(unsigned index, Identifier name) {
return getArgAttr(index, name).template dyn_cast_or_null<AttrClass>();
}
template <typename AttrClass>
AttrClass getArgAttrOfType(unsigned index, StringRef name) {
return getArgAttr(index, name).template dyn_cast_or_null<AttrClass>();
}
/// Set the attributes held by the argument at 'index'.
void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
void setArgAttrs(unsigned index, NamedAttributeList attributes);
void setAllArgAttrs(ArrayRef<NamedAttributeList> attributes) {
assert(attributes.size() == getNumArguments());
for (unsigned i = 0, e = attributes.size(); i != e; ++i)
setArgAttrs(i, attributes[i]);
}
/// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setArgAttr(unsigned index, Identifier name, Attribute value);
void setArgAttr(unsigned index, StringRef name, Attribute value) {
setArgAttr(index, Identifier::get(name, this->getOperation()->getContext()),
value);
}
/// Remove the attribute 'name' from the argument at 'index'.
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
Identifier name);
protected:
/// Returns the attribute entry name for the set of argument attributes at
/// index 'arg'.
static StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
return ::mlir::impl::getArgAttrName(arg, out);
}
/// Returns the dictionary attribute corresponding to the argument at 'index'.
/// If there are no argument attributes at 'index', a null attribute is
/// returned.
DictionaryAttr getArgAttrDict(unsigned index) {
assert(index < getNumArguments() && "invalid argument number");
return ::mlir::impl::getArgAttrDict(this->getOperation(), index);
}
/// Hook for concrete classes to verify that the type attribute respects
/// op-specific invariants. Default implementation always succeeds.
LogicalResult verifyType() { return success(); }
};
template <typename ConcreteType>
LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
MLIRContext *ctx = op->getContext();
auto funcOp = cast<ConcreteType>(op);
if (!funcOp.isTypeAttrValid())
return funcOp.emitOpError("requires a type attribute '")
<< getTypeAttrName() << '\'';
if (failed(funcOp.verifyType()))
return failure();
for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) {
// Verify that all of the argument attributes are dialect attributes, i.e.
// that they contain a dialect prefix in their name. Call the dialect, if
// registered, to verify the attributes themselves.
for (auto attr : funcOp.getArgAttrs(i)) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("arguments may only have dialect attributes");
auto dialectNamePair = attr.first.strref().split('.');
if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
/*argIndex=*/i, attr)))
return failure();
}
}
}
// Check that the op has exactly one region for the body.
if (op->getNumRegions() != 1)
return funcOp.emitOpError("expects one region");
// Check that if the entry block exists, it has the same number of arguments
// as the function-like operation.
if (funcOp.isExternal())
return success();
unsigned numArguments = funcOp.getNumArguments();
if (funcOp.front().getNumArguments() != numArguments)
return funcOp.emitOpError("entry block must have ")
<< numArguments << " arguments to match function signature";
return success();
}
//===----------------------------------------------------------------------===//
// Function Argument Attribute.
//===----------------------------------------------------------------------===//
/// Set the attributes held by the argument at 'index'.
template <typename ConcreteType>
void FunctionLike<ConcreteType>::setArgAttrs(
unsigned index, ArrayRef<NamedAttribute> attributes) {
assert(index < getNumArguments() && "invalid argument number");
SmallString<8> nameOut;
getArgAttrName(index, nameOut);
Operation *op = this->getOperation();
if (attributes.empty())
return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
}
template <typename ConcreteType>
void FunctionLike<ConcreteType>::setArgAttrs(unsigned index,
NamedAttributeList attributes) {
assert(index < getNumArguments() && "invalid argument number");
SmallString<8> nameOut;
if (auto newAttr = attributes.getDictionary())
return this->getOperation()->setAttr(getArgAttrName(index, nameOut),
newAttr);
static_cast<ConcreteType *>(this)->removeAttr(getArgAttrName(index, nameOut));
}
/// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
template <typename ConcreteType>
void FunctionLike<ConcreteType>::setArgAttr(unsigned index, Identifier name,
Attribute value) {
auto curAttr = getArgAttrDict(index);
NamedAttributeList attrList(curAttr);
attrList.set(name, value);
// If the attribute changed, then set the new arg attribute list.
if (curAttr != attrList.getDictionary())
setArgAttrs(index, attrList);
}
/// Remove the attribute 'name' from the argument at 'index'.
template <typename ConcreteType>
NamedAttributeList::RemoveResult
FunctionLike<ConcreteType>::removeArgAttr(unsigned index, Identifier name) {
// Build an attribute list and remove the attribute at 'name'.
NamedAttributeList attrList(getArgAttrDict(index));
auto result = attrList.remove(name);
// If the attribute was removed, then update the argument dictionary.
if (result == NamedAttributeList::RemoveResult::Removed)
setArgAttrs(index, attrList);
return result;
}
} // end namespace OpTrait
} // end namespace mlir
#endif // MLIR_IR_FUNCTIONSUPPORT_H


@ -0,0 +1,143 @@
//===- Identifier.h - MLIR Identifier Class ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_IR_IDENTIFIER_H
#define MLIR_IR_IDENTIFIER_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
class MLIRContext;
/// This class represents a uniqued string owned by an MLIRContext. Strings
/// represented by this type cannot contain nul characters, and may not have a
/// zero length.
///
/// This is a POD type with pointer size, so it should be passed around by
/// value. The underlying data is owned by MLIRContext and is thus immortal for
/// almost all clients.
class Identifier {
public:
/// Return an identifier for the specified string.
static Identifier get(StringRef str, MLIRContext *context);
Identifier(const Identifier &) = default;
Identifier &operator=(const Identifier &other) = default;
/// Return a StringRef for the string.
StringRef strref() const { return StringRef(pointer, size()); }
/// Identifiers implicitly convert to StringRefs.
operator StringRef() const { return strref(); }
/// Return an std::string.
std::string str() const { return strref().str(); }
/// Return a null terminated C string.
const char *c_str() const { return pointer; }
/// Return a pointer to the start of the string data.
const char *data() const { return pointer; }
/// Return the number of bytes in this string.
unsigned size() const { return ::strlen(pointer); }
/// Return true if this identifier is the specified string.
bool is(StringRef string) const { return strref().equals(string); }
const char *begin() const { return pointer; }
const char *end() const { return pointer + size(); }
void print(raw_ostream &os) const;
void dump() const;
const void *getAsOpaquePointer() const {
return static_cast<const void *>(pointer);
}
static Identifier getFromOpaquePointer(const void *pointer) {
return Identifier((const char *)pointer);
}
private:
/// These are the bytes of the string, which is nul terminated.
const char *pointer;
explicit Identifier(const char *pointer) : pointer(pointer) {}
};
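/// Example (sketch): identifiers are uniqued in the context, so equal strings
/// yield pointer-equal identifiers (assumes `ctx` is a valid MLIRContext*):
///
///   Identifier a = Identifier::get("foo", ctx);
///   Identifier b = Identifier::get("foo", ctx);
///   assert(a == b && a.data() == b.data());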
inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) {
identifier.print(os);
return os;
}
inline bool operator==(Identifier lhs, Identifier rhs) {
return lhs.data() == rhs.data();
}
inline bool operator!=(Identifier lhs, Identifier rhs) {
return lhs.data() != rhs.data();
}
inline bool operator==(Identifier lhs, StringRef rhs) { return lhs.is(rhs); }
inline bool operator!=(Identifier lhs, StringRef rhs) { return !lhs.is(rhs); }
inline bool operator==(StringRef lhs, Identifier rhs) { return rhs.is(lhs); }
inline bool operator!=(StringRef lhs, Identifier rhs) { return !rhs.is(lhs); }
// Make identifiers hashable.
inline llvm::hash_code hash_value(Identifier arg) {
return llvm::hash_value(arg.strref());
}
} // end namespace mlir
namespace llvm {
// Identifiers hash just like pointers; there is no need to hash the bytes.
template <>
struct DenseMapInfo<mlir::Identifier> {
static mlir::Identifier getEmptyKey() {
auto pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
return mlir::Identifier::getFromOpaquePointer(pointer);
}
static mlir::Identifier getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
return mlir::Identifier::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::Identifier Val) {
return DenseMapInfo<const void *>::getHashValue(Val.data());
}
static bool isEqual(mlir::Identifier LHS, mlir::Identifier RHS) {
return LHS == RHS;
}
};
/// The pointer inside an identifier comes from a StringMap, so its alignment
/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
/// steal the low bits.
template <>
struct PointerLikeTypeTraits<mlir::Identifier> {
public:
static inline void *getAsVoidPointer(mlir::Identifier I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::Identifier getFromVoidPointer(void *P) {
return mlir::Identifier::getFromOpaquePointer(P);
}
enum { NumLowBitsAvailable = 2 };
};
} // end namespace llvm
#endif


@ -0,0 +1,137 @@
//===- IntegerSet.h - MLIR Integer Set Class --------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Integer sets are sets of points from the integer lattice constrained by
// affine equality/inequality constraints. This class is meant to represent
// integer sets in the IR - for 'affine.if' operations and as attributes of
// other operations. It is typically expected to contain only a handful of
// affine constraints, and is immutable like an affine map. Integer sets are not
// unique'd - although the affine expressions that make up its equalities and
// inequalities are themselves unique.
// This class is not meant for affine analysis and operations like set
// operations, emptiness checks, or other math operations for analysis and
// transformation. For the latter, use FlatAffineConstraints.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_INTEGER_SET_H
#define MLIR_IR_INTEGER_SET_H
#include "mlir/IR/AffineExpr.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
namespace detail {
struct IntegerSetStorage;
}
class MLIRContext;
/// An integer set representing a conjunction of one or more affine equalities
/// and inequalities. An integer set in the IR is immutable like the affine map,
/// but integer sets are not unique'd. The affine expressions that make up the
/// equalities and inequalities of an integer set are themselves unique and are
/// allocated by the bump pointer allocator.
class IntegerSet {
public:
using ImplType = detail::IntegerSetStorage;
IntegerSet() : set(nullptr) {}
explicit IntegerSet(ImplType *set) : set(set) {}
IntegerSet(const IntegerSet &other) : set(other.set) {}
IntegerSet &operator=(const IntegerSet &other) = default;
static IntegerSet get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> constraints,
ArrayRef<bool> eqFlags);
// Returns the canonical empty IntegerSet (i.e. a set with no integer points).
static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols,
MLIRContext *context) {
auto one = getAffineConstantExpr(1, context);
    // The sole constraint is "1 == 0", which no integer point satisfies.
return get(numDims, numSymbols, one, true);
}
/// Returns true if this is the canonical empty integer set.
bool isEmptyIntegerSet() const;
explicit operator bool() { return set; }
bool operator==(IntegerSet other) const { return set == other.set; }
unsigned getNumDims() const;
unsigned getNumSymbols() const;
unsigned getNumOperands() const;
unsigned getNumConstraints() const;
unsigned getNumEqualities() const;
unsigned getNumInequalities() const;
ArrayRef<AffineExpr> getConstraints() const;
AffineExpr getConstraint(unsigned idx) const;
/// Returns the equality bits, which specify whether each of the constraints
/// is an equality or inequality.
ArrayRef<bool> getEqFlags() const;
/// Returns true if the idx^th constraint is an equality, false if it is an
/// inequality.
bool isEq(unsigned idx) const;
MLIRContext *getContext() const;
void print(raw_ostream &os) const;
void dump() const;
friend ::llvm::hash_code hash_value(IntegerSet arg);
private:
ImplType *set;
/// Sets with constraints fewer than kUniquingThreshold are uniqued.
constexpr static unsigned kUniquingThreshold = 4;
};
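/// Example (sketch): build the set { (d0)[s0] : d0 - s0 >= 0 } from affine
/// expressions (assumes `ctx` is a valid MLIRContext*):
///
///   AffineExpr d0 = getAffineDimExpr(0, ctx);
///   AffineExpr s0 = getAffineSymbolExpr(0, ctx);
///   IntegerSet set = IntegerSet::get(/*dimCount=*/1, /*symbolCount=*/1,
///                                    /*constraints=*/{d0 - s0},
///                                    /*eqFlags=*/{false});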
// Make IntegerSet hashable.
inline ::llvm::hash_code hash_value(IntegerSet arg) {
return ::llvm::hash_value(arg.set);
}
} // end namespace mlir
namespace llvm {
// IntegerSets hash just like pointers.
template <> struct DenseMapInfo<mlir::IntegerSet> {
static mlir::IntegerSet getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::IntegerSet(static_cast<mlir::IntegerSet::ImplType *>(pointer));
}
static mlir::IntegerSet getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::IntegerSet(static_cast<mlir::IntegerSet::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::IntegerSet val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::IntegerSet LHS, mlir::IntegerSet RHS) {
return LHS == RHS;
}
};
} // namespace llvm
#endif // MLIR_IR_INTEGER_SET_H


@ -0,0 +1,270 @@
//===- Location.h - MLIR Location Classes -----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// These classes provide the ability to relate MLIR objects back to source
// location position information.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_LOCATION_H
#define MLIR_IR_LOCATION_H
#include "mlir/IR/Attributes.h"
namespace mlir {
class Attribute;
class MLIRContext;
class Identifier;
namespace detail {
struct LocationStorage;
struct UnknownLocationStorage;
struct FileLineColLocationStorage;
struct NameLocationStorage;
struct CallSiteLocationStorage;
struct FusedLocationStorage;
} // namespace detail
/// Location objects represent source locations information in MLIR.
/// LocationAttr acts as the anchor for all Location based attributes.
class LocationAttr : public Attribute {
public:
using Attribute::Attribute;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR &&
attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR;
}
};
/// This class defines the main interface for locations in MLIR and acts as a
/// non-nullable wrapper around a LocationAttr.
class Location {
public:
Location(LocationAttr loc) : impl(loc) {
assert(loc && "location should never be null.");
}
/// Access the impl location attribute.
operator LocationAttr() const { return impl; }
LocationAttr *operator->() const { return const_cast<LocationAttr *>(&impl); }
/// Type casting utilities on the underlying location.
template <typename U> bool isa() const { return impl.isa<U>(); }
template <typename U> U dyn_cast() const { return impl.dyn_cast<U>(); }
template <typename U> U cast() const { return impl.cast<U>(); }
/// Comparison operators.
bool operator==(Location rhs) const { return impl == rhs.impl; }
bool operator!=(Location rhs) const { return !(*this == rhs); }
/// Print the location.
void print(raw_ostream &os) const { impl.print(os); }
void dump() const { impl.dump(); }
friend ::llvm::hash_code hash_value(Location arg);
/// Methods for supporting PointerLikeTypeTraits.
const void *getAsOpaquePointer() const { return impl.getAsOpaquePointer(); }
static Location getFromOpaquePointer(const void *pointer) {
return LocationAttr(reinterpret_cast<const AttributeStorage *>(pointer));
}
protected:
/// The internal backing location attribute.
LocationAttr impl;
};
inline raw_ostream &operator<<(raw_ostream &os, const Location &loc) {
loc.print(os);
return os;
}
/// Represents a location as a call site. "callee" is the concrete location
/// (Unknown/NameLocation/FileLineColLoc) and "caller" points to the caller's
/// location (another CallSiteLoc or a concrete location). Multiple
/// CallSiteLocs can be chained to form a call stack.
class CallSiteLoc
: public Attribute::AttrBase<CallSiteLoc, LocationAttr,
detail::CallSiteLocationStorage> {
public:
using Base::Base;
/// Return a uniqued call location object.
static Location get(Location callee, Location caller, MLIRContext *context);
/// Return a call site location which represents a name reference in one line
/// or a stack of frames. The input frames are ordered from innermost to
/// outermost.
static Location get(Location name, ArrayRef<Location> frames,
MLIRContext *context);
/// The concrete location information this object presents.
Location getCallee() const;
/// The caller's location.
Location getCaller() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::CallSiteLocation;
}
};
/// Represents a location derived from a file/line/column location. The column
/// and line may be zero to represent unknown column and/or unknown line/column
/// information.
class FileLineColLoc
: public Attribute::AttrBase<FileLineColLoc, LocationAttr,
detail::FileLineColLocationStorage> {
public:
using Base::Base;
/// Return a uniqued FileLineCol location object.
static Location get(Identifier filename, unsigned line, unsigned column,
MLIRContext *context);
static Location get(StringRef filename, unsigned line, unsigned column,
MLIRContext *context);
StringRef getFilename() const;
unsigned getLine() const;
unsigned getColumn() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::FileLineColLocation;
}
};
/// Represents a value composed of multiple source constructs, with an optional
/// metadata attribute.
class FusedLoc : public Attribute::AttrBase<FusedLoc, LocationAttr,
detail::FusedLocationStorage> {
public:
using Base::Base;
/// Return a uniqued Fused Location object. The first location in the list
/// will get precedence during diagnostic emission, with the rest being
/// displayed as supplementary "fused from here" style notes.
static Location get(ArrayRef<Location> locs, Attribute metadata,
MLIRContext *context);
static Location get(ArrayRef<Location> locs, MLIRContext *context) {
return get(locs, Attribute(), context);
}
ArrayRef<Location> getLocations() const;
/// Returns the optional metadata attached to this fused location. Given that
/// it is optional, the return value may be a null attribute.
Attribute getMetadata() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::FusedLocation;
}
};
/// Represents an identity name attached to a child location.
class NameLoc : public Attribute::AttrBase<NameLoc, LocationAttr,
detail::NameLocationStorage> {
public:
using Base::Base;
/// Return a uniqued name location object. The child location must not be
/// another NameLoc.
static Location get(Identifier name, Location child, MLIRContext *context);
/// Return a uniqued name location object with an unknown child.
static Location get(Identifier name, MLIRContext *context);
/// Return the name identifier.
Identifier getName() const;
/// Return the child location.
Location getChildLoc() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::NameLocation;
}
};
/// Represents an unknown location. This is always a singleton for a given
/// MLIRContext.
class UnknownLoc : public Attribute::AttrBase<UnknownLoc, LocationAttr> {
public:
using Base::Base;
/// Get an instance of the UnknownLoc.
static Location get(MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::UnknownLocation;
}
};
// Make Location hashable.
inline ::llvm::hash_code hash_value(Location arg) {
return hash_value(arg.impl);
}
} // end namespace mlir
namespace llvm {
// Type hash just like pointers.
template <> struct DenseMapInfo<mlir::Location> {
static mlir::Location getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::Location::getFromOpaquePointer(pointer);
}
static mlir::Location getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::Location::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::Location val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::Location LHS, mlir::Location RHS) {
return LHS == RHS;
}
};
/// We align LocationStorage by 8, so allow LLVM to steal the low bits.
template <> struct PointerLikeTypeTraits<mlir::Location> {
public:
static inline void *getAsVoidPointer(mlir::Location I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::Location getFromVoidPointer(void *P) {
return mlir::Location::getFromOpaquePointer(P);
}
enum {
NumLowBitsAvailable =
PointerLikeTypeTraits<mlir::Attribute>::NumLowBitsAvailable
};
};
} // namespace llvm
#endif
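
// A minimal usage sketch (not part of the header above): composing the
// location classes it declares. It assumes an MLIRContext is available;
// "test.mlir" and the identifier names are illustrative only.
static mlir::Location exampleBuildLocation(mlir::MLIRContext &ctx) {
  // Concrete file/line/column location.
  mlir::Location fileLoc =
      mlir::FileLineColLoc::get("test.mlir", /*line=*/10, /*column=*/8, &ctx);
  // Attach an identity name to the concrete location.
  mlir::Location named = mlir::NameLoc::get(
      mlir::Identifier::get("inlined_fn", &ctx), fileLoc, &ctx);
  // Chain a call-site frame on top; here the caller is simply unknown.
  return mlir::CallSiteLoc::get(/*callee=*/named,
                                /*caller=*/mlir::UnknownLoc::get(&ctx), &ctx);
}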

View File

@ -0,0 +1,92 @@
//===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_IR_MLIRCONTEXT_H
#define MLIR_IR_MLIRCONTEXT_H
#include "mlir/Support/LLVM.h"
#include <functional>
#include <memory>
#include <vector>
namespace mlir {
class AbstractOperation;
class DiagnosticEngine;
class Dialect;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
class StorageUniquer;
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
/// them.
///
/// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with
/// a very generic name ("Context") and because it is uncommon for clients to
/// interact with it.
///
class MLIRContext {
public:
explicit MLIRContext();
~MLIRContext();
/// Return information about all registered IR dialects.
std::vector<Dialect *> getRegisteredDialects();
/// Get a registered IR dialect with the given namespace. If an exact match is
/// not found, then return nullptr.
Dialect *getRegisteredDialect(StringRef name);
/// Get a registered IR dialect for the given derived dialect type. The
/// derived type must provide a static 'getDialectNamespace' method.
template <typename T> T *getRegisteredDialect() {
return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
}
/// Return information about all registered operations. This isn't very
/// efficient: typically you should ask the operations about their properties
/// directly.
std::vector<AbstractOperation *> getRegisteredOperations();
// This is effectively private given that only MLIRContext.cpp can see the
// MLIRContextImpl type.
MLIRContextImpl &getImpl() { return *impl; }
/// Returns the diagnostic engine for this context.
DiagnosticEngine &getDiagEngine();
/// Returns the storage uniquer used for creating affine constructs.
StorageUniquer &getAffineUniquer();
/// Returns the storage uniquer used for constructing type storage instances.
/// This should not be used directly.
StorageUniquer &getTypeUniquer();
/// Returns the storage uniquer used for constructing attribute storage
/// instances. This should not be used directly.
StorageUniquer &getAttributeUniquer();
private:
const std::unique_ptr<MLIRContextImpl> impl;
MLIRContext(const MLIRContext &) = delete;
void operator=(const MLIRContext &) = delete;
};
} // end namespace mlir
#endif // MLIR_IR_MLIRCONTEXT_H
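
// A minimal sketch (not part of the header above) of typical context usage.
// It assumes the "std" dialect was linked in and registered; the lookup
// returns nullptr otherwise.
static void exampleContextUsage() {
  mlir::MLIRContext context;
  // Enumerate every dialect registered with this context.
  for (mlir::Dialect *dialect : context.getRegisteredDialects())
    (void)dialect;
  // Namespace-based lookup; nullptr when no exact match is registered.
  mlir::Dialect *stdDialect = context.getRegisteredDialect("std");
  (void)stdDialect;
}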

View File

@ -0,0 +1,177 @@
//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file provides a simple and efficient mechanism for performing general
// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
// include/llvm/IR/PatternMatch.h.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_MATCHERS_H
#define MLIR_MATCHERS_H
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include <type_traits>
namespace mlir {
namespace detail {
/// The matcher that matches a certain kind of Attribute and binds the value
/// inside the Attribute.
template <
typename AttrClass,
// Require AttrClass to be derived from Attribute and get its
// value type
typename ValueType =
typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value,
AttrClass>::type::ValueType,
// Require the ValueType is not void
typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
struct attr_value_binder {
ValueType *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
attr_value_binder(ValueType *bv) : bind_value(bv) {}
bool match(const Attribute &attr) {
if (auto intAttr = attr.dyn_cast<AttrClass>()) {
*bind_value = intAttr.getValue();
return true;
}
return false;
}
};
/// The matcher that matches a constant foldable operation that has no side
/// effect, no operands and produces a single result.
template <typename AttrT> struct constant_op_binder {
AttrT *bind_value;
/// Creates a matcher instance that binds the constant attribute value to
/// bind_value if match succeeds.
constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
bool match(Operation *op) {
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
return false;
if (!op->hasNoSideEffect())
return false;
SmallVector<OpFoldResult, 1> foldedOp;
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
if ((*bind_value = attr.dyn_cast<AttrT>()))
return true;
}
}
return false;
}
};
/// The matcher that matches a constant scalar / vector splat / tensor splat
/// integer operation and binds the constant integer value.
struct constant_int_op_binder {
IntegerAttr::ValueType *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
bool match(Operation *op) {
Attribute attr;
if (!constant_op_binder<Attribute>(&attr).match(op))
return false;
auto type = op->getResult(0)->getType();
if (type.isa<IntegerType>()) {
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
}
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getSplatValue());
}
}
return false;
}
};
// The matcher that matches a given target constant scalar / vector splat /
// tensor splat integer value.
template <int64_t TargetValue> struct constant_int_value_matcher {
bool match(Operation *op) {
APInt value;
return constant_int_op_binder(&value).match(op) && TargetValue == value;
}
};
/// The matcher that matches a certain kind of op.
template <typename OpClass> struct op_matcher {
bool match(Operation *op) { return isa<OpClass>(op); }
};
} // end namespace detail
/// Entry point for matching a pattern over a Value.
template <typename Pattern>
inline bool matchPattern(Value *value, const Pattern &pattern) {
// TODO: handle other cases
if (auto *op = value->getDefiningOp())
return const_cast<Pattern &>(pattern).match(op);
return false;
}
/// Entry point for matching a pattern over an Operation.
template <typename Pattern>
inline bool matchPattern(Operation *op, const Pattern &pattern) {
return const_cast<Pattern &>(pattern).match(op);
}
/// Matches a constant holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value.
inline detail::constant_int_op_binder
m_ConstantInt(IntegerAttr::ValueType *bind_value) {
return detail::constant_int_op_binder(bind_value);
}
/// Matches a value from a constant foldable operation and writes the value to
/// bind_value.
template <typename AttrT>
inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
return detail::constant_op_binder<AttrT>(bind_value);
}
/// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_value_matcher<1> m_One() {
return detail::constant_int_value_matcher<1>();
}
/// Matches the given OpClass.
template <typename OpClass> inline detail::op_matcher<OpClass> m_Op() {
return detail::op_matcher<OpClass>();
}
/// Matches a constant scalar / vector splat / tensor splat integer zero.
inline detail::constant_int_value_matcher<0> m_Zero() {
return detail::constant_int_value_matcher<0>();
}
} // end namespace mlir
#endif // MLIR_MATCHERS_H
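
// A minimal sketch (not part of the header above) of the matcher entry
// points. `op` is assumed to be any valid Operation*, e.g. one visited by a
// pass.
static bool exampleMatch(mlir::Operation *op) {
  // Bind the value of a constant-foldable integer (splat) op, if any.
  mlir::IntegerAttr::ValueType value; // an APInt
  if (mlir::matchPattern(op, mlir::m_ConstantInt(&value)))
    return true;
  // Check for the additive and multiplicative identities.
  return mlir::matchPattern(op, mlir::m_Zero()) ||
         mlir::matchPattern(op, mlir::m_One());
}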

View File

@ -0,0 +1,213 @@
//===- Module.h - MLIR Module Class -----------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Module is the top-level container for code in an MLIR program.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_MODULE_H
#define MLIR_IR_MODULE_H
#include "mlir/IR/SymbolTable.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// Module Operation.
//===----------------------------------------------------------------------===//
/// ModuleOp represents a module, or an operation containing one region with a
/// single block containing opaque operations. The region of a module is not
/// allowed to implicitly capture global values, and all external references
/// must use symbolic references via attributes (e.g. via a string name).
class ModuleOp : public Op<ModuleOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
OpTrait::IsIsolatedFromAbove, OpTrait::SymbolTable> {
public:
using Op::Op;
using Op::print;
static StringRef getOperationName() { return "module"; }
static void build(Builder *builder, OperationState *result);
/// Construct a module from the given location.
static ModuleOp create(Location loc);
/// Operation hooks.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
/// Return body of this module.
Region &getBodyRegion();
Block *getBody();
/// Print this module in the custom top-level form.
void print(raw_ostream &os);
void dump();
//===--------------------------------------------------------------------===//
// Body Management.
//===--------------------------------------------------------------------===//
/// Iteration over the operations in the module.
using iterator = Block::iterator;
iterator begin() { return getBody()->begin(); }
iterator end() { return getBody()->end(); }
Operation &front() { return *begin(); }
/// This returns a range of operations of the given type 'T' held within the
/// module.
template <typename T> llvm::iterator_range<Block::op_iterator<T>> getOps() {
return getBody()->getOps<T>();
}
/// Insert the operation into the back of the body, before the terminator.
void push_back(Operation *op) {
insert(Block::iterator(getBody()->getTerminator()), op);
}
/// Insert the operation at the given insertion point. Note: The operation is
/// never inserted after the terminator, even if the insertion point is end().
void insert(Operation *insertPt, Operation *op) {
insert(Block::iterator(insertPt), op);
}
void insert(Block::iterator insertPt, Operation *op) {
auto *body = getBody();
if (insertPt == body->end())
insertPt = Block::iterator(body->getTerminator());
body->getOperations().insert(insertPt, op);
}
};
/// The ModuleTerminatorOp is a special terminator operation for the body of a
/// ModuleOp, it has no semantic meaning beyond keeping the body of a ModuleOp
/// well-formed.
///
/// This operation does _not_ have a custom syntax. However, ModuleOp will omit
/// the terminator in its custom syntax for brevity.
class ModuleTerminatorOp
: public Op<ModuleTerminatorOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
OpTrait::IsTerminator> {
public:
using Op::Op;
static StringRef getOperationName() { return "module_terminator"; }
static void build(Builder *, OperationState *) {}
LogicalResult verify();
};
//===----------------------------------------------------------------------===//
// Module Manager.
//===----------------------------------------------------------------------===//
/// A class used to manage the symbols held by a module. This class ensures
/// that symbols inserted into a module have unique names, and provides
/// efficient named lookup of held symbols.
class ModuleManager {
public:
ModuleManager(ModuleOp module) : module(module), symbolTable(module) {}
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them.
template <typename T, typename NameTy> T lookupSymbol(NameTy &&name) const {
return symbolTable.lookup<T>(name);
}
/// Insert a new symbol into the module, auto-renaming it as necessary.
void insert(Operation *op) {
symbolTable.insert(op);
module.push_back(op);
}
void insert(Block::iterator insertPt, Operation *op) {
symbolTable.insert(op);
module.insert(insertPt, op);
}
/// Remove the given symbol from the module symbol table and then erase it.
void erase(Operation *op) {
symbolTable.erase(op);
op->erase();
}
/// Return the internally held module.
ModuleOp getModule() const { return module; }
/// Return the context of the internal module.
MLIRContext *getContext() { return module.getContext(); }
private:
ModuleOp module;
SymbolTable symbolTable;
};
/// This class acts as an owning reference to a module, and will automatically
/// destroy the held module if valid.
class OwningModuleRef {
public:
OwningModuleRef(std::nullptr_t = nullptr) {}
OwningModuleRef(ModuleOp module) : module(module) {}
OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
~OwningModuleRef() {
if (module)
module.erase();
}
// Assign from another module reference.
OwningModuleRef &operator=(OwningModuleRef &&other) {
if (module)
module.erase();
module = other.release();
return *this;
}
/// Allow accessing the internal module.
ModuleOp get() const { return module; }
ModuleOp operator*() const { return module; }
ModuleOp *operator->() { return &module; }
explicit operator bool() const { return module; }
/// Release the referenced module.
ModuleOp release() {
ModuleOp released;
std::swap(released, module);
return released;
}
private:
ModuleOp module;
};
} // end namespace mlir
namespace llvm {
/// Allow stealing the low bits of ModuleOp.
template <> struct PointerLikeTypeTraits<mlir::ModuleOp> {
public:
static inline void *getAsVoidPointer(mlir::ModuleOp I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::ModuleOp getFromVoidPointer(void *P) {
return mlir::ModuleOp::getFromOpaquePointer(P);
}
enum { NumLowBitsAvailable = 3 };
};
} // end namespace llvm
#endif // MLIR_IR_MODULE_H
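
// A minimal sketch (not part of the header above): creating an empty module
// and managing its ownership. The context reference is assumed to outlive the
// returned module.
static mlir::OwningModuleRef exampleBuildModule(mlir::MLIRContext &ctx) {
  // Create a module at an unknown source location; the OwningModuleRef erases
  // it automatically when the reference goes out of scope.
  mlir::OwningModuleRef module(
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx)));
  // A ModuleManager would keep symbol names unique on insertion.
  mlir::ModuleManager manager(*module);
  (void)manager;
  return module;
}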

1416
third_party/mlir/include/mlir/IR/OpBase.td vendored Normal file

File diff suppressed because it is too large Load Diff


View File

@ -0,0 +1,516 @@
//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// These classes are used by the implementation details of Op types.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_OPIMPLEMENTATION_H
#define MLIR_IR_OPIMPLEMENTATION_H
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
class Builder;
//===----------------------------------------------------------------------===//
// OpAsmPrinter
//===----------------------------------------------------------------------===//
/// This is a pure-virtual base class that exposes the asmprinter hooks
/// necessary to implement a custom print() method.
class OpAsmPrinter {
public:
OpAsmPrinter() {}
virtual ~OpAsmPrinter();
virtual raw_ostream &getStream() const = 0;
/// Print implementations for various things an operation contains.
virtual void printOperand(Value *value) = 0;
/// Print a comma separated list of operands.
template <typename ContainerType>
void printOperands(const ContainerType &container) {
printOperands(container.begin(), container.end());
}
/// Print a comma separated list of operands.
template <typename IteratorType>
void printOperands(IteratorType it, IteratorType end) {
if (it == end)
return;
printOperand(*it);
for (++it; it != end; ++it) {
getStream() << ", ";
printOperand(*it);
}
}
virtual void printType(Type type) = 0;
virtual void printAttribute(Attribute attr) = 0;
/// Print a successor, and its use list, of a terminator operation given the
/// terminator and the successor index.
virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0;
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
/// printed some other way (like as a fixed operand).
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) = 0;
/// Print the entire operation with the default generic assembly form.
virtual void printGenericOp(Operation *op) = 0;
/// Prints a region.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
bool printBlockTerminators = true) = 0;
/// Prints an affine map of SSA ids, where SSA id names are used in place
/// of dims/symbols.
/// Operand values must come from single-result sources, and be valid
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
ArrayRef<Value *> operands) = 0;
/// Print an optional arrow followed by a type list.
void printOptionalArrowTypeList(ArrayRef<Type> types) {
if (types.empty())
return;
auto &os = getStream() << " -> ";
bool wrapped = types.size() != 1 || types[0].isa<FunctionType>();
if (wrapped)
os << '(';
interleaveComma(types, *this);
if (wrapped)
os << ')';
}
/// Print the complete type of an operation in functional form.
void printFunctionalType(Operation *op) {
auto &os = getStream();
os << "(";
interleaveComma(op->getNonSuccessorOperands(), os,
[&](Value *operand) { printType(operand->getType()); });
os << ") -> ";
if (op->getNumResults() == 1 &&
!op->getResult(0)->getType().isa<FunctionType>()) {
printType(op->getResult(0)->getType());
} else {
os << '(';
interleaveComma(op->getResultTypes(), os);
os << ')';
}
}
private:
OpAsmPrinter(const OpAsmPrinter &) = delete;
void operator=(const OpAsmPrinter &) = delete;
};
// Make the implementations convenient to use.
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) {
p.printOperand(&value);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
p.printType(type);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
p.printAttribute(attr);
return p;
}
// Support printing anything that isn't convertible to one of the above types,
// even if it isn't exactly one of them. For example, we want to print
// FunctionType with the Type version above, not have it match this.
template <typename T, typename std::enable_if<
!std::is_convertible<T &, Value &>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value,
T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
p.getStream() << other;
return p;
}
//===----------------------------------------------------------------------===//
// OpAsmParser
//===----------------------------------------------------------------------===//
/// The OpAsmParser has methods for interacting with the asm parser: parsing
/// things from it, emitting errors etc. It has an intentionally high-level API
/// that is designed to reduce/constrain syntax innovation in individual
/// operations.
///
/// For example, consider an op like this:
///
/// %x = load %p[%1, %2] : memref<...>
///
/// The "%x = load" tokens are already parsed and therefore invisible to the
/// custom op parser. This can be supported by calling `parseOperandList` to
/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
/// parse the indices, then calling `parseColonTypeList` to parse the result
/// type.
///
class OpAsmParser {
public:
virtual ~OpAsmParser();
/// Emit a diagnostic at the specified location and return failure.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
const Twine &message = {}) = 0;
/// Return a builder which provides useful access to MLIRContext, global
/// objects like types and attributes.
virtual Builder &getBuilder() const = 0;
/// Get the location of the next token and store it into the argument. This
/// always succeeds.
virtual llvm::SMLoc getCurrentLocation() = 0;
ParseResult getCurrentLocation(llvm::SMLoc *loc) {
*loc = getCurrentLocation();
return success();
}
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;
// These methods emit an error and return failure or success. This allows
// these to be chained together into a linear sequence of || expressions in
// many cases.
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
/// Parse a '->' token.
virtual ParseResult parseArrow() = 0;
/// Parse a '->' token if present.
virtual ParseResult parseOptionalArrow() = 0;
/// Parse a `:` token.
virtual ParseResult parseColon() = 0;
/// Parse a `:` token if present.
virtual ParseResult parseOptionalColon() = 0;
/// Parse a `,` token.
virtual ParseResult parseComma() = 0;
/// Parse a `,` token if present.
virtual ParseResult parseOptionalComma() = 0;
/// Parse a `=` token.
virtual ParseResult parseEqual() = 0;
/// Parse a keyword.
ParseResult parseKeyword(const char *keyword, const Twine &msg = "") {
if (parseOptionalKeyword(keyword))
return emitError(getNameLoc(), "expected '") << keyword << "'" << msg;
return success();
}
/// Parse a keyword if present.
virtual ParseResult parseOptionalKeyword(const char *keyword) = 0;
/// Parse a `(` token.
virtual ParseResult parseLParen() = 0;
/// Parse a `(` token if present.
virtual ParseResult parseOptionalLParen() = 0;
/// Parse a `)` token.
virtual ParseResult parseRParen() = 0;
/// Parse a `)` token if present.
virtual ParseResult parseOptionalRParen() = 0;
/// Parse a `[` token.
virtual ParseResult parseLSquare() = 0;
/// Parse a `[` token if present.
virtual ParseResult parseOptionalLSquare() = 0;
/// Parse a `]` token.
virtual ParseResult parseRSquare() = 0;
/// Parse a `]` token if present.
virtual ParseResult parseOptionalRSquare() = 0;
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name.
ParseResult parseAttribute(Attribute &result, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) {
return parseAttribute(result, Type(), attrName, attrs);
}
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
virtual ParseResult
parseAttribute(Attribute &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of attribute.
Attribute attr;
if (parseAttribute(attr, type, attrName, attrs))
return failure();
// Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>();
if (!result)
return emitError(loc, "invalid kind of constant specified");
return success();
}
/// Parse a named dictionary into 'result' if it is present.
virtual ParseResult
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) = 0;
//===--------------------------------------------------------------------===//
// Operand Parsing
//===--------------------------------------------------------------------===//
/// This is the representation of an operand reference.
struct OperandType {
llvm::SMLoc location; // Location of the token.
StringRef name; // Value name, e.g. %42 or %abc
unsigned number; // Number, e.g. 12 for an operand like %xyz#12
};
/// Parse a single operand.
virtual ParseResult parseOperand(OperandType &result) = 0;
/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList and parseRegionArgumentList.
enum class Delimiter {
/// Zero or more operands with no delimiters.
None,
/// Parens surrounding zero or more operands.
Paren,
/// Square brackets surrounding zero or more operands.
Square,
/// Parens supporting zero or more operands, or nothing.
OptionalParen,
/// Square brackets supporting zero or more ops, or nothing.
OptionalSquare,
};
/// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimiter, and an optional required operand count.
virtual ParseResult
parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
Delimiter delimiter) {
return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
}
/// Parse zero or more SSA comma-separated trailing operand references with a
/// specified surrounding delimiter, and an optional required operand count. A
/// leading comma is expected before the operands.
virtual ParseResult
parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
Delimiter delimiter) {
return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
delimiter);
}
/// Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<Value *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error on failure, or
/// appending the results to the list on success. This method should be used
/// when all operands have the same type.
ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<Value *> &result) {
for (auto elt : operands)
if (resolveOperand(elt, type, result))
return failure();
return success();
}
/// Resolve a list of operands and a list of operand types to SSA values,
/// emitting an error and returning failure, or appending the results
/// to the list on success.
ParseResult resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<Value *> &result) {
if (operands.size() != types.size())
return emitError(loc)
<< operands.size() << " operands present, but expected "
<< types.size();
for (unsigned i = 0, e = operands.size(); i != e; ++i)
if (resolveOperand(operands[i], types[i], result))
return failure();
return success();
}
/// Parses an affine map attribute where dims and symbols are SSA operands.
/// Operand values must come from single-result sources, and be valid
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
virtual ParseResult
parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
/// Parses a region. Any parsed blocks are appended to "region" and must be
/// moved to the op regions after the op is created. The first block of the
/// region takes "arguments" of types "argTypes".
virtual ParseResult parseRegion(Region &region,
ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) = 0;
/// Parses a region if present.
virtual ParseResult parseOptionalRegion(Region &region,
ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) = 0;
/// Parse a region argument. Region arguments define new values; so this also
/// checks if values with the same name have not been defined yet.
virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
/// Parse zero or more region arguments with a specified surrounding
/// delimiter, and an optional required argument count. Region arguments
/// define new values; so this also checks if values with the same names have
/// not been defined yet.
virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
Delimiter delimiter) {
return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
delimiter);
}
/// Parse a region argument if present.
virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
//===--------------------------------------------------------------------===//
// Successor Parsing
//===--------------------------------------------------------------------===//
/// Parse a single operation successor and its operand list.
virtual ParseResult
parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) = 0;
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
/// Parse an optional arrow followed by a type list.
virtual ParseResult
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a colon followed by a type.
virtual ParseResult parseColonType(Type &result) = 0;
/// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
template <typename TypeType> ParseResult parseColonType(TypeType &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of type.
Type type;
if (parseColonType(type))
return failure();
// Check for the right kind of type.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");
return success();
}
/// Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse an optional colon followed by a type list, which if present must
/// have at least one type.
virtual ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
}
/// Add the specified type to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
result.push_back(type);
return success();
}
/// Add the specified types to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypesToList(ArrayRef<Type> types,
SmallVectorImpl<Type> &result) {
result.append(types.begin(), types.end());
return success();
}
private:
/// Parse either an operand list or a region argument list depending on
/// whether isOperandList is true.
ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
bool isOperandList,
int requiredOperandCount,
Delimiter delimiter);
};
} // end namespace mlir
#endif
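
// A minimal sketch (not part of the header above) of the parser/printer hooks
// for a hypothetical single-operand op "foo.negate %x : type". The op name and
// both function names are illustrative, not part of MLIR.
static mlir::ParseResult parseNegateOp(mlir::OpAsmParser *parser,
                                       mlir::OperationState *result) {
  mlir::OpAsmParser::OperandType operand;
  mlir::Type type;
  // Chain the parse steps with || so the first failure short-circuits.
  if (parser->parseOperand(operand) || parser->parseColonType(type) ||
      parser->resolveOperand(operand, type, result->operands))
    return mlir::failure();
  result->addTypes(type);
  return mlir::success();
}
static void printNegateOp(mlir::OpAsmPrinter *p, mlir::Operation *op) {
  *p << "foo.negate ";
  p->printOperand(op->getOperand(0));
  *p << " : " << op->getOperand(0)->getType();
}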

View File

@ -0,0 +1,710 @@
//===- Operation.h - MLIR Operation Class -----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the Operation class.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_OPERATION_H
#define MLIR_IR_OPERATION_H
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Region.h"
#include "llvm/ADT/Twine.h"
namespace mlir {
class BlockAndValueMapping;
class Location;
class MLIRContext;
class OperandIterator;
class OperandTypeIterator;
struct OperationState;
class ResultIterator;
class ResultTypeIterator;
/// Terminator operations can have Block operands to represent successors.
using BlockOperand = IROperandImpl<Block>;
/// Operation is a basic unit of execution within a function. Operations can
/// be nested within other operations effectively forming a tree. Child
/// operations are organized into operation blocks represented by a 'Block'
/// class.
class Operation final
: public llvm::ilist_node_with_parent<Operation, Block>,
private llvm::TrailingObjects<Operation, OpResult, BlockOperand, unsigned,
Region, detail::OperandStorage> {
public:
/// Create a new Operation with the specific fields.
static Operation *create(Location location, OperationName name,
ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList, MLIRContext *context);
/// Overload of create that takes an existing NamedAttributeList to avoid
/// unnecessarily uniquing a list of attributes.
static Operation *create(Location location, OperationName name,
ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes,
const NamedAttributeList &attributes,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList, MLIRContext *context);
/// Create a new Operation from the fields stored in `state`.
static Operation *create(const OperationState &state);
/// The name of an operation is the key identifier for it.
OperationName getName() { return name; }
/// If this operation has a registered operation description, return it.
/// Otherwise return null.
const AbstractOperation *getAbstractOperation() {
return getName().getAbstractOperation();
}
/// Returns true if this operation has a registered operation description,
/// otherwise false.
bool isRegistered() { return getAbstractOperation(); }
/// Remove this operation from its parent block and delete it.
void erase();
/// Create a deep copy of this operation, remapping any operands that use
/// values outside of the operation using the map that is provided (leaving
/// them alone if no entry is present). Replaces references to cloned
/// sub-operations to the corresponding operation that is copied, and adds
/// those mappings to the map.
Operation *clone(BlockAndValueMapping &mapper);
Operation *clone();
/// Create a deep copy of this operation but keep the operation regions empty.
/// Operands are remapped using `mapper` (if present), and `mapper` is updated
/// to contain the results.
Operation *cloneWithoutRegions(BlockAndValueMapping &mapper);
Operation *cloneWithoutRegions();
/// Returns the operation block that contains this operation.
Block *getBlock() { return block; }
/// Return the context this operation is associated with.
MLIRContext *getContext();
/// Return the dialect this operation is associated with, or nullptr if the
/// associated dialect is not registered.
Dialect *getDialect();
/// The source location the operation was defined or derived from.
Location getLoc() { return location; }
/// Set the source location the operation was defined or derived from.
void setLoc(Location loc) { location = loc; }
/// Returns the region to which the instruction belongs, which can be a
/// function body region or a region that belongs to another operation.
/// Returns nullptr if the instruction is unlinked.
Region *getContainingRegion() const;
/// Returns the closest surrounding operation that contains this operation
/// or nullptr if this is a top-level operation.
Operation *getParentOp();
/// Return the closest surrounding parent operation that is of type 'OpTy'.
template <typename OpTy> OpTy getParentOfType() {
auto *op = this;
while ((op = op->getParentOp()))
if (auto parentOp = llvm::dyn_cast<OpTy>(op))
return parentOp;
return OpTy();
}
/// Replace any uses of 'from' with 'to' within this operation.
void replaceUsesOfWith(Value *from, Value *to);
/// Destroys this operation and its subclass data.
void destroy();
/// This drops all operand uses from this operation, which is an essential
/// step in breaking cyclic dependences between references when they are to
/// be deleted.
void dropAllReferences();
/// Drop uses of all values defined by this operation or its nested regions.
void dropAllDefinedValueUses();
/// Unlink this operation from its current block and insert it right before
/// `existingInst` which may be in the same or another block in the same
/// function.
void moveBefore(Operation *existingInst);
/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
void moveBefore(Block *block, llvm::iplist<Operation>::iterator iterator);
/// Given an operation 'other' that is within the same parent block, return
/// whether the current operation is before 'other' in the operation list
/// of the parent block.
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool isBeforeInBlock(Operation *other);
void print(raw_ostream &os);
void dump();
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
/// Returns if the operation has a resizable operand list, i.e. operands can
/// be added.
bool hasResizableOperandsList() { return getOperandStorage().isResizable(); }
/// Replace the current operands of this operation with the ones provided in
/// 'operands'. If the operand list is not resizable, the size of 'operands'
/// must be less than or equal to the current number of operands.
void setOperands(ArrayRef<Value *> operands) {
getOperandStorage().setOperands(this, operands);
}
unsigned getNumOperands() { return getOperandStorage().size(); }
Value *getOperand(unsigned idx) { return getOpOperand(idx).get(); }
void setOperand(unsigned idx, Value *value) {
return getOpOperand(idx).set(value);
}
// Support operand iteration.
using operand_iterator = OperandIterator;
using operand_range = llvm::iterator_range<operand_iterator>;
operand_iterator operand_begin();
operand_iterator operand_end();
/// Returns an iterator range over the underlying operand values (Value *).
operand_range getOperands();
/// Erase the operand at position `idx`.
void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
MutableArrayRef<OpOperand> getOpOperands() {
return getOperandStorage().getOperands();
}
OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
// Support operand type iteration.
using operand_type_iterator = OperandTypeIterator;
using operand_type_range = llvm::iterator_range<operand_type_iterator>;
operand_type_iterator operand_type_begin();
operand_type_iterator operand_type_end();
operand_type_range getOperandTypes();
//===--------------------------------------------------------------------===//
// Results
//===--------------------------------------------------------------------===//
/// Return true if there are no users of any results of this operation.
bool use_empty();
unsigned getNumResults() { return numResults; }
Value *getResult(unsigned idx) { return &getOpResult(idx); }
// Support result iteration.
using result_iterator = ResultIterator;
using result_range = llvm::iterator_range<result_iterator>;
result_iterator result_begin();
result_iterator result_end();
result_range getResults();
MutableArrayRef<OpResult> getOpResults() {
return {getTrailingObjects<OpResult>(), numResults};
}
OpResult &getOpResult(unsigned idx) { return getOpResults()[idx]; }
// Support result type iteration.
using result_type_iterator = ResultTypeIterator;
using result_type_range = llvm::iterator_range<result_type_iterator>;
result_type_iterator result_type_begin();
result_type_iterator result_type_end();
result_type_range getResultTypes();
//===--------------------------------------------------------------------===//
// Attributes
//===--------------------------------------------------------------------===//
// Operations may optionally carry a list of attributes that associate
// constants to names. Attributes may be dynamically added and removed over
// the lifetime of an operation.
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
/// Return the internal attribute list on this operation.
NamedAttributeList &getAttrList() { return attrs; }
/// Set the attribute list on this operation.
/// Using a NamedAttributeList is more efficient as it does not require new
/// uniquing in the MLIRContext.
void setAttrs(NamedAttributeList newAttrs) { attrs = newAttrs; }
/// Return the specified attribute if present, null otherwise.
Attribute getAttr(Identifier name) { return attrs.get(name); }
Attribute getAttr(StringRef name) { return attrs.get(name); }
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
/// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
void setAttr(StringRef name, Attribute value) {
setAttr(Identifier::get(name, getContext()), value);
}
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
return attrs.remove(name);
}
/// A utility iterator that filters out non-dialect attributes.
class dialect_attr_iterator
: public llvm::filter_iterator<ArrayRef<NamedAttribute>::iterator,
bool (*)(NamedAttribute)> {
static bool filter(NamedAttribute attr) {
// Dialect attributes are prefixed by the dialect name, like operations.
return attr.first.strref().count('.');
}
explicit dialect_attr_iterator(ArrayRef<NamedAttribute>::iterator it,
ArrayRef<NamedAttribute>::iterator end)
: llvm::filter_iterator<ArrayRef<NamedAttribute>::iterator,
bool (*)(NamedAttribute)>(it, end, &filter) {}
// Allow access to the constructor.
friend Operation;
};
using dialect_attr_range = llvm::iterator_range<dialect_attr_iterator>;
/// Return a range corresponding to the dialect attributes for this operation.
dialect_attr_range getDialectAttrs() {
auto attrs = getAttrs();
return {dialect_attr_iterator(attrs.begin(), attrs.end()),
dialect_attr_iterator(attrs.end(), attrs.end())};
}
dialect_attr_iterator dialect_attr_begin() {
auto attrs = getAttrs();
return dialect_attr_iterator(attrs.begin(), attrs.end());
}
dialect_attr_iterator dialect_attr_end() {
auto attrs = getAttrs();
return dialect_attr_iterator(attrs.end(), attrs.end());
}
/// Set the dialect attributes for this operation, preserving all non-dialect
/// attributes that are currently set.
template <typename DialectAttrT>
void setDialectAttrs(DialectAttrT &&dialectAttrs) {
SmallVector<NamedAttribute, 16> attrs;
attrs.assign(std::begin(dialectAttrs), std::end(dialectAttrs));
for (auto attr : getAttrs())
if (!attr.first.strref().count('.'))
attrs.push_back(attr);
setAttrs(llvm::makeArrayRef(attrs));
}
//===--------------------------------------------------------------------===//
// Blocks
//===--------------------------------------------------------------------===//
/// Returns the number of regions held by this operation.
unsigned getNumRegions() { return numRegions; }
/// Returns the regions held by this operation.
MutableArrayRef<Region> getRegions() {
auto *regions = getTrailingObjects<Region>();
return {regions, numRegions};
}
/// Returns the region held by this operation at position 'index'.
Region &getRegion(unsigned index) {
assert(index < numRegions && "invalid region index");
return getRegions()[index];
}
//===--------------------------------------------------------------------===//
// Terminators
//===--------------------------------------------------------------------===//
MutableArrayRef<BlockOperand> getBlockOperands() {
return {getTrailingObjects<BlockOperand>(), numSuccs};
}
/// Return the operands of this operation that are *not* successor arguments.
operand_range getNonSuccessorOperands();
operand_range getSuccessorOperands(unsigned index);
Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) {
assert(!isKnownNonTerminator() && "only terminators may have successors");
assert(opIndex < getNumSuccessorOperands(succIndex));
return getOperand(getSuccessorOperandIndex(succIndex) + opIndex);
}
bool hasSuccessors() { return numSuccs != 0; }
unsigned getNumSuccessors() { return numSuccs; }
unsigned getNumSuccessorOperands(unsigned index) {
assert(!isKnownNonTerminator() && "only terminators may have successors");
assert(index < getNumSuccessors());
return getTrailingObjects<unsigned>()[index];
}
Block *getSuccessor(unsigned index) {
assert(index < getNumSuccessors());
return getBlockOperands()[index].get();
}
void setSuccessor(Block *block, unsigned index);
/// Erase a specific operand from the operand list of the successor at
/// 'index'.
void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
assert(succIndex < getNumSuccessors());
assert(opIndex < getNumSuccessorOperands(succIndex));
getOperandStorage().eraseOperand(getSuccessorOperandIndex(succIndex) +
opIndex);
--getTrailingObjects<unsigned>()[succIndex];
}
/// Get the index of the first operand of the successor at the provided
/// index.
unsigned getSuccessorOperandIndex(unsigned index);
//===--------------------------------------------------------------------===//
// Accessors for various properties of operations
//===--------------------------------------------------------------------===//
/// Returns whether the operation is commutative.
bool isCommutative() {
if (auto *absOp = getAbstractOperation())
return absOp->hasProperty(OperationProperty::Commutative);
return false;
}
/// Returns whether the operation is known to have no side effects.
bool hasNoSideEffect() {
if (auto *absOp = getAbstractOperation())
return absOp->hasProperty(OperationProperty::NoSideEffect);
return false;
}
/// Represents the status of whether an operation is a terminator. We
/// represent an 'unknown' status because we want to support unregistered
/// terminators.
enum class TerminatorStatus { Terminator, NonTerminator, Unknown };
/// Returns the status of whether this operation is a terminator or not.
TerminatorStatus getTerminatorStatus() {
if (auto *absOp = getAbstractOperation()) {
return absOp->hasProperty(OperationProperty::Terminator)
? TerminatorStatus::Terminator
: TerminatorStatus::NonTerminator;
}
return TerminatorStatus::Unknown;
}
/// Returns if the operation is known to be a terminator.
bool isKnownTerminator() {
return getTerminatorStatus() == TerminatorStatus::Terminator;
}
/// Returns if the operation is known to *not* be a terminator.
bool isKnownNonTerminator() {
return getTerminatorStatus() == TerminatorStatus::NonTerminator;
}
/// Returns if the operation is known to be completely isolated from enclosing
/// regions, i.e. no internal regions reference values defined above this
/// operation.
bool isKnownIsolatedFromAbove() {
if (auto *absOp = getAbstractOperation())
return absOp->hasProperty(OperationProperty::IsolatedFromAbove);
return false;
}
/// Attempt to fold this operation with the specified constant operand values
/// - the elements in "operands" will correspond directly to the operands of
/// the operation, but may be null if non-constant. If folding is successful,
/// this fills in the `results` vector. If not, `results` is unspecified.
LogicalResult fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results);
/// Returns if the operation was registered with a particular trait, e.g.
/// hasTrait<OperandsAreIntegerLike>().
template <template <typename T> class Trait> bool hasTrait() {
auto *absOp = getAbstractOperation();
return absOp ? absOp->hasTrait<Trait>() : false;
}
//===--------------------------------------------------------------------===//
// Operation Walkers
//===--------------------------------------------------------------------===//
/// Walk this operation in postorder, calling the callback for each operation
/// including this one.
void walk(llvm::function_ref<void(Operation *)> callback);
/// Specialization of walk to only visit operations of 'T'.
template <typename T> void walk(llvm::function_ref<void(T)> callback) {
walk([&](Operation *op) {
if (auto derivedOp = dyn_cast<T>(op))
callback(derivedOp);
});
}
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
/// Emit an error with the op name prefixed, like "'dim' op " which is
/// convenient for verifiers.
InFlightDiagnostic emitOpError(const Twine &message = {});
/// Emit an error about fatal conditions with this operation, reporting up to
/// any diagnostic handlers that may be listening.
InFlightDiagnostic emitError(const Twine &message = {});
/// Emit a warning about this operation, reporting up to any diagnostic
/// handlers that may be listening.
InFlightDiagnostic emitWarning(const Twine &message = {});
/// Emit a remark about this operation, reporting up to any diagnostic
/// handlers that may be listening.
InFlightDiagnostic emitRemark(const Twine &message = {});
private:
Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
const NamedAttributeList &attributes, MLIRContext *context);
// Operations are deleted through the destroy() member because they are
// allocated with malloc.
~Operation();
/// Returns the operand storage object.
detail::OperandStorage &getOperandStorage() {
return *getTrailingObjects<detail::OperandStorage>();
}
/// Provide a 'getParent' method for ilist_node_with_parent methods.
/// We mark it as const function because ilist_node_with_parent specifically
/// requires a 'getParent() const' method. Once ilist_node removes this
/// constraint, we should drop the const to fit the rest of the MLIR const
/// model.
Block *getParent() const { return block; }
/// The operation block that contains this operation.
Block *block = nullptr;
/// This holds information about the source location the operation was defined
/// or derived from.
Location location;
/// Relative order of this operation in its parent block. Used for
/// O(1) local dominance checks between operations.
mutable unsigned orderIndex = 0;
const unsigned numResults, numSuccs, numRegions;
/// This holds the name of the operation.
OperationName name;
/// This holds general named attributes for the operation.
NamedAttributeList attrs;
// allow ilist_traits access to 'block' field.
friend struct llvm::ilist_traits<Operation>;
// allow block to access the 'orderIndex' field.
friend class Block;
// allow ilist_node_with_parent to access the 'getParent' method.
friend class llvm::ilist_node_with_parent<Operation, Block>;
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<Operation, OpResult, BlockOperand, unsigned,
Region, detail::OperandStorage>;
size_t numTrailingObjects(OverloadToken<OpResult>) const {
return numResults;
}
size_t numTrailingObjects(OverloadToken<BlockOperand>) const {
return numSuccs;
}
size_t numTrailingObjects(OverloadToken<Region>) const { return numRegions; }
size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
};
inline raw_ostream &operator<<(raw_ostream &os, Operation &op) {
op.print(os);
return os;
}
/// This class implements the const/non-const operand iterators for the
/// Operation class in terms of getOperand(idx).
class OperandIterator final
: public indexed_accessor_iterator<OperandIterator, Operation *, Value *,
Value *, Value *> {
public:
/// Initializes the operand iterator to the specified operand index.
OperandIterator(Operation *object, unsigned index)
: indexed_accessor_iterator<OperandIterator, Operation *, Value *,
Value *, Value *>(object, index) {}
Value *operator*() const { return this->object->getOperand(this->index); }
};
/// This class implements the operand type iterators for the Operation
/// class in terms of operand_iterator->getType().
class OperandTypeIterator final
: public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
static Type unwrap(Value *value) { return value->getType(); }
public:
using reference = Type;
/// Initializes the operand type iterator to the specified operand iterator.
OperandTypeIterator(OperandIterator it)
: llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
}
};
// Implement the inline operand iterator methods.
inline auto Operation::operand_begin() -> operand_iterator {
return operand_iterator(this, 0);
}
inline auto Operation::operand_end() -> operand_iterator {
return operand_iterator(this, getNumOperands());
}
inline auto Operation::getOperands() -> operand_range {
return {operand_begin(), operand_end()};
}
inline auto Operation::operand_type_begin() -> operand_type_iterator {
return operand_type_iterator(operand_begin());
}
inline auto Operation::operand_type_end() -> operand_type_iterator {
return operand_type_iterator(operand_end());
}
inline auto Operation::getOperandTypes() -> operand_type_range {
return {operand_type_begin(), operand_type_end()};
}
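// Example (editor's sketch): iterating the operand and operand-type ranges
// defined above:
//
//   for (Value *operand : op->getOperands())
//     (void)operand;
//   for (Type type : op->getOperandTypes())
//     (void)type;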
/// This class implements the result iterators for the Operation class
/// in terms of getResult(idx).
class ResultIterator final
: public indexed_accessor_iterator<ResultIterator, Operation *, Value *,
Value *, Value *> {
public:
/// Initializes the result iterator to the specified index.
ResultIterator(Operation *object, unsigned index)
: indexed_accessor_iterator<ResultIterator, Operation *, Value *, Value *,
Value *>(object, index) {}
Value *operator*() const { return this->object->getResult(this->index); }
};
/// This class implements the result type iterators for the Operation
/// class in terms of result_iterator->getType().
class ResultTypeIterator final
: public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
static Type unwrap(Value *value) { return value->getType(); }
public:
using reference = Type;
/// Initializes the result type iterator to the specified result iterator.
ResultTypeIterator(ResultIterator it)
: llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
};
// Implement the inline result iterator methods.
inline auto Operation::result_begin() -> result_iterator {
return result_iterator(this, 0);
}
inline auto Operation::result_end() -> result_iterator {
return result_iterator(this, getNumResults());
}
inline auto Operation::getResults() -> llvm::iterator_range<result_iterator> {
return {result_begin(), result_end()};
}
inline auto Operation::result_type_begin() -> result_type_iterator {
return result_type_iterator(result_begin());
}
inline auto Operation::result_type_end() -> result_type_iterator {
return result_type_iterator(result_end());
}
inline auto Operation::getResultTypes() -> result_type_range {
return {result_type_begin(), result_type_end()};
}
} // end namespace mlir
namespace llvm {
/// Provide isa functionality for operation casts.
template <typename T> struct isa_impl<T, ::mlir::Operation> {
static inline bool doit(const ::mlir::Operation &op) {
return T::classof(const_cast<::mlir::Operation *>(&op));
}
};
/// Provide specializations for operation casts as the resulting T is value
/// typed.
template <typename T> struct cast_retty_impl<T, ::mlir::Operation *> {
using ret_type = T;
};
template <typename T> struct cast_retty_impl<T, ::mlir::Operation> {
using ret_type = T;
};
template <class T>
struct cast_convert_val<T, ::mlir::Operation, ::mlir::Operation> {
static T doit(::mlir::Operation &val) { return T(&val); }
};
template <class T>
struct cast_convert_val<T, ::mlir::Operation *, ::mlir::Operation *> {
static T doit(::mlir::Operation *val) { return T(val); }
};
} // end namespace llvm
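// Example (editor's sketch): these specializations are what make the usual
// LLVM casting idioms work on operations; 'ConstantOp' stands for any
// registered op class, and 'use' is a hypothetical helper:
//
//   if (auto constant = dyn_cast<ConstantOp>(op))
//     use(constant);
//   assert(isa<Operation>(*op) || true); // isa<T>(op) works the same way.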
#endif // MLIR_IR_OPERATION_H

View File

@ -0,0 +1,483 @@
//===- OperationSupport.h ---------------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines a number of support types that Operation and related
// classes build on top of.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_OPERATION_SUPPORT_H
#define MLIR_IR_OPERATION_SUPPORT_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/TrailingObjects.h"
#include <memory>
namespace mlir {
class Block;
class Dialect;
class Operation;
struct OperationState;
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
class OpFoldResult;
class ParseResult;
class Pattern;
class Region;
class RewritePattern;
class Type;
class Value;
/// This is an adaptor from a list of values to named operands of OpTy. In a
/// generic operation context, e.g., in dialect conversions, an ordered array of
/// `Value`s is treated as operands of `OpTy`. This adaptor takes a reference
/// to the array and provides accessors with the same names as `OpTy` for
/// operands. This makes it possible to create function templates that operate on
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
/// This is a vector that owns the patterns inside of it.
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
enum class OperationProperty {
/// This bit is set for an operation if it is a commutative operation: that
/// is a binary operator (two inputs) where "a op b" and "b op a" produce the
/// same results.
Commutative = 0x1,
/// This bit is set for operations that have no side effects: that means that
/// they do not read or write memory, or access any hidden state.
NoSideEffect = 0x2,
/// This bit is set for an operation if it is a terminator: that is, an
/// operation that must appear at the end of a block.
Terminator = 0x4,
/// This bit is set for operations that are completely isolated from above.
/// This is used for operations whose regions are explicit capture only, i.e.
/// they are never allowed to implicitly reference values defined above the
/// parent operation.
IsolatedFromAbove = 0x8,
};
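// Example (editor's sketch): properties are combined into a bitmask of
// AbstractOperation::OperationProperties (a uint32_t), e.g. for an op that is
// both commutative and side-effect free:
//
//   uint32_t props = static_cast<uint32_t>(OperationProperty::Commutative) |
//                    static_cast<uint32_t>(OperationProperty::NoSideEffect);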
/// This is a "type erased" representation of a registered operation. This
/// should only be used by things like the AsmPrinter and other code that needs
/// to be parameterized by generic operation hooks. Most user code should use
/// the concrete operation types.
class AbstractOperation {
public:
using OperationProperties = uint32_t;
/// This is the name of the operation.
const StringRef name;
/// This is the dialect that this operation belongs to.
Dialect &dialect;
/// Return true if this "op class" can match against the specified operation.
bool (&classof)(Operation *op);
/// Use the specified object to parse this op's custom assembly format.
ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result);
/// This hook implements the AsmPrinter for this operation.
void (&printAssembly)(Operation *op, OpAsmPrinter *p);
/// This hook implements the verifier for this operation. It should emit an
/// error message and return failure if a problem is detected, or return
/// success if everything is ok.
LogicalResult (&verifyInvariants)(Operation *op);
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the Builder::createOrFold API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
///
/// 1. They can leave the operation alone and without changing the IR, and
/// return failure.
/// 2. They can mutate the operation in place, without changing anything else
/// in the IR. In this case, return success.
/// 3. They can return a list of existing values that can be used instead of
/// the operation. In this case, fill in the results list and return
/// success. The caller will remove the operation and use those results
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
/// generalized constant folding.
LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results);
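// Example (editor's sketch): a fold hook body implementing the "x + 0 -> x"
// flavor of case 3 above; 'isAddOfZero' is a hypothetical helper:
//
//   if (isAddOfZero(op, operands)) {
//     results.push_back(op->getOperand(0));
//     return success();
//   }
//   return failure(); // Case 1: leave the operation alone.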
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context);
/// Returns whether the operation has a particular property.
bool hasProperty(OperationProperty property) const {
return opProperties & static_cast<OperationProperties>(property);
}
/// Returns whether the operation has a particular trait.
template <template <typename T> class Trait> bool hasTrait() const {
return hasRawTrait(ClassID::getID<Trait>());
}
/// Look up the specified operation in the specified MLIRContext and return a
/// pointer to it if present. Otherwise, return a null pointer.
static const AbstractOperation *lookup(StringRef opName,
MLIRContext *context);
/// This constructor is used by Dialect objects when they register the list of
/// operations they contain.
template <typename T> static AbstractOperation get(Dialect &dialect) {
return AbstractOperation(
T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
T::getCanonicalizationPatterns, T::hasTrait);
}
private:
AbstractOperation(
StringRef name, Dialect &dialect, OperationProperties opProperties,
bool (&classof)(Operation *op),
ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
void (&printAssembly)(Operation *op, OpAsmPrinter *p),
LogicalResult (&verifyInvariants)(Operation *op),
LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context),
bool (&hasTrait)(ClassID *traitID))
: name(name), dialect(dialect), classof(classof),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
opProperties(opProperties), hasRawTrait(hasTrait) {}
/// The properties of the operation.
const OperationProperties opProperties;
/// This hook returns whether the operation contains the trait corresponding
/// to the given ClassID.
bool (&hasRawTrait)(ClassID *traitID);
};
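// Example (editor's sketch): querying registration info for an op by name;
// "std.addf" is assumed to be registered in 'context':
//
//   const AbstractOperation *info =
//       AbstractOperation::lookup("std.addf", context);
//   if (info && info->hasProperty(OperationProperty::Commutative))
//     /* ...operand order does not matter... */;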
class OperationName {
public:
using RepresentationUnion =
llvm::PointerUnion<Identifier, const AbstractOperation *>;
OperationName(AbstractOperation *op) : representation(op) {}
OperationName(StringRef name, MLIRContext *context);
/// Return the name of the dialect this operation is registered to.
StringRef getDialect() const;
/// Return the name of this operation. This always succeeds.
StringRef getStringRef() const;
/// If this operation has a registered operation description, return it.
/// Otherwise return null.
const AbstractOperation *getAbstractOperation() const;
void print(raw_ostream &os) const;
void dump() const;
void *getAsOpaquePointer() const {
return static_cast<void *>(representation.getOpaqueValue());
}
static OperationName getFromOpaquePointer(void *pointer);
private:
RepresentationUnion representation;
OperationName(RepresentationUnion representation)
: representation(representation) {}
};
inline raw_ostream &operator<<(raw_ostream &os, OperationName identifier) {
identifier.print(os);
return os;
}
inline bool operator==(OperationName lhs, OperationName rhs) {
return lhs.getAsOpaquePointer() == rhs.getAsOpaquePointer();
}
inline bool operator!=(OperationName lhs, OperationName rhs) {
return lhs.getAsOpaquePointer() != rhs.getAsOpaquePointer();
}
// Make operation names hashable.
inline llvm::hash_code hash_value(OperationName arg) {
return llvm::hash_value(arg.getAsOpaquePointer());
}
/// This represents an operation in an abstracted form, suitable for use with
/// the builder APIs. This object is a large and heavyweight object meant to
/// be used as a temporary object on the stack. It is generally unwise to put
/// this in a collection.
struct OperationState {
MLIRContext *const context;
Location location;
OperationName name;
SmallVector<Value *, 4> operands;
/// Types of the results of this operation.
SmallVector<Type, 4> types;
SmallVector<NamedAttribute, 4> attributes;
/// Successors of this operation and their respective operands.
SmallVector<Block *, 1> successors;
/// Regions that the op will hold.
SmallVector<std::unique_ptr<Region>, 1> regions;
/// True if the operation has a resizable operand list.
bool resizableOperandList = false;
public:
OperationState(Location location, StringRef name);
OperationState(Location location, OperationName name);
OperationState(Location location, StringRef name, ArrayRef<Value *> operands,
ArrayRef<Type> types, ArrayRef<NamedAttribute> attributes,
ArrayRef<Block *> successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {},
bool resizableOperandList = false);
void addOperands(ArrayRef<Value *> newOperands) {
assert(successors.empty() &&
"Non successor operands should be added first.");
operands.append(newOperands.begin(), newOperands.end());
}
void addTypes(ArrayRef<Type> newTypes) {
types.append(newTypes.begin(), newTypes.end());
}
/// Add an attribute with the specified name.
void addAttribute(StringRef name, Attribute attr) {
addAttribute(Identifier::get(name, getContext()), attr);
}
/// Add an attribute with the specified name.
void addAttribute(Identifier name, Attribute attr) {
attributes.push_back({name, attr});
}
/// Add an array of named attributes.
void addAttributes(ArrayRef<NamedAttribute> newAttributes) {
attributes.append(newAttributes.begin(), newAttributes.end());
}
void addSuccessor(Block *successor, ArrayRef<Value *> succOperands) {
successors.push_back(successor);
// Insert a sentinel operand to mark a barrier between successor operands.
operands.push_back(nullptr);
operands.append(succOperands.begin(), succOperands.end());
}
/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
/// created. When it is, the region bodies will be transferred.
Region *addRegion();
/// Take a region that should be attached to the Operation. The body of the
/// region will be transferred when the Operation is constructed. If the
/// region is null, a new empty region will be attached to the Operation.
void addRegion(std::unique_ptr<Region> &&region);
/// Sets the operand list of the operation as resizable.
void setOperandListToResizable(bool isResizable = true) {
resizableOperandList = isResizable;
}
/// Get the context held by this operation state.
MLIRContext *getContext() { return location->getContext(); }
};
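// Example (editor's sketch): assembling a state by hand and materializing the
// operation; "mydialect.foo", 'loc', 'lhs', 'rhs', and 'builder' are
// assumptions, and Operation::create(state) is the usual entry point:
//
//   OperationState state(loc, "mydialect.foo");
//   state.addOperands({lhs, rhs});
//   state.addTypes(builder.getF32Type());
//   state.addAttribute("alpha", builder.getF32FloatAttr(1.0f));
//   Operation *op = Operation::create(state);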
namespace detail {
/// A utility class holding the information necessary to dynamically resize
/// operands.
struct ResizableStorage {
ResizableStorage(OpOperand *opBegin, unsigned numOperands)
: firstOpAndIsDynamic(opBegin, false), capacity(numOperands) {}
~ResizableStorage() { cleanupStorage(); }
/// Cleanup any allocated storage.
void cleanupStorage() {
// If the storage is dynamic, then we need to free the storage.
if (isStorageDynamic())
free(firstOpAndIsDynamic.getPointer());
}
/// Sets the storage pointer to a new dynamically allocated block.
void setDynamicStorage(OpOperand *opBegin) {
/// Cleanup the old storage if necessary.
cleanupStorage();
firstOpAndIsDynamic.setPointerAndInt(opBegin, true);
}
/// Returns the current storage pointer.
OpOperand *getPointer() { return firstOpAndIsDynamic.getPointer(); }
/// Returns true if the current storage of operands is in a dynamically
/// allocated memory block rather than in the trailing objects storage.
bool isStorageDynamic() const { return firstOpAndIsDynamic.getInt(); }
/// A pointer to the first operand element. This is either to the trailing
/// objects storage, or a dynamically allocated block of memory.
llvm::PointerIntPair<OpOperand *, 1, bool> firstOpAndIsDynamic;
// The maximum number of operands that can be currently held by the storage.
unsigned capacity;
};
/// This class handles the management of operation operands. Operands are
/// stored similarly to the elements of a SmallVector except for two key
/// differences. The first is the inline storage, which is a trailing objects
/// array. The second is that being able to dynamically resize the operand list
/// is optional.
class OperandStorage final
: private llvm::TrailingObjects<OperandStorage, ResizableStorage,
OpOperand> {
public:
OperandStorage(unsigned numOperands, bool resizable)
: numOperands(numOperands), resizable(resizable) {
// Initialize the resizable storage.
if (resizable) {
new (&getResizableStorage())
ResizableStorage(getTrailingObjects<OpOperand>(), numOperands);
}
}
~OperandStorage() {
// Manually destruct the operands.
for (auto &operand : getOperands())
operand.~OpOperand();
// If the storage is resizable then destruct the utility.
if (resizable)
getResizableStorage().~ResizableStorage();
}
/// Replace the operands contained in the storage with the ones provided in
/// 'operands'.
void setOperands(Operation *owner, ArrayRef<Value *> operands);
/// Erase an operand held by the storage.
void eraseOperand(unsigned index);
/// Get the operation operands held by the storage.
MutableArrayRef<OpOperand> getOperands() {
return {getRawOperands(), size()};
}
/// Return the number of operands held in the storage.
unsigned size() const { return numOperands; }
/// Returns the additional size necessary for allocating this object.
static size_t additionalAllocSize(unsigned numOperands, bool resizable) {
return additionalSizeToAlloc<ResizableStorage, OpOperand>(resizable ? 1 : 0,
numOperands);
}
/// Returns if this storage is resizable.
bool isResizable() const { return resizable; }
private:
/// Clear the storage and destroy the current operands held by the storage.
void clear() { numOperands = 0; }
/// Returns the current pointer for the raw operands array.
OpOperand *getRawOperands() {
return resizable ? getResizableStorage().getPointer()
: getTrailingObjects<OpOperand>();
}
/// Returns the resizable operand utility class.
ResizableStorage &getResizableStorage() {
assert(resizable);
return *getTrailingObjects<ResizableStorage>();
}
/// Grow the internal resizable operand storage.
void grow(ResizableStorage &resizeUtil, size_t minSize);
/// The current number of operands.
unsigned numOperands : 31;
/// Whether this storage is resizable or not.
bool resizable : 1;
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<OperandStorage, ResizableStorage, OpOperand>;
size_t numTrailingObjects(OverloadToken<ResizableStorage>) const {
return resizable ? 1 : 0;
}
};
} // end namespace detail
} // end namespace mlir
namespace llvm {
// Operation names hash just like pointers; there is no need to hash the bytes.
template <> struct DenseMapInfo<mlir::OperationName> {
static mlir::OperationName getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::OperationName::getFromOpaquePointer(pointer);
}
static mlir::OperationName getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::OperationName::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::OperationName Val) {
return DenseMapInfo<void *>::getHashValue(Val.getAsOpaquePointer());
}
static bool isEqual(mlir::OperationName LHS, mlir::OperationName RHS) {
return LHS == RHS;
}
};
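// Example (editor's sketch): with this specialization, OperationName can be
// used directly as a DenseMap key, e.g. to count ops by kind:
//
//   llvm::DenseMap<mlir::OperationName, unsigned> opCounts;
//   ++opCounts[op->getName()];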
/// The pointer inside an identifier comes from a StringMap, so its alignment
/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
/// steal the low bits.
template <> struct PointerLikeTypeTraits<mlir::OperationName> {
public:
static inline void *getAsVoidPointer(mlir::OperationName I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::OperationName getFromVoidPointer(void *P) {
return mlir::OperationName::getFromOpaquePointer(P);
}
enum {
NumLowBitsAvailable = PointerLikeTypeTraits<
mlir::OperationName::RepresentationUnion>::NumLowBitsAvailable
};
};
} // end namespace llvm
#endif // MLIR_IR_OPERATION_SUPPORT_H

View File

@ -0,0 +1,455 @@
//===- PatternMatch.h - PatternMatcher classes ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_PATTERNMATCHER_H
#define MLIR_PATTERNMATCHER_H
#include "mlir/IR/Builders.h"
namespace mlir {
class PatternRewriter;
//===----------------------------------------------------------------------===//
// PatternBenefit class
//===----------------------------------------------------------------------===//
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65K. The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// This also has a sentinel representation that can be used for patterns that
/// fail to match.
///
class PatternBenefit {
enum { ImpossibleToMatchSentinel = 65535 };
public:
/*implicit*/ PatternBenefit(unsigned benefit);
PatternBenefit(const PatternBenefit &) = default;
PatternBenefit &operator=(const PatternBenefit &) = default;
static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
/// If the corresponding pattern can match, return its benefit. If the
/// corresponding pattern isImpossibleToMatch(), then this aborts.
unsigned short getBenefit() const;
bool operator==(const PatternBenefit &rhs) const {
return representation == rhs.representation;
}
bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
bool operator<(const PatternBenefit &rhs) const {
return representation < rhs.representation;
}
private:
PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
unsigned short representation;
};
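// Example (editor's sketch): benefits are ordinarily the number of matched
// operations, with the sentinel reserved for non-matches:
//
//   PatternBenefit benefit(/*numMatchedOps=*/2);
//   assert(!benefit.isImpossibleToMatch() && benefit.getBenefit() == 2);
//   assert(PatternBenefit::impossibleToMatch().isImpossibleToMatch());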
/// Pattern state is used by patterns that want to maintain state between their
/// match and rewrite phases. Patterns can define a pattern-specific subclass
/// of this.
class PatternState {
public:
virtual ~PatternState() {}
protected:
// Must be subclassed.
PatternState() {}
};
/// This is the type returned by a pattern match. A match failure returns a
/// None value. A match success returns a Some value with any state the pattern
/// may need to maintain (but may also be null).
using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
//===----------------------------------------------------------------------===//
// Pattern class
//===----------------------------------------------------------------------===//
/// Instances of Pattern can be matched against SSA IR. These matches get used
/// in ways dependent on their subclasses and the driver doing the matching.
/// For example, RewritePatterns implement a rewrite from one matched pattern
/// to a replacement DAG tile.
class Pattern {
public:
/// Return the benefit (the inverse of "cost") of matching this pattern. The
/// benefit of a Pattern is always static - rewrites that may have dynamic
/// benefit can be instantiated multiple times (different Pattern instances)
/// for each benefit that they may return, and be guarded by different match
/// condition predicates.
PatternBenefit getBenefit() const { return benefit; }
/// Return the root node that this pattern matches. Patterns that can
/// match multiple root types are instantiated once per root.
OperationName getRootKind() const { return rootKind; }
//===--------------------------------------------------------------------===//
// Implementation hooks for patterns to implement.
//===--------------------------------------------------------------------===//
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). On failure, this
/// returns a None value. On success it returns a (possibly null)
/// pattern-specific state wrapped in an Optional.
virtual PatternMatchResult match(Operation *op) const = 0;
virtual ~Pattern() {}
//===--------------------------------------------------------------------===//
// Helper methods to simplify pattern implementations
//===--------------------------------------------------------------------===//
/// This method indicates that no match was found.
static PatternMatchResult matchFailure() { return None; }
/// This method indicates that a match was found and has the specified cost.
PatternMatchResult
matchSuccess(std::unique_ptr<PatternState> state = {}) const {
return PatternMatchResult(std::move(state));
}
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
private:
const OperationName rootKind;
const PatternBenefit benefit;
virtual void anchor();
};
/// RewritePattern is the common base class for all DAG to DAG replacements.
/// There are two possible usages of this class:
/// * Multi-step RewritePattern with "match" and "rewrite"
/// - By overloading the "match" and "rewrite" functions, the user can
/// separate the concerns of matching and rewriting.
/// * Single-step RewritePattern with "matchAndRewrite"
/// - By overloading the "matchAndRewrite" function, the user can perform
/// the rewrite in the same call as the match. This removes the need for
/// any PatternState.
///
class RewritePattern : public Pattern {
public:
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// rewriter. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). On failure, this
/// returns a None value. On success, it returns a (possibly null)
/// pattern-specific state wrapped in an Optional. This state is passed back
/// into the rewrite function if this match is selected.
PatternMatchResult match(Operation *op) const override;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
virtual PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
if (auto matchResult = match(op)) {
rewrite(op, std::move(*matchResult), rewriter);
return matchSuccess();
}
return matchFailure();
}
/// Return a list of operations that may be generated when rewriting an
/// operation instance with this pattern.
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
RewritePattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: Pattern(rootName, benefit, context) {}
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching. They can also specify
/// the names of operations that may be generated during a successful rewrite.
RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context);
/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
llvm::SmallVector<OperationName, 2> generatedOps;
};
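// Example (editor's sketch): a two-step pattern with separate match and
// rewrite phases; "mydialect.transpose" and 'isIdentityPermutation' are
// hypothetical:
//
//   struct FoldIdentityTranspose : public RewritePattern {
//     FoldIdentityTranspose(MLIRContext *ctx)
//         : RewritePattern("mydialect.transpose", /*benefit=*/1, ctx) {}
//     PatternMatchResult match(Operation *op) const override {
//       return isIdentityPermutation(op) ? matchSuccess() : matchFailure();
//     }
//     void rewrite(Operation *op, PatternRewriter &rewriter) const override {
//       rewriter.replaceOp(op, op->getOperand(0));
//     }
//   };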
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const final {
rewrite(llvm::cast<SourceOp>(op), std::move(state), rewriter);
}
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(llvm::cast<SourceOp>(op), rewriter);
}
PatternMatchResult match(Operation *op) const final {
return match(llvm::cast<SourceOp>(op));
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(llvm::cast<SourceOp>(op), rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const {
rewrite(op, rewriter);
}
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual PatternMatchResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual PatternMatchResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const {
if (auto matchResult = match(op)) {
rewrite(op, std::move(*matchResult), rewriter);
return matchSuccess();
}
return matchFailure();
}
};
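// Example (editor's sketch): a single-step pattern over an assumed op class
// 'AddFOp' with 'lhs'/'rhs' accessors and a hypothetical 'isZero' helper:
//
//   struct SimplifyAddZero : public OpRewritePattern<AddFOp> {
//     using OpRewritePattern<AddFOp>::OpRewritePattern;
//     PatternMatchResult matchAndRewrite(AddFOp op,
//                                        PatternRewriter &rewriter) const override {
//       if (!isZero(op.rhs()))
//         return matchFailure();
//       rewriter.replaceOp(op.getOperation(), op.lhs());
//       return matchSuccess();
//     }
//   };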
//===----------------------------------------------------------------------===//
// PatternRewriter class
//===----------------------------------------------------------------------===//
/// This class coordinates the application of a pattern to the current function,
/// providing a way to create operations and keep track of what gets deleted.
///
/// This class serves two purposes:
/// 1) It is the interface that patterns interact with to make mutations to the
/// IR they are being applied to.
/// 2) It is a base class that clients of the PatternMatcher use when they want
/// to apply patterns and observe their effects (e.g. to keep worklists or
/// other data structures up to date).
///
class PatternRewriter : public OpBuilder {
public:
/// Create an operation of a specific op type at the current insertion point
/// without verifying to see if it is valid.
template <typename OpTy, typename... Args>
OpTy create(Location location, Args... args) {
OperationState state(location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *op = createOperation(state);
auto result = dyn_cast<OpTy>(op);
assert(result && "Builder didn't return the right type");
return result;
}
/// Creates an operation of a specific op type at the current insertion point.
/// If the result is an invalid op (the verifier hook fails), emit an error
/// and return null.
template <typename OpTy, typename... Args>
OpTy createChecked(Location location, Args... args) {
OperationState state(location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *op = createOperation(state);
// If the Operation we produce is valid, return it.
if (!OpTy::verifyInvariants(op)) {
auto result = dyn_cast<OpTy>(op);
assert(result && "Builder didn't return the right type");
return result;
}
// Otherwise, the error message got emitted. Just remove the operation
// we made.
op->erase();
return OpTy();
}
/// This is implemented to create the specified operations and serves as a
/// notification hook for rewriters that want to know about new operations.
virtual Operation *createOperation(const OperationState &state) = 0;
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
/// of control to the region and for passing it the correct block arguments.
virtual void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values. In addition to replacing and removing the specified operation,
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those values are dead, this will
/// remove them as well.
virtual void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead);
void replaceOp(Operation *op, ArrayRef<Value *> newValues) {
replaceOp(op, newValues, llvm::None);
}
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must have the same types.
template <typename OpTy, typename... Args>
void replaceOpWithNewOp(Operation *op, Args &&... args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
}
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must have the same types. This allows
/// specifying a list of ops that may be removed if dead.
template <typename OpTy, typename... Args>
void replaceOpWithNewOp(ArrayRef<Value *> valuesToRemoveIfDead, Operation *op,
Args &&... args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
valuesToRemoveIfDead);
}
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before) {
return block->splitBlock(before);
}
/// This method is used as the final notification hook for patterns that end
/// up modifying the pattern root in place, by changing its operands. This is
/// a minor efficiency win (it avoids creating a new operation and removing
/// the old one) but also often allows simpler code in the client.
///
/// The valuesToRemoveIfDead list is an optional list of values that the
/// rewriter should remove if they are dead at this point.
///
void updatedRootInPlace(Operation *op,
ArrayRef<Value *> valuesToRemoveIfDead = {});
protected:
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
virtual ~PatternRewriter();
// These are the callback methods that subclasses can choose to implement if
// they would like to be notified about certain types of mutations.
/// Notify the pattern rewriter that the specified operation has been mutated
/// in place. This is called after the mutation is done.
virtual void notifyRootUpdated(Operation *op) {}
/// Notify the pattern rewriter that the specified operation is about to be
/// replaced with another set of operations. This is called before the uses
/// of the operation have been changed.
virtual void notifyRootReplaced(Operation *op) {}
/// This is called on an operation that a pattern match is removing, right
/// before the operation is deleted. At this point, the operation has zero
/// uses.
virtual void notifyOperationRemoved(Operation *op) {}
private:
/// 'op' and 'newOp' are known to have the same number of results; replace the
/// uses of 'op' with uses of 'newOp'.
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
ArrayRef<Value *> valuesToRemoveIfDead);
};
//===----------------------------------------------------------------------===//
// Pattern-driven rewriters
//===----------------------------------------------------------------------===//
/// This is a vector that owns the patterns inside of it.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
/// This class manages optimization and execution of a group of rewrite
/// patterns, providing an API for finding and applying the best match against
/// a given node.
///
class RewritePatternMatcher {
public:
/// Create a RewritePatternMatcher with the specified set of patterns.
explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
/// Try to match the given operation to a pattern and rewrite it. Return
/// true if any pattern matches.
bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
private:
RewritePatternMatcher(const RewritePatternMatcher &) = delete;
void operator=(const RewritePatternMatcher &) = delete;
/// The group of patterns that are matched for optimization through this
/// matcher.
OwningRewritePatternList patterns;
};
/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return true if no more patterns can be matched in
/// the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself.
///
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
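// Example (editor's sketch): collecting patterns and driving them to a fixed
// point over an isolated-from-above op; 'SimplifyAddZero' refers to the
// sketch above, and 'funcOp' and 'context' are assumptions:
//
//   OwningRewritePatternList patterns;
//   patterns.emplace_back(llvm::make_unique<SimplifyAddZero>(context));
//   bool converged = applyPatternsGreedily(funcOp, std::move(patterns));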
/// Helper class to create a list of rewrite patterns given a list of their
/// types and a list of attributes perfect-forwarded to each of the conversion
/// constructors.
template <typename Arg, typename... Args> struct RewriteListBuilder {
template <typename... ConstructorArgs>
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
RewriteListBuilder<Args...>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
RewriteListBuilder<Arg>::build(
patterns, std::forward<ConstructorArgs>(constructorArgs)...);
}
};
// Template specialization to stop recursion.
template <typename Arg> struct RewriteListBuilder<Arg> {
template <typename... ConstructorArgs>
static void build(OwningRewritePatternList &patterns,
ConstructorArgs &&... constructorArgs) {
patterns.emplace_back(llvm::make_unique<Arg>(
std::forward<ConstructorArgs>(constructorArgs)...));
}
};
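// Example (editor's sketch): expanding a list of pattern types in one call;
// 'PatternA' and 'PatternB' are hypothetical RewritePattern subclasses whose
// constructors each take an MLIRContext*:
//
//   OwningRewritePatternList patterns;
//   RewriteListBuilder<PatternA, PatternB>::build(patterns, context);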
} // end namespace mlir
#endif // MLIR_PATTERNMATCHER_H

Some files were not shown because too many files have changed in this diff.