mlir-translate: support -verify-diagnostics

MLIR translation tools can emit diagnostics and we want to be able to check if
it is indeed the case in tests. Reuse the source manager error handlers
provided for mlir-opt to support the verification in mlir-translate. This
requires us to change the signature of the functions that are registered to
translate sources to MLIR: it now takes a source manager instead of a memory
buffer.

PiperOrigin-RevId: 279132972
Change-Id: I9750d8f199766c391cda142aad8957d52beece3b
This commit is contained in:
A. Unique TensorFlower 2019-11-07 11:42:11 -08:00 committed by TensorFlower Gardener
parent 8c0b8fa3c4
commit 990149b809
14 changed files with 179 additions and 129 deletions

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/Support/Endian.h" #include "llvm/Support/Endian.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir #include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
@ -861,8 +862,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
return OwningModuleRef(module); return OwningModuleRef(module);
} }
static OwningModuleRef FlatBufferFileToMlirTrans( static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) { MLIRContext* context) {
const llvm::MemoryBuffer* input =
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
std::string error; std::string error;
auto loc = auto loc =
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
@ -884,4 +887,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
} }
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
"tflite-flatbuffer-to-mlir", FlatBufferFileToMlirTrans); "tflite-flatbuffer-to-mlir",
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
return FlatBufferFileToMlirTrans(&source_mgr, context);
});

View File

