mlir-translate: support -verify-diagnostics
MLIR translation tools can emit diagnostics and we want to be able to check if it is indeed the case in tests. Reuse the source manager error handlers provided for mlir-opt to support the verification in mlir-translate. This requires us to change the signature of the functions that are registered to translate sources to MLIR: it now takes a source manager instead of a memory buffer. PiperOrigin-RevId: 279132972 Change-Id: I9750d8f199766c391cda142aad8957d52beece3b
This commit is contained in:
parent
8c0b8fa3c4
commit
990149b809
tensorflow/compiler/mlir
third_party/mlir
include/mlir
lib
Dialect/SPIRV/Serialization
Support
Target/LLVMIR
Translation
tools/mlir-translate
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Endian.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
@ -861,8 +862,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
return OwningModuleRef(module);
|
||||
}
|
||||
|
||||
static OwningModuleRef FlatBufferFileToMlirTrans(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
|
||||
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||
MLIRContext* context) {
|
||||
const llvm::MemoryBuffer* input =
|
||||
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
|
||||
std::string error;
|
||||
auto loc =
|
||||
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
||||
@ -884,4 +887,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
|
||||
}
|
||||
|
||||
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
||||
"tflite-flatbuffer-to-mlir", FlatBufferFileToMlirTrans);
|
||||
"tflite-flatbuffer-to-mlir",
|
||||
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
|
||||
return FlatBufferFileToMlirTrans(&source_mgr, context);
|
||||
});
|
||||
|
@ -86,14 +86,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
|
||||
if (use_splatted_constant) {
|
||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
std::move(file), debug_info_file, input_arrays, input_dtypes,
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
/*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
add_pseudo_input_nodes, context);
|
||||
}
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
std::move(file), debug_info_file, input_arrays, input_dtypes,
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
/*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false,
|
||||
/*upgrade_legacy=*/true, add_pseudo_input_nodes, context);
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
@ -43,15 +42,15 @@ using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||
mlir::MLIRContext* context) {
|
||||
GraphDef graphdef;
|
||||
TF_RETURN_IF_ERROR(tensorflow::LoadProtoFromBuffer(
|
||||
{input->getBufferStart(), input->getBufferSize()}, &graphdef));
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
if (!debug_info_file.empty()) {
|
||||
@ -91,17 +90,16 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||
mlir::MLIRContext* context) {
|
||||
auto module_or = GraphdefToMlirImport(
|
||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
add_pseudo_input_nodes, context);
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||
graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
@ -136,17 +134,16 @@ mlir::OwningModuleRef SavedModelToMlirImport(
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||
mlir::MLIRContext* context) {
|
||||
auto module_or = GraphdefToMlirImport(
|
||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
add_pseudo_input_nodes, context);
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||
graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
|
@ -32,22 +32,22 @@ namespace tensorflow {
|
||||
// `input_filename` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context);
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Similar as the above function, but replaces all constant tensors
|
||||
// with randomly generated splat values.
|
||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context);
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Converts a TensorFlow SavedModel stored in the directory with the given
|
||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||
|
@ -40,12 +40,12 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
|
||||
static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
|
||||
MLIRContext* context) {
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||
graph_as_function, upgrade_legacy,
|
||||
/*add_pseudo_input_nodes=*/false, context);
|
||||
}
|
||||
|
||||
@ -53,11 +53,11 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
|
||||
"graphdef-to-mlir", GraphdefToMlirTranslateFunction);
|
||||
|
||||
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
|
||||
llvm::StringRef input, MLIRContext* context) {
|
||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||
graph_as_function, upgrade_legacy,
|
||||
/*add_pseudo_input_nodes=*/false, context);
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
@ -105,8 +107,11 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (failed(
|
||||
(*requested_translation)(std::move(input), output->os(), &context)))
|
||||
llvm::SourceMgr source_mgr;
|
||||
source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
||||
mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
|
||||
|
||||
if (failed((*requested_translation)(source_mgr, output->os(), &context)))
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
@ -64,12 +64,11 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
|
||||
} // namespace
|
||||
|
||||
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) {
|
||||
llvm::StringRef input, mlir::MLIRContext* context) {
|
||||
HloProto hlo_proto;
|
||||
string content(input->getBufferStart(), input->getBufferSize());
|
||||
string content(input.data(), input.size());
|
||||
if (!LoadHloProto(content, &hlo_proto)) {
|
||||
LOG(ERROR) << "Failed to load proto: "
|
||||
<< input->getBufferIdentifier().str();
|
||||
LOG(ERROR) << "Failed to load proto";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -86,9 +85,9 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) {
|
||||
llvm::StringRef input, mlir::MLIRContext* context) {
|
||||
HloProto hlo_proto;
|
||||
string content(input->getBufferStart(), input->getBufferSize());
|
||||
string content(input.data(), input.size());
|
||||
|
||||
auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
|
||||
if (!hlo_module_error.ok()) {
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace llvm {
|
||||
class MemoryBuffer;
|
||||
class StringRef;
|
||||
} // namespace llvm
|
||||
|
||||
@ -34,14 +33,14 @@ namespace xla {
|
||||
|
||||
// Converts a HloModuleProto stored in the file with the given `input_filename`
|
||||
// into a MLIR module. Creates MLIR entities into the given MLIR `context`.
|
||||
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context);
|
||||
mlir::OwningModuleRef HloToMlirHloTranslateFunction(llvm::StringRef input,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Converts a HloModule stored in text form for a file with the given
|
||||
// `input_filename` into a MLIR module. Creates MLIR entities into the given
|
||||
// MLIR `context`.
|
||||
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context);
|
||||
llvm::StringRef input, mlir::MLIRContext* context);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
34
third_party/mlir/include/mlir/Translation.h
vendored
34
third_party/mlir/include/mlir/Translation.h
vendored
@ -27,6 +27,7 @@
|
||||
|
||||
namespace llvm {
|
||||
class MemoryBuffer;
|
||||
class SourceMgr;
|
||||
class StringRef;
|
||||
} // namespace llvm
|
||||
|
||||
@ -36,12 +37,19 @@ class MLIRContext;
|
||||
class ModuleOp;
|
||||
class OwningModuleRef;
|
||||
|
||||
/// Interface of the function that translates a source file held by the given
|
||||
/// MemoryBuffer to MLIR. The implementation should create a new MLIR ModuleOp
|
||||
/// in the given context and return a pointer to it, or a nullptr in case of any
|
||||
/// error.
|
||||
using TranslateToMLIRFunction = std::function<OwningModuleRef(
|
||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *)>;
|
||||
/// Interface of the function that translates the sources managed by `sourceMgr`
|
||||
/// to MLIR. The source manager has at least one buffer. The implementation
|
||||
/// should create a new MLIR ModuleOp in the given context and return a pointer
|
||||
/// to it, or a nullptr in case of any error.
|
||||
using TranslateSourceMgrToMLIRFunction =
|
||||
std::function<OwningModuleRef(llvm::SourceMgr &sourceMgr, MLIRContext *)>;
|
||||
|
||||
/// Interface of the function that translates the given string to MLIR. The
|
||||
/// implementation should create a new MLIR ModuleOp in the given context. If
|
||||
/// source-related error reporting is required from within the function, use
|
||||
/// TranslateSourceMgrToMLIRFunction instead.
|
||||
using TranslateStringRefToMLIRFunction =
|
||||
std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
|
||||
|
||||
/// Interface of the function that translates MLIR to a different format and
|
||||
/// outputs the result to a stream. It is allowed to modify the module.
|
||||
@ -53,11 +61,10 @@ using TranslateFromMLIRFunction =
|
||||
/// should be written to the given raw_ostream. The implementation should create
|
||||
/// all MLIR constructs needed during the process inside the given context. This
|
||||
/// can be used for round-tripping external formats through the MLIR system.
|
||||
using TranslateFunction =
|
||||
std::function<LogicalResult(std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
llvm::raw_ostream &output, MLIRContext *)>;
|
||||
using TranslateFunction = std::function<LogicalResult(
|
||||
llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
|
||||
|
||||
/// Use Translate[ToMLIR|FromMLIR|]Registration as a global initialiser that
|
||||
/// Use Translate[ToMLIR|FromMLIR]Registration as a global initialiser that
|
||||
/// registers a function and associates it with name. This requires that a
|
||||
/// translation has not been registered to a given name.
|
||||
///
|
||||
@ -69,7 +76,9 @@ using TranslateFunction =
|
||||
/// \{
|
||||
struct TranslateToMLIRRegistration {
|
||||
TranslateToMLIRRegistration(llvm::StringRef name,
|
||||
const TranslateToMLIRFunction &function);
|
||||
const TranslateSourceMgrToMLIRFunction &function);
|
||||
TranslateToMLIRRegistration(llvm::StringRef name,
|
||||
const TranslateStringRefToMLIRFunction &function);
|
||||
};
|
||||
|
||||
struct TranslateFromMLIRRegistration {
|
||||
@ -83,7 +92,8 @@ struct TranslateRegistration {
|
||||
/// \}
|
||||
|
||||
/// Get a read-only reference to the translator registry.
|
||||
const llvm::StringMap<TranslateToMLIRFunction> &getTranslationToMLIRRegistry();
|
||||
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
|
||||
getTranslationToMLIRRegistry();
|
||||
const llvm::StringMap<TranslateFromMLIRFunction> &
|
||||
getTranslationFromMLIRRegistry();
|
||||
const llvm::StringMap<TranslateFunction> &getTranslationRegistry();
|
||||
|
@ -42,7 +42,7 @@ using namespace mlir;
|
||||
|
||||
// Deserializes the SPIR-V binary module stored in the file named as
|
||||
// `inputFilename` and returns a module containing the SPIR-V module.
|
||||
OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
|
||||
MLIRContext *context) {
|
||||
Builder builder(context);
|
||||
|
||||
@ -70,9 +70,10 @@ OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
}
|
||||
|
||||
static TranslateToMLIRRegistration fromBinary(
|
||||
"deserialize-spirv",
|
||||
[](std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *context) {
|
||||
return deserializeModule(std::move(input), context);
|
||||
"deserialize-spirv", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
|
||||
assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
|
||||
return deserializeModule(
|
||||
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -111,13 +112,9 @@ static TranslateFromMLIRRegistration
|
||||
// Round-trip registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
|
||||
llvm::raw_ostream &output, MLIRContext *context) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
||||
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
|
||||
|
||||
// Parse the memory buffer as a MLIR module.
|
||||
// Parse an MLIR module from the source manager.
|
||||
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||
if (!srcModule)
|
||||
return failure();
|
||||
@ -151,7 +148,7 @@ LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
|
||||
static TranslateRegistration
|
||||
roundtrip("test-spirv-roundtrip",
|
||||
[](std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
llvm::raw_ostream &output, MLIRContext *context) {
|
||||
return roundTripModule(std::move(input), output, context);
|
||||
[](llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output,
|
||||
MLIRContext *context) {
|
||||
return roundTripModule(sourceMgr, output, context);
|
||||
});
|
||||
|
@ -55,15 +55,15 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
|
||||
wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size() +
|
||||
fileToFileRegistry.size());
|
||||
for (const auto &kv : toMLIRRegistry) {
|
||||
TranslateToMLIRFunction function = kv.second;
|
||||
TranslateFunction wrapper =
|
||||
[function](std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
llvm::raw_ostream &output, MLIRContext *context) {
|
||||
OwningModuleRef module = function(std::move(input), context);
|
||||
if (!module)
|
||||
return failure();
|
||||
return printMLIROutput(*module, output);
|
||||
};
|
||||
TranslateSourceMgrToMLIRFunction function = kv.second;
|
||||
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
|
||||
llvm::raw_ostream &output,
|
||||
MLIRContext *context) {
|
||||
OwningModuleRef module = function(sourceMgr, context);
|
||||
if (!module)
|
||||
return failure();
|
||||
return printMLIROutput(*module, output);
|
||||
};
|
||||
wrapperStorage.emplace_back(std::move(wrapper));
|
||||
|
||||
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
|
||||
@ -71,18 +71,14 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
|
||||
|
||||
for (const auto &kv : fromMLIRRegistry) {
|
||||
TranslateFromMLIRFunction function = kv.second;
|
||||
TranslateFunction wrapper =
|
||||
[function](std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
llvm::raw_ostream &output, MLIRContext *context) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
||||
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
|
||||
|
||||
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||
if (!module)
|
||||
return failure();
|
||||
return function(module.get(), output);
|
||||
};
|
||||
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
|
||||
llvm::raw_ostream &output,
|
||||
MLIRContext *context) {
|
||||
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||
if (!module)
|
||||
return failure();
|
||||
return function(module.get(), output);
|
||||
};
|
||||
wrapperStorage.emplace_back(std::move(wrapper));
|
||||
|
||||
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
|
||||
|
@ -571,15 +571,15 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
|
||||
|
||||
// Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
|
||||
// LLVM dialect.
|
||||
OwningModuleRef
|
||||
translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
MLIRContext *context) {
|
||||
OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *context) {
|
||||
LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
|
||||
assert(dialect && "Could not find LLVMDialect?");
|
||||
|
||||
llvm::SMDiagnostic err;
|
||||
std::unique_ptr<llvm::Module> llvmModule =
|
||||
llvm::parseIR(*input, err, dialect->getLLVMContext(),
|
||||
llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
|
||||
dialect->getLLVMContext(),
|
||||
/*UpgradeDebugInfo=*/true,
|
||||
/*DataLayoutString=*/"");
|
||||
if (!llvmModule) {
|
||||
@ -593,7 +593,7 @@ translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
}
|
||||
|
||||
static TranslateToMLIRRegistration
|
||||
fromLLVM("import-llvm", [](std::unique_ptr<llvm::MemoryBuffer> input,
|
||||
MLIRContext *context) {
|
||||
return translateLLVMIRToModule(std::move(input), context);
|
||||
});
|
||||
fromLLVM("import-llvm",
|
||||
[](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
|
||||
return translateLLVMIRToModule(sourceMgr, context);
|
||||
});
|
||||
|
33
third_party/mlir/lib/Translation/Translation.cpp
vendored
33
third_party/mlir/lib/Translation/Translation.cpp
vendored
@ -23,15 +23,16 @@
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Get the mutable static map between registered "to MLIR" translations and the
|
||||
// TranslateToMLIRFunctions that perform those translations.
|
||||
static llvm::StringMap<TranslateToMLIRFunction> &
|
||||
static llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
|
||||
getMutableTranslationToMLIRRegistry() {
|
||||
static llvm::StringMap<TranslateToMLIRFunction> translationToMLIRRegistry;
|
||||
static llvm::StringMap<TranslateSourceMgrToMLIRFunction>
|
||||
translationToMLIRRegistry;
|
||||
return translationToMLIRRegistry;
|
||||
}
|
||||
// Get the mutable static map between registered "from MLIR" translations and
|
||||
@ -49,8 +50,10 @@ static llvm::StringMap<TranslateFunction> &getMutableTranslationRegistry() {
|
||||
return translationRegistry;
|
||||
}
|
||||
|
||||
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
||||
StringRef name, const TranslateToMLIRFunction &function) {
|
||||
// Puts `function` into the to-MLIR translation registry unless there is already
|
||||
// a function registered for the same name.
|
||||
static void registerTranslateToMLIRFunction(
|
||||
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
|
||||
auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
|
||||
if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
|
||||
llvm::report_fatal_error(
|
||||
@ -59,6 +62,24 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
||||
translationToMLIRRegistry[name] = function;
|
||||
}
|
||||
|
||||
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
||||
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
|
||||
registerTranslateToMLIRFunction(name, function);
|
||||
}
|
||||
|
||||
// Wraps `function` with a lambda that extracts a StringRef from a source
|
||||
// manager and registers the wrapper lambda as a to-MLIR conversion.
|
||||
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
||||
StringRef name, const TranslateStringRefToMLIRFunction &function) {
|
||||
auto translationFunction = [function](llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *ctx) {
|
||||
const llvm::MemoryBuffer *buffer =
|
||||
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
|
||||
return function(buffer->getBuffer(), ctx);
|
||||
};
|
||||
registerTranslateToMLIRFunction(name, translationFunction);
|
||||
}
|
||||
|
||||
TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
|
||||
StringRef name, const TranslateFromMLIRFunction &function) {
|
||||
auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
|
||||
@ -84,7 +105,7 @@ TranslateRegistration::TranslateRegistration(
|
||||
|
||||
// Merely add the const qualifier to the mutable registry so that external users
|
||||
// cannot modify it.
|
||||
const llvm::StringMap<TranslateToMLIRFunction> &
|
||||
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
|
||||
mlir::getTranslationToMLIRRegistry() {
|
||||
return getMutableTranslationToMLIRRegistry();
|
||||
}
|
||||
|
@ -20,6 +20,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
@ -46,6 +47,12 @@ static llvm::cl::opt<bool>
|
||||
"process each chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
static llvm::cl::opt<bool> verifyDiagnostics(
|
||||
"verify-diagnostics",
|
||||
llvm::cl::desc("Check that emitted diagnostics match "
|
||||
"expected-* lines on the corresponding line"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
@ -69,11 +76,24 @@ int main(int argc, char **argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
/// Processes the memory buffer with a new MLIRContext.
|
||||
// Processes the memory buffer with a new MLIRContext.
|
||||
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
|
||||
raw_ostream &os) {
|
||||
MLIRContext context;
|
||||
return (*translationRequested)(std::move(ownedBuffer), os, &context);
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
|
||||
|
||||
if (!verifyDiagnostics) {
|
||||
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
|
||||
return (*translationRequested)(sourceMgr, os, &context);
|
||||
}
|
||||
|
||||
// In the diagnostic verification flow, we ignore whether the translation
|
||||
// failed (in most cases, it is expected to fail). Instead, we check if the
|
||||
// diagnostics were produced as expected.
|
||||
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
|
||||
(*translationRequested)(sourceMgr, os, &context);
|
||||
return sourceMgrHandler.verify();
|
||||
};
|
||||
|
||||
if (splitInputFile) {
|
||||
|
Loading…
Reference in New Issue
Block a user