Merge pull request #39330 from shraiysh:master
PiperOrigin-RevId: 311248675 Change-Id: Idf82ebbb155efec1624565eb13bd67573b68037a
This commit is contained in:
commit
7bffd6c498
tensorflow
compiler/mlir/python/mlir_wrapper
BUILDattrs.ccbasic_classes.ccbuilders.ccfilecheck_wrapper.ccmlir_wrapper.ccmlir_wrapper.hops.cctypes.cc
python/tf_program
41
tensorflow/compiler/mlir/python/mlir_wrapper/BUILD
Normal file
41
tensorflow/compiler/mlir/python/mlir_wrapper/BUILD
Normal file
@ -0,0 +1,41 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "mlir_wrapper",
|
||||
srcs = [
|
||||
"attrs.cc",
|
||||
"basic_classes.cc",
|
||||
"builders.cc",
|
||||
"mlir_wrapper.cc",
|
||||
"mlir_wrapper.h",
|
||||
"ops.cc",
|
||||
"types.cc",
|
||||
],
|
||||
module_name = "mlir_wrapper",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "filecheck_wrapper",
|
||||
srcs = ["filecheck_wrapper.cc"],
|
||||
module_name = "filecheck_wrapper",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@llvm-project//llvm:support",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
25
tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc
Normal file
25
tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc
Normal file
@ -0,0 +1,25 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
void init_attrs(py::module& m) {
|
||||
py::class_<mlir::Attribute>(m, "Attribute");
|
||||
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "IntegerAttr")
|
||||
.def("get",
|
||||
py::overload_cast<mlir::Type, int64_t>(&mlir::IntegerAttr::get));
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/FileCheck.h"
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
void init_basic_classes(py::module& m) {
|
||||
py::class_<mlir::MLIRContext>(m, "MLIRContext").def(py::init<>());
|
||||
|
||||
py::class_<mlir::Location>(m, "Location");
|
||||
|
||||
py::class_<mlir::UnknownLoc>(m, "UnknownLoc")
|
||||
.def("get", &mlir::UnknownLoc::get);
|
||||
|
||||
py::class_<mlir::Region>(m, "Region")
|
||||
.def("back", &mlir::Region::back, py::return_value_policy::reference)
|
||||
.def("front", &mlir::Region::front, py::return_value_policy::reference)
|
||||
.def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); })
|
||||
.def("push_back", &mlir::Region::push_back)
|
||||
.def("size", [](mlir::Region& r) { return r.getBlocks().size(); })
|
||||
.def("front", &mlir::Region::front, py::return_value_policy::reference);
|
||||
py::class_<mlir::Block::iterator>(m, "Block_Iterator");
|
||||
py::class_<mlir::Block>(m, "Block")
|
||||
.def("new", ([]() { return new mlir::Block; }),
|
||||
py::return_value_policy::reference)
|
||||
.def("end", &mlir::Block::end)
|
||||
.def("addArgument", &mlir::Block::addArgument);
|
||||
|
||||
py::class_<mlir::Value>(m, "Value").def("getType", &mlir::Value::getType);
|
||||
py::class_<mlir::OpResult, mlir::Value>(m, "OpResult");
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "BlockArgument");
|
||||
}
|
51
tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc
Normal file
51
tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
void init_builders(py::module& m) {
|
||||
py::class_<mlir::Builder>(m, "Builder")
|
||||
.def(py::init<mlir::MLIRContext*>())
|
||||
.def("getFunctionType",
|
||||
[](mlir::Builder& b, std::vector<mlir::Type> inputs,
|
||||
std::vector<mlir::Type> outputs) {
|
||||
return b.getFunctionType(llvm::ArrayRef<mlir::Type>(inputs),
|
||||
llvm::ArrayRef<mlir::Type>(outputs));
|
||||
});
|
||||
py::class_<mlir::OpBuilder>(m, "OpBuilder")
|
||||
.def(py::init<mlir::MLIRContext*>())
|
||||
.def(py::init<mlir::Region&>())
|
||||
.def(py::init<mlir::Operation*>())
|
||||
.def(py::init<mlir::Block*, mlir::Block::iterator>())
|
||||
.def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc)
|
||||
.def("setInsertionPoint",
|
||||
py::overload_cast<mlir::Block*, mlir::Block::iterator>(
|
||||
&mlir::OpBuilder::setInsertionPoint))
|
||||
.def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint)
|
||||
.def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint)
|
||||
.def(
|
||||
"createOperation",
|
||||
[](mlir::OpBuilder& opb, mlir::OperationState& state) {
|
||||
return opb.createOperation(state);
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def("getContext", &mlir::OpBuilder::getContext,
|
||||
py::return_value_policy::reference);
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "OpBuilder_InsertionPoint")
|
||||
.def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock);
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/FileCheck.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
PYBIND11_MODULE(filecheck_wrapper, m) {
|
||||
m.def("check", [](std::string input, std::string check) {
|
||||
llvm::FileCheckRequest fcr;
|
||||
llvm::FileCheck fc(fcr);
|
||||
llvm::SourceMgr SM = llvm::SourceMgr();
|
||||
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
|
||||
llvm::SMLoc());
|
||||
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check),
|
||||
llvm::SMLoc());
|
||||
llvm::Regex regex = fc.buildCheckPrefixRegex();
|
||||
fc.readCheckFile(SM, llvm::StringRef(check), regex);
|
||||
return fc.checkInput(SM, llvm::StringRef(input));
|
||||
});
|
||||
}
|
38
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
Normal file
38
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
PYBIND11_MODULE(mlir_wrapper, m) {
|
||||
m.def("registerDialects", []() {
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
});
|
||||
|
||||
init_basic_classes(m);
|
||||
init_types(m);
|
||||
init_builders(m);
|
||||
init_ops(m);
|
||||
init_attrs(m);
|
||||
}
|
30
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h
Normal file
30
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
||||
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void init_basic_classes(py::module& m);
|
||||
void init_types(py::module& m);
|
||||
void init_builders(py::module& m);
|
||||
void init_ops(py::module& m);
|
||||
void init_attrs(py::module& m);
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
194
tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc
Normal file
194
tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc
Normal file
@ -0,0 +1,194 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
void init_ops(py::module& m) {
|
||||
py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
|
||||
m, "Operation")
|
||||
.def("getRegion", &mlir::Operation::getRegion,
|
||||
py::return_value_policy::reference)
|
||||
.def("getResult", &mlir::Operation::getResult)
|
||||
.def("dump", &mlir::Operation::dump)
|
||||
.def("getNumResults", &mlir::Operation::getNumResults);
|
||||
|
||||
py::class_<mlir::OperationState>(m, "OperationState")
|
||||
.def(py::init([](mlir::Location loc, std::string name) {
|
||||
return mlir::OperationState(loc, llvm::StringRef(name));
|
||||
}))
|
||||
.def("addTypes",
|
||||
[](mlir::OperationState& state, std::vector<mlir::Type> tys) {
|
||||
state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
|
||||
})
|
||||
.def("addOperands",
|
||||
[](mlir::OperationState& os, std::vector<mlir::Value> ops) {
|
||||
os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
|
||||
})
|
||||
.def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
|
||||
py::return_value_policy::reference);
|
||||
|
||||
py::class_<mlir::ModuleOp>(m, "ModuleOp")
|
||||
.def("create",
|
||||
[](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
|
||||
.def("push_back",
|
||||
[](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); })
|
||||
.def("dump", &mlir::ModuleOp::dump)
|
||||
.def("getAsStr", [](mlir::ModuleOp& m) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
m.print(os);
|
||||
return os.str();
|
||||
});
|
||||
|
||||
py::class_<mlir::FuncOp>(m, "FuncOp")
|
||||
.def("create",
|
||||
[](mlir::Location location, std::string name,
|
||||
mlir::FunctionType type) {
|
||||
auto func = mlir::FuncOp::create(location, name, type);
|
||||
func.addEntryBlock();
|
||||
return func;
|
||||
})
|
||||
.def(
|
||||
"getBody",
|
||||
[](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); },
|
||||
py::return_value_policy::reference)
|
||||
.def("getArguments",
|
||||
[](mlir::FuncOp& f) { return f.getArguments().vec(); })
|
||||
.def("getName", [](mlir::FuncOp& f) { return f.getName().str(); })
|
||||
.def("getType", &mlir::FuncOp::getType);
|
||||
|
||||
py::class_<mlir::ReturnOp>(m, "ReturnOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
std::vector<mlir::Value> values) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::ReturnOp>(loc,
|
||||
mlir::ArrayRef<mlir::Value>(values))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::AddOp
|
||||
py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
|
||||
mlir::Value reduction_indices,
|
||||
bool keep_dims = false) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
|
||||
reduction_indices, keep_dims)
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::ConstOp
|
||||
py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
mlir::Attribute value) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::EqualOp
|
||||
py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::GreaterEqualOp
|
||||
py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::GreaterOp
|
||||
py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::LegacyCallOp
|
||||
py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
std::vector<mlir::Type> output, std::vector<mlir::Value> args,
|
||||
std::string f) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::TF::LegacyCallOp>(
|
||||
loc, mlir::ArrayRef<mlir::Type>(output),
|
||||
mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::LessEqualOp
|
||||
py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::LessOp
|
||||
py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::NegOp
|
||||
py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
mlir::Value x) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
|
||||
});
|
||||
|
||||
py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
|
||||
.def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) {
|
||||
return opb
|
||||
.create<mlir::TF::NotEqualOp>(
|
||||
loc, x, y, mlir::BoolAttr::get(true, opb.getContext()))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::SubOp
|
||||
py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
|
||||
});
|
||||
}
|
48
tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
Normal file
48
tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
void init_types(py::module& m) {
|
||||
// Type
|
||||
py::class_<mlir::Type> Type(m, "Type");
|
||||
Type.def("getKind", &mlir::Type::getKind);
|
||||
|
||||
// Type Enums
|
||||
py::enum_<mlir::StandardTypes::Kind>(Type, "StandardTypes_Kind")
|
||||
.value("BF16", mlir::StandardTypes::BF16);
|
||||
|
||||
// Type Sub-classes
|
||||
py::class_<mlir::FunctionType, mlir::Type>(m, "FunctionType")
|
||||
.def("getResults",
|
||||
[](mlir::FunctionType& ft) { return ft.getResults().vec(); });
|
||||
|
||||
py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
|
||||
.def("get", &mlir::FloatType::get);
|
||||
|
||||
py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
|
||||
.def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
|
||||
&mlir::IntegerType::get));
|
||||
|
||||
py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
|
||||
.def("get", &mlir::UnrankedTensorType::get);
|
||||
|
||||
py::class_<mlir::RankedTensorType, mlir::Type>(m, "RankedTensorType")
|
||||
.def("get", [](std::vector<int64_t> shape, mlir::Type ty) {
|
||||
return mlir::RankedTensorType::get(mlir::ArrayRef<int64_t>(shape), ty);
|
||||
});
|
||||
}
|
22
tensorflow/python/tf_program/BUILD
Normal file
22
tensorflow/python/tf_program/BUILD
Normal file
@ -0,0 +1,22 @@
|
||||
package(licenses = ["notice"])
|
||||
|
||||
py_library(
|
||||
name = "pywrap_tfd",
|
||||
srcs = ["pywrap_tfd.py"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/python/mlir_wrapper",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "mlir_gen",
|
||||
srcs = ["mlir_gen.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":pywrap_tfd",
|
||||
"//tensorflow/python/autograph/pyct",
|
||||
"//tensorflow/python/autograph/pyct/static_analysis",
|
||||
"//tensorflow/python/types",
|
||||
"@gast_archive//:gast",
|
||||
],
|
||||
)
|
456
tensorflow/python/tf_program/mlir_gen.py
Normal file
456
tensorflow/python/tf_program/mlir_gen.py
Normal file
@ -0,0 +1,456 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""mlir_gen: Generate mlir code from python code."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
# pylint: disable=missing-function-docstring
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast as ast
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
|
||||
import tensorflow.python.tf_program.pywrap_tfd as tfp
|
||||
from tensorflow.python.types import core
|
||||
|
||||
|
||||
class SymbolTable(object):
|
||||
"""Symbol Table for python code."""
|
||||
|
||||
def __init__(self):
|
||||
self.symbols = []
|
||||
self.enter_scope()
|
||||
|
||||
def enter_scope(self):
|
||||
"""Enter a new scope - at function level."""
|
||||
self.symbols.append({'types': {}, 'symbols': {}})
|
||||
self.curr_table = self.symbols[len(self.symbols) - 1]
|
||||
|
||||
def insert_symbol(self, name, value):
|
||||
self.curr_table['symbols'][name] = value
|
||||
self.curr_table['types'][name] = value.getType()
|
||||
return value
|
||||
|
||||
def insert_type(self, name, type_):
|
||||
self.curr_table['types'][name] = type_
|
||||
|
||||
def exit_scope(self):
|
||||
self.symbols.pop()
|
||||
self.curr_table = self.symbols[len(self.symbols) - 1]
|
||||
|
||||
def lookup(self, name):
|
||||
curr_idx = len(self.symbols) - 1
|
||||
while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']):
|
||||
curr_idx -= 1
|
||||
if curr_idx < 0:
|
||||
return None
|
||||
return self.symbols[curr_idx]['symbols'][name]
|
||||
|
||||
def lookup_type(self, name):
|
||||
curr_idx = len(self.symbols) - 1
|
||||
while curr_idx >= 0 and (name not in self.symbols[curr_idx]['types']):
|
||||
curr_idx -= 1
|
||||
if curr_idx < 0:
|
||||
return None
|
||||
return self.symbols[curr_idx]['types'][name]
|
||||
|
||||
def __repr__(self):
|
||||
s = '\n'.join(
|
||||
' ' * idx * 2 + str(table) for idx, table in enumerate(self.symbols))
|
||||
return s
|
||||
|
||||
|
||||
class ProcessType(ast.NodeVisitor):
|
||||
"""Visit a node and return processed type Currently only visits annotations and gives their type.
|
||||
"""
|
||||
|
||||
def __init__(self, prog, ctx):
|
||||
self.prog = prog
|
||||
self.ctx = ctx
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
# Supported: core.Tensor
|
||||
value = self.visit(node.value)
|
||||
if value is None or not hasattr(value, node.attr):
|
||||
raise AttributeError(str(type(value)) + ' has no attribute ' + node.attr)
|
||||
attr = getattr(value, node.attr)
|
||||
|
||||
if attr == core.Tensor:
|
||||
return tfp.UnrankedTensorType.get(tfp.IntegerType.get(32, self.prog.ctx))
|
||||
return attr
|
||||
|
||||
def visit_Name(self, node):
|
||||
if node.id == 'int':
|
||||
return tfp.IntegerType.get(32, self.prog.ctx)
|
||||
if node.id == 'bool':
|
||||
return tfp.IntegerType.get(1, self.prog.ctx)
|
||||
if node.id in self.ctx.info.namespace:
|
||||
return self.ctx.info.namespace[node.id]
|
||||
|
||||
|
||||
class MLIRGen(ast.NodeVisitor):
|
||||
"""Visit the AST and generate MLIR code Requires liveness, reading_definitions.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx):
|
||||
self.ctx = ctx
|
||||
self.symbol_table = SymbolTable()
|
||||
self.prog = tfp.TFProgram()
|
||||
self.opbuilder = None
|
||||
|
||||
def visit_block(self, block):
|
||||
return [self.visit(item) for item in block]
|
||||
|
||||
def process_type(self, node):
|
||||
return ProcessType(self.prog, self.ctx).visit(node)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
value = self.visit(node.value)
|
||||
if isinstance(value, tuple):
|
||||
# If it is a tuple of values, assign one to each in targets
|
||||
# TODO: This currently is assuming that all elts in targets[0] are Name
|
||||
# objects. This might not be always True.
|
||||
for key, val in zip(node.targets[0].elts, value):
|
||||
self.symbol_table.insert_symbol(key.id, val)
|
||||
else:
|
||||
self.symbol_table.insert_symbol(node.targets[0].id, value)
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
left = self.visit(node.left)
|
||||
right = self.visit(node.right)
|
||||
if isinstance(node.op, ast.Sub):
|
||||
return tfp.Tf_SubOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
|
||||
left, right).getResult(0)
|
||||
if isinstance(node.op, ast.Add):
|
||||
return tfp.Tf_AddV2Op.create(self.opbuilder,
|
||||
self.opbuilder.getUnknownLoc(), left,
|
||||
right).getResult(0)
|
||||
|
||||
def visit_BoolOp(self, node):
|
||||
values = [self.visit(value) for value in node.values]
|
||||
if isinstance(node.op, ast.Or):
|
||||
return tfp.OrOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
|
||||
values).getResult(0)
|
||||
if isinstance(node.op, ast.And):
|
||||
return tfp.AndOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
|
||||
values).getResult(0)
|
||||
|
||||
def visit_Call(self, node):
|
||||
func = self.visit(node.func)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
callop = tfp.Tf_LegacyCallOp.create(self.opbuilder,
|
||||
self.opbuilder.getUnknownLoc(),
|
||||
func.getType().getResults(), args,
|
||||
func.getName())
|
||||
if callop.getNumResults() == 1:
|
||||
return callop[0]
|
||||
return tuple(callop.getResult(idx) for idx in range(callop.getNumResults()))
|
||||
|
||||
def visit_Compare(self, node):
|
||||
left = self.visit(node.left)
|
||||
opb = self.opbuilder
|
||||
for op, right in zip(node.ops, node.comparators):
|
||||
if isinstance(op, ast.Eq):
|
||||
left = tfp.Tf_EqualOp.create(opb, opb.getUnknownLoc(), left,
|
||||
self.visit(right)).getResult(0)
|
||||
elif isinstance(op, ast.Lt):
|
||||
left = tfp.Tf_LessOp.create(opb, opb.getUnknownLoc(), left,
|
||||
self.visit(right)).getResult(0)
|
||||
elif isinstance(op, ast.LtE):
|
||||
left = tfp.Tf_LessEqualOp.create(opb, opb.getUnknownLoc(), left,
|
||||
self.visit(right)).getResult(0)
|
||||
elif isinstance(op, ast.Gt):
|
||||
left = tfp.Tf_GreaterOp.create(opb, opb.getUnknownLoc(), left,
|
||||
self.visit(right)).getResult(0)
|
||||
elif isinstance(op, ast.GtE):
|
||||
left = tfp.Tf_GreaterEqualOp.create(opb, opb.getUnknownLoc(), left,
|
||||
self.visit(right)).getResult(0)
|
||||
elif isinstance(op, ast.NotEq):
|
||||
left = tfp.Tf_NotEqualOp.create(opb, opb.getUnknownLoc(), left,
|
||||
self.visit(right)).getResult(0)
|
||||
else:
|
||||
raise NotImplementedError('CompareOp operator not recognized')
|
||||
return left
|
||||
|
||||
def visit_Constant(self, node):
|
||||
opb = self.opbuilder
|
||||
value = None
|
||||
if isinstance(node.value, int):
|
||||
value = tfp.Tf_ConstOp.create(
|
||||
opb, opb.getUnknownLoc(),
|
||||
tfp.IntegerAttr.get(
|
||||
tfp.IntegerType.get(32, self.prog.ctx), node.value)).getResult(0)
|
||||
return value
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
# Cache the current builder
|
||||
cache_builder = self.opbuilder
|
||||
inputs, outputs = [], []
|
||||
|
||||
for arg in node.args.args:
|
||||
inputs.append(self.process_type(arg.annotation))
|
||||
|
||||
if node.returns:
|
||||
outputs = [self.process_type(node.returns)]
|
||||
|
||||
currfunc = self.prog.add_function(
|
||||
self.ctx.namer.new_symbol(node.name, []),
|
||||
self.prog.get_function_type(inputs, outputs))
|
||||
|
||||
# Add the function to symbol table and enter new scope
|
||||
self.symbol_table.insert_symbol(node.name, currfunc)
|
||||
self.symbol_table.enter_scope()
|
||||
|
||||
# Add arguments to symbol table
|
||||
for arg, value in zip(node.args.args, currfunc.getArguments()):
|
||||
self.symbol_table.insert_symbol(arg.id, value)
|
||||
self.opbuilder = tfp.OpBuilder(currfunc.getBody())
|
||||
|
||||
self.visit_block(node.body)
|
||||
self.symbol_table.exit_scope()
|
||||
self.opbuilder = cache_builder
|
||||
|
||||
def visit_If(self, node):
|
||||
cond = self.visit(node.test)
|
||||
|
||||
# Create ifop
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
|
||||
modified_in_cond = list(body_scope.modified | orelse_scope.modified)
|
||||
outputs = [
|
||||
self.symbol_table.lookup_type(str(var)) for var in modified_in_cond
|
||||
]
|
||||
ifop = tfp.IfOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), cond,
|
||||
outputs)
|
||||
|
||||
# Cache the builder
|
||||
cache_builder = self.opbuilder
|
||||
|
||||
# Visit body
|
||||
self.opbuilder = tfp.OpBuilder(ifop.getRegion(0))
|
||||
# Enter scope to avoid values generated inside the region to come in symbol
|
||||
# table
|
||||
self.symbol_table.enter_scope()
|
||||
for stmt in node.body:
|
||||
self.visit(stmt)
|
||||
retvals = [
|
||||
self.symbol_table.lookup(str(varname)) for varname in modified_in_cond
|
||||
]
|
||||
tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
|
||||
self.symbol_table.exit_scope()
|
||||
|
||||
# Visit orelse
|
||||
self.opbuilder = tfp.OpBuilder(ifop.getRegion(1))
|
||||
self.symbol_table.enter_scope()
|
||||
for stmt in node.orelse:
|
||||
self.visit(stmt)
|
||||
retvals = [
|
||||
self.symbol_table.lookup(str(varname)) for varname in modified_in_cond
|
||||
]
|
||||
tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
|
||||
self.symbol_table.exit_scope()
|
||||
|
||||
# Reset builder and enter return values in symbol table
|
||||
self.opbuilder = cache_builder
|
||||
for idx, var in enumerate(modified_in_cond):
|
||||
self.symbol_table.insert_symbol(str(var), ifop.getResult(idx))
|
||||
|
||||
if ifop.getNumResults() == 1:
|
||||
return ifop.getResult(0)
|
||||
|
||||
return tuple(ifop.getResult(i) for i in range(ifop.getNumResults()))
|
||||
|
||||
def visit_Name(self, node):
|
||||
if self.symbol_table.lookup(node.id):
|
||||
return self.symbol_table.lookup(node.id)
|
||||
raise NotImplementedError('Symbol not found' + node.id)
|
||||
|
||||
def visit_Return(self, node):
|
||||
opb = self.opbuilder
|
||||
value = self.visit(node.value)
|
||||
if isinstance(value, tuple):
|
||||
# For more than one return values
|
||||
return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), list(value))
|
||||
return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), [value])
|
||||
|
||||
def visit_Tuple(self, node):
|
||||
return tuple(self.visit(elt) for elt in node.elts)
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
operand = self.visit(node.operand)
|
||||
if isinstance(node.op, ast.USub):
|
||||
return tfp.Tf_NegOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
|
||||
operand).getResult(0)
|
||||
|
||||
def _get_basic_loop_vars(self, modified, live_in, live_out):
|
||||
# [This is directly from
|
||||
# tensorflow/python/autograph/converters/control_flow.py]
|
||||
# The loop variables corresponding to simple symbols (e.g. `x`).
|
||||
basic_loop_vars = []
|
||||
for s in modified:
|
||||
if s.is_composite():
|
||||
# TODO: Raise an error when this happens for a TF loop.
|
||||
continue
|
||||
# Variables not live into or out of the loop are considered local to the
|
||||
# loop.
|
||||
if s not in live_in and s not in live_out:
|
||||
continue
|
||||
basic_loop_vars.append(s)
|
||||
return frozenset(basic_loop_vars)
|
||||
|
||||
def _get_composite_loop_vars(self, modified, live_in):
|
||||
# [This is directly from
|
||||
# tensorflow/python/autograph/converters/control_flow.py]
|
||||
# The loop variables corresponding to composite symbols (e.g. `self.x`).
|
||||
composite_loop_vars = []
|
||||
for s in modified:
|
||||
if not s.is_composite():
|
||||
continue
|
||||
# Mutations made to objects created inside the loop will appear as writes
|
||||
# to composite symbols. Because these mutations appear as modifications
|
||||
# made to composite symbols, we check whether the composite's parent is
|
||||
# actually live into the loop.
|
||||
# Example:
|
||||
# while cond:
|
||||
# x = Foo()
|
||||
# x.foo = 2 * x.foo # x.foo is live into the loop, but x is not.
|
||||
#
|
||||
# Note that some parents might not be symbols - for example, in x['foo'],
|
||||
# 'foo' is a parent, but it's a literal, not a symbol. We don't check the
|
||||
# liveness of literals.
|
||||
support_set_symbols = tuple(
|
||||
sss for sss in s.support_set if sss.is_symbol())
|
||||
if not all(sss in live_in for sss in support_set_symbols):
|
||||
continue
|
||||
composite_loop_vars.append(s)
|
||||
return frozenset(composite_loop_vars)
|
||||
|
||||
def _get_loop_vars(self, node, modified):
|
||||
# [This is directly from python/autograph/converters/control_flow.py]
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
|
||||
live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
|
||||
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
|
||||
reserved_symbols = body_scope.referenced
|
||||
|
||||
basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out)
|
||||
composite_loop_vars = self._get_composite_loop_vars(modified, live_in)
|
||||
loop_vars = tuple(basic_loop_vars | composite_loop_vars)
|
||||
|
||||
# Variable that are used or defined inside the loop, but not defined
|
||||
# before entering the loop. Only simple variables must be defined. The
|
||||
# composite ones will be implicitly checked at runtime.
|
||||
undefined_lives = basic_loop_vars - defined_in
|
||||
|
||||
return loop_vars, reserved_symbols, undefined_lives
|
||||
|
||||
def visit_While(self, node):
|
||||
|
||||
# Create a new WhileOp
|
||||
# `inputs` are initial values for loop variables
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
loop_vars, _, _ = self._get_loop_vars(node, body_scope.modified)
|
||||
inputs = [self.symbol_table.lookup(str(name)) for name in loop_vars]
|
||||
types = [input_.getType() for input_ in inputs]
|
||||
while_op = tfp.WhileOp.create(self.opbuilder,
|
||||
self.opbuilder.getUnknownLoc(), inputs, types)
|
||||
|
||||
# cache the current builder
|
||||
cache_builder = self.opbuilder
|
||||
|
||||
# Process cond
|
||||
self.symbol_table.enter_scope()
|
||||
for input_, type_ in zip(loop_vars, types):
|
||||
self.symbol_table.insert_symbol(
|
||||
str(input_),
|
||||
while_op.getRegion(0).front().addArgument(type_))
|
||||
self.opbuilder = tfp.OpBuilder(while_op.getRegion(0))
|
||||
tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
|
||||
[self.visit(node.test)])
|
||||
self.symbol_table.exit_scope()
|
||||
|
||||
# Process body
|
||||
self.symbol_table.enter_scope()
|
||||
for input_, type_ in zip(loop_vars, types):
|
||||
self.symbol_table.insert_symbol(
|
||||
str(input_),
|
||||
while_op.getRegion(1).front().addArgument(type_))
|
||||
self.opbuilder = tfp.OpBuilder(while_op.getRegion(1))
|
||||
self.visit_block(node.body)
|
||||
tfp.ReturnOp.create(
|
||||
self.opbuilder, self.opbuilder.getUnknownLoc(),
|
||||
[self.symbol_table.lookup(str(name)) for name in loop_vars])
|
||||
self.symbol_table.exit_scope()
|
||||
|
||||
# Enter new values as symbols
|
||||
for idx, var in enumerate(loop_vars):
|
||||
self.symbol_table.insert_symbol(str(var), while_op.getResult(idx))
|
||||
|
||||
# Restore builder
|
||||
self.opbuilder = cache_builder
|
||||
|
||||
|
||||
def mlir_gen_internal(node, entity_info):
|
||||
"""Returns mlir module for unprocessed node `node`."""
|
||||
namer = naming.Namer({})
|
||||
graphs = cfg.build(node)
|
||||
ctx = transformer.Context(entity_info, namer, None)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
node = reaching_fndefs.resolve(node, ctx, graphs)
|
||||
node = liveness.resolve(node, ctx, graphs)
|
||||
mlir_generator = MLIRGen(ctx)
|
||||
mlir_generator.visit(node)
|
||||
return mlir_generator.prog
|
||||
|
||||
|
||||
def mlir_gen(func):
|
||||
"""Parse a function and return TFProgram."""
|
||||
node, source = parser.parse_entity(func, future_features=())
|
||||
entity_info = transformer.EntityInfo(
|
||||
name=func.__name__,
|
||||
source_code=source,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace=inspect_utils.getnamespace(func))
|
||||
return mlir_gen_internal(node, entity_info)
|
||||
|
||||
|
||||
def mlir_gen_from_source(source=None, src_file=None):
|
||||
"""Parse a function as either a string or from a supplied file path and return a TFProgram.
|
||||
"""
|
||||
if source is None:
|
||||
source = open(src_file).read()
|
||||
node = ast.parse(source)
|
||||
entity_info = transformer.EntityInfo(
|
||||
name='mlir_module',
|
||||
source_code=source,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace={})
|
||||
return mlir_gen_internal(node, entity_info)
|
159
tensorflow/python/tf_program/pywrap_tfd.py
Normal file
159
tensorflow/python/tf_program/pywrap_tfd.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Intermediate between python bindings for MLIR and mlir generation for tensorflow program.
|
||||
|
||||
This passes most of the mlir classes as is, but adds a few new operations and
|
||||
the basic structure for a tensorflow program.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.mlir.python.mlir_wrapper import mlir_wrapper as mlir
|
||||
|
||||
# Class Definitions
|
||||
OpBuilder = mlir.OpBuilder
|
||||
Block = mlir.Block
|
||||
|
||||
# Types
|
||||
Type = mlir.Type
|
||||
IntegerType = mlir.IntegerType
|
||||
FloatType = mlir.FloatType
|
||||
RankedTensorType = mlir.RankedTensorType
|
||||
UnrankedTensorType = mlir.UnrankedTensorType
|
||||
IntegerAttr = mlir.IntegerAttr
|
||||
|
||||
# Standard Ops
|
||||
ReturnOp = mlir.ReturnOp
|
||||
|
||||
# TF Dialect Ops
|
||||
Tf_AnyOp = mlir.Tf_AnyOp
|
||||
Tf_AddV2Op = mlir.Tf_AddV2Op
|
||||
Tf_ConstOp = mlir.Tf_ConstOp
|
||||
Tf_EqualOp = mlir.Tf_EqualOp
|
||||
Tf_GreaterEqualOp = mlir.Tf_GreaterEqualOp
|
||||
Tf_GreaterOp = mlir.Tf_GreaterOp
|
||||
Tf_LegacyCallOp = mlir.Tf_LegacyCallOp
|
||||
Tf_LessEqualOp = mlir.Tf_LessEqualOp
|
||||
Tf_LessOp = mlir.Tf_LessOp
|
||||
Tf_NegOp = mlir.Tf_NegOp
|
||||
Tf_NotEqualOp = mlir.Tf_NotEqualOp
|
||||
Tf_SubOp = mlir.Tf_SubOp
|
||||
|
||||
|
||||
class IfOp(object):
|
||||
"""
|
||||
tfp.if(cond) ({body}, {orelse}) : type If `cond` is true, `body` is
|
||||
executed, otherwise `orelse` is executed.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, opb, loc, cond, outputs):
|
||||
state = mlir.OperationState(loc, "tfp.If")
|
||||
state.addOperands([cond])
|
||||
state.addTypes(outputs)
|
||||
state.addRegion().push_back(Block.new()) # body region
|
||||
state.addRegion().push_back(Block.new()) # orelse region
|
||||
return opb.createOperation(state)
|
||||
|
||||
|
||||
class OrOp(object):
|
||||
"""
|
||||
tfp.Or(ops...) This is like tf.Any, except that the first dimension is opened
|
||||
into `ops`.
|
||||
|
||||
Returns a tensor of 1-bit integers which is "Logical OR" of the
|
||||
coressponding elements in ops...
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, opb, loc, values):
|
||||
state = mlir.OperationState(loc, "tfp.Or")
|
||||
state.addTypes(
|
||||
[UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
|
||||
state.addOperands(values)
|
||||
return opb.createOperation(state)
|
||||
|
||||
|
||||
class AndOp(object):
|
||||
"""
|
||||
tfp.And(ops...) This is like tf.All, except that the first dimension is opened
|
||||
to `ops`.
|
||||
|
||||
Returns a tensor of 1-bit integers which is "Logical AND" of the
|
||||
coressponding elements in ops...
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, opb, loc, values):
|
||||
state = mlir.OperationState(loc, "tfp.And")
|
||||
state.addTypes(
|
||||
[UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
|
||||
state.addOperands(values)
|
||||
return opb.createOperation(state)
|
||||
|
||||
|
||||
class WhileOp(object):
|
||||
"""tfp.While(init-vals, {
|
||||
|
||||
^bb1(cond-args):
|
||||
cond-region
|
||||
return cond
|
||||
}, {
|
||||
^bb1(body-args):
|
||||
body-region
|
||||
})
|
||||
As long as `cond-region` returns a "true"-like value, the body-region
|
||||
is executed and the arguments are replaced by its return values for the next
|
||||
iteration.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, opb, loc, inputs, outputs):
|
||||
state = mlir.OperationState(loc, "tfp.While")
|
||||
state.addOperands(inputs)
|
||||
state.addTypes(outputs)
|
||||
state.addRegion().push_back(Block.new()) # cond region
|
||||
state.addRegion().push_back(Block.new()) # body region
|
||||
return opb.createOperation(state)
|
||||
|
||||
|
||||
class TFProgram(object):
|
||||
"""Python wrap for a Tensorflow Program (essentially an mlir Module)."""
|
||||
|
||||
def __init__(self):
|
||||
mlir.registerDialects()
|
||||
self.ctx = mlir.MLIRContext()
|
||||
self.builder = mlir.Builder(self.ctx)
|
||||
self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx))
|
||||
self.curr_func = None
|
||||
|
||||
def add_function(self, name, func_type):
|
||||
self.curr_func = mlir.FuncOp.create(
|
||||
mlir.UnknownLoc.get(self.ctx), name, func_type)
|
||||
self.module.push_back(self.curr_func)
|
||||
return self.curr_func
|
||||
|
||||
def get_function_type(self, inputs, outputs):
|
||||
return self.builder.getFunctionType(inputs, outputs)
|
||||
|
||||
def dump(self):
|
||||
self.module.dump()
|
||||
|
||||
def __str__(self):
|
||||
return self.module.getAsStr()
|
20
tensorflow/python/tf_program/tests/BUILD
Normal file
20
tensorflow/python/tf_program/tests/BUILD
Normal file
@ -0,0 +1,20 @@
|
||||
package(licenses = ["notice"])
|
||||
|
||||
py_test(
|
||||
name = "mlir_gen_test",
|
||||
size = "small",
|
||||
testonly = True,
|
||||
srcs = ["mlir_gen_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
"no_pip",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/tf_program:mlir_gen",
|
||||
"//tensorflow/python/types",
|
||||
],
|
||||
)
|
247
tensorflow/python/tf_program/tests/mlir_gen_test.py
Normal file
247
tensorflow/python/tf_program/tests/mlir_gen_test.py
Normal file
@ -0,0 +1,247 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `mlir_gen` module"""
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.types import core
|
||||
from tensorflow.python.tf_program.mlir_gen import mlir_gen
|
||||
|
||||
import tensorflow.compiler.mlir.python.mlir_wrapper.filecheck_wrapper as fw
|
||||
|
||||
|
||||
class MLIRGenTestBase(test.TestCase):
|
||||
|
||||
def _check_code(self, mlir_code, exp_mlir_code):
|
||||
return self.assertTrue(fw.check(str(mlir_code), exp_mlir_code))
|
||||
|
||||
|
||||
class MLIRGenTest(MLIRGenTestBase):
|
||||
"""MLIR Generation Tests for Tensorflow Program"""
|
||||
|
||||
def test_simple(self):
|
||||
|
||||
def test_fn():
|
||||
pass
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
mlir_code_exp = r"""
|
||||
CHECK-LABEL: @test_fn
|
||||
"""
|
||||
self._check_code(mlir_code, mlir_code_exp)
|
||||
|
||||
def test_argument(self):
|
||||
|
||||
def test_fn(x: core.Tensor) -> core.Tensor:
|
||||
return x
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
mlir_code_exp = r"""
|
||||
CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
CHECK-NEXT: return %arg0 : tensor<*xi32>
|
||||
"""
|
||||
self._check_code(mlir_code, mlir_code_exp)
|
||||
|
||||
def test_constant(self):
|
||||
|
||||
def test_fn() -> int:
|
||||
return 23
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: func @test_fn() -> i32
|
||||
CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
|
||||
CHECK: return %[[r0]] : tensor<i32>
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
def test_BoolOp(self):
|
||||
|
||||
def test_fn(x: bool, y: bool) -> bool:
|
||||
return x or y or x and x and y
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: func @test_fn(%arg0: i1, %arg1: i1) -> i1
|
||||
CHECK: %[[r0:[0-9]+]] = "tfp.And"(%arg0, %arg0, %arg1) : (i1, i1, i1) -> tensor<*xi1>
|
||||
CHECK: %[[r1:[0-9]+]] = "tfp.Or"(%arg0, %arg1, %[[r0]]) : (i1, i1, tensor<*xi1>) -> tensor<*xi1>
|
||||
return %[[r1]] : tensor<*xi1>
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
def test_Call(self):
|
||||
|
||||
def test_fn():
|
||||
|
||||
def f1():
|
||||
return 23
|
||||
|
||||
def f2():
|
||||
return f1()
|
||||
|
||||
f2()
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: func @test_fn()
|
||||
CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f2} : () -> ()
|
||||
CHECK: }
|
||||
CHECK-LABEL: func @f1() {
|
||||
CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
|
||||
CHECK: return %[[r0]] : tensor<i32>
|
||||
CHECK: }
|
||||
CHECK-LABEL: func @f2() {
|
||||
CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f1} : () -> ()
|
||||
}
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
def test_Compare(self):
|
||||
|
||||
def test_fn(x: core.Tensor, y: core.Tensor, z: core.Tensor):
|
||||
return x > y < z
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>)
|
||||
CHECK: %[[r0:[0-9]+]] = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
|
||||
CHECK: %[[r1:[0-9]+]] = "tf.Less"(%[[r0]], %arg2) : (tensor<*xi1>, tensor<*xi32>) -> tensor<*xi1>
|
||||
CHECK: return %[[r1]] : tensor<*xi1>
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
def test_Assign_BinOp(self):
|
||||
|
||||
def test_fn() -> int:
|
||||
y = 12 + 23 - 24
|
||||
return y
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: func @test_fn() -> i32
|
||||
CHECK: %[[r0:[0-9]+]] = "tf.AddV2"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
CHECK: %[[r1:[0-9]+]] = "tf.Sub"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
CHECK: return %[[r1]] : tensor<i32>
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
def test_if(self):
|
||||
|
||||
def test_fn(x: core.Tensor) -> int:
|
||||
res = 0
|
||||
if x > 0:
|
||||
res = 1
|
||||
elif x < 0:
|
||||
res = -1
|
||||
else:
|
||||
res = 0
|
||||
return res
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> i32
|
||||
|
||||
CHECK: %[[r1:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
CHECK-NEXT: %[[r2:[0-9]+]] = "tfp.If"(%[[r1]]) ( {
|
||||
CHECK: return %{{[0-9]+}} : tensor<i32>
|
||||
CHECK-NEXT: }, {
|
||||
CHECK: %[[r3:[0-9]+]] = "tf.Less"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
CHECK: %[[r4:[0-9]+]] = "tfp.If"(%[[r3]]) ( {
|
||||
CHECK: %[[r5:[0-9]+]] = "tf.Neg"(%{{[0-9]+}}) : (tensor<i32>) -> tensor<i32>
|
||||
CHECK: return %[[r5]] : tensor<i32>
|
||||
CHECK-NEXT: }, {
|
||||
CHECK: return %{{[0-9]+}} : tensor<i32>
|
||||
CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
|
||||
CHECK: return %[[r4]] : tensor<i32>
|
||||
CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
|
||||
CHECK-NEXT: return %[[r2]] : tensor<i32>
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
def test_while(self):
|
||||
|
||||
def test_fn(x: core.Tensor) -> core.Tensor:
|
||||
s = 0
|
||||
while x > 0:
|
||||
s = s + x
|
||||
return s
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>
|
||||
|
||||
CHECK: %[[r1:[0-9]+]] = "tfp.While"(%0) ( {
|
||||
CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
|
||||
CHECK: %[[r2:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
CHECK-NEXT: return %[[r2]] : tensor<*xi1>
|
||||
CHECK-NEXT: }, {
|
||||
CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
|
||||
CHECK: %[[r3:[0-9]+]] = "tf.AddV2"(%arg1, %arg0) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
CHECK-NEXT: return %[[r3]] : tensor<*xi32>
|
||||
CHECK-NEXT: }) : (tensor<i32>) -> tensor<i32>
|
||||
CHECK-NEXT: return %[[r1]] : tensor<i32>
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
def test_fibonacci(self):
|
||||
|
||||
def test_fn(x: core.Tensor) -> core.Tensor:
|
||||
res, idx = 0, 2
|
||||
a, b = 0, 1
|
||||
if x == 0 or x == 1:
|
||||
res = x
|
||||
else:
|
||||
while idx <= x:
|
||||
res = a + b
|
||||
a = b
|
||||
b = res
|
||||
idx = idx + 1
|
||||
return res
|
||||
|
||||
mlir_code = mlir_gen(test_fn)
|
||||
exp_mlir_code = r"""
|
||||
CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>
|
||||
CHECK: %[[r5:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
CHECK: %[[r7:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
CHECK: %[[r8:[0-9]+]] = "tfp.Or"(%[[r5]], %[[r7]]) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1>
|
||||
|
||||
CHECK: %[[r9:[0-9]+]]:4 = "tfp.If"(%[[r8]]) ( {
|
||||
CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>
|
||||
CHECK-NEXT: }, {
|
||||
CHECK-NEXT: %[[r10:[0-9]+]]:4 = "tfp.While"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||
CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
|
||||
CHECK-NEXT: %[[r11:[0-9]+]] = "tf.LessEqual"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>) -> tensor<*xi1>
|
||||
CHECK-NEXT: return %[[r11]] : tensor<*xi1>
|
||||
CHECK-NEXT: }, {
|
||||
CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
|
||||
CHECK-NEXT: %[[r12:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
CHECK: %[[r13:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
CHECK-NEXT: }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
|
||||
CHECK-NEXT: return %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
CHECK-NEXT: }) : (tensor<*xi1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
|
||||
CHECK-NEXT: return %[[r9]]#{{[0-9]+}} : tensor<i32>
|
||||
"""
|
||||
self._check_code(mlir_code, exp_mlir_code)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user