mlir-translate: support -verify-diagnostics

MLIR translation tools can emit diagnostics and we want to be able to check if
it is indeed the case in tests. Reuse the source manager error handlers
provided for mlir-opt to support the verification in mlir-translate. This
requires us to change the signature of the functions that are registered to
translate sources to MLIR: it now takes a source manager instead of a memory
buffer.

PiperOrigin-RevId: 279132972
Change-Id: I9750d8f199766c391cda142aad8957d52beece3b
This commit is contained in:
A. Unique TensorFlower 2019-11-07 11:42:11 -08:00 committed by TensorFlower Gardener
parent 8c0b8fa3c4
commit 990149b809
14 changed files with 179 additions and 129 deletions

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/Support/Endian.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
@ -861,8 +862,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
return OwningModuleRef(module);
}
static OwningModuleRef FlatBufferFileToMlirTrans(
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
MLIRContext* context) {
const llvm::MemoryBuffer* input =
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
std::string error;
auto loc =
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
@ -884,4 +887,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
}
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
"tflite-flatbuffer-to-mlir", FlatBufferFileToMlirTrans);
"tflite-flatbuffer-to-mlir",
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
return FlatBufferFileToMlirTrans(&source_mgr, context);
});

View File

@ -86,14 +86,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
if (use_splatted_constant) {
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
std::move(file), debug_info_file, input_arrays, input_dtypes,
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes,
/*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
add_pseudo_input_nodes, context);
}
return tensorflow::GraphdefToMlirTranslateFunction(
std::move(file), debug_info_file, input_arrays, input_dtypes,
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes,
/*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false,
/*upgrade_legacy=*/true, add_pseudo_input_nodes, context);

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "absl/memory/memory.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
@ -43,15 +42,15 @@ using stream_executor::port::Status;
using stream_executor::port::StatusOr;
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
std::unique_ptr<llvm::MemoryBuffer> input,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
mlir::MLIRContext* context) {
GraphDef graphdef;
TF_RETURN_IF_ERROR(tensorflow::LoadProtoFromBuffer(
{input->getBufferStart(), input->getBufferSize()}, &graphdef));
TF_RETURN_IF_ERROR(
tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));
GraphDebugInfo debug_info;
if (!debug_info_file.empty()) {
@ -91,17 +90,16 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
}
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
mlir::MLIRContext* context) {
auto module_or = GraphdefToMlirImport(
std::move(input), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
add_pseudo_input_nodes, context);
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
if (!module_or.status().ok()) {
LOG(ERROR) << "Graph import failed: " << module_or.status();
return nullptr;
@ -136,17 +134,16 @@ mlir::OwningModuleRef SavedModelToMlirImport(
}
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
mlir::MLIRContext* context) {
auto module_or = GraphdefToMlirImport(
std::move(input), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
add_pseudo_input_nodes, context);
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
if (!module_or.status().ok()) {
LOG(ERROR) << "Graph import failed: " << module_or.status();
return nullptr;

View File

@ -32,22 +32,22 @@ namespace tensorflow {
// `input_filename` into a MLIR module. Creates MLIR entities into the
// given MLIR `context`.
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
bool add_pseudo_input_nodes, mlir::MLIRContext* context);
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
mlir::MLIRContext* context);
// Similar as the above function, but replaces all constant tensors
// with randomly generated splat values.
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
bool add_pseudo_input_nodes, mlir::MLIRContext* context);
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
mlir::MLIRContext* context);
// Converts a TensorFlow SavedModel stored in the directory with the given
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the

View File

@ -40,12 +40,12 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
}
} // namespace
static OwningModuleRef GraphdefToMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
MLIRContext* context) {
return tensorflow::GraphdefToMlirTranslateFunction(
std::move(input), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
graph_as_function, upgrade_legacy,
/*add_pseudo_input_nodes=*/false, context);
}
@ -53,11 +53,11 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
"graphdef-to-mlir", GraphdefToMlirTranslateFunction);
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
llvm::StringRef input, MLIRContext* context) {
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
std::move(input), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
graph_as_function, upgrade_legacy,
/*add_pseudo_input_nodes=*/false, context);
}

View File

@ -18,6 +18,8 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
@ -105,8 +107,11 @@ int main(int argc, char** argv) {
return 1;
}
if (failed(
(*requested_translation)(std::move(input), output->os(), &context)))
llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
if (failed((*requested_translation)(source_mgr, output->os(), &context)))
return 1;
}

View File

@ -64,12 +64,11 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
} // namespace
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) {
llvm::StringRef input, mlir::MLIRContext* context) {
HloProto hlo_proto;
string content(input->getBufferStart(), input->getBufferSize());
string content(input.data(), input.size());
if (!LoadHloProto(content, &hlo_proto)) {
LOG(ERROR) << "Failed to load proto: "
<< input->getBufferIdentifier().str();
LOG(ERROR) << "Failed to load proto";
return nullptr;
}
@ -86,9 +85,9 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction(
}
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) {
llvm::StringRef input, mlir::MLIRContext* context) {
HloProto hlo_proto;
string content(input->getBufferStart(), input->getBufferSize());
string content(input.data(), input.size());
auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
if (!hlo_module_error.ok()) {

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
namespace llvm {
class MemoryBuffer;
class StringRef;
} // namespace llvm
@ -34,14 +33,14 @@ namespace xla {
// Converts a HloModuleProto stored in the file with the given `input_filename`
// into a MLIR module. Creates MLIR entities into the given MLIR `context`.
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context);
mlir::OwningModuleRef HloToMlirHloTranslateFunction(llvm::StringRef input,
mlir::MLIRContext* context);
// Converts a HloModule stored in text form for a file with the given
// `input_filename` into a MLIR module. Creates MLIR entities into the given
// MLIR `context`.
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context);
llvm::StringRef input, mlir::MLIRContext* context);
} // namespace xla

View File

@ -27,6 +27,7 @@
namespace llvm {
class MemoryBuffer;
class SourceMgr;
class StringRef;
} // namespace llvm
@ -36,12 +37,19 @@ class MLIRContext;
class ModuleOp;
class OwningModuleRef;
/// Interface of the function that translates a source file held by the given
/// MemoryBuffer to MLIR. The implementation should create a new MLIR ModuleOp
/// in the given context and return a pointer to it, or a nullptr in case of any
/// error.
using TranslateToMLIRFunction = std::function<OwningModuleRef(
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *)>;
/// Interface of the function that translates the sources managed by `sourceMgr`
/// to MLIR. The source manager has at least one buffer. The implementation
/// should create a new MLIR ModuleOp in the given context and return a pointer
/// to it, or a nullptr in case of any error.
using TranslateSourceMgrToMLIRFunction =
std::function<OwningModuleRef(llvm::SourceMgr &sourceMgr, MLIRContext *)>;
/// Interface of the function that translates the given string to MLIR. The
/// implementation should create a new MLIR ModuleOp in the given context. If
/// source-related error reporting is required from within the function, use
/// TranslateSourceMgrToMLIRFunction instead.
using TranslateStringRefToMLIRFunction =
std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
/// Interface of the function that translates MLIR to a different format and
/// outputs the result to a stream. It is allowed to modify the module.
@ -53,11 +61,10 @@ using TranslateFromMLIRFunction =
/// should be written to the given raw_ostream. The implementation should create
/// all MLIR constructs needed during the process inside the given context. This
/// can be used for round-tripping external formats through the MLIR system.
using TranslateFunction =
std::function<LogicalResult(std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *)>;
using TranslateFunction = std::function<LogicalResult(
llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
/// Use Translate[ToMLIR|FromMLIR|]Registration as a global initialiser that
/// Use Translate[ToMLIR|FromMLIR]Registration as a global initialiser that
/// registers a function and associates it with name. This requires that a
/// translation has not been registered to a given name.
///
@ -69,7 +76,9 @@ using TranslateFunction =
/// \{
struct TranslateToMLIRRegistration {
TranslateToMLIRRegistration(llvm::StringRef name,
const TranslateToMLIRFunction &function);
const TranslateSourceMgrToMLIRFunction &function);
TranslateToMLIRRegistration(llvm::StringRef name,
const TranslateStringRefToMLIRFunction &function);
};
struct TranslateFromMLIRRegistration {
@ -83,7 +92,8 @@ struct TranslateRegistration {
/// \}
/// Get a read-only reference to the translator registry.
const llvm::StringMap<TranslateToMLIRFunction> &getTranslationToMLIRRegistry();
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
getTranslationToMLIRRegistry();
const llvm::StringMap<TranslateFromMLIRFunction> &
getTranslationFromMLIRRegistry();
const llvm::StringMap<TranslateFunction> &getTranslationRegistry();

View File

@ -42,7 +42,7 @@ using namespace mlir;
// Deserializes the SPIR-V binary module stored in the file named as
// `inputFilename` and returns a module containing the SPIR-V module.
OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
MLIRContext *context) {
Builder builder(context);
@ -70,9 +70,10 @@ OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
}
static TranslateToMLIRRegistration fromBinary(
"deserialize-spirv",
[](std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *context) {
return deserializeModule(std::move(input), context);
"deserialize-spirv", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
return deserializeModule(
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
});
//===----------------------------------------------------------------------===//
@ -111,13 +112,9 @@ static TranslateFromMLIRRegistration
// Round-trip registration
//===----------------------------------------------------------------------===//
LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output, MLIRContext *context) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
// Parse the memory buffer as a MLIR module.
// Parse an MLIR module from the source manager.
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!srcModule)
return failure();
@ -151,7 +148,7 @@ LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
static TranslateRegistration
roundtrip("test-spirv-roundtrip",
[](std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *context) {
return roundTripModule(std::move(input), output, context);
[](llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output,
MLIRContext *context) {
return roundTripModule(sourceMgr, output, context);
});

View File

@ -55,15 +55,15 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size() +
fileToFileRegistry.size());
for (const auto &kv : toMLIRRegistry) {
TranslateToMLIRFunction function = kv.second;
TranslateFunction wrapper =
[function](std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *context) {
OwningModuleRef module = function(std::move(input), context);
if (!module)
return failure();
return printMLIROutput(*module, output);
};
TranslateSourceMgrToMLIRFunction function = kv.second;
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output,
MLIRContext *context) {
OwningModuleRef module = function(sourceMgr, context);
if (!module)
return failure();
return printMLIROutput(*module, output);
};
wrapperStorage.emplace_back(std::move(wrapper));
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
@ -71,18 +71,14 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
for (const auto &kv : fromMLIRRegistry) {
TranslateFromMLIRFunction function = kv.second;
TranslateFunction wrapper =
[function](std::unique_ptr<llvm::MemoryBuffer> input,
llvm::raw_ostream &output, MLIRContext *context) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!module)
return failure();
return function(module.get(), output);
};
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output,
MLIRContext *context) {
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!module)
return failure();
return function(module.get(), output);
};
wrapperStorage.emplace_back(std::move(wrapper));
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());

View File

@ -571,15 +571,15 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
// Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
// LLVM dialect.
OwningModuleRef
translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
MLIRContext *context) {
OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
MLIRContext *context) {
LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
assert(dialect && "Could not find LLVMDialect?");
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> llvmModule =
llvm::parseIR(*input, err, dialect->getLLVMContext(),
llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
dialect->getLLVMContext(),
/*UpgradeDebugInfo=*/true,
/*DataLayoutString=*/"");
if (!llvmModule) {
@ -593,7 +593,7 @@ translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
}
static TranslateToMLIRRegistration
fromLLVM("import-llvm", [](std::unique_ptr<llvm::MemoryBuffer> input,
MLIRContext *context) {
return translateLLVMIRToModule(std::move(input), context);
});
fromLLVM("import-llvm",
[](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
return translateLLVMIRToModule(sourceMgr, context);
});

View File

@ -23,15 +23,16 @@
#include "mlir/IR/Module.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
// Get the mutable static map between registered "to MLIR" translations and the
// TranslateToMLIRFunctions that perform those translations.
static llvm::StringMap<TranslateToMLIRFunction> &
static llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
getMutableTranslationToMLIRRegistry() {
static llvm::StringMap<TranslateToMLIRFunction> translationToMLIRRegistry;
static llvm::StringMap<TranslateSourceMgrToMLIRFunction>
translationToMLIRRegistry;
return translationToMLIRRegistry;
}
// Get the mutable static map between registered "from MLIR" translations and
@ -49,8 +50,10 @@ static llvm::StringMap<TranslateFunction> &getMutableTranslationRegistry() {
return translationRegistry;
}
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateToMLIRFunction &function) {
// Puts `function` into the to-MLIR translation registry unless there is already
// a function registered for the same name.
static void registerTranslateToMLIRFunction(
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
llvm::report_fatal_error(
@ -59,6 +62,24 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
translationToMLIRRegistry[name] = function;
}
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
registerTranslateToMLIRFunction(name, function);
}
// Wraps `function` with a lambda that extracts a StringRef from a source
// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateStringRefToMLIRFunction &function) {
auto translationFunction = [function](llvm::SourceMgr &sourceMgr,
MLIRContext *ctx) {
const llvm::MemoryBuffer *buffer =
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
return function(buffer->getBuffer(), ctx);
};
registerTranslateToMLIRFunction(name, translationFunction);
}
TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
StringRef name, const TranslateFromMLIRFunction &function) {
auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
@ -84,7 +105,7 @@ TranslateRegistration::TranslateRegistration(
// Merely add the const qualifier to the mutable registry so that external users
// cannot modify it.
const llvm::StringMap<TranslateToMLIRFunction> &
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
mlir::getTranslationToMLIRRegistry() {
return getMutableTranslationToMLIRRegistry();
}

View File

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
@ -46,6 +47,12 @@ static llvm::cl::opt<bool>
"process each chunk independently"),
llvm::cl::init(false));
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
@ -69,11 +76,24 @@ int main(int argc, char **argv) {
return 1;
}
/// Processes the memory buffer with a new MLIRContext.
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
MLIRContext context;
return (*translationRequested)(std::move(ownedBuffer), os, &context);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return (*translationRequested)(sourceMgr, os, &context);
}
// In the diagnostic verification flow, we ignore whether the translation
// failed (in most cases, it is expected to fail). Instead, we check if the
// diagnostics were produced as expected.
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
(*translationRequested)(sourceMgr, os, &context);
return sourceMgrHandler.verify();
};
if (splitInputFile) {