Update third_party/tensorflow/compiler/mlir/tensorflow/utils/... to not depend on the global Dialect Registry (NFC)
PiperOrigin-RevId: 328171679 Change-Id: I3a4e40a04cb14b9c3d53239b31d3d642bb97daac
This commit is contained in:
parent
116792db45
commit
d41d51215b
@ -17,10 +17,13 @@ limitations under the License.
|
||||
#define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
namespace mhlo {
|
||||
|
||||
void registerAllDialects();
|
||||
|
||||
// Add chlo, mhlo, lmhlo dialects to the provided registry.
|
||||
void registerAllMhloDialects(DialectRegistry ®istry);
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -31,3 +31,11 @@ void mlir::mhlo::registerAllDialects() {
|
||||
|
||||
// Dependent dialects
|
||||
}
|
||||
|
||||
void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry ®istry) {
|
||||
// clang-format off
|
||||
registry.insert<mlir::chlo::HloClientDialect,
|
||||
mlir::lmhlo::LmhloDialect,
|
||||
mlir::mhlo::MhloDialect>();
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -1512,6 +1512,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/hlo:sink_constants_to_control_flow",
|
||||
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
|
||||
"//tensorflow/compiler/mlir/xla:type_to_shape",
|
||||
|
@ -36,7 +36,9 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
@ -276,16 +278,9 @@ Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::shape::ShapeDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::mhlo::MhloDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
static void RegisterDialects(mlir::DialectRegistry& registry) {
|
||||
mlir::RegisterAllTensorFlowDialects(registry);
|
||||
mlir::mhlo::registerAllMhloDialects(registry);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -418,9 +413,8 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext mlir_context;
|
||||
mlir_context.loadAllGloballyRegisteredDialects();
|
||||
RegisterDialects(mlir_context.getDialectRegistry());
|
||||
mlir::OwningModuleRef mlir_module;
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -507,10 +501,8 @@ Status CompileGraphToXlaHlo(
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
RegisterDialects();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
RegisterDialects(context.getDialectRegistry());
|
||||
GraphImportConfig config;
|
||||
config.graph_as_function = true;
|
||||
// Disable shape inference during import as some TensorFlow op fails during
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
@ -33,17 +34,13 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
static void RegisterDialects(mlir::MLIRContext &context) {
|
||||
context.loadDialect<mlir::TF::TensorFlowDialect>();
|
||||
}
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
RegisterDialects(context);
|
||||
mlir::Builder b(&context);
|
||||
|
||||
PartialTensorShape output_shape =
|
||||
@ -53,7 +50,7 @@ TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
RegisterDialects(context);
|
||||
mlir::Builder b(&context);
|
||||
|
||||
PartialTensorShape output_shape = ConvertTypeToTensorShape(
|
||||
@ -63,7 +60,7 @@ TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
RegisterDialects(context);
|
||||
mlir::Builder b(&context);
|
||||
|
||||
PartialTensorShape output_shape = ConvertTypeToTensorShape(
|
||||
@ -80,8 +77,8 @@ TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) {
|
||||
}
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) {
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext context;
|
||||
RegisterDialects(context);
|
||||
mlir::Builder b(&context);
|
||||
|
||||
// Create the sample tensor to convert.
|
||||
@ -126,9 +123,8 @@ class ConvertTensorTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
TEST_F(ConvertTensorTest, Simple) {
|
||||
RegisterDialects();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
RegisterDialects(context);
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
|
||||
{Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
|
||||
ASSERT_NO_FATAL_FAILURE(
|
||||
|
@ -36,7 +36,6 @@ std::string ConvertToMlirString(const std::vector<int64_t>& dims,
|
||||
}
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
|
||||
std::string buf;
|
||||
llvm::raw_string_ostream os(buf);
|
||||
|
@ -60,7 +60,6 @@ class FakeDevice : public Device {
|
||||
|
||||
TEST(DeviceUtilTest, AddDeviceToOp) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
@ -102,7 +101,6 @@ TEST(DeviceUtilTest, AddDeviceToOp) {
|
||||
|
||||
TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
@ -112,7 +110,6 @@ TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
|
||||
|
||||
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
|
@ -66,7 +66,6 @@ Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph,
|
||||
WritableFile* file) {
|
||||
WritableFileRawStream os(std::move(file));
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::OwningModuleRef module;
|
||||
if (flib_def) {
|
||||
flib_def = &graph.flib_def();
|
||||
|
@ -28,7 +28,6 @@ namespace {
|
||||
|
||||
TEST(DumpMlirModuleTest, NoEnvPrefix) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
unsetenv("TF_DUMP_GRAPH_PREFIX");
|
||||
@ -39,7 +38,6 @@ TEST(DumpMlirModuleTest, NoEnvPrefix) {
|
||||
|
||||
TEST(DumpMlirModuleTest, LogInfo) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
setenv("TF_DUMP_GRAPH_PREFIX", "-", 1);
|
||||
@ -50,7 +48,6 @@ TEST(DumpMlirModuleTest, LogInfo) {
|
||||
|
||||
TEST(DumpMlirModuleTest, Valid) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1);
|
||||
|
@ -29,7 +29,6 @@ using testing::HasSubstr;
|
||||
|
||||
TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) {
|
||||
MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
auto id = Identifier::get("test.cc", &context);
|
||||
auto loc = FileLineColLoc::get(id, 0, 0, &context);
|
||||
|
||||
|
@ -602,7 +602,6 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::Builder builder(&context);
|
||||
auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
|
||||
auto status_or_device_coodinates =
|
||||
@ -616,7 +615,6 @@ TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
mlir::Builder builder(&context);
|
||||
auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
|
||||
auto status_or_device_coodinates =
|
||||
@ -627,9 +625,8 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
@ -644,8 +641,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
@ -665,8 +662,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
@ -685,8 +682,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
@ -705,8 +702,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
@ -728,8 +725,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
@ -753,8 +750,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
@ -780,8 +777,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
|
Loading…
Reference in New Issue
Block a user