Add missing dialect registration to MLIR TF Lite lstm_utils_test

This is caught by new assertions added in MLIR upstream to detect such misconfiguration.

PiperOrigin-RevId: 313539845
Change-Id: I90735dca1c7e417b3e3f42fa144177522bc242a4
This commit is contained in:
Mehdi Amini 2020-05-28 00:33:52 -07:00 committed by TensorFlower Gardener
parent 83a67afb59
commit e651638ee4

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/test.h"
namespace mlir {
@ -92,7 +93,9 @@ class LstmUtilsTest : public ::testing::Test {
LstmUtilsTest() {}
void SetUp() override {
builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
RegisterDialects();
context_ = std::make_unique<mlir::MLIRContext>();
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
fused_lstm_func_cifg_ =
createLstmCompositeFunc(builder_.get(), false, true);
@ -105,10 +108,17 @@ class LstmUtilsTest : public ::testing::Test {
fused_ln_lstm_func_.erase();
builder_.reset();
}
void RegisterDialects() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<TensorFlowLiteDialect>();
}
FuncOp fused_lstm_func_;
FuncOp fused_lstm_func_cifg_;
FuncOp fused_ln_lstm_func_;
mlir::MLIRContext context_;
std::unique_ptr<mlir::MLIRContext> context_;
std::unique_ptr<mlir::Builder> builder_;
};