Updates LLVM usage to match
[418c218efa95](https://github.com/llvm/llvm-project/commit/418c218efa95)

PiperOrigin-RevId: 357010454
Change-Id: I3454c9b6a75e3794398a02c152eecfb4ef3eedff
Authored by A. Unique TensorFlower on 2021-02-11 10:48:15 -08:00; committed by TensorFlower Gardener
parent 0fcaae3843
commit f9b33b85b7
17 changed files with 90 additions and 40 deletions
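
The upstream change removes the pattern of registering dialects through a live context's getDialectRegistry(). Every call site below migrates to one of two replacements: collect dialects in a standalone mlir::DialectRegistry and pass it to the MLIRContext constructor, or append the registry to an existing context and load the dialects explicitly (the old DialectRegistry::loadAll(&context) becomes context.loadAllAvailableDialects()). A minimal sketch of both patterns follows; the include paths and the dialect chosen are illustrative assumptions, not part of this commit:

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // assumed path: mlir::StandardOpsDialect
#include "mlir/IR/Dialect.h"                  // assumed path: mlir::DialectRegistry
#include "mlir/IR/MLIRContext.h"

// Pattern A: build the registry first, then construct the context from it.
void PatternA() {
  mlir::DialectRegistry registry;
  registry.insert<mlir::StandardOpsDialect>();  // illustrative dialect
  mlir::MLIRContext context(registry);
  // ... use the context; the dialect is registered and loads lazily ...
}

// Pattern B: extend an already-constructed context, then force-load the
// registered dialects so parsing their ops works immediately.
void PatternB(mlir::MLIRContext& context) {
  mlir::DialectRegistry registry;
  registry.insert<mlir::StandardOpsDialect>();
  context.appendDialectRegistry(registry);
  context.loadAllAvailableDialects();
}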

View File

@@ -100,10 +100,10 @@ int main(int argc, char** argv) {
   }
   // Load the MLIR module.
-  mlir::MLIRContext context;
-  context.getDialectRegistry()
-      .insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
-              mlir::StandardOpsDialect>();
+  mlir::DialectRegistry registry;
+  registry.insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
+                  mlir::StandardOpsDialect>();
+  mlir::MLIRContext context(registry);
   llvm::SourceMgr source_mgr;
   source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());

View File

@@ -53,8 +53,9 @@ TfLiteStatus QuantizeModel(
     return kTfLiteError;
   }
-  MLIRContext context;
-  context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
+  DialectRegistry registry;
+  registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+  MLIRContext context(registry);
   StatusScopedDiagnosticHandler statusHandler(&context,
                                               /*propagate=*/true);

View File

@@ -182,8 +182,9 @@ Status MlirFunctionOptimizationPass::Run(
           << ", Total: " << registry_->passes().size();
   GraphDebugInfo debug_info;
-  mlir::MLIRContext context;
-  RegisterDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  RegisterDialects(registry);
+  mlir::MLIRContext context(registry);
   GraphImportConfig import_config;
   import_config.graph_as_function = true;
   import_config.control_outputs = *control_ret_node_names;
@@ -342,8 +343,9 @@ Status MlirV1CompatGraphOptimizationPass::Run(
           << " passes)";
   GraphDebugInfo debug_info;
-  mlir::MLIRContext context;
-  RegisterDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  RegisterDialects(registry);
+  mlir::MLIRContext context(registry);
   GraphImportConfig import_config;
   import_config.upgrade_legacy = true;
   // Restrict functionalization to TPU nodes to avoid problems in v1 session

View File

@@ -238,8 +238,9 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
                                         const std::string &pass_pipeline,
                                         bool show_debug_info,
                                         TF_Status *status) {
-  mlir::MLIRContext context;
-  mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  mlir::RegisterAllTensorFlowDialects(registry);
+  mlir::MLIRContext context(registry);
   mlir::OwningModuleRef module;
   {
     mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);

View File

@@ -30,17 +30,20 @@ limitations under the License.
 PYBIND11_MODULE(mlir_wrapper, m) {
   m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) {
-    mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
-    context.getDialectRegistry().loadAll(&context);
+    mlir::DialectRegistry registry;
+    mlir::RegisterAllTensorFlowDialects(registry);
+    context.appendDialectRegistry(registry);
+    context.loadAllAvailableDialects();
   });
 
   m.def("verify", [](std::string input) {
     llvm::SourceMgr SM = llvm::SourceMgr();
     SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                           llvm::SMLoc());
-    mlir::MLIRContext ctx;
-    mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
-    ctx.getDialectRegistry().loadAll(&ctx);
+    mlir::DialectRegistry registry;
+    mlir::RegisterAllTensorFlowDialects(registry);
+    mlir::MLIRContext ctx(registry);
+    ctx.loadAllAvailableDialects();
     auto module = mlir::parseSourceFile(SM, &ctx);
     if (!module) {
       return false;

View File

@@ -77,8 +77,10 @@ using tensorflow::tracing::TracingTensorHandle;
 namespace {
 void RegisterDialects(mlir::MLIRContext& ctx) {
-  mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
-  ctx.getDialectRegistry().loadAll(&ctx);
+  mlir::DialectRegistry registry;
+  mlir::RegisterAllTensorFlowDialects(registry);
+  ctx.appendDialectRegistry(registry);
+  ctx.loadAllAvailableDialects();
 }
 Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,

View File

@@ -146,7 +146,8 @@ void LoadImporterDialects(mlir::MLIRContext& context) {
   // Load dialects involved in the conversion
   mlir::DialectRegistry registry;
   mlir::RegisterAllTensorFlowDialects(registry);
-  registry.loadAll(&context);
+  context.appendDialectRegistry(registry);
+  context.loadAllAvailableDialects();
 }
 // This class is used to generate new MLIR function name strings that are both

View File

@@ -491,8 +491,9 @@ Status CompileSerializedMlirToXlaHlo(
     XlaCompilationResult* compilation_result,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
         custom_legalization_passes) {
-  mlir::MLIRContext mlir_context;
-  RegisterDialects(mlir_context.getDialectRegistry());
+  mlir::DialectRegistry mlir_registry;
+  RegisterDialects(mlir_registry);
+  mlir::MLIRContext mlir_context(mlir_registry);
   mlir::OwningModuleRef mlir_module;
   TF_RETURN_IF_ERROR(
@@ -646,7 +647,9 @@ Status GraphToModule(const Graph& graph,
                      const GraphDebugInfo& debug_info,
                      mlir::MLIRContext* context,
                      mlir::OwningModuleRef* module) {
-  RegisterDialects(context->getDialectRegistry());
+  mlir::DialectRegistry registry;
+  RegisterDialects(registry);
+  context->appendDialectRegistry(registry);
   GraphImportConfig config;
   config.graph_as_function = true;
   config.control_outputs = control_rets;

View File

@@ -371,7 +371,9 @@ static mlir::OwningModuleRef SerializedMlirStringAttrToMlirModuleTranslate(
   }
   auto str_attr = attr.cast<mlir::StringAttr>();
-  RegisterMlirInputDialects(context->getDialectRegistry());
+  mlir::DialectRegistry registry;
+  RegisterMlirInputDialects(registry);
+  context->appendDialectRegistry(registry);
   mlir::OwningModuleRef module_ref;
   auto status =
       DeserializeMlirModule(str_attr.getValue().str(), context, &module_ref);

View File

@@ -92,7 +92,7 @@ std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText(
     StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) {
   mlir_ctx->allowUnregisteredDialects(/*allow=*/true);
   // Load dialects involved in the conversion
-  mlir::DialectRegistry& registry = mlir_ctx->getDialectRegistry();
+  mlir::DialectRegistry registry;
   // clang-format off
   registry.insert<mlir::StandardOpsDialect,
                   mlir::scf::SCFDialect,
@@ -102,7 +102,8 @@ std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText(
                   mlir::tf_executor::TensorFlowExecutorDialect,
                   mlir::TFR::TFRDialect>();
   // clang-format on
-  registry.loadAll(mlir_ctx);
+  mlir_ctx->appendDialectRegistry(registry);
+  mlir_ctx->loadAllAvailableDialects();
   // Load the TFR functions in a mlir::ModuleOp
   auto memory_buffer = llvm::MemoryBuffer::getMemBuffer(

View File

@@ -33,12 +33,12 @@ limitations under the License.
 PYBIND11_MODULE(tfr_wrapper, m) {
   m.def("verify", [](std::string input) {
-    mlir::MLIRContext ctx;
-    auto& registry = ctx.getDialectRegistry();
+    mlir::DialectRegistry registry;
     registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
                     mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
                     mlir::TFR::TFRDialect>();
-    ctx.getDialectRegistry().loadAll(&ctx);
+    mlir::MLIRContext ctx(registry);
+    ctx.loadAllAvailableDialects();
     llvm::SourceMgr source_mgr = llvm::SourceMgr();
     source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),

View File

@@ -388,9 +388,10 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
     bool embed_memref_prints, bool generate_fatbin, bool print_ptx,
     bool enable_ftz, bool cpu_codegen) {
-  auto& registry = context.getDialectRegistry();
+  mlir::DialectRegistry registry;
   mlir::RegisterAllTensorFlowDialects(registry);
   registry.insert<mlir::chlo::HloClientDialect, mlir::mhlo::MhloDialect>();
+  context.appendDialectRegistry(registry);
   mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
   TF_RETURN_IF_ERROR(

View File

@@ -9,7 +9,7 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
 // CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
 // CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32>
 // CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32>
-// CHECK: [[CM2:%.*]] = constant -2 : i32
+// CHECK: [[CM2:%.*]] = constant -2 : index
 // CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]])
 // CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]])
 // CHECK: [[BCASTHEAD:%.*]] = shape.broadcast [[LHSHEAD]], [[RHSHEAD]]

View File

@@ -2774,7 +2774,7 @@ static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
   Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
   Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
   Value const_neg2 =
-      rewriter->create<ConstantOp>(loc, rewriter->getI32IntegerAttr(-2));
+      rewriter->create<ConstantOp>(loc, rewriter->getIndexAttr(-2));
   auto lhs_splitted =
       rewriter->create<shape::SplitAtOp>(loc, lhs_shape, const_neg2);
   auto rhs_splitted =
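
The second functional change in this commit pairs with the FileCheck update in the previous file: the -2 constant fed to shape.split_at is now materialized with rewriter->getIndexAttr(-2) rather than rewriter->getI32IntegerAttr(-2), which suggests the op expects an index-typed split point at the new LLVM revision; the expectation "constant -2 : index" tracks the same change.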

View File

@@ -27,8 +27,9 @@ namespace xla {
 namespace gpu {
 
 TEST(IrEmissionUtilsTest, TestOperandPartitionNoAlias) {
-  mlir::MLIRContext context;
-  mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  mlir::mhlo::registerAllMhloDialects(registry);
+  mlir::MLIRContext context(registry);
   auto module = mlir::parseSourceString(R"(
     func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
@@ -43,8 +44,9 @@ TEST(IrEmissionUtilsTest, TestOperandPartitionNoAlias) {
 }
 
 TEST(IrEmissionUtilsTest, TestOperandPartitionWithAlias0) {
-  mlir::MLIRContext context;
-  mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  mlir::mhlo::registerAllMhloDialects(registry);
+  mlir::MLIRContext context(registry);
   auto module = mlir::parseSourceString(R"(
     func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {
@@ -59,8 +61,9 @@ TEST(IrEmissionUtilsTest, TestOperandPartitionWithAlias0) {
 }
 
 TEST(IrEmissionUtilsTest, TestOperandPartitionWithAlias1) {
-  mlir::MLIRContext context;
-  mlir::mhlo::registerAllMhloDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  mlir::mhlo::registerAllMhloDialects(registry);
+  mlir::MLIRContext context(registry);
   auto module = mlir::parseSourceString(R"(
     func @foo(%arg0 : memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>) {

View File

@@ -685,8 +685,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "9db6e97a8605f6a447ed171e59d5fa46fdfdc432"
-    LLVM_SHA256 = "f30fe9eb9a342187d25babccd85c3af4f09ee7340108a9f3a259af1dc0c76484"
+    LLVM_COMMIT = "418c218efa950245ba075b9bb3a53505b807c5df"
+    LLVM_SHA256 = "9b16312cb14ee38866f96e937b7af1efb6ab57da157ad48aa126a57db07301a1"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),

View File

@@ -4498,6 +4498,34 @@ gentbl(
     ],
 )
 
+filegroup(
+    name = "LinalgSparseOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Linalg/IR/LinalgBase.td",
+        "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td",
+        "include/mlir/Interfaces/ViewLikeInterface.td",
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl(
+    name = "LinalgSparseOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td",
+    td_srcs = [":LinalgSparseOpsTdFiles"],
+)
+
 gentbl(
     name = "LinalgInterfacesIncGen",
     strip_include_prefix = "include",
@@ -4643,6 +4671,7 @@ cc_library(
         ":LinalgInterfacesIncGen",
         ":LinalgNamedStructuredOpsIncGen",
         ":LinalgOpsIncGen",
+        ":LinalgSparseOpsIncGen",
         ":LinalgStructuredOpsIncGen",
         ":Parser",
         ":SideEffectInterfaces",
@@ -4700,6 +4729,7 @@ cc_library(
        ":LLVMDialect",
        ":LinalgOps",
        ":LinalgPassIncGen",
+       ":LinalgSparseOpsIncGen",
       ":LinalgStructuredOpsIncGen",
       ":Pass",
       ":SCFDialect",
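
The BUILD changes track a file added upstream: LinalgSparseOps.td gets a filegroup for its tablegen sources and a gentbl rule (LinalgSparseOpsIncGen) generating op declarations and definitions, and the generated target is wired into the deps of the two Linalg libraries so the new .h.inc/.cpp.inc headers resolve.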