Updates LLVM usage to match
[f9dc2b707935](https://github.com/llvm/llvm-project/commit/f9dc2b707935)

PiperOrigin-RevId: 327538369
Change-Id: I199bf5d4f7f311229949d6174bea84c833b21074
This commit is contained in:
A. Unique TensorFlower 2020-08-19 17:17:34 -07:00 committed by TensorFlower Gardener
parent 3d7a7556c5
commit e2ff54f938
45 changed files with 96 additions and 123 deletions

View File

@ -43,7 +43,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",

View File

@ -13,112 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/MlirOptMain.h"
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyPasses(
"verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> allowUnregisteredDialects(
"allow-unregistered-dialect",
llvm::cl::desc("Allow operation with no registered dialects"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> showDialects(
"show-dialects", llvm::cl::desc("Print the list of registered dialects"),
llvm::cl::init(false));
int main(int argc, char **argv) {
mlir::registerAllDialects();
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
mlir::registerAllPasses();
mlir::mhlo::registerAllDialects();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
llvm::InitLLVM y(argc, argv);
// Register any pass manager command line options.
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
// Parse pass names in main to ensure static initialization completed.
llvm::cl::ParseCommandLineOptions(argc, argv,
"MLIR modular optimizer driver\n");
if (showDialects) {
mlir::MLIRContext context;
llvm::outs() << "Registered Dialects:\n";
for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
}
// Set up the input file.
std::string errorMessage;
auto file = mlir::openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return 1;
}
auto output = mlir::openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
exit(1);
}
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects))) {
return 1;
}
// Keep the output file if the invocation of MlirOptMain was successful.
output->keep();
return 0;
return failed(
mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));
}

View File

@ -673,6 +673,7 @@ cc_library(
":flatbuffer_tflite_operator_lib",
":tensorflow_lite",
":tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",

View File

@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
@ -354,8 +355,13 @@ class Translator {
if (emit_custom_ops) {
enabled_op_types_.emplace(OpType::kCustomOp);
}
tf_dialect_ = module.getContext()->getRegisteredDialect("tf");
tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl");
tf_dialect_ =
module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
tfl_dialect_ = module.getContext()
->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
// Right now the TF executor dialect is still needed to build NodeDef.
module.getContext()
->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
}
Optional<std::string> TranslateInternal();

View File

@ -65,6 +65,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -479,7 +480,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
auto dialect = elem_type.getContext()->getLoadedDialect("tf");
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
@ -1072,6 +1073,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
const std::vector<std::string>& ordered_input_arrays,
const std::vector<std::string>& ordered_output_arrays,
bool experimental_prune_unreachable_nodes_unconditionally) {
context->loadDialect<
mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();
auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
if (nullptr == model_ptr) {

View File

@ -249,7 +249,7 @@ Status mlir::CustomOptionsToAttributes(
{static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
attributes->emplace_back(builder.getNamedAttr(
"custom_option",
OpaqueElementsAttr::get(builder.getContext()->getRegisteredDialect("tfl"),
OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"),
type, content)));
return Status::OK();

View File

@ -98,6 +98,7 @@ int main(int argc, char** argv) {
// Load the MLIR module.
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));

View File

@ -49,6 +49,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const GraphDef& input,
string* result) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
GraphImportConfig specs;
mlir::TFL::QuantizationSpecs quant_specs;

View File

@ -122,6 +122,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::TFL::QuantizationSpecs quant_specs;
// Parse input arrays.

View File

@ -52,6 +52,7 @@ TfLiteStatus QuantizeModel(
}
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);

View File

@ -37,6 +37,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);

View File

@ -46,7 +46,7 @@ stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
} else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
auto etype = complex_type.getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
auto dialect = etype.getContext()->getLoadedDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);

View File

