Integrate LLVM at llvm/llvm-project@f9dc2b7079
Updates LLVM usage to match [f9dc2b707935](https://github.com/llvm/llvm-project/commit/f9dc2b707935) PiperOrigin-RevId: 327538369 Change-Id: I199bf5d4f7f311229949d6174bea84c833b21074
This commit is contained in:
parent
3d7a7556c5
commit
e2ff54f938
@ -43,7 +43,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:MlirOptLib",
|
"@llvm-project//mlir:MlirOptLib",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -13,112 +13,18 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "llvm/Support/CommandLine.h"
|
|
||||||
#include "llvm/Support/InitLLVM.h"
|
|
||||||
#include "llvm/Support/SourceMgr.h"
|
|
||||||
#include "llvm/Support/ToolOutputFile.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||||
#include "mlir/IR/AsmState.h"
|
|
||||||
#include "mlir/IR/Dialect.h"
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/InitAllDialects.h"
|
#include "mlir/InitAllDialects.h"
|
||||||
#include "mlir/InitAllPasses.h"
|
#include "mlir/InitAllPasses.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Pass/PassManager.h"
|
|
||||||
#include "mlir/Support/FileUtilities.h"
|
|
||||||
#include "mlir/Support/MlirOptMain.h"
|
#include "mlir/Support/MlirOptMain.h"
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
|
||||||
llvm::cl::desc("<input file>"),
|
|
||||||
llvm::cl::init("-"));
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static llvm::cl::opt<std::string> outputFilename(
|
|
||||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
|
||||||
llvm::cl::init("-"));
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static llvm::cl::opt<bool> splitInputFile(
|
|
||||||
"split-input-file",
|
|
||||||
llvm::cl::desc("Split the input file into pieces and process each "
|
|
||||||
"chunk independently"),
|
|
||||||
llvm::cl::init(false));
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static llvm::cl::opt<bool> verifyDiagnostics(
|
|
||||||
"verify-diagnostics",
|
|
||||||
llvm::cl::desc("Check that emitted diagnostics match "
|
|
||||||
"expected-* lines on the corresponding line"),
|
|
||||||
llvm::cl::init(false));
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static llvm::cl::opt<bool> verifyPasses(
|
|
||||||
"verify-each",
|
|
||||||
llvm::cl::desc("Run the verifier after each transformation pass"),
|
|
||||||
llvm::cl::init(true));
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static llvm::cl::opt<bool> allowUnregisteredDialects(
|
|
||||||
"allow-unregistered-dialect",
|
|
||||||
llvm::cl::desc("Allow operation with no registered dialects"),
|
|
||||||
llvm::cl::init(false));
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static llvm::cl::opt<bool> showDialects(
|
|
||||||
"show-dialects", llvm::cl::desc("Print the list of registered dialects"),
|
|
||||||
llvm::cl::init(false));
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
mlir::registerAllDialects();
|
mlir::DialectRegistry registry;
|
||||||
|
mlir::registerAllDialects(registry);
|
||||||
mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
|
|
||||||
mlir::mhlo::registerAllDialects();
|
|
||||||
mlir::mhlo::registerAllMhloPasses();
|
mlir::mhlo::registerAllMhloPasses();
|
||||||
mlir::lmhlo::registerAllLmhloPasses();
|
mlir::lmhlo::registerAllLmhloPasses();
|
||||||
|
return failed(
|
||||||
llvm::InitLLVM y(argc, argv);
|
mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));
|
||||||
|
|
||||||
// Register any pass manager command line options.
|
|
||||||
mlir::registerAsmPrinterCLOptions();
|
|
||||||
mlir::registerMLIRContextCLOptions();
|
|
||||||
mlir::registerPassManagerCLOptions();
|
|
||||||
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
|
||||||
|
|
||||||
// Parse pass names in main to ensure static initialization completed.
|
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
|
||||||
"MLIR modular optimizer driver\n");
|
|
||||||
|
|
||||||
if (showDialects) {
|
|
||||||
mlir::MLIRContext context;
|
|
||||||
llvm::outs() << "Registered Dialects:\n";
|
|
||||||
for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
|
|
||||||
llvm::outs() << dialect->getNamespace() << "\n";
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up the input file.
|
|
||||||
std::string errorMessage;
|
|
||||||
auto file = mlir::openInputFile(inputFilename, &errorMessage);
|
|
||||||
if (!file) {
|
|
||||||
llvm::errs() << errorMessage << "\n";
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto output = mlir::openOutputFile(outputFilename, &errorMessage);
|
|
||||||
if (!output) {
|
|
||||||
llvm::errs() << errorMessage << "\n";
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
|
|
||||||
splitInputFile, verifyDiagnostics, verifyPasses,
|
|
||||||
allowUnregisteredDialects))) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
// Keep the output file if the invocation of MlirOptMain was successful.
|
|
||||||
output->keep();
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
@ -673,6 +673,7 @@ cc_library(
|
|||||||
":flatbuffer_tflite_operator_lib",
|
":flatbuffer_tflite_operator_lib",
|
||||||
":tensorflow_lite",
|
":tensorflow_lite",
|
||||||
":tensorflow_lite_dialect_registration",
|
":tensorflow_lite_dialect_registration",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
@ -61,6 +61,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
|
#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
|
||||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
|
||||||
@ -354,8 +355,13 @@ class Translator {
|
|||||||
if (emit_custom_ops) {
|
if (emit_custom_ops) {
|
||||||
enabled_op_types_.emplace(OpType::kCustomOp);
|
enabled_op_types_.emplace(OpType::kCustomOp);
|
||||||
}
|
}
|
||||||
tf_dialect_ = module.getContext()->getRegisteredDialect("tf");
|
tf_dialect_ =
|
||||||
tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl");
|
module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
|
||||||
|
tfl_dialect_ = module.getContext()
|
||||||
|
->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
|
||||||
|
// Right now the TF executor dialect is still needed to build NodeDef.
|
||||||
|
module.getContext()
|
||||||
|
->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<std::string> TranslateInternal();
|
Optional<std::string> TranslateInternal();
|
||||||
|
@ -65,6 +65,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
|
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
@ -479,7 +480,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
|||||||
|
|
||||||
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
|
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
|
||||||
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
|
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
|
||||||
auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
|
auto dialect = elem_type.getContext()->getLoadedDialect("tf");
|
||||||
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
|
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
|
||||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||||
|
|
||||||
@ -1072,6 +1073,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
|||||||
const std::vector<std::string>& ordered_input_arrays,
|
const std::vector<std::string>& ordered_input_arrays,
|
||||||
const std::vector<std::string>& ordered_output_arrays,
|
const std::vector<std::string>& ordered_output_arrays,
|
||||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
|
context->loadDialect<
|
||||||
|
mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
|
||||||
|
mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();
|
||||||
|
|
||||||
auto model_ptr =
|
auto model_ptr =
|
||||||
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
||||||
if (nullptr == model_ptr) {
|
if (nullptr == model_ptr) {
|
||||||
|
@ -249,7 +249,7 @@ Status mlir::CustomOptionsToAttributes(
|
|||||||
{static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
|
{static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
|
||||||
attributes->emplace_back(builder.getNamedAttr(
|
attributes->emplace_back(builder.getNamedAttr(
|
||||||
"custom_option",
|
"custom_option",
|
||||||
OpaqueElementsAttr::get(builder.getContext()->getRegisteredDialect("tfl"),
|
OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"),
|
||||||
type, content)));
|
type, content)));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -98,6 +98,7 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
// Load the MLIR module.
|
// Load the MLIR module.
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
llvm::SourceMgr source_mgr;
|
llvm::SourceMgr source_mgr;
|
||||||
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
|
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
|
||||||
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
|
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
|
||||||
|
@ -49,6 +49,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
const GraphDef& input,
|
const GraphDef& input,
|
||||||
string* result) {
|
string* result) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
GraphImportConfig specs;
|
GraphImportConfig specs;
|
||||||
mlir::TFL::QuantizationSpecs quant_specs;
|
mlir::TFL::QuantizationSpecs quant_specs;
|
||||||
|
|
||||||
|
@ -122,6 +122,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
|||||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||||
string* result) {
|
string* result) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::TFL::QuantizationSpecs quant_specs;
|
mlir::TFL::QuantizationSpecs quant_specs;
|
||||||
|
|
||||||
// Parse input arrays.
|
// Parse input arrays.
|
||||||
|
@ -52,6 +52,7 @@ TfLiteStatus QuantizeModel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
StatusScopedDiagnosticHandler statusHandler(&context,
|
StatusScopedDiagnosticHandler statusHandler(&context,
|
||||||
/*propagate=*/true);
|
/*propagate=*/true);
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
|||||||
flatbuffers::FlatBufferBuilder* builder,
|
flatbuffers::FlatBufferBuilder* builder,
|
||||||
tflite::ErrorReporter* error_reporter) {
|
tflite::ErrorReporter* error_reporter) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
StatusScopedDiagnosticHandler statusHandler(&context,
|
StatusScopedDiagnosticHandler statusHandler(&context,
|
||||||
/*propagate=*/true);
|
/*propagate=*/true);
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
|
|||||||
} else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
|
} else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
|
||||||
auto etype = complex_type.getElementType();
|
auto etype = complex_type.getElementType();
|
||||||
if (etype.isF32()) {
|
if (etype.isF32()) {
|
||||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
auto dialect = etype.getContext()->getLoadedDialect("tf");
|
||||||
tensorflow::TensorProto repr;
|
tensorflow::TensorProto repr;
|
||||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||||
|
|
||||||
|
@ -56,9 +56,9 @@ inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
|
|||||||
const std::string& content) {
|
const std::string& content) {
|
||||||
ShapedType type = RankedTensorType::get(
|
ShapedType type = RankedTensorType::get(
|
||||||
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
|
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
|
||||||
return OpaqueElementsAttr::get(
|
return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
|
||||||
builder->getContext()->getRegisteredDialect("tfl"), type,
|
type,
|
||||||
StringRef(content.data(), content.size()));
|
StringRef(content.data(), content.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline TensorType GetInputType(FuncOp func, int idx) {
|
inline TensorType GetInputType(FuncOp func, int idx) {
|
||||||
|
@ -128,6 +128,7 @@ Status MlirFunctionOptimizationPass::Run(
|
|||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
RegisterDialects();
|
RegisterDialects();
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
GraphImportConfig import_config;
|
GraphImportConfig import_config;
|
||||||
import_config.graph_as_function = true;
|
import_config.graph_as_function = true;
|
||||||
import_config.control_outputs = *control_ret_node_names;
|
import_config.control_outputs = *control_ret_node_names;
|
||||||
@ -208,6 +209,7 @@ Status MlirV1CompatGraphOptimizationPass::Run(
|
|||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
RegisterDialects();
|
RegisterDialects();
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
GraphImportConfig import_config;
|
GraphImportConfig import_config;
|
||||||
import_config.upgrade_legacy = true;
|
import_config.upgrade_legacy = true;
|
||||||
// Restrict functionalization to TPU nodes to avoid problems in v1 session
|
// Restrict functionalization to TPU nodes to avoid problems in v1 session
|
||||||
|
@ -41,6 +41,7 @@ std::string ImportGraphDef(const std::string &proto,
|
|||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
GraphImportConfig specs;
|
GraphImportConfig specs;
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
|
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
|
||||||
if (!module.ok()) {
|
if (!module.ok()) {
|
||||||
Set_TF_Status_from_Status(status, module.status());
|
Set_TF_Status_from_Status(status, module.status());
|
||||||
@ -85,6 +86,7 @@ std::string ExperimentalConvertSavedModelToMlir(
|
|||||||
std::vector<string> exported_names =
|
std::vector<string> exported_names =
|
||||||
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
auto module_or = ConvertSavedModelToMlir(
|
auto module_or = ConvertSavedModelToMlir(
|
||||||
&bundle, &context, absl::Span<std::string>(exported_names));
|
&bundle, &context, absl::Span<std::string>(exported_names));
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
@ -115,6 +117,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
|
|||||||
// Convert the SavedModelBundle to an MLIR module.
|
// Convert the SavedModelBundle to an MLIR module.
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
auto module_or =
|
auto module_or =
|
||||||
ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy);
|
ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy);
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
|
@ -38,6 +38,7 @@ PYBIND11_MODULE(mlir_wrapper, m) {
|
|||||||
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
|
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
|
||||||
llvm::SMLoc());
|
llvm::SMLoc());
|
||||||
mlir::MLIRContext ctx;
|
mlir::MLIRContext ctx;
|
||||||
|
ctx.loadAllGloballyRegisteredDialects();
|
||||||
auto module = mlir::parseSourceFile(SM, &ctx);
|
auto module = mlir::parseSourceFile(SM, &ctx);
|
||||||
if (!module) {
|
if (!module) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -240,7 +240,7 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
|
|||||||
auto def_op = val.getDefiningOp();
|
auto def_op = val.getDefiningOp();
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
auto exec_dialect =
|
auto exec_dialect =
|
||||||
function.getContext()->getRegisteredDialect("tf_executor");
|
function.getContext()->getLoadedDialect("tf_executor");
|
||||||
assert(def_op->getDialect() == exec_dialect &&
|
assert(def_op->getDialect() == exec_dialect &&
|
||||||
"unable to forward control dependencies");
|
"unable to forward control dependencies");
|
||||||
#endif
|
#endif
|
||||||
|
@ -104,7 +104,7 @@ LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LaunchToDeviceAttributePass::runOnFunction() {
|
void LaunchToDeviceAttributePass::runOnFunction() {
|
||||||
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
|
const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
|
||||||
if (!tf_dialect) {
|
if (!tf_dialect) {
|
||||||
getFunction().emitError() << "'tf' dialect is not registered";
|
getFunction().emitError() << "'tf' dialect is not registered";
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
@ -152,7 +152,7 @@ void UnmarkChildren(Block* block) {
|
|||||||
|
|
||||||
void MarkOpsForOutsideCompilation::runOnOperation() {
|
void MarkOpsForOutsideCompilation::runOnOperation() {
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
|
const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
|
||||||
if (!tf_dialect) {
|
if (!tf_dialect) {
|
||||||
getOperation().emitError() << "'tf' dialect is not registered";
|
getOperation().emitError() << "'tf' dialect is not registered";
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
@ -438,7 +438,7 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
|
|||||||
|
|
||||||
void ReplicateToIslandPass::runOnOperation() {
|
void ReplicateToIslandPass::runOnOperation() {
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
|
const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
|
||||||
if (!tf_dialect) {
|
if (!tf_dialect) {
|
||||||
module.emitError() << "'tf' dialect is not registered";
|
module.emitError() << "'tf' dialect is not registered";
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
@ -597,7 +597,7 @@ ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
|
|||||||
bool propagate_caller_callee_constants)
|
bool propagate_caller_callee_constants)
|
||||||
: graph_version_(graph_version),
|
: graph_version_(graph_version),
|
||||||
propagate_caller_callee_constants_(propagate_caller_callee_constants) {
|
propagate_caller_callee_constants_(propagate_caller_callee_constants) {
|
||||||
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
|
tf_dialect_ = context->getLoadedDialect<TensorFlowDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
|
ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
|
||||||
|
@ -34,7 +34,7 @@ class SimpleTFDeviceAssignmentPass
|
|||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
Builder builder(&getContext());
|
Builder builder(&getContext());
|
||||||
Dialect* tf = getContext().getRegisteredDialect<TensorFlowDialect>();
|
Dialect* tf = getContext().getLoadedDialect<TensorFlowDialect>();
|
||||||
getFunction().walk([&](Operation* op) {
|
getFunction().walk([&](Operation* op) {
|
||||||
if (auto device_attr = op->getAttrOfType<StringAttr>("device")) {
|
if (auto device_attr = op->getAttrOfType<StringAttr>("device")) {
|
||||||
// We assign default device to ops with device attribute that is empty.
|
// We assign default device to ops with device attribute that is empty.
|
||||||
|
@ -726,7 +726,7 @@ Status Exporter::Convert(mlir::ModuleOp module,
|
|||||||
mlir::Identifier::get("main", module.getContext());
|
mlir::Identifier::get("main", module.getContext());
|
||||||
absl::optional<mlir::FuncOp> entry_func;
|
absl::optional<mlir::FuncOp> entry_func;
|
||||||
FunctionDefLibrary flib;
|
FunctionDefLibrary flib;
|
||||||
auto tf_dialect = module.getContext()->getRegisteredDialect("tf");
|
auto tf_dialect = module.getContext()->getLoadedDialect("tf");
|
||||||
for (auto function : module.getOps<mlir::FuncOp>()) {
|
for (auto function : module.getOps<mlir::FuncOp>()) {
|
||||||
if (function.isExternal())
|
if (function.isExternal())
|
||||||
return errors::FailedPrecondition("External functions not supported");
|
return errors::FailedPrecondition("External functions not supported");
|
||||||
@ -799,7 +799,7 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
|||||||
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
|
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
|
||||||
mlir::FuncOp func, const GraphExportConfig& configs,
|
mlir::FuncOp func, const GraphExportConfig& configs,
|
||||||
FunctionDef* function_def) {
|
FunctionDef* function_def) {
|
||||||
Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf");
|
Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
|
||||||
FunctionDefLibrary flib;
|
FunctionDefLibrary flib;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
|
Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
|
||||||
|
@ -420,6 +420,7 @@ Status CompileSerializedMlirToXlaHlo(
|
|||||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||||
RegisterDialects();
|
RegisterDialects();
|
||||||
mlir::MLIRContext mlir_context;
|
mlir::MLIRContext mlir_context;
|
||||||
|
mlir_context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef mlir_module;
|
mlir::OwningModuleRef mlir_module;
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -509,6 +510,7 @@ Status CompileGraphToXlaHlo(
|
|||||||
RegisterDialects();
|
RegisterDialects();
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
GraphImportConfig config;
|
GraphImportConfig config;
|
||||||
config.graph_as_function = true;
|
config.graph_as_function = true;
|
||||||
// Disable shape inference during import as some TensorFlow op fails during
|
// Disable shape inference during import as some TensorFlow op fails during
|
||||||
|
@ -161,7 +161,7 @@ StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
|
|||||||
default:
|
default:
|
||||||
// TODO(shpeisman): restructure code to reuse dialect pointer across
|
// TODO(shpeisman): restructure code to reuse dialect pointer across
|
||||||
// calls.
|
// calls.
|
||||||
auto* dialect = builder->getContext()->getRegisteredDialect("tf");
|
auto* dialect = builder->getContext()->getLoadedDialect("tf");
|
||||||
return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
|
return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,6 +43,7 @@ static void RegisterDialects() {
|
|||||||
|
|
||||||
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
|
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::Builder b(&context);
|
mlir::Builder b(&context);
|
||||||
|
|
||||||
PartialTensorShape output_shape =
|
PartialTensorShape output_shape =
|
||||||
@ -52,6 +53,7 @@ TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
|
|||||||
|
|
||||||
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
|
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::Builder b(&context);
|
mlir::Builder b(&context);
|
||||||
|
|
||||||
PartialTensorShape output_shape = ConvertTypeToTensorShape(
|
PartialTensorShape output_shape = ConvertTypeToTensorShape(
|
||||||
@ -61,6 +63,7 @@ TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
|
|||||||
|
|
||||||
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
|
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::Builder b(&context);
|
mlir::Builder b(&context);
|
||||||
|
|
||||||
PartialTensorShape output_shape = ConvertTypeToTensorShape(
|
PartialTensorShape output_shape = ConvertTypeToTensorShape(
|
||||||
|
@ -36,6 +36,7 @@ std::string ConvertToMlirString(const std::vector<int64_t>& dims,
|
|||||||
}
|
}
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
mlir::Builder b(&context);
|
mlir::Builder b(&context);
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
|
auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
|
||||||
std::string buf;
|
std::string buf;
|
||||||
llvm::raw_string_ostream os(buf);
|
llvm::raw_string_ostream os(buf);
|
||||||
|
@ -60,6 +60,7 @@ class FakeDevice : public Device {
|
|||||||
|
|
||||||
TEST(DeviceUtilTest, AddDeviceToOp) {
|
TEST(DeviceUtilTest, AddDeviceToOp) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module_ref =
|
mlir::OwningModuleRef module_ref =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
|
|
||||||
@ -101,6 +102,7 @@ TEST(DeviceUtilTest, AddDeviceToOp) {
|
|||||||
|
|
||||||
TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
|
TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module_ref =
|
mlir::OwningModuleRef module_ref =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
|
|
||||||
@ -110,6 +112,7 @@ TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
|
|||||||
|
|
||||||
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
|
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module_ref =
|
mlir::OwningModuleRef module_ref =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
|
|
||||||
|
@ -66,6 +66,7 @@ Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph,
|
|||||||
WritableFile* file) {
|
WritableFile* file) {
|
||||||
WritableFileRawStream os(std::move(file));
|
WritableFileRawStream os(std::move(file));
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module;
|
mlir::OwningModuleRef module;
|
||||||
if (flib_def) {
|
if (flib_def) {
|
||||||
flib_def = &graph.flib_def();
|
flib_def = &graph.flib_def();
|
||||||
|
@ -28,6 +28,7 @@ namespace {
|
|||||||
|
|
||||||
TEST(DumpMlirModuleTest, NoEnvPrefix) {
|
TEST(DumpMlirModuleTest, NoEnvPrefix) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module_ref =
|
mlir::OwningModuleRef module_ref =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
unsetenv("TF_DUMP_GRAPH_PREFIX");
|
unsetenv("TF_DUMP_GRAPH_PREFIX");
|
||||||
@ -38,6 +39,7 @@ TEST(DumpMlirModuleTest, NoEnvPrefix) {
|
|||||||
|
|
||||||
TEST(DumpMlirModuleTest, LogInfo) {
|
TEST(DumpMlirModuleTest, LogInfo) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module_ref =
|
mlir::OwningModuleRef module_ref =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
setenv("TF_DUMP_GRAPH_PREFIX", "-", 1);
|
setenv("TF_DUMP_GRAPH_PREFIX", "-", 1);
|
||||||
@ -48,6 +50,7 @@ TEST(DumpMlirModuleTest, LogInfo) {
|
|||||||
|
|
||||||
TEST(DumpMlirModuleTest, Valid) {
|
TEST(DumpMlirModuleTest, Valid) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module_ref =
|
mlir::OwningModuleRef module_ref =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1);
|
setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1);
|
||||||
|
@ -29,6 +29,7 @@ using testing::HasSubstr;
|
|||||||
|
|
||||||
TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) {
|
TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
auto id = Identifier::get("test.cc", &context);
|
auto id = Identifier::get("test.cc", &context);
|
||||||
auto loc = FileLineColLoc::get(id, 0, 0, &context);
|
auto loc = FileLineColLoc::get(id, 0, 0, &context);
|
||||||
|
|
||||||
|
@ -602,6 +602,7 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
|
|||||||
|
|
||||||
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
|
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::Builder builder(&context);
|
mlir::Builder builder(&context);
|
||||||
auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
|
auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
|
||||||
auto status_or_device_coodinates =
|
auto status_or_device_coodinates =
|
||||||
@ -615,6 +616,7 @@ TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
|
|||||||
|
|
||||||
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::Builder builder(&context);
|
mlir::Builder builder(&context);
|
||||||
auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
|
auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
|
||||||
auto status_or_device_coodinates =
|
auto status_or_device_coodinates =
|
||||||
@ -627,6 +629,7 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
|||||||
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
|
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
|
||||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module_ref =
|
mlir::OwningModuleRef module_ref =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
|||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "llvm/Support/ToolOutputFile.h"
|
#include "llvm/Support/ToolOutputFile.h"
|
||||||
#include "mlir/IR/AsmState.h" // from @llvm-project
|
#include "mlir/IR/AsmState.h" // from @llvm-project
|
||||||
|
#include "mlir/InitAllDialects.h" // from @llvm-project
|
||||||
|
#include "mlir/InitAllPasses.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||||
@ -63,6 +65,8 @@ static llvm::cl::opt<bool> allowUnregisteredDialects(
|
|||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
|
mlir::registerAllPasses();
|
||||||
|
|
||||||
tensorflow::InitMlir y(&argc, &argv);
|
tensorflow::InitMlir y(&argc, &argv);
|
||||||
|
|
||||||
// Register various MLIR command line options.
|
// Register various MLIR command line options.
|
||||||
@ -84,9 +88,12 @@ int main(int argc, char **argv) {
|
|||||||
auto output = mlir::openOutputFile(output_filename, &error_message);
|
auto output = mlir::openOutputFile(output_filename, &error_message);
|
||||||
QCHECK(output) << error_message;
|
QCHECK(output) << error_message;
|
||||||
|
|
||||||
|
mlir::DialectRegistry registry;
|
||||||
|
mlir::registerAllDialects(registry);
|
||||||
if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline,
|
if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline,
|
||||||
split_input_file, verify_diagnostics,
|
registry, split_input_file, verify_diagnostics,
|
||||||
verify_passes, allowUnregisteredDialects)))
|
verify_passes, allowUnregisteredDialects,
|
||||||
|
/*preloadDialectsInContext=*/true)))
|
||||||
return 1;
|
return 1;
|
||||||
output->keep();
|
output->keep();
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -111,6 +111,7 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
if (import_saved_model_object_graph) {
|
if (import_saved_model_object_graph) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
|
|
||||||
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
|
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||||
input_filename, tags, exported_names, &context);
|
input_filename, tags, exported_names, &context);
|
||||||
@ -119,6 +120,7 @@ int main(int argc, char** argv) {
|
|||||||
module_or.ConsumeValueOrDie()->print(output->os());
|
module_or.ConsumeValueOrDie()->print(output->os());
|
||||||
} else if (import_saved_model_signature_defs) {
|
} else if (import_saved_model_signature_defs) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
|
|
||||||
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
|
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||||
input_filename, tags, exported_names, &context, upgrade_legacy);
|
input_filename, tags, exported_names, &context, upgrade_legacy);
|
||||||
@ -139,6 +141,7 @@ int main(int argc, char** argv) {
|
|||||||
llvm::SourceMgr sourceMgr;
|
llvm::SourceMgr sourceMgr;
|
||||||
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
|
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context);
|
mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context);
|
||||||
return (*requested_translation)(sourceMgr, os, &context);
|
return (*requested_translation)(sourceMgr, os, &context);
|
||||||
};
|
};
|
||||||
|
@ -125,6 +125,7 @@ int main(int argc, char** argv) {
|
|||||||
"TF GraphDef to TFJS JSON converter\n");
|
"TF GraphDef to TFJS JSON converter\n");
|
||||||
|
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
llvm::SourceMgr source_mgr;
|
llvm::SourceMgr source_mgr;
|
||||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||||
|
|
||||||
|
@ -261,6 +261,7 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
|
|||||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||||
RegisterDialects();
|
RegisterDialects();
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
|
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
|
||||||
|
@ -90,8 +90,9 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
if (showDialects) {
|
if (showDialects) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
llvm::outs() << "Registered Dialects:\n";
|
llvm::outs() << "Registered Dialects:\n";
|
||||||
for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
|
for (mlir::Dialect *dialect : context.getLoadedDialects()) {
|
||||||
llvm::outs() << dialect->getNamespace() << "\n";
|
llvm::outs() << dialect->getNamespace() << "\n";
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
@ -111,9 +112,12 @@ int main(int argc, char **argv) {
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
|
mlir::DialectRegistry registry;
|
||||||
|
registerAllDialects(registry);
|
||||||
|
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
|
||||||
splitInputFile, verifyDiagnostics, verifyPasses,
|
splitInputFile, verifyDiagnostics, verifyPasses,
|
||||||
allowUnregisteredDialects))) {
|
allowUnregisteredDialects,
|
||||||
|
/*preloadDialectsInContext=*/true))) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
// Keep the output file if the invocation of MlirOptMain was successful.
|
// Keep the output file if the invocation of MlirOptMain was successful.
|
||||||
|
@ -64,6 +64,7 @@ inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
|
|||||||
|
|
||||||
TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
|
TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
Builder b(&context);
|
Builder b(&context);
|
||||||
|
|
||||||
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
|
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
|
||||||
@ -74,6 +75,7 @@ TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
|
|||||||
|
|
||||||
TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
|
TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
Builder b(&context);
|
Builder b(&context);
|
||||||
|
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
@ -95,6 +97,7 @@ TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
|
|||||||
|
|
||||||
TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
|
TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
Builder b(&context);
|
Builder b(&context);
|
||||||
|
|
||||||
// Memref without any affine map. Note: memory space is ignored for shape.
|
// Memref without any affine map. Note: memory space is ignored for shape.
|
||||||
|
@ -152,6 +152,7 @@ Status ConvertGraphDefToXlaViaMlir(
|
|||||||
|
|
||||||
RegisterDialects();
|
RegisterDialects();
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
mlir::OwningModuleRef module,
|
mlir::OwningModuleRef module,
|
||||||
ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));
|
ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));
|
||||||
|
@ -622,6 +622,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
|||||||
|
|
||||||
// Compile must be thread-safe so create a new LLVM context for the module.
|
// Compile must be thread-safe so create a new LLVM context for the module.
|
||||||
mlir::MLIRContext mlir_context;
|
mlir::MLIRContext mlir_context;
|
||||||
|
mlir_context.loadAllGloballyRegisteredDialects();
|
||||||
llvm::LLVMContext llvm_context;
|
llvm::LLVMContext llvm_context;
|
||||||
auto llvm_module =
|
auto llvm_module =
|
||||||
absl::make_unique<llvm::Module>("__compute_module", llvm_context);
|
absl::make_unique<llvm::Module>("__compute_module", llvm_context);
|
||||||
@ -833,6 +834,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
|||||||
|
|
||||||
// Compile must be thread-safe so create a new LLVM context for the module.
|
// Compile must be thread-safe so create a new LLVM context for the module.
|
||||||
mlir::MLIRContext mlir_context;
|
mlir::MLIRContext mlir_context;
|
||||||
|
mlir_context.loadAllGloballyRegisteredDialects();
|
||||||
llvm::LLVMContext llvm_context;
|
llvm::LLVMContext llvm_context;
|
||||||
llvm::Module llvm_module("__compute_module", llvm_context);
|
llvm::Module llvm_module("__compute_module", llvm_context);
|
||||||
llvm_module.setDataLayout(target_machine->createDataLayout());
|
llvm_module.setDataLayout(target_machine->createDataLayout());
|
||||||
|
@ -25,6 +25,7 @@ namespace mlir_gpu {
|
|||||||
|
|
||||||
EmissionContext::EmissionContext(std::unique_ptr<HloModule> module)
|
EmissionContext::EmissionContext(std::unique_ptr<HloModule> module)
|
||||||
: module_(std::move(module)), context_() {
|
: module_(std::move(module)), context_() {
|
||||||
|
context_.loadAllGloballyRegisteredDialects();
|
||||||
error_handler_ = [](const ErrorMap& instructions_with_error,
|
error_handler_ = [](const ErrorMap& instructions_with_error,
|
||||||
HloModule* module) {
|
HloModule* module) {
|
||||||
std::set<const HloComputation*> computations_with_error;
|
std::set<const HloComputation*> computations_with_error;
|
||||||
|
@ -46,6 +46,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) {
|
|||||||
hlo_module.entry_computation()->root_instruction();
|
hlo_module.entry_computation()->root_instruction();
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
context.loadAllGloballyRegisteredDialects();
|
||||||
mlir::OwningModuleRef mlir_module(
|
mlir::OwningModuleRef mlir_module(
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)));
|
||||||
|
|
||||||
|
@ -699,8 +699,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check out LLVM and MLIR from llvm-project.
|
# Check out LLVM and MLIR from llvm-project.
|
||||||
LLVM_COMMIT = "e75bc5c791e0e8dbe79f7453e55af9e8d03c9cc0"
|
LLVM_COMMIT = "f9dc2b7079350d0fed3bb3775f496b90483c9e42"
|
||||||
LLVM_SHA256 = "9c22f59d50853329cd0105ecb95256ad345313372ddda593030cd81b7c72e657"
|
LLVM_SHA256 = "59866525042c3165c4fcb4c855bc315a390b4ec8eb76846bbd3e5ac3d8f50c1d"
|
||||||
LLVM_URLS = [
|
LLVM_URLS = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||||
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||||
|
2
third_party/mlir/BUILD
vendored
2
third_party/mlir/BUILD
vendored
@ -1124,6 +1124,7 @@ cc_library(
|
|||||||
":ControlFlowInterfaces",
|
":ControlFlowInterfaces",
|
||||||
":IR",
|
":IR",
|
||||||
":LLVMOpsIncGen",
|
":LLVMOpsIncGen",
|
||||||
|
":OpenMPDialect",
|
||||||
":SideEffectInterfaces",
|
":SideEffectInterfaces",
|
||||||
":Support",
|
":Support",
|
||||||
"@llvm-project//llvm:AsmParser",
|
"@llvm-project//llvm:AsmParser",
|
||||||
@ -3542,6 +3543,7 @@ cc_library(
|
|||||||
":LinalgOps",
|
":LinalgOps",
|
||||||
":LinalgTransforms",
|
":LinalgTransforms",
|
||||||
":Pass",
|
":Pass",
|
||||||
|
":SCFDialect",
|
||||||
":SCFToStandard",
|
":SCFToStandard",
|
||||||
":StandardOps",
|
":StandardOps",
|
||||||
":StandardToLLVM",
|
":StandardToLLVM",
|
||||||
|
1
third_party/mlir/test.BUILD
vendored
1
third_party/mlir/test.BUILD
vendored
@ -186,6 +186,7 @@ cc_library(
|
|||||||
"@llvm-project//mlir:LinalgTransforms",
|
"@llvm-project//mlir:LinalgTransforms",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
|
"@llvm-project//mlir:SPIRVDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:StandardOpsTransforms",
|
"@llvm-project//mlir:StandardOpsTransforms",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user