mlir-translate: support -verify-diagnostics
MLIR translation tools can emit diagnostics and we want to be able to check if it is indeed the case in tests. Reuse the source manager error handlers provided for mlir-opt to support the verification in mlir-translate. This requires us to change the signature of the functions that are registered to translate sources to MLIR: it now takes a source manager instead of a memory buffer. PiperOrigin-RevId: 279132972 Change-Id: I9750d8f199766c391cda142aad8957d52beece3b
This commit is contained in:
parent
8c0b8fa3c4
commit
990149b809
tensorflow/compiler/mlir
third_party/mlir
include/mlir
lib
Dialect/SPIRV/Serialization
Support
Target/LLVMIR
Translation
tools/mlir-translate
@ -40,6 +40,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/Endian.h"
|
#include "llvm/Support/Endian.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "llvm/Support/MemoryBuffer.h"
|
#include "llvm/Support/MemoryBuffer.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||||
@ -861,8 +862,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
|||||||
return OwningModuleRef(module);
|
return OwningModuleRef(module);
|
||||||
}
|
}
|
||||||
|
|
||||||
static OwningModuleRef FlatBufferFileToMlirTrans(
|
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
|
MLIRContext* context) {
|
||||||
|
const llvm::MemoryBuffer* input =
|
||||||
|
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
|
||||||
std::string error;
|
std::string error;
|
||||||
auto loc =
|
auto loc =
|
||||||
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
||||||
@ -884,4 +887,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
|
|||||||
}
|
}
|
||||||
|
|
||||||
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
||||||
"tflite-flatbuffer-to-mlir", FlatBufferFileToMlirTrans);
|
"tflite-flatbuffer-to-mlir",
|
||||||
|
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
|
||||||
|
return FlatBufferFileToMlirTrans(&source_mgr, context);
|
||||||
|
});
|
||||||
|
@ -86,14 +86,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
|||||||
|
|
||||||
if (use_splatted_constant) {
|
if (use_splatted_constant) {
|
||||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||||
std::move(file), debug_info_file, input_arrays, input_dtypes,
|
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
input_shapes, output_arrays, prune_unused_nodes,
|
||||||
/*convert_legacy_fed_inputs=*/true,
|
/*convert_legacy_fed_inputs=*/true,
|
||||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||||
add_pseudo_input_nodes, context);
|
add_pseudo_input_nodes, context);
|
||||||
}
|
}
|
||||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||||
std::move(file), debug_info_file, input_arrays, input_dtypes,
|
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
input_shapes, output_arrays, prune_unused_nodes,
|
||||||
/*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false,
|
/*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false,
|
||||||
/*upgrade_legacy=*/true, add_pseudo_input_nodes, context);
|
/*upgrade_legacy=*/true, add_pseudo_input_nodes, context);
|
||||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||||
@ -43,15 +42,15 @@ using stream_executor::port::Status;
|
|||||||
using stream_executor::port::StatusOr;
|
using stream_executor::port::StatusOr;
|
||||||
|
|
||||||
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::StringRef input, absl::string_view debug_info_file,
|
||||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
|
mlir::MLIRContext* context) {
|
||||||
GraphDef graphdef;
|
GraphDef graphdef;
|
||||||
TF_RETURN_IF_ERROR(tensorflow::LoadProtoFromBuffer(
|
TF_RETURN_IF_ERROR(
|
||||||
{input->getBufferStart(), input->getBufferSize()}, &graphdef));
|
tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));
|
||||||
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
if (!debug_info_file.empty()) {
|
if (!debug_info_file.empty()) {
|
||||||
@ -91,17 +90,16 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::StringRef input, absl::string_view debug_info_file,
|
||||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
|
mlir::MLIRContext* context) {
|
||||||
auto module_or = GraphdefToMlirImport(
|
auto module_or = GraphdefToMlirImport(
|
||||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
|
||||||
add_pseudo_input_nodes, context);
|
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -136,17 +134,16 @@ mlir::OwningModuleRef SavedModelToMlirImport(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::StringRef input, absl::string_view debug_info_file,
|
||||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context) {
|
mlir::MLIRContext* context) {
|
||||||
auto module_or = GraphdefToMlirImport(
|
auto module_or = GraphdefToMlirImport(
|
||||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
graph_as_function, upgrade_legacy, add_pseudo_input_nodes, context);
|
||||||
add_pseudo_input_nodes, context);
|
|
||||||
if (!module_or.status().ok()) {
|
if (!module_or.status().ok()) {
|
||||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -32,22 +32,22 @@ namespace tensorflow {
|
|||||||
// `input_filename` into a MLIR module. Creates MLIR entities into the
|
// `input_filename` into a MLIR module. Creates MLIR entities into the
|
||||||
// given MLIR `context`.
|
// given MLIR `context`.
|
||||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::StringRef input, absl::string_view debug_info_file,
|
||||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context);
|
mlir::MLIRContext* context);
|
||||||
|
|
||||||
// Similar as the above function, but replaces all constant tensors
|
// Similar as the above function, but replaces all constant tensors
|
||||||
// with randomly generated splat values.
|
// with randomly generated splat values.
|
||||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::StringRef input, absl::string_view debug_info_file,
|
||||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||||
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
|
bool graph_as_function, bool upgrade_legacy, bool add_pseudo_input_nodes,
|
||||||
bool add_pseudo_input_nodes, mlir::MLIRContext* context);
|
mlir::MLIRContext* context);
|
||||||
|
|
||||||
// Converts a TensorFlow SavedModel stored in the directory with the given
|
// Converts a TensorFlow SavedModel stored in the directory with the given
|
||||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||||
|
@ -40,12 +40,12 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static OwningModuleRef GraphdefToMlirTranslateFunction(
|
static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
|
MLIRContext* context) {
|
||||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
graph_as_function, upgrade_legacy,
|
||||||
/*add_pseudo_input_nodes=*/false, context);
|
/*add_pseudo_input_nodes=*/false, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,11 +53,11 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
|
|||||||
"graphdef-to-mlir", GraphdefToMlirTranslateFunction);
|
"graphdef-to-mlir", GraphdefToMlirTranslateFunction);
|
||||||
|
|
||||||
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext* context) {
|
llvm::StringRef input, MLIRContext* context) {
|
||||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||||
std::move(input), debug_info_file, input_arrays, input_dtypes,
|
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
output_arrays, prune_unused_nodes, convert_legacy_fed_inputs,
|
||||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
graph_as_function, upgrade_legacy,
|
||||||
/*add_pseudo_input_nodes=*/false, context);
|
/*add_pseudo_input_nodes=*/false, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "llvm/Support/InitLLVM.h"
|
#include "llvm/Support/InitLLVM.h"
|
||||||
#include "llvm/Support/MemoryBuffer.h"
|
#include "llvm/Support/MemoryBuffer.h"
|
||||||
|
#include "llvm/Support/SMLoc.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "llvm/Support/ToolOutputFile.h"
|
#include "llvm/Support/ToolOutputFile.h"
|
||||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||||
@ -105,8 +107,11 @@ int main(int argc, char** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(
|
llvm::SourceMgr source_mgr;
|
||||||
(*requested_translation)(std::move(input), output->os(), &context)))
|
source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
||||||
|
mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
|
||||||
|
|
||||||
|
if (failed((*requested_translation)(source_mgr, output->os(), &context)))
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,12 +64,11 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) {
|
llvm::StringRef input, mlir::MLIRContext* context) {
|
||||||
HloProto hlo_proto;
|
HloProto hlo_proto;
|
||||||
string content(input->getBufferStart(), input->getBufferSize());
|
string content(input.data(), input.size());
|
||||||
if (!LoadHloProto(content, &hlo_proto)) {
|
if (!LoadHloProto(content, &hlo_proto)) {
|
||||||
LOG(ERROR) << "Failed to load proto: "
|
LOG(ERROR) << "Failed to load proto";
|
||||||
<< input->getBufferIdentifier().str();
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,9 +85,9 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context) {
|
llvm::StringRef input, mlir::MLIRContext* context) {
|
||||||
HloProto hlo_proto;
|
HloProto hlo_proto;
|
||||||
string content(input->getBufferStart(), input->getBufferSize());
|
string content(input.data(), input.size());
|
||||||
|
|
||||||
auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
|
auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
|
||||||
if (!hlo_module_error.ok()) {
|
if (!hlo_module_error.ok()) {
|
||||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
class MemoryBuffer;
|
|
||||||
class StringRef;
|
class StringRef;
|
||||||
} // namespace llvm
|
} // namespace llvm
|
||||||
|
|
||||||
@ -34,14 +33,14 @@ namespace xla {
|
|||||||
|
|
||||||
// Converts a HloModuleProto stored in the file with the given `input_filename`
|
// Converts a HloModuleProto stored in the file with the given `input_filename`
|
||||||
// into a MLIR module. Creates MLIR entities into the given MLIR `context`.
|
// into a MLIR module. Creates MLIR entities into the given MLIR `context`.
|
||||||
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
mlir::OwningModuleRef HloToMlirHloTranslateFunction(llvm::StringRef input,
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context);
|
mlir::MLIRContext* context);
|
||||||
|
|
||||||
// Converts a HloModule stored in text form for a file with the given
|
// Converts a HloModule stored in text form for a file with the given
|
||||||
// `input_filename` into a MLIR module. Creates MLIR entities into the given
|
// `input_filename` into a MLIR module. Creates MLIR entities into the given
|
||||||
// MLIR `context`.
|
// MLIR `context`.
|
||||||
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, mlir::MLIRContext* context);
|
llvm::StringRef input, mlir::MLIRContext* context);
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
34
third_party/mlir/include/mlir/Translation.h
vendored
34
third_party/mlir/include/mlir/Translation.h
vendored
@ -27,6 +27,7 @@
|
|||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
class MemoryBuffer;
|
class MemoryBuffer;
|
||||||
|
class SourceMgr;
|
||||||
class StringRef;
|
class StringRef;
|
||||||
} // namespace llvm
|
} // namespace llvm
|
||||||
|
|
||||||
@ -36,12 +37,19 @@ class MLIRContext;
|
|||||||
class ModuleOp;
|
class ModuleOp;
|
||||||
class OwningModuleRef;
|
class OwningModuleRef;
|
||||||
|
|
||||||
/// Interface of the function that translates a source file held by the given
|
/// Interface of the function that translates the sources managed by `sourceMgr`
|
||||||
/// MemoryBuffer to MLIR. The implementation should create a new MLIR ModuleOp
|
/// to MLIR. The source manager has at least one buffer. The implementation
|
||||||
/// in the given context and return a pointer to it, or a nullptr in case of any
|
/// should create a new MLIR ModuleOp in the given context and return a pointer
|
||||||
/// error.
|
/// to it, or a nullptr in case of any error.
|
||||||
using TranslateToMLIRFunction = std::function<OwningModuleRef(
|
using TranslateSourceMgrToMLIRFunction =
|
||||||
std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *)>;
|
std::function<OwningModuleRef(llvm::SourceMgr &sourceMgr, MLIRContext *)>;
|
||||||
|
|
||||||
|
/// Interface of the function that translates the given string to MLIR. The
|
||||||
|
/// implementation should create a new MLIR ModuleOp in the given context. If
|
||||||
|
/// source-related error reporting is required from within the function, use
|
||||||
|
/// TranslateSourceMgrToMLIRFunction instead.
|
||||||
|
using TranslateStringRefToMLIRFunction =
|
||||||
|
std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
|
||||||
|
|
||||||
/// Interface of the function that translates MLIR to a different format and
|
/// Interface of the function that translates MLIR to a different format and
|
||||||
/// outputs the result to a stream. It is allowed to modify the module.
|
/// outputs the result to a stream. It is allowed to modify the module.
|
||||||
@ -53,11 +61,10 @@ using TranslateFromMLIRFunction =
|
|||||||
/// should be written to the given raw_ostream. The implementation should create
|
/// should be written to the given raw_ostream. The implementation should create
|
||||||
/// all MLIR constructs needed during the process inside the given context. This
|
/// all MLIR constructs needed during the process inside the given context. This
|
||||||
/// can be used for round-tripping external formats through the MLIR system.
|
/// can be used for round-tripping external formats through the MLIR system.
|
||||||
using TranslateFunction =
|
using TranslateFunction = std::function<LogicalResult(
|
||||||
std::function<LogicalResult(std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
|
||||||
llvm::raw_ostream &output, MLIRContext *)>;
|
|
||||||
|
|
||||||
/// Use Translate[ToMLIR|FromMLIR|]Registration as a global initialiser that
|
/// Use Translate[ToMLIR|FromMLIR]Registration as a global initialiser that
|
||||||
/// registers a function and associates it with name. This requires that a
|
/// registers a function and associates it with name. This requires that a
|
||||||
/// translation has not been registered to a given name.
|
/// translation has not been registered to a given name.
|
||||||
///
|
///
|
||||||
@ -69,7 +76,9 @@ using TranslateFunction =
|
|||||||
/// \{
|
/// \{
|
||||||
struct TranslateToMLIRRegistration {
|
struct TranslateToMLIRRegistration {
|
||||||
TranslateToMLIRRegistration(llvm::StringRef name,
|
TranslateToMLIRRegistration(llvm::StringRef name,
|
||||||
const TranslateToMLIRFunction &function);
|
const TranslateSourceMgrToMLIRFunction &function);
|
||||||
|
TranslateToMLIRRegistration(llvm::StringRef name,
|
||||||
|
const TranslateStringRefToMLIRFunction &function);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TranslateFromMLIRRegistration {
|
struct TranslateFromMLIRRegistration {
|
||||||
@ -83,7 +92,8 @@ struct TranslateRegistration {
|
|||||||
/// \}
|
/// \}
|
||||||
|
|
||||||
/// Get a read-only reference to the translator registry.
|
/// Get a read-only reference to the translator registry.
|
||||||
const llvm::StringMap<TranslateToMLIRFunction> &getTranslationToMLIRRegistry();
|
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
|
||||||
|
getTranslationToMLIRRegistry();
|
||||||
const llvm::StringMap<TranslateFromMLIRFunction> &
|
const llvm::StringMap<TranslateFromMLIRFunction> &
|
||||||
getTranslationFromMLIRRegistry();
|
getTranslationFromMLIRRegistry();
|
||||||
const llvm::StringMap<TranslateFunction> &getTranslationRegistry();
|
const llvm::StringMap<TranslateFunction> &getTranslationRegistry();
|
||||||
|
@ -42,7 +42,7 @@ using namespace mlir;
|
|||||||
|
|
||||||
// Deserializes the SPIR-V binary module stored in the file named as
|
// Deserializes the SPIR-V binary module stored in the file named as
|
||||||
// `inputFilename` and returns a module containing the SPIR-V module.
|
// `inputFilename` and returns a module containing the SPIR-V module.
|
||||||
OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
Builder builder(context);
|
Builder builder(context);
|
||||||
|
|
||||||
@ -70,9 +70,10 @@ OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static TranslateToMLIRRegistration fromBinary(
|
static TranslateToMLIRRegistration fromBinary(
|
||||||
"deserialize-spirv",
|
"deserialize-spirv", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
|
||||||
[](std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *context) {
|
assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
|
||||||
return deserializeModule(std::move(input), context);
|
return deserializeModule(
|
||||||
|
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
|
||||||
});
|
});
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -111,13 +112,9 @@ static TranslateFromMLIRRegistration
|
|||||||
// Round-trip registration
|
// Round-trip registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
|
||||||
llvm::raw_ostream &output, MLIRContext *context) {
|
llvm::raw_ostream &output, MLIRContext *context) {
|
||||||
llvm::SourceMgr sourceMgr;
|
// Parse an MLIR module from the source manager.
|
||||||
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
|
||||||
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
|
|
||||||
|
|
||||||
// Parse the memory buffer as a MLIR module.
|
|
||||||
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||||
if (!srcModule)
|
if (!srcModule)
|
||||||
return failure();
|
return failure();
|
||||||
@ -151,7 +148,7 @@ LogicalResult roundTripModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
|||||||
|
|
||||||
static TranslateRegistration
|
static TranslateRegistration
|
||||||
roundtrip("test-spirv-roundtrip",
|
roundtrip("test-spirv-roundtrip",
|
||||||
[](std::unique_ptr<llvm::MemoryBuffer> input,
|
[](llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output,
|
||||||
llvm::raw_ostream &output, MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
return roundTripModule(std::move(input), output, context);
|
return roundTripModule(sourceMgr, output, context);
|
||||||
});
|
});
|
||||||
|
@ -55,15 +55,15 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
|
|||||||
wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size() +
|
wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size() +
|
||||||
fileToFileRegistry.size());
|
fileToFileRegistry.size());
|
||||||
for (const auto &kv : toMLIRRegistry) {
|
for (const auto &kv : toMLIRRegistry) {
|
||||||
TranslateToMLIRFunction function = kv.second;
|
TranslateSourceMgrToMLIRFunction function = kv.second;
|
||||||
TranslateFunction wrapper =
|
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
|
||||||
[function](std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::raw_ostream &output,
|
||||||
llvm::raw_ostream &output, MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
OwningModuleRef module = function(std::move(input), context);
|
OwningModuleRef module = function(sourceMgr, context);
|
||||||
if (!module)
|
if (!module)
|
||||||
return failure();
|
return failure();
|
||||||
return printMLIROutput(*module, output);
|
return printMLIROutput(*module, output);
|
||||||
};
|
};
|
||||||
wrapperStorage.emplace_back(std::move(wrapper));
|
wrapperStorage.emplace_back(std::move(wrapper));
|
||||||
|
|
||||||
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
|
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
|
||||||
@ -71,18 +71,14 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
|
|||||||
|
|
||||||
for (const auto &kv : fromMLIRRegistry) {
|
for (const auto &kv : fromMLIRRegistry) {
|
||||||
TranslateFromMLIRFunction function = kv.second;
|
TranslateFromMLIRFunction function = kv.second;
|
||||||
TranslateFunction wrapper =
|
TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr,
|
||||||
[function](std::unique_ptr<llvm::MemoryBuffer> input,
|
llvm::raw_ostream &output,
|
||||||
llvm::raw_ostream &output, MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
llvm::SourceMgr sourceMgr;
|
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||||
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
if (!module)
|
||||||
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
|
return failure();
|
||||||
|
return function(module.get(), output);
|
||||||
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
};
|
||||||
if (!module)
|
|
||||||
return failure();
|
|
||||||
return function(module.get(), output);
|
|
||||||
};
|
|
||||||
wrapperStorage.emplace_back(std::move(wrapper));
|
wrapperStorage.emplace_back(std::move(wrapper));
|
||||||
|
|
||||||
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
|
addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
|
||||||
|
@ -571,15 +571,15 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
|
|||||||
|
|
||||||
// Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
|
// Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
|
||||||
// LLVM dialect.
|
// LLVM dialect.
|
||||||
OwningModuleRef
|
OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
|
||||||
translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
MLIRContext *context) {
|
||||||
MLIRContext *context) {
|
|
||||||
LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
|
LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
|
||||||
assert(dialect && "Could not find LLVMDialect?");
|
assert(dialect && "Could not find LLVMDialect?");
|
||||||
|
|
||||||
llvm::SMDiagnostic err;
|
llvm::SMDiagnostic err;
|
||||||
std::unique_ptr<llvm::Module> llvmModule =
|
std::unique_ptr<llvm::Module> llvmModule =
|
||||||
llvm::parseIR(*input, err, dialect->getLLVMContext(),
|
llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
|
||||||
|
dialect->getLLVMContext(),
|
||||||
/*UpgradeDebugInfo=*/true,
|
/*UpgradeDebugInfo=*/true,
|
||||||
/*DataLayoutString=*/"");
|
/*DataLayoutString=*/"");
|
||||||
if (!llvmModule) {
|
if (!llvmModule) {
|
||||||
@ -593,7 +593,7 @@ translateLLVMIRToModule(std::unique_ptr<llvm::MemoryBuffer> input,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static TranslateToMLIRRegistration
|
static TranslateToMLIRRegistration
|
||||||
fromLLVM("import-llvm", [](std::unique_ptr<llvm::MemoryBuffer> input,
|
fromLLVM("import-llvm",
|
||||||
MLIRContext *context) {
|
[](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
|
||||||
return translateLLVMIRToModule(std::move(input), context);
|
return translateLLVMIRToModule(sourceMgr, context);
|
||||||
});
|
});
|
||||||
|
33
third_party/mlir/lib/Translation/Translation.cpp
vendored
33
third_party/mlir/lib/Translation/Translation.cpp
vendored
@ -23,15 +23,16 @@
|
|||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "llvm/Support/ManagedStatic.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
// Get the mutable static map between registered "to MLIR" translations and the
|
// Get the mutable static map between registered "to MLIR" translations and the
|
||||||
// TranslateToMLIRFunctions that perform those translations.
|
// TranslateToMLIRFunctions that perform those translations.
|
||||||
static llvm::StringMap<TranslateToMLIRFunction> &
|
static llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
|
||||||
getMutableTranslationToMLIRRegistry() {
|
getMutableTranslationToMLIRRegistry() {
|
||||||
static llvm::StringMap<TranslateToMLIRFunction> translationToMLIRRegistry;
|
static llvm::StringMap<TranslateSourceMgrToMLIRFunction>
|
||||||
|
translationToMLIRRegistry;
|
||||||
return translationToMLIRRegistry;
|
return translationToMLIRRegistry;
|
||||||
}
|
}
|
||||||
// Get the mutable static map between registered "from MLIR" translations and
|
// Get the mutable static map between registered "from MLIR" translations and
|
||||||
@ -49,8 +50,10 @@ static llvm::StringMap<TranslateFunction> &getMutableTranslationRegistry() {
|
|||||||
return translationRegistry;
|
return translationRegistry;
|
||||||
}
|
}
|
||||||
|
|
||||||
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
// Puts `function` into the to-MLIR translation registry unless there is already
|
||||||
StringRef name, const TranslateToMLIRFunction &function) {
|
// a function registered for the same name.
|
||||||
|
static void registerTranslateToMLIRFunction(
|
||||||
|
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
|
||||||
auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
|
auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
|
||||||
if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
|
if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
|
||||||
llvm::report_fatal_error(
|
llvm::report_fatal_error(
|
||||||
@ -59,6 +62,24 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
|||||||
translationToMLIRRegistry[name] = function;
|
translationToMLIRRegistry[name] = function;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
||||||
|
StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
|
||||||
|
registerTranslateToMLIRFunction(name, function);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wraps `function` with a lambda that extracts a StringRef from a source
|
||||||
|
// manager and registers the wrapper lambda as a to-MLIR conversion.
|
||||||
|
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
|
||||||
|
StringRef name, const TranslateStringRefToMLIRFunction &function) {
|
||||||
|
auto translationFunction = [function](llvm::SourceMgr &sourceMgr,
|
||||||
|
MLIRContext *ctx) {
|
||||||
|
const llvm::MemoryBuffer *buffer =
|
||||||
|
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
|
||||||
|
return function(buffer->getBuffer(), ctx);
|
||||||
|
};
|
||||||
|
registerTranslateToMLIRFunction(name, translationFunction);
|
||||||
|
}
|
||||||
|
|
||||||
TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
|
TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
|
||||||
StringRef name, const TranslateFromMLIRFunction &function) {
|
StringRef name, const TranslateFromMLIRFunction &function) {
|
||||||
auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
|
auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
|
||||||
@ -84,7 +105,7 @@ TranslateRegistration::TranslateRegistration(
|
|||||||
|
|
||||||
// Merely add the const qualifier to the mutable registry so that external users
|
// Merely add the const qualifier to the mutable registry so that external users
|
||||||
// cannot modify it.
|
// cannot modify it.
|
||||||
const llvm::StringMap<TranslateToMLIRFunction> &
|
const llvm::StringMap<TranslateSourceMgrToMLIRFunction> &
|
||||||
mlir::getTranslationToMLIRRegistry() {
|
mlir::getTranslationToMLIRRegistry() {
|
||||||
return getMutableTranslationToMLIRRegistry();
|
return getMutableTranslationToMLIRRegistry();
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Diagnostics.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/Support/FileUtilities.h"
|
#include "mlir/Support/FileUtilities.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
@ -46,6 +47,12 @@ static llvm::cl::opt<bool>
|
|||||||
"process each chunk independently"),
|
"process each chunk independently"),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
|
static llvm::cl::opt<bool> verifyDiagnostics(
|
||||||
|
"verify-diagnostics",
|
||||||
|
llvm::cl::desc("Check that emitted diagnostics match "
|
||||||
|
"expected-* lines on the corresponding line"),
|
||||||
|
llvm::cl::init(false));
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
||||||
@ -69,11 +76,24 @@ int main(int argc, char **argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Processes the memory buffer with a new MLIRContext.
|
// Processes the memory buffer with a new MLIRContext.
|
||||||
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
|
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
return (*translationRequested)(std::move(ownedBuffer), os, &context);
|
llvm::SourceMgr sourceMgr;
|
||||||
|
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
|
||||||
|
|
||||||
|
if (!verifyDiagnostics) {
|
||||||
|
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
|
||||||
|
return (*translationRequested)(sourceMgr, os, &context);
|
||||||
|
}
|
||||||
|
|
||||||
|
// In the diagnostic verification flow, we ignore whether the translation
|
||||||
|
// failed (in most cases, it is expected to fail). Instead, we check if the
|
||||||
|
// diagnostics were produced as expected.
|
||||||
|
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
|
||||||
|
(*translationRequested)(sourceMgr, os, &context);
|
||||||
|
return sourceMgrHandler.verify();
|
||||||
};
|
};
|
||||||
|
|
||||||
if (splitInputFile) {
|
if (splitInputFile) {
|
||||||
|
Loading…
Reference in New Issue
Block a user