@ -56,9 +56,9 @@ inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
const std::string& content) {
ShapedType type = RankedTensorType::get(
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
return OpaqueElementsAttr::get(
builder->getContext()->getRegisteredDialect("tfl"), type,
StringRef(content.data(), content.size()));
return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
type,
StringRef(content.data(), content.size()));
}
inline TensorType GetInputType(FuncOp func, int idx) {

View File

@ -128,6 +128,7 @@ Status MlirFunctionOptimizationPass::Run(
GraphDebugInfo debug_info;
RegisterDialects();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
@ -208,6 +209,7 @@ Status MlirV1CompatGraphOptimizationPass::Run(
GraphDebugInfo debug_info;
RegisterDialects();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
GraphImportConfig import_config;
import_config.upgrade_legacy = true;
// Restrict functionalization to TPU nodes to avoid problems in v1 session

View File

@ -41,6 +41,7 @@ std::string ImportGraphDef(const std::string &proto,
GraphDebugInfo debug_info;
GraphImportConfig specs;
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
@ -85,6 +86,7 @@ std::string ExperimentalConvertSavedModelToMlir(
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module_or = ConvertSavedModelToMlir(
&bundle, &context, absl::Span<std::string>(exported_names));
if (!module_or.status().ok()) {
@ -115,6 +117,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module_or =
ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy);
if (!module_or.status().ok()) {

View File

@ -38,6 +38,7 @@ PYBIND11_MODULE(mlir_wrapper, m) {
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
llvm::SMLoc());
mlir::MLIRContext ctx;
ctx.loadAllGloballyRegisteredDialects();
auto module = mlir::parseSourceFile(SM, &ctx);
if (!module) {
return false;

View File

@ -240,7 +240,7 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
auto def_op = val.getDefiningOp();
#ifndef NDEBUG
auto exec_dialect =
function.getContext()->getRegisteredDialect("tf_executor");
function.getContext()->getLoadedDialect("tf_executor");
assert(def_op->getDialect() == exec_dialect &&
"unable to forward control dependencies");
#endif

View File

@ -104,7 +104,7 @@ LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect,
}
void LaunchToDeviceAttributePass::runOnFunction() {
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
if (!tf_dialect) {
getFunction().emitError() << "'tf' dialect is not registered";
return signalPassFailure();

View File

@ -152,7 +152,7 @@ void UnmarkChildren(Block* block) {
void MarkOpsForOutsideCompilation::runOnOperation() {
auto module = getOperation();
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
if (!tf_dialect) {
getOperation().emitError() << "'tf' dialect is not registered";
return signalPassFailure();

View File

@ -438,7 +438,7 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
void ReplicateToIslandPass::runOnOperation() {
auto module = getOperation();
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
if (!tf_dialect) {
module.emitError() << "'tf' dialect is not registered";
return signalPassFailure();

View File

@ -597,7 +597,7 @@ ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
bool propagate_caller_callee_constants)
: graph_version_(graph_version),
propagate_caller_callee_constants_(propagate_caller_callee_constants) {
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
tf_dialect_ = context->getLoadedDialect<TensorFlowDialect>();
}
ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,

View File

@ -34,7 +34,7 @@ class SimpleTFDeviceAssignmentPass
void runOnFunction() override {
Builder builder(&getContext());
Dialect* tf = getContext().getRegisteredDialect<TensorFlowDialect>();
Dialect* tf = getContext().getLoadedDialect<TensorFlowDialect>();
getFunction().walk([&](Operation* op) {
if (auto device_attr = op->getAttrOfType<StringAttr>("device")) {
// We assign default device to ops with device attribute that is empty.

View File

@ -726,7 +726,7 @@ Status Exporter::Convert(mlir::ModuleOp module,
mlir::Identifier::get("main", module.getContext());
absl::optional<mlir::FuncOp> entry_func;
FunctionDefLibrary flib;
auto tf_dialect = module.getContext()->getRegisteredDialect("tf");
auto tf_dialect = module.getContext()->getLoadedDialect("tf");
for (auto function : module.getOps<mlir::FuncOp>()) {
if (function.isExternal())
return errors::FailedPrecondition("External functions not supported");
@ -799,7 +799,7 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
mlir::FuncOp func, const GraphExportConfig& configs,
FunctionDef* function_def) {
Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf");
Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
FunctionDefLibrary flib;
TF_RETURN_IF_ERROR(
Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));

View File

@ -420,6 +420,7 @@ Status CompileSerializedMlirToXlaHlo(
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
RegisterDialects();
mlir::MLIRContext mlir_context;
mlir_context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef mlir_module;
TF_RETURN_IF_ERROR(
@ -509,6 +510,7 @@ Status CompileGraphToXlaHlo(
RegisterDialects();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
GraphImportConfig config;
config.graph_as_function = true;
// Disable shape inference during import as some TensorFlow op fails during

View File

@ -161,7 +161,7 @@ StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
default:
// TODO(shpeisman): restructure code to reuse dialect pointer across
// calls.
auto* dialect = builder->getContext()->getRegisteredDialect("tf");
auto* dialect = builder->getContext()->getLoadedDialect("tf");
return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
}

View File

@ -43,6 +43,7 @@ static void RegisterDialects() {
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::Builder b(&context);
PartialTensorShape output_shape =
@ -52,6 +53,7 @@ TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::Builder b(&context);
PartialTensorShape output_shape = ConvertTypeToTensorShape(
@ -61,6 +63,7 @@ TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::Builder b(&context);
PartialTensorShape output_shape = ConvertTypeToTensorShape(

View File

@ -36,6 +36,7 @@ std::string ConvertToMlirString(const std::vector<int64_t>& dims,
}
mlir::MLIRContext context;
mlir::Builder b(&context);
context.loadAllGloballyRegisteredDialects();
auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
std::string buf;
llvm::raw_string_ostream os(buf);

View File

@ -60,6 +60,7 @@ class FakeDevice : public Device {
TEST(DeviceUtilTest, AddDeviceToOp) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
@ -101,6 +102,7 @@ TEST(DeviceUtilTest, AddDeviceToOp) {
TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
@ -110,6 +112,7 @@ TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));

View File

@ -66,6 +66,7 @@ Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph,
WritableFile* file) {
WritableFileRawStream os(std::move(file));
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module;
if (flib_def) {
flib_def = &graph.flib_def();

View File

@ -28,6 +28,7 @@ namespace {
TEST(DumpMlirModuleTest, NoEnvPrefix) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
unsetenv("TF_DUMP_GRAPH_PREFIX");
@ -38,6 +39,7 @@ TEST(DumpMlirModuleTest, NoEnvPrefix) {
TEST(DumpMlirModuleTest, LogInfo) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
setenv("TF_DUMP_GRAPH_PREFIX", "-", 1);
@ -48,6 +50,7 @@ TEST(DumpMlirModuleTest, LogInfo) {
TEST(DumpMlirModuleTest, Valid) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1);

View File

@ -29,6 +29,7 @@ using testing::HasSubstr;
TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto id = Identifier::get("test.cc", &context);
auto loc = FileLineColLoc::get(id, 0, 0, &context);

View File

@ -602,6 +602,7 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::Builder builder(&context);
auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
auto status_or_device_coodinates =
@ -615,6 +616,7 @@ TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::Builder builder(&context);
auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
auto status_or_device_coodinates =
@ -627,6 +629,7 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());

View File

@ -18,6 +18,8 @@ limitations under the License.
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
@ -63,6 +65,8 @@ static llvm::cl::opt<bool> allowUnregisteredDialects(
llvm::cl::init(false));
int main(int argc, char **argv) {
mlir::registerAllPasses();
tensorflow::InitMlir y(&argc, &argv);
// Register various MLIR command line options.
@ -84,9 +88,12 @@ int main(int argc, char **argv) {
auto output = mlir::openOutputFile(output_filename, &error_message);
QCHECK(output) << error_message;
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline,
split_input_file, verify_diagnostics,
verify_passes, allowUnregisteredDialects)))
registry, split_input_file, verify_diagnostics,
verify_passes, allowUnregisteredDialects,
/*preloadDialectsInContext=*/true)))
return 1;
output->keep();
return 0;

View File

@ -111,6 +111,7 @@ int main(int argc, char** argv) {
if (import_saved_model_object_graph) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, exported_names, &context);
@ -119,6 +120,7 @@ int main(int argc, char** argv) {
module_or.ConsumeValueOrDie()->print(output->os());
} else if (import_saved_model_signature_defs) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, exported_names, &context, upgrade_legacy);
@ -139,6 +141,7 @@ int main(int argc, char** argv) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context);
return (*requested_translation)(sourceMgr, os, &context);
};

View File

@ -125,6 +125,7 @@ int main(int argc, char** argv) {
"TF GraphDef to TFJS JSON converter\n");
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
llvm::SourceMgr source_mgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

View File

@ -261,6 +261,7 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
llvm::ArrayRef<uint32_t> unroll_factors) {
RegisterDialects();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));

View File

@ -90,8 +90,9 @@ int main(int argc, char **argv) {
if (showDialects) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
llvm::outs() << "Registered Dialects:\n";
for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
for (mlir::Dialect *dialect : context.getLoadedDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
@ -111,9 +112,12 @@ int main(int argc, char **argv) {
exit(1);
}
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
mlir::DialectRegistry registry;
registerAllDialects(registry);
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects))) {
allowUnregisteredDialects,
/*preloadDialectsInContext=*/true))) {
return 1;
}
// Keep the output file if the invocation of MlirOptMain was successful.

View File

@ -64,6 +64,7 @@ inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
Builder b(&context);
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
@ -74,6 +75,7 @@ TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
Builder b(&context);
EXPECT_TRUE(
@ -95,6 +97,7 @@ TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
Builder b(&context);
// Memref without any affine map. Note: memory space is ignored for shape.

View File

@ -152,6 +152,7 @@ Status ConvertGraphDefToXlaViaMlir(
RegisterDialects();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
TF_ASSIGN_OR_RETURN(
mlir::OwningModuleRef module,
ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));

View File

@ -622,6 +622,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// Compile must be thread-safe so create a new LLVM context for the module.
mlir::MLIRContext mlir_context;
mlir_context.loadAllGloballyRegisteredDialects();
llvm::LLVMContext llvm_context;
auto llvm_module =
absl::make_unique<llvm::Module>("__compute_module", llvm_context);
@ -833,6 +834,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
// Compile must be thread-safe so create a new LLVM context for the module.
mlir::MLIRContext mlir_context;
mlir_context.loadAllGloballyRegisteredDialects();
llvm::LLVMContext llvm_context;
llvm::Module llvm_module("__compute_module", llvm_context);
llvm_module.setDataLayout(target_machine->createDataLayout());

View File

@ -25,6 +25,7 @@ namespace mlir_gpu {
EmissionContext::EmissionContext(std::unique_ptr<HloModule> module)
: module_(std::move(module)), context_() {
context_.loadAllGloballyRegisteredDialects();
error_handler_ = [](const ErrorMap& instructions_with_error,
HloModule* module) {
std::set<const HloComputation*> computations_with_error;

View File

@ -46,6 +46,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) {
hlo_module.entry_computation()->root_instruction();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef mlir_module(
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)));

View File

@ -699,8 +699,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
)
# Check out LLVM and MLIR from llvm-project.
LLVM_COMMIT = "e75bc5c791e0e8dbe79f7453e55af9e8d03c9cc0"
LLVM_SHA256 = "9c22f59d50853329cd0105ecb95256ad345313372ddda593030cd81b7c72e657"
LLVM_COMMIT = "f9dc2b7079350d0fed3bb3775f496b90483c9e42"
LLVM_SHA256 = "59866525042c3165c4fcb4c855bc315a390b4ec8eb76846bbd3e5ac3d8f50c1d"
LLVM_URLS = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),

View File

@ -1124,6 +1124,7 @@ cc_library(
":ControlFlowInterfaces",
":IR",
":LLVMOpsIncGen",
":OpenMPDialect",
":SideEffectInterfaces",
":Support",
"@llvm-project//llvm:AsmParser",
@ -3542,6 +3543,7 @@ cc_library(
":LinalgOps",
":LinalgTransforms",
":Pass",
":SCFDialect",
":SCFToStandard",
":StandardOps",
":StandardToLLVM",

View File

@ -186,6 +186,7 @@ cc_library(
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",