@ -86,14 +86,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
if (use_splatted_constant) { if (use_splatted_constant) {
return tensorflow::GraphdefToSplattedMlirTranslateFunction( return tensorflow::GraphdefToSplattedMlirTranslateFunction(
std::move(file), debug_info_file, input_arrays, input_dtypes, file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes, input_shapes, output_arrays, prune_unused_nodes,
/*convert_legacy_fed_inputs=*/true, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true, /*graph_as_function=*/false, /*upgrade_legacy=*/true,
add_pseudo_input_nodes, context); add_pseudo_input_nodes, context);
} }
return tensorflow::GraphdefToMlirTranslateFunction( return tensorflow::GraphdefToMlirTranslateFunction(
std::move(file), debug_info_file, input_arrays, input_dtypes, file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, prune_unused_nodes, input_shapes, output_arrays, prune_unused_nodes,
/*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false, /*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false,
/*upgrade_legacy=*/true, add_pseudo_input_nodes, context); /*upgrade_legacy=*/true, add_pseudo_input_nodes, context);

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir
@ -43,15 +42,15 @@ using stream_executor::port::Status;
using stream_executor::port::StatusOr; using stream_executor::port::StatusOr;
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport( static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
std::unique_ptr<llvm::MemoryBuffer> input, llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view input_shapes, absl::string_view output_arrays,
absl::string_view output_arrays, bool prune_unused_nodes, bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
bool add_pseudo_input_nodes, mlir::MLIRContext* context) { mlir::MLIRContext* context) {
GraphDef graphdef; GraphDef graphdef;
TF_RETURN_IF_ERROR(tensorflow::LoadProtoFromBuffer( TF_RETURN_IF_ERROR(
{input->getBufferStart(), input->getBufferSize()}, &graphdef)); tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));
GraphDebugInfo debug_info; GraphDebugInfo debug_info;
if (!debug_info_file.empty()) { if (!debug_info_file.empty()) {
@ -91,17 +90,16 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
} }
mlir::OwningModuleRef GraphdefToMlirTranslateFunction( mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view input_shapes, absl::string_view output_arrays,
absl::string_view output_arrays, bool prune_unused_nodes, bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
bool add_pseudo_input_nodes, mlir::MLIRContext* context) { mlir::MLIRContext* context) {
auto module_or = GraphdefToMlirImport( auto module_or = GraphdefToMlirImport(
std::move(input), debug_info_file, input_arrays, input_dtypes, input, debug_info_file, input_arrays, input_dtypes, input_shapes,
input_shapes, output_arrays, prune_unused_nodes, output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
add_pseudo_input_nodes, context);
if (!module_or.status().ok()) { if (!module_or.status().ok()) {
LOG(ERROR) << "Graph import failed: " << module_or.status(); LOG(ERROR) << "Graph import failed: " << module_or.status();
return nullptr; return nullptr;
@ -136,17 +134,16 @@ mlir::OwningModuleRef SavedModelToMlirImport(
} }
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view input_shapes, absl::string_view output_arrays,
absl::string_view output_arrays, bool prune_unused_nodes, bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
bool add_pseudo_input_nodes, mlir::MLIRContext* context) { mlir::MLIRContext* context) {
auto module_or = GraphdefToMlirImport( auto module_or = GraphdefToMlirImport(
std::move(input), debug_info_file, input_arrays, input_dtypes, input, debug_info_file, input_arrays, input_dtypes, input_shapes,
input_shapes, output_arrays, prune_unused_nodes, output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
add_pseudo_input_nodes, context);
if (!module_or.status().ok()) { if (!module_or.status().ok()) {
LOG(ERROR) << "Graph import failed: " << module_or.status(); LOG(ERROR) << "Graph import failed: " << module_or.status();
return nullptr; return nullptr;

View File

@ -32,22 +32,22 @@ namespace tensorflow {
// `input_filename` into a MLIR module. Creates MLIR entities into the // `input_filename` into a MLIR module. Creates MLIR entities into the
// given MLIR `context`. // given MLIR `context`.
mlir::OwningModuleRef GraphdefToMlirTranslateFunction( mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view input_shapes, absl::string_view output_arrays,
absl::string_view output_arrays, bool prune_unused_nodes, bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
bool add_pseudo_input_nodes, mlir::MLIRContext* context); mlir::MLIRContext* context);
// Similar as the above function, but replaces all constant tensors // Similar as the above function, but replaces all constant tensors
// with randomly generated splat values. // with randomly generated splat values.
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view input_shapes, absl::string_view output_arrays,
absl::string_view output_arrays, bool prune_unused_nodes, bool prune_unused_nodes, bool convert_legacy_fed_inputs,
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
bool add_pseudo_input_nodes, mlir::MLIRContext* context); mlir::MLIRContext* context);
// Converts a TensorFlow SavedModel stored in the directory with the given // Converts a TensorFlow SavedModel stored in the directory with the given
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the // `saved_model_dir` into a MLIR module. Creates MLIR entities into the

View File

@ -40,12 +40,12 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
} }
} // namespace } // namespace
static OwningModuleRef GraphdefToMlirTranslateFunction( static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) { MLIRContext* context) {
return tensorflow::GraphdefToMlirTranslateFunction( return tensorflow::GraphdefToMlirTranslateFunction(
std::move(input), debug_info_file, input_arrays, input_dtypes, input, debug_info_file, input_arrays, input_dtypes, input_shapes,
input_shapes, output_arrays, prune_unused_nodes, output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, graph_as_function, upgrade_legacy,
/*add_pseudo_input_nodes=*/false, context); /*add_pseudo_input_nodes=*/false, context);
} }
@ -53,11 +53,11 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
"graphdef-to-mlir", GraphdefToMlirTranslateFunction); "graphdef-to-mlir", GraphdefToMlirTranslateFunction);
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction( static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) { llvm::StringRef input, MLIRContext* context) {
return tensorflow::GraphdefToSplattedMlirTranslateFunction( return tensorflow::GraphdefToSplattedMlirTranslateFunction(
std::move(input), debug_info_file, input_arrays, input_dtypes, input, debug_info_file, input_arrays, input_dtypes, input_shapes,
input_shapes, output_arrays, prune_unused_nodes, output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, graph_as_function, upgrade_legacy,
/*add_pseudo_input_nodes=*/false, context); /*add_pseudo_input_nodes=*/false, context);
} }

View File

@ -18,6 +18,8 @@ limitations under the License.
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "llvm/Support/InitLLVM.h" #include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
@ -105,8 +107,11 @@ int main(int argc, char** argv) {
return 1; return 1;
} }
if (failed( llvm::SourceMgr source_mgr;
(*requested_translation)(std::move(input), output->os(), &context))) source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
if (failed((*requested_translation)(source_mgr, output->os(), &context)))
return 1; return 1;
} }

View File

@ -64,12 +64,11 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
} // namespace } // namespace
mlir::OwningModuleRef HloToMlirHloTranslateFunction( mlir::OwningModuleRef HloToMlirHloTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) { llvm::StringRef input, mlir::MLIRContext* context) {
HloProto hlo_proto; HloProto hlo_proto;
string content(input->getBufferStart(), input->getBufferSize()); string content(input.data(), input.size());
if (!LoadHloProto(content, &hlo_proto)) { if (!LoadHloProto(content, &hlo_proto)) {
LOG(ERROR) << "Failed to load proto: " LOG(ERROR) << "Failed to load proto";
<< input->getBufferIdentifier().str();
return nullptr; return nullptr;
} }
@ -86,9 +85,9 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction(
} }
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction( mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) { llvm::StringRef input, mlir::MLIRContext* context) {
HloProto hlo_proto; HloProto hlo_proto;
string content(input->getBufferStart(), input->getBufferSize()); string content(input.data(), input.size());
auto hlo_module_error = ParseAndReturnUnverifiedModule(content); auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
if (!hlo_module_error.ok()) { if (!hlo_module_error.ok()) {

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
namespace llvm { namespace llvm {
class MemoryBuffer;
class StringRef; class StringRef;
} // namespace llvm } // namespace llvm
@ -34,14 +33,14 @@ namespace xla {
// Converts a HloModuleProto stored in the file with the given `input_filename` // Converts a HloModuleProto stored in the file with the given `input_filename`
// into a MLIR module. Creates MLIR entities into the given MLIR `context`. // into a MLIR module. Creates MLIR entities into the given MLIR `context`.
mlir::OwningModuleRef HloToMlirHloTranslateFunction( mlir::OwningModuleRef HloToMlirHloTranslateFunction(llvm::StringRef input,
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context); mlir::MLIRContext* context);
// Converts a HloModule stored in text form for a file with the given // Converts a HloModule stored in text form for a file with the given
// `input_filename` into a MLIR module. Creates MLIR entities into the given // `input_filename` into a MLIR module. Creates MLIR entities into the given
// MLIR `context`. // MLIR `context`.
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction( mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context); llvm::StringRef input, mlir::MLIRContext* context);
} // namespace xla } // namespace xla

View File

@ -27,6 +27,7 @@
namespace llvm { namespace llvm {
class MemoryBuffer; class MemoryBuffer;
class SourceMgr;
class StringRef; class StringRef;
} // namespace llvm } // namespace llvm
@ -36,12 +37,19 @@ class MLIRContext;
class ModuleOp; class ModuleOp;
class OwningModuleRef; class OwningModuleRef;
/// Interface of the function that translates a source file held by the given /// Interface of the function that translates the sources managed by `sourceMgr`
/// MemoryBuffer to MLIR. The implementation should create a new MLIR ModuleOp /// to MLIR. The source manager has at least one buffer. The implementation
/// in the given context and return a pointer to it, or a nullptr in case of any /// should create a new MLIR ModuleOp in the given context and return a pointer
/// error. /// to it, or a nullptr in case of any error.
using TranslateToMLIRFunction = std::function<OwningModuleRef( using TranslateSourceMgrToMLIRFunction =
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *)>; std::function<OwningModuleRef(llvm::SourceMgr &sourceMgr, MLIRContext *)>;
/// Interface of the function that translates the given string to MLIR. The
/// implementation should create a new MLIR ModuleOp in the given context. If
/// source-related error reporting is required from within the function, use
/// TranslateSourceMgrToMLIRFunction instead.
using TranslateStringRefToMLIRFunction =
std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
/// Interface of the function that translates MLIR to a different format and /// Interface of the function that translates MLIR to a different format and
/// outputs the result to a stream. It is allowed to modify the module. /// outputs the result to a stream. It is allowed to modify the module.
@ -53,11 +61,10 @@ using TranslateFromMLIRFunction =
/// should be written to the given raw_ostream. The implementation should create /// should be written to the given raw_ostream. The implementation should create
/// all MLIR constructs needed during the process inside the given context. This /// all MLIR constructs needed during the process inside the given context. This
/// can be used for round-tripping external formats through the MLIR system. /// can be used for round-tripping external formats through the MLIR system.
using TranslateFunction = using TranslateFunction = std::function<LogicalResult(
std::function<LogicalResult(std::unique_ptr<llvm::MemoryBuffer> input, llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
llvm::raw_ostream &output, MLIRContext *)>;
/// Use Translate[ToMLIR|FromMLIR|]Registration as a global initialiser that /// Use Translate[ToMLIR|FromMLIR]Registration as a global initialiser that
/// registers a function and associates it with name. This requires that a /// registers a function and associates it with name. This requires that a
/// translation has not been registered to a given name. /// translation has not been registered to a given name.
/// ///
@ -69,7 +76,9 @@ using TranslateFunction =
/// \{ /// \{
struct TranslateToMLIRRegistration { struct TranslateToMLIRRegistration {
TranslateToMLIRRegistration(llvm::StringRef name, TranslateToMLIRRegistration(llvm::StringRef name,
const TranslateToMLIRFunction &function); const TranslateSourceMgrToMLIRFunction &function);
TranslateToMLIRRegistration(llvm::StringRef name,
const TranslateStringRefToMLIRFunction &function);
}; };
struct TranslateFromMLIRRegistration { struct TranslateFromMLIRRegistration {
@ -83,7 +92,8 @@ struct TranslateRegistration {
/// \} /// \}
/// Get a read-only reference to the translator registry. /// Get a read-only reference to the translator registry.
const llvm::StringMap<TranslateToMLIRFunction> &getTranslationToMLIRRegistry(); const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
getTranslationToMLIRRegistry();
const llvm::StringMap<TranslateFromMLIRFunction> & const llvm::StringMap<TranslateFromMLIRFunction> &
getTranslationFromMLIRRegistry(); getTranslationFromMLIRRegistry();
const llvm::StringMap<TranslateFunction> &getTranslationRegistry(); const llvm::StringMap<TranslateFunction> &getTranslationRegistry();

View File

@ -42,7 +42,7 @@ using namespace mlir;
// Deserializes the SPIR-V binary module stored in the file named as // Deserializes the SPIR-V binary module stored in the file named as
// `inputFilename` and returns a module containing the SPIR-V module. // `inputFilename` and returns a module containing the SPIR-V module.
OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input, OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
MLIRContext *context) { MLIRContext *context) {
Builder builder(context); Builder builder(context);
@ -70,9 +70,10 @@ OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
} }
static TranslateToMLIRRegistration fromBinary( static TranslateToMLIRRegistration fromBinary(
"deserialize-spirv", "deserialize-spirv", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
[](std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *context) { assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
return deserializeModule(std::move(input), context); return deserializeModule(
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
}); });
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -111,13 +112,9 @@ static TranslateFromMLIRRegistration
// Round-trip registration // Round-trip registration
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input, LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output, MLIRContext *context) { llvm::raw_ostream &output, MLIRContext *context) {
llvm::SourceMgr sourceMgr; // Parse an MLIR module from the source manager.
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
// Parse the memory buffer as a MLIR module.
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context)); auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!srcModule) if (!srcModule)
return failure(); return failure();
@ -151,7 +148,7 @@ LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
static TranslateRegistration static TranslateRegistration
roundtrip("test-spirv-roundtrip", roundtrip("test-spirv-roundtrip",
[](std::unique_ptr<llvm::MemoryBuffer> input, [](llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output,
llvm::raw_ostream &output, MLIRContext *context) { MLIRContext *context) {
return roundTripModule(std::move(input), output, context); return roundTripModule(sourceMgr, output, context);
}); });

View File

@ -55,11 +55,11 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size() + wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size() +
fileToFileRegistry.size()); fileToFileRegistry.size());
for (const auto &kv : toMLIRRegistry) { for (const auto &kv : toMLIRRegistry) {
TranslateToMLIRFunction function = kv.second; TranslateSourceMgrToMLIRFunction function = kv.second;
TranslateFunction wrapper = TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
[function](std::unique_ptr<llvm::MemoryBuffer> input, llvm::raw_ostream &output,
llvm::raw_ostream &output, MLIRContext *context) { MLIRContext *context) {
OwningModuleRef module = function(std::move(input), context); OwningModuleRef module = function(sourceMgr, context);
if (!module) if (!module)
return failure(); return failure();
return printMLIROutput(*module, output); return printMLIROutput(*module, output);
@ -71,13 +71,9 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
for (const auto &kv : fromMLIRRegistry) { for (const auto &kv : fromMLIRRegistry) {
TranslateFromMLIRFunction function = kv.second; TranslateFromMLIRFunction function = kv.second;
TranslateFunction wrapper = TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
[function](std::unique_ptr<llvm::MemoryBuffer> input, llvm::raw_ostream &output,
llvm::raw_ostream &output, MLIRContext *context) { MLIRContext *context) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context)); auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
if (!module) if (!module)
return failure(); return failure();

View File

@ -571,15 +571,15 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
// Deserializes the LLVM bitcode stored in `input` into an MLIR module in the // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
// LLVM dialect. // LLVM dialect.
OwningModuleRef OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
MLIRContext *context) { MLIRContext *context) {
LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>(); LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
assert(dialect && "Could not find LLVMDialect?"); assert(dialect && "Could not find LLVMDialect?");
llvm::SMDiagnostic err; llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> llvmModule = std::unique_ptr<llvm::Module> llvmModule =
llvm::parseIR(*input, err, dialect->getLLVMContext(), llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
dialect->getLLVMContext(),
/*UpgradeDebugInfo=*/true, /*UpgradeDebugInfo=*/true,
/*DataLayoutString=*/""); /*DataLayoutString=*/"");
if (!llvmModule) { if (!llvmModule) {
@ -593,7 +593,7 @@ translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
} }
static TranslateToMLIRRegistration static TranslateToMLIRRegistration
fromLLVM("import-llvm", [](std::unique_ptr<llvm::MemoryBuffer> input, fromLLVM("import-llvm",
MLIRContext *context) { [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
return translateLLVMIRToModule(std::move(input), context); return translateLLVMIRToModule(sourceMgr, context);
}); });

View File

@ -23,15 +23,16 @@
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/SourceMgr.h"
using namespace mlir; using namespace mlir;
// Get the mutable static map between registered "to MLIR" translations and the // Get the mutable static map between registered "to MLIR" translations and the
// TranslateToMLIRFunctions that perform those translations. // TranslateToMLIRFunctions that perform those translations.
static llvm::StringMap<TranslateToMLIRFunction> & static llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
getMutableTranslationToMLIRRegistry() { getMutableTranslationToMLIRRegistry() {
static llvm::StringMap<TranslateToMLIRFunction> translationToMLIRRegistry; static llvm::StringMap<TranslateSourceMgrToMLIRFunction>
translationToMLIRRegistry;
return translationToMLIRRegistry; return translationToMLIRRegistry;
} }
// Get the mutable static map between registered "from MLIR" translations and // Get the mutable static map between registered "from MLIR" translations and
@ -49,8 +50,10 @@ static llvm::StringMap<TranslateFunction> &getMutableTranslationRegistry() {
return translationRegistry; return translationRegistry;
} }
TranslateToMLIRRegistration::TranslateToMLIRRegistration( // Puts `function` into the to-MLIR translation registry unless there is already
StringRef name, const TranslateToMLIRFunction &function) { // a function registered for the same name.
static void registerTranslateToMLIRFunction(
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry(); auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end()) if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
llvm::report_fatal_error( llvm::report_fatal_error(
@ -59,6 +62,24 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
translationToMLIRRegistry[name] = function; translationToMLIRRegistry[name] = function;
} }
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
registerTranslateToMLIRFunction(name, function);
}
// Wraps `function` with a lambda that extracts a StringRef from a source
// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, const TranslateStringRefToMLIRFunction &function) {
auto translationFunction = [function](llvm::SourceMgr &sourceMgr,
MLIRContext *ctx) {
const llvm::MemoryBuffer *buffer =
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
return function(buffer->getBuffer(), ctx);
};
registerTranslateToMLIRFunction(name, translationFunction);
}
TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
StringRef name, const TranslateFromMLIRFunction &function) { StringRef name, const TranslateFromMLIRFunction &function) {
auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry(); auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
@ -84,7 +105,7 @@ TranslateRegistration::TranslateRegistration(
// Merely add the const qualifier to the mutable registry so that external users // Merely add the const qualifier to the mutable registry so that external users
// cannot modify it. // cannot modify it.
const llvm::StringMap<TranslateToMLIRFunction> & const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
mlir::getTranslationToMLIRRegistry() { mlir::getTranslationToMLIRRegistry() {
return getMutableTranslationToMLIRRegistry(); return getMutableTranslationToMLIRRegistry();
} }

View File

@ -20,6 +20,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/Support/FileUtilities.h" #include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
@ -46,6 +47,12 @@ static llvm::cl::opt<bool>
"process each chunk independently"), "process each chunk independently"),
llvm::cl::init(false)); llvm::cl::init(false));
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
int main(int argc, char **argv) { int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);
@ -69,11 +76,24 @@ int main(int argc, char **argv) {
return 1; return 1;
} }
/// Processes the memory buffer with a new MLIRContext. // Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) { raw_ostream &os) {
MLIRContext context; MLIRContext context;
return (*translationRequested)(std::move(ownedBuffer), os, &context); llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return (*translationRequested)(sourceMgr, os, &context);
}
// In the diagnostic verification flow, we ignore whether the translation
// failed (in most cases, it is expected to fail). Instead, we check if the
// diagnostics were produced as expected.
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
(*translationRequested)(sourceMgr, os, &context);
return sourceMgrHandler.verify();
}; };
if (splitInputFile) { if (splitInputFile) {