diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD new file mode 100644 index 00000000000..78f4312da46 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -0,0 +1,41 @@ +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") + +package(licenses = ["notice"]) + +tf_python_pybind_extension( + name = "mlir_wrapper", + srcs = [ + "attrs.cc", + "basic_classes.cc", + "builders.cc", + "mlir_wrapper.cc", + "mlir_wrapper.h", + "ops.cc", + "types.cc", + ], + module_name = "mlir_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@pybind11", + ], +) + +tf_python_pybind_extension( + name = "filecheck_wrapper", + srcs = ["filecheck_wrapper.cc"], + module_name = "filecheck_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@pybind11", + ], +) diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc new file mode 100644 index 00000000000..ca7faf2e1d3 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_attrs(py::module& m) { + py::class_(m, "Attribute"); + py::class_(m, "IntegerAttr") + .def("get", + py::overload_cast(&mlir::IntegerAttr::get)); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc new file mode 100644 index 00000000000..25adb44fe1d --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_basic_classes(py::module& m) { + py::class_(m, "MLIRContext").def(py::init<>()); + + py::class_(m, "Location"); + + py::class_(m, "UnknownLoc") + .def("get", &mlir::UnknownLoc::get); + + py::class_(m, "Region") + .def("back", &mlir::Region::back, py::return_value_policy::reference) + .def("front", &mlir::Region::front, py::return_value_policy::reference) + .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); }) + .def("push_back", &mlir::Region::push_back) + .def("size", [](mlir::Region& r) { return r.getBlocks().size(); }) + .def("front", &mlir::Region::front, py::return_value_policy::reference); + py::class_(m, "Block_Iterator"); + py::class_(m, "Block") + .def("new", ([]() { return new mlir::Block; }), + py::return_value_policy::reference) + .def("end", &mlir::Block::end) + .def("addArgument", &mlir::Block::addArgument); + + py::class_(m, "Value").def("getType", &mlir::Value::getType); + py::class_(m, "OpResult"); + py::class_(m, "BlockArgument"); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc new file mode 100644 index 00000000000..338f17ed6df --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Builders.h" // from @llvm-project + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_builders(py::module& m) { + py::class_(m, "Builder") + .def(py::init()) + .def("getFunctionType", + [](mlir::Builder& b, std::vector inputs, + std::vector outputs) { + return b.getFunctionType(llvm::ArrayRef(inputs), + llvm::ArrayRef(outputs)); + }); + py::class_(m, "OpBuilder") + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc) + .def("setInsertionPoint", + py::overload_cast( + &mlir::OpBuilder::setInsertionPoint)) + .def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint) + .def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint) + .def( + "createOperation", + [](mlir::OpBuilder& opb, mlir::OperationState& state) { + return opb.createOperation(state); + }, + py::return_value_policy::reference) + .def("getContext", &mlir::OpBuilder::getContext, + py::return_value_policy::reference); + + py::class_(m, "OpBuilder_InsertionPoint") + .def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc new file mode 100644 index 00000000000..8a841856b72 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "llvm/Support/SourceMgr.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(filecheck_wrapper, m) { + m.def("check", [](std::string input, std::string check) { + llvm::FileCheckRequest fcr; + llvm::FileCheck fc(fcr); + llvm::SourceMgr SM = llvm::SourceMgr(); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), + llvm::SMLoc()); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check), + llvm::SMLoc()); + llvm::Regex regex = fc.buildCheckPrefixRegex(); + fc.readCheckFile(SM, llvm::StringRef(check), regex); + return fc.checkInput(SM, llvm::StringRef(input)); + }); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc new file mode 100644 index 00000000000..6f468cd4267 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(mlir_wrapper, m) { + m.def("registerDialects", []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + }); + + init_basic_classes(m); + init_types(m); + init_builders(m); + init_ops(m); + init_attrs(m); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h new file mode 100644 index 00000000000..562c59b43e1 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H +#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +void init_basic_classes(py::module& m); +void init_types(py::module& m); +void init_builders(py::module& m); +void init_ops(py::module& m); +void init_attrs(py::module& m); + +#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc new file mode 100644 index 00000000000..4432829653e --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc @@ -0,0 +1,194 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project + +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +void init_ops(py::module& m) { + py::class_>( + m, "Operation") + .def("getRegion", &mlir::Operation::getRegion, + py::return_value_policy::reference) + .def("getResult", &mlir::Operation::getResult) + .def("dump", &mlir::Operation::dump) + .def("getNumResults", &mlir::Operation::getNumResults); + + py::class_(m, "OperationState") + .def(py::init([](mlir::Location loc, std::string name) { + return mlir::OperationState(loc, llvm::StringRef(name)); + })) + .def("addTypes", + [](mlir::OperationState& state, std::vector tys) { + state.addTypes(mlir::ArrayRef(tys)); + }) + .def("addOperands", + [](mlir::OperationState& os, std::vector ops) { + os.addOperands(mlir::ArrayRef(ops)); + }) + .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion), + py::return_value_policy::reference); + + py::class_(m, "ModuleOp") + .def("create", + [](mlir::Location loc) { return mlir::ModuleOp::create(loc); }) + .def("push_back", + [](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); }) + .def("dump", &mlir::ModuleOp::dump) + .def("getAsStr", [](mlir::ModuleOp& m) { + std::string str; + llvm::raw_string_ostream os(str); + m.print(os); + return os.str(); + }); + + py::class_(m, "FuncOp") + .def("create", + [](mlir::Location location, std::string name, + mlir::FunctionType type) { + auto func = mlir::FuncOp::create(location, name, type); + func.addEntryBlock(); + return func; + }) + .def( + "getBody", + [](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); }, + py::return_value_policy::reference) + .def("getArguments", + [](mlir::FuncOp& f) { return f.getArguments().vec(); }) + .def("getName", [](mlir::FuncOp& f) { return f.getName().str(); }) + .def("getType", &mlir::FuncOp::getType); + + py::class_(m, "ReturnOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector values) -> mlir::Operation* { + return opb + .create(loc, + mlir::ArrayRef(values)) + .getOperation(); + }); + + // mlir::TF::AddOp + py::class_(m, "Tf_AddV2Op") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + py::class_(m, "Tf_AnyOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input, + mlir::Value reduction_indices, + bool keep_dims = false) -> mlir::Operation* { + return opb + .create(loc, opb.getI1Type(), input, + reduction_indices, keep_dims) + .getOperation(); + }); + + // mlir::TF::ConstOp + py::class_(m, "Tf_ConstOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Attribute value) -> mlir::Operation* { + return opb.create(loc, value).getOperation(); + }); + + // mlir::TF::EqualOp + py::class_(m, "Tf_EqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb + .create(loc, x, y, opb.getBoolAttr(true)) + .getOperation(); + }); + + // mlir::TF::GreaterEqualOp + py::class_(m, "Tf_GreaterEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y) + .getOperation(); + }); + + // mlir::TF::GreaterOp + py::class_(m, "Tf_GreaterOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LegacyCallOp + py::class_(m, "Tf_LegacyCallOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector output, std::vector args, + std::string f) -> mlir::Operation* { + return opb + .create( + loc, mlir::ArrayRef(output), + mlir::ArrayRef(args), mlir::StringRef(f)) + .getOperation(); + }); + + // mlir::TF::LessEqualOp + py::class_(m, "Tf_LessEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LessOp + py::class_(m, "Tf_LessOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::NegOp + py::class_(m, "Tf_NegOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Value x) -> mlir::Operation* { + return opb.create(loc, x).getOperation(); + }); + + py::class_(m, "Tf_NotEqualOp") + .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) { + return opb + .create( + loc, x, y, mlir::BoolAttr::get(true, opb.getContext())) + .getOperation(); + }); + + // mlir::TF::SubOp + py::class_(m, "Tf_SubOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc new file mode 100644 index 00000000000..2be67f8e93e --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +void init_types(py::module& m) { + // Type + py::class_ Type(m, "Type"); + Type.def("getKind", &mlir::Type::getKind); + + // Type Enums + py::enum_(Type, "StandardTypes_Kind") + .value("BF16", mlir::StandardTypes::BF16); + + // Type Sub-classes + py::class_(m, "FunctionType") + .def("getResults", + [](mlir::FunctionType& ft) { return ft.getResults().vec(); }); + + py::class_(m, "FloatType") + .def("get", &mlir::FloatType::get); + + py::class_(m, "IntegerType") + .def("get", py::overload_cast( + &mlir::IntegerType::get)); + + py::class_(m, "UnrankedTensorType") + .def("get", &mlir::UnrankedTensorType::get); + + py::class_(m, "RankedTensorType") + .def("get", [](std::vector shape, mlir::Type ty) { + return mlir::RankedTensorType::get(mlir::ArrayRef(shape), ty); + }); +} diff --git a/tensorflow/python/tf_program/BUILD b/tensorflow/python/tf_program/BUILD new file mode 100644 index 00000000000..9dfb0df8a24 --- /dev/null +++ b/tensorflow/python/tf_program/BUILD @@ -0,0 +1,22 @@ +package(licenses = ["notice"]) + +py_library( + name = "pywrap_tfd", + srcs = ["pywrap_tfd.py"], + deps = [ + "//tensorflow/compiler/mlir/python/mlir_wrapper", + ], +) + +py_library( + name = "mlir_gen", + srcs = ["mlir_gen.py"], + visibility = ["//visibility:public"], + deps = [ + ":pywrap_tfd", + "//tensorflow/python/autograph/pyct", + "//tensorflow/python/autograph/pyct/static_analysis", + "//tensorflow/python/types", + "@gast_archive//:gast", + ], +) diff --git a/tensorflow/python/tf_program/mlir_gen.py b/tensorflow/python/tf_program/mlir_gen.py new file mode 100644 index 00000000000..8395848a53a --- /dev/null +++ b/tensorflow/python/tf_program/mlir_gen.py @@ -0,0 +1,456 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""mlir_gen: Generate mlir code from python code.""" + +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast as ast +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import inspect_utils +from tensorflow.python.autograph.pyct import naming +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct.static_analysis import activity +from tensorflow.python.autograph.pyct.static_analysis import annos +from tensorflow.python.autograph.pyct.static_analysis import liveness +from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs +import tensorflow.python.tf_program.pywrap_tfd as tfp +from tensorflow.python.types import core + + +class SymbolTable(object): + """Symbol Table for python code.""" + + def __init__(self): + self.symbols = [] + self.enter_scope() + + def enter_scope(self): + """Enter a new scope - at function level.""" + self.symbols.append({'types': {}, 'symbols': {}}) + self.curr_table = self.symbols[len(self.symbols) - 1] + + def insert_symbol(self, name, value): + self.curr_table['symbols'][name] = value + self.curr_table['types'][name] = value.getType() + return value + + def insert_type(self, name, type_): + self.curr_table['types'][name] = type_ + + def exit_scope(self): + self.symbols.pop() + self.curr_table = self.symbols[len(self.symbols) - 1] + + def lookup(self, name): + curr_idx = len(self.symbols) - 1 + while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']): + curr_idx -= 1 + if curr_idx < 0: + return None + return self.symbols[curr_idx]['symbols'][name] + + def lookup_type(self, name): + curr_idx = len(self.symbols) - 1 + while curr_idx >= 0 and (name not in self.symbols[curr_idx]['types']): + curr_idx -= 1 + if curr_idx < 0: + return None + return self.symbols[curr_idx]['types'][name] + + def __repr__(self): + s = '\n'.join( + ' ' * idx * 2 + str(table) for idx, table in enumerate(self.symbols)) + return s + + +class ProcessType(ast.NodeVisitor): + """Visit a node and return processed type Currently only visits annotations and gives their type. + """ + + def __init__(self, prog, ctx): + self.prog = prog + self.ctx = ctx + + def visit_Attribute(self, node): + # Supported: core.Tensor + value = self.visit(node.value) + if value is None or not hasattr(value, node.attr): + raise AttributeError(str(type(value)) + ' has no attribute ' + node.attr) + attr = getattr(value, node.attr) + + if attr == core.Tensor: + return tfp.UnrankedTensorType.get(tfp.IntegerType.get(32, self.prog.ctx)) + return attr + + def visit_Name(self, node): + if node.id == 'int': + return tfp.IntegerType.get(32, self.prog.ctx) + if node.id == 'bool': + return tfp.IntegerType.get(1, self.prog.ctx) + if node.id in self.ctx.info.namespace: + return self.ctx.info.namespace[node.id] + + +class MLIRGen(ast.NodeVisitor): + """Visit the AST and generate MLIR code Requires liveness, reading_definitions. + """ + + def __init__(self, ctx): + self.ctx = ctx + self.symbol_table = SymbolTable() + self.prog = tfp.TFProgram() + self.opbuilder = None + + def visit_block(self, block): + return [self.visit(item) for item in block] + + def process_type(self, node): + return ProcessType(self.prog, self.ctx).visit(node) + + def visit_Assign(self, node): + value = self.visit(node.value) + if isinstance(value, tuple): + # If it is a tuple of values, assign one to each in targets + # TODO: This currently is assuming that all elts in targets[0] are Name + # objects. This might not be always True. + for key, val in zip(node.targets[0].elts, value): + self.symbol_table.insert_symbol(key.id, val) + else: + self.symbol_table.insert_symbol(node.targets[0].id, value) + + def visit_BinOp(self, node): + left = self.visit(node.left) + right = self.visit(node.right) + if isinstance(node.op, ast.Sub): + return tfp.Tf_SubOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), + left, right).getResult(0) + if isinstance(node.op, ast.Add): + return tfp.Tf_AddV2Op.create(self.opbuilder, + self.opbuilder.getUnknownLoc(), left, + right).getResult(0) + + def visit_BoolOp(self, node): + values = [self.visit(value) for value in node.values] + if isinstance(node.op, ast.Or): + return tfp.OrOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), + values).getResult(0) + if isinstance(node.op, ast.And): + return tfp.AndOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), + values).getResult(0) + + def visit_Call(self, node): + func = self.visit(node.func) + args = [self.visit(arg) for arg in node.args] + callop = tfp.Tf_LegacyCallOp.create(self.opbuilder, + self.opbuilder.getUnknownLoc(), + func.getType().getResults(), args, + func.getName()) + if callop.getNumResults() == 1: + return callop[0] + return tuple(callop.getResult(idx) for idx in range(callop.getNumResults())) + + def visit_Compare(self, node): + left = self.visit(node.left) + opb = self.opbuilder + for op, right in zip(node.ops, node.comparators): + if isinstance(op, ast.Eq): + left = tfp.Tf_EqualOp.create(opb, opb.getUnknownLoc(), left, + self.visit(right)).getResult(0) + elif isinstance(op, ast.Lt): + left = tfp.Tf_LessOp.create(opb, opb.getUnknownLoc(), left, + self.visit(right)).getResult(0) + elif isinstance(op, ast.LtE): + left = tfp.Tf_LessEqualOp.create(opb, opb.getUnknownLoc(), left, + self.visit(right)).getResult(0) + elif isinstance(op, ast.Gt): + left = tfp.Tf_GreaterOp.create(opb, opb.getUnknownLoc(), left, + self.visit(right)).getResult(0) + elif isinstance(op, ast.GtE): + left = tfp.Tf_GreaterEqualOp.create(opb, opb.getUnknownLoc(), left, + self.visit(right)).getResult(0) + elif isinstance(op, ast.NotEq): + left = tfp.Tf_NotEqualOp.create(opb, opb.getUnknownLoc(), left, + self.visit(right)).getResult(0) + else: + raise NotImplementedError('CompareOp operator not recognized') + return left + + def visit_Constant(self, node): + opb = self.opbuilder + value = None + if isinstance(node.value, int): + value = tfp.Tf_ConstOp.create( + opb, opb.getUnknownLoc(), + tfp.IntegerAttr.get( + tfp.IntegerType.get(32, self.prog.ctx), node.value)).getResult(0) + return value + + def visit_FunctionDef(self, node): + # Cache the current builder + cache_builder = self.opbuilder + inputs, outputs = [], [] + + for arg in node.args.args: + inputs.append(self.process_type(arg.annotation)) + + if node.returns: + outputs = [self.process_type(node.returns)] + + currfunc = self.prog.add_function( + self.ctx.namer.new_symbol(node.name, []), + self.prog.get_function_type(inputs, outputs)) + + # Add the function to symbol table and enter new scope + self.symbol_table.insert_symbol(node.name, currfunc) + self.symbol_table.enter_scope() + + # Add arguments to symbol table + for arg, value in zip(node.args.args, currfunc.getArguments()): + self.symbol_table.insert_symbol(arg.id, value) + self.opbuilder = tfp.OpBuilder(currfunc.getBody()) + + self.visit_block(node.body) + self.symbol_table.exit_scope() + self.opbuilder = cache_builder + + def visit_If(self, node): + cond = self.visit(node.test) + + # Create ifop + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) + modified_in_cond = list(body_scope.modified | orelse_scope.modified) + outputs = [ + self.symbol_table.lookup_type(str(var)) for var in modified_in_cond + ] + ifop = tfp.IfOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), cond, + outputs) + + # Cache the builder + cache_builder = self.opbuilder + + # Visit body + self.opbuilder = tfp.OpBuilder(ifop.getRegion(0)) + # Enter scope to avoid values generated inside the region to come in symbol + # table + self.symbol_table.enter_scope() + for stmt in node.body: + self.visit(stmt) + retvals = [ + self.symbol_table.lookup(str(varname)) for varname in modified_in_cond + ] + tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals) + self.symbol_table.exit_scope() + + # Visit orelse + self.opbuilder = tfp.OpBuilder(ifop.getRegion(1)) + self.symbol_table.enter_scope() + for stmt in node.orelse: + self.visit(stmt) + retvals = [ + self.symbol_table.lookup(str(varname)) for varname in modified_in_cond + ] + tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals) + self.symbol_table.exit_scope() + + # Reset builder and enter return values in symbol table + self.opbuilder = cache_builder + for idx, var in enumerate(modified_in_cond): + self.symbol_table.insert_symbol(str(var), ifop.getResult(idx)) + + if ifop.getNumResults() == 1: + return ifop.getResult(0) + + return tuple(ifop.getResult(i) for i in range(ifop.getNumResults())) + + def visit_Name(self, node): + if self.symbol_table.lookup(node.id): + return self.symbol_table.lookup(node.id) + raise NotImplementedError('Symbol not found' + node.id) + + def visit_Return(self, node): + opb = self.opbuilder + value = self.visit(node.value) + if isinstance(value, tuple): + # For more than one return values + return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), list(value)) + return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), [value]) + + def visit_Tuple(self, node): + return tuple(self.visit(elt) for elt in node.elts) + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + if isinstance(node.op, ast.USub): + return tfp.Tf_NegOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), + operand).getResult(0) + + def _get_basic_loop_vars(self, modified, live_in, live_out): + # [This is directly from + # tensorflow/python/autograph/converters/control_flow.py] + # The loop variables corresponding to simple symbols (e.g. `x`). + basic_loop_vars = [] + for s in modified: + if s.is_composite(): + # TODO: Raise an error when this happens for a TF loop. + continue + # Variables not live into or out of the loop are considered local to the + # loop. + if s not in live_in and s not in live_out: + continue + basic_loop_vars.append(s) + return frozenset(basic_loop_vars) + + def _get_composite_loop_vars(self, modified, live_in): + # [This is directly from + # tensorflow/python/autograph/converters/control_flow.py] + # The loop variables corresponding to composite symbols (e.g. `self.x`). + composite_loop_vars = [] + for s in modified: + if not s.is_composite(): + continue + # Mutations made to objects created inside the loop will appear as writes + # to composite symbols. Because these mutations appear as modifications + # made to composite symbols, we check whether the composite's parent is + # actually live into the loop. + # Example: + # while cond: + # x = Foo() + # x.foo = 2 * x.foo # x.foo is live into the loop, but x is not. + # + # Note that some parents might not be symbols - for example, in x['foo'], + # 'foo' is a parent, but it's a literal, not a symbol. We don't check the + # liveness of literals. + support_set_symbols = tuple( + sss for sss in s.support_set if sss.is_symbol()) + if not all(sss in live_in for sss in support_set_symbols): + continue + composite_loop_vars.append(s) + return frozenset(composite_loop_vars) + + def _get_loop_vars(self, node, modified): + # [This is directly from python/autograph/converters/control_flow.py] + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) + live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + reserved_symbols = body_scope.referenced + + basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out) + composite_loop_vars = self._get_composite_loop_vars(modified, live_in) + loop_vars = tuple(basic_loop_vars | composite_loop_vars) + + # Variable that are used or defined inside the loop, but not defined + # before entering the loop. Only simple variables must be defined. The + # composite ones will be implicitly checked at runtime. + undefined_lives = basic_loop_vars - defined_in + + return loop_vars, reserved_symbols, undefined_lives + + def visit_While(self, node): + + # Create a new WhileOp + # `inputs` are initial values for loop variables + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + loop_vars, _, _ = self._get_loop_vars(node, body_scope.modified) + inputs = [self.symbol_table.lookup(str(name)) for name in loop_vars] + types = [input_.getType() for input_ in inputs] + while_op = tfp.WhileOp.create(self.opbuilder, + self.opbuilder.getUnknownLoc(), inputs, types) + + # cache the current builder + cache_builder = self.opbuilder + + # Process cond + self.symbol_table.enter_scope() + for input_, type_ in zip(loop_vars, types): + self.symbol_table.insert_symbol( + str(input_), + while_op.getRegion(0).front().addArgument(type_)) + self.opbuilder = tfp.OpBuilder(while_op.getRegion(0)) + tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), + [self.visit(node.test)]) + self.symbol_table.exit_scope() + + # Process body + self.symbol_table.enter_scope() + for input_, type_ in zip(loop_vars, types): + self.symbol_table.insert_symbol( + str(input_), + while_op.getRegion(1).front().addArgument(type_)) + self.opbuilder = tfp.OpBuilder(while_op.getRegion(1)) + self.visit_block(node.body) + tfp.ReturnOp.create( + self.opbuilder, self.opbuilder.getUnknownLoc(), + [self.symbol_table.lookup(str(name)) for name in loop_vars]) + self.symbol_table.exit_scope() + + # Enter new values as symbols + for idx, var in enumerate(loop_vars): + self.symbol_table.insert_symbol(str(var), while_op.getResult(idx)) + + # Restore builder + self.opbuilder = cache_builder + + +def mlir_gen_internal(node, entity_info): + """Returns mlir module for unprocessed node `node`.""" + namer = naming.Namer({}) + graphs = cfg.build(node) + ctx = transformer.Context(entity_info, namer, None) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = reaching_definitions.resolve(node, ctx, graphs) + node = reaching_fndefs.resolve(node, ctx, graphs) + node = liveness.resolve(node, ctx, graphs) + mlir_generator = MLIRGen(ctx) + mlir_generator.visit(node) + return mlir_generator.prog + + +def mlir_gen(func): + """Parse a function and return TFProgram.""" + node, source = parser.parse_entity(func, future_features=()) + entity_info = transformer.EntityInfo( + name=func.__name__, + source_code=source, + source_file=None, + future_features=(), + namespace=inspect_utils.getnamespace(func)) + return mlir_gen_internal(node, entity_info) + + +def mlir_gen_from_source(source=None, src_file=None): + """Parse a function as either a string or from a supplied file path and return a TFProgram. + """ + if source is None: + source = open(src_file).read() + node = ast.parse(source) + entity_info = transformer.EntityInfo( + name='mlir_module', + source_code=source, + source_file=None, + future_features=(), + namespace={}) + return mlir_gen_internal(node, entity_info) diff --git a/tensorflow/python/tf_program/pywrap_tfd.py b/tensorflow/python/tf_program/pywrap_tfd.py new file mode 100644 index 00000000000..0d9a236f5d3 --- /dev/null +++ b/tensorflow/python/tf_program/pywrap_tfd.py @@ -0,0 +1,159 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Intermediate between python bindings for MLIR and mlir generation for tensorflow program. + +This passes most of the mlir classes as is, but adds a few new operations and +the basic structure for a tensorflow program. +""" + +# pylint: disable=invalid-name + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.mlir.python.mlir_wrapper import mlir_wrapper as mlir + +# Class Definitions +OpBuilder = mlir.OpBuilder +Block = mlir.Block + +# Types +Type = mlir.Type +IntegerType = mlir.IntegerType +FloatType = mlir.FloatType +RankedTensorType = mlir.RankedTensorType +UnrankedTensorType = mlir.UnrankedTensorType +IntegerAttr = mlir.IntegerAttr + +# Standard Ops +ReturnOp = mlir.ReturnOp + +# TF Dialect Ops +Tf_AnyOp = mlir.Tf_AnyOp +Tf_AddV2Op = mlir.Tf_AddV2Op +Tf_ConstOp = mlir.Tf_ConstOp +Tf_EqualOp = mlir.Tf_EqualOp +Tf_GreaterEqualOp = mlir.Tf_GreaterEqualOp +Tf_GreaterOp = mlir.Tf_GreaterOp +Tf_LegacyCallOp = mlir.Tf_LegacyCallOp +Tf_LessEqualOp = mlir.Tf_LessEqualOp +Tf_LessOp = mlir.Tf_LessOp +Tf_NegOp = mlir.Tf_NegOp +Tf_NotEqualOp = mlir.Tf_NotEqualOp +Tf_SubOp = mlir.Tf_SubOp + + +class IfOp(object): + """ + tfp.if(cond) ({body}, {orelse}) : type If `cond` is true, `body` is + executed, otherwise `orelse` is executed. + """ + + @classmethod + def create(cls, opb, loc, cond, outputs): + state = mlir.OperationState(loc, "tfp.If") + state.addOperands([cond]) + state.addTypes(outputs) + state.addRegion().push_back(Block.new()) # body region + state.addRegion().push_back(Block.new()) # orelse region + return opb.createOperation(state) + + +class OrOp(object): + """ + tfp.Or(ops...) This is like tf.Any, except that the first dimension is opened + into `ops`. + + Returns a tensor of 1-bit integers which is "Logical OR" of the + coressponding elements in ops... + """ + + @classmethod + def create(cls, opb, loc, values): + state = mlir.OperationState(loc, "tfp.Or") + state.addTypes( + [UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))]) + state.addOperands(values) + return opb.createOperation(state) + + +class AndOp(object): + """ + tfp.And(ops...) This is like tf.All, except that the first dimension is opened + to `ops`. + + Returns a tensor of 1-bit integers which is "Logical AND" of the + coressponding elements in ops... + """ + + @classmethod + def create(cls, opb, loc, values): + state = mlir.OperationState(loc, "tfp.And") + state.addTypes( + [UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))]) + state.addOperands(values) + return opb.createOperation(state) + + +class WhileOp(object): + """tfp.While(init-vals, { + + ^bb1(cond-args): + cond-region + return cond + }, { + ^bb1(body-args): + body-region + }) + As long as `cond-region` returns a "true"-like value, the body-region + is executed and the arguments are replaced by its return values for the next + iteration. + """ + + @classmethod + def create(cls, opb, loc, inputs, outputs): + state = mlir.OperationState(loc, "tfp.While") + state.addOperands(inputs) + state.addTypes(outputs) + state.addRegion().push_back(Block.new()) # cond region + state.addRegion().push_back(Block.new()) # body region + return opb.createOperation(state) + + +class TFProgram(object): + """Python wrap for a Tensorflow Program (essentially an mlir Module).""" + + def __init__(self): + mlir.registerDialects() + self.ctx = mlir.MLIRContext() + self.builder = mlir.Builder(self.ctx) + self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx)) + self.curr_func = None + + def add_function(self, name, func_type): + self.curr_func = mlir.FuncOp.create( + mlir.UnknownLoc.get(self.ctx), name, func_type) + self.module.push_back(self.curr_func) + return self.curr_func + + def get_function_type(self, inputs, outputs): + return self.builder.getFunctionType(inputs, outputs) + + def dump(self): + self.module.dump() + + def __str__(self): + return self.module.getAsStr() diff --git a/tensorflow/python/tf_program/tests/BUILD b/tensorflow/python/tf_program/tests/BUILD new file mode 100644 index 00000000000..1cf0fad6c93 --- /dev/null +++ b/tensorflow/python/tf_program/tests/BUILD @@ -0,0 +1,20 @@ +package(licenses = ["notice"]) + +py_test( + name = "mlir_gen_test", + size = "small", + testonly = True, + srcs = ["mlir_gen_test.py"], + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_oss_py2", + "no_pip", + ], + deps = [ + "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + "//tensorflow/python:client_testlib", + "//tensorflow/python/tf_program:mlir_gen", + "//tensorflow/python/types", + ], +) diff --git a/tensorflow/python/tf_program/tests/mlir_gen_test.py b/tensorflow/python/tf_program/tests/mlir_gen_test.py new file mode 100644 index 00000000000..5e1ca5b36e0 --- /dev/null +++ b/tensorflow/python/tf_program/tests/mlir_gen_test.py @@ -0,0 +1,247 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `mlir_gen` module""" + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import test +from tensorflow.python.types import core +from tensorflow.python.tf_program.mlir_gen import mlir_gen + +import tensorflow.compiler.mlir.python.mlir_wrapper.filecheck_wrapper as fw + + +class MLIRGenTestBase(test.TestCase): + + def _check_code(self, mlir_code, exp_mlir_code): + return self.assertTrue(fw.check(str(mlir_code), exp_mlir_code)) + + +class MLIRGenTest(MLIRGenTestBase): + """MLIR Generation Tests for Tensorflow Program""" + + def test_simple(self): + + def test_fn(): + pass + + mlir_code = mlir_gen(test_fn) + mlir_code_exp = r""" + CHECK-LABEL: @test_fn + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_argument(self): + + def test_fn(x: core.Tensor) -> core.Tensor: + return x + + mlir_code = mlir_gen(test_fn) + mlir_code_exp = r""" + CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32> { + CHECK-NEXT: return %arg0 : tensor<*xi32> + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_constant(self): + + def test_fn() -> int: + return 23 + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: func @test_fn() -> i32 + CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor} : () -> tensor + CHECK: return %[[r0]] : tensor + """ + self._check_code(mlir_code, exp_mlir_code) + + def test_BoolOp(self): + + def test_fn(x: bool, y: bool) -> bool: + return x or y or x and x and y + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: func @test_fn(%arg0: i1, %arg1: i1) -> i1 + CHECK: %[[r0:[0-9]+]] = "tfp.And"(%arg0, %arg0, %arg1) : (i1, i1, i1) -> tensor<*xi1> + CHECK: %[[r1:[0-9]+]] = "tfp.Or"(%arg0, %arg1, %[[r0]]) : (i1, i1, tensor<*xi1>) -> tensor<*xi1> + return %[[r1]] : tensor<*xi1> + """ + self._check_code(mlir_code, exp_mlir_code) + + def test_Call(self): + + def test_fn(): + + def f1(): + return 23 + + def f2(): + return f1() + + f2() + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: func @test_fn() + CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f2} : () -> () + CHECK: } + CHECK-LABEL: func @f1() { + CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor} : () -> tensor + CHECK: return %[[r0]] : tensor + CHECK: } + CHECK-LABEL: func @f2() { + CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f1} : () -> () + } + """ + self._check_code(mlir_code, exp_mlir_code) + + def test_Compare(self): + + def test_fn(x: core.Tensor, y: core.Tensor, z: core.Tensor): + return x > y < z + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) + CHECK: %[[r0:[0-9]+]] = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + CHECK: %[[r1:[0-9]+]] = "tf.Less"(%[[r0]], %arg2) : (tensor<*xi1>, tensor<*xi32>) -> tensor<*xi1> + CHECK: return %[[r1]] : tensor<*xi1> + """ + self._check_code(mlir_code, exp_mlir_code) + + def test_Assign_BinOp(self): + + def test_fn() -> int: + y = 12 + 23 - 24 + return y + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: func @test_fn() -> i32 + CHECK: %[[r0:[0-9]+]] = "tf.AddV2"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor, tensor) -> tensor + CHECK: %[[r1:[0-9]+]] = "tf.Sub"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor, tensor) -> tensor + CHECK: return %[[r1]] : tensor + """ + self._check_code(mlir_code, exp_mlir_code) + + def test_if(self): + + def test_fn(x: core.Tensor) -> int: + res = 0 + if x > 0: + res = 1 + elif x < 0: + res = -1 + else: + res = 0 + return res + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> i32 + + CHECK: %[[r1:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor) -> tensor<*xi1> + CHECK-NEXT: %[[r2:[0-9]+]] = "tfp.If"(%[[r1]]) ( { + CHECK: return %{{[0-9]+}} : tensor + CHECK-NEXT: }, { + CHECK: %[[r3:[0-9]+]] = "tf.Less"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor) -> tensor<*xi1> + CHECK: %[[r4:[0-9]+]] = "tfp.If"(%[[r3]]) ( { + CHECK: %[[r5:[0-9]+]] = "tf.Neg"(%{{[0-9]+}}) : (tensor) -> tensor + CHECK: return %[[r5]] : tensor + CHECK-NEXT: }, { + CHECK: return %{{[0-9]+}} : tensor + CHECK-NEXT: }) : (tensor<*xi1>) -> tensor + CHECK: return %[[r4]] : tensor + CHECK-NEXT: }) : (tensor<*xi1>) -> tensor + CHECK-NEXT: return %[[r2]] : tensor + """ + self._check_code(mlir_code, exp_mlir_code) + + def test_while(self): + + def test_fn(x: core.Tensor) -> core.Tensor: + s = 0 + while x > 0: + s = s + x + return s + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32> + + CHECK: %[[r1:[0-9]+]] = "tfp.While"(%0) ( { + CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor): + CHECK: %[[r2:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor) -> tensor<*xi1> + CHECK-NEXT: return %[[r2]] : tensor<*xi1> + CHECK-NEXT: }, { + CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor): + CHECK: %[[r3:[0-9]+]] = "tf.AddV2"(%arg1, %arg0) : (tensor, tensor<*xi32>) -> tensor<*xi32> + CHECK-NEXT: return %[[r3]] : tensor<*xi32> + CHECK-NEXT: }) : (tensor) -> tensor + CHECK-NEXT: return %[[r1]] : tensor + """ + self._check_code(mlir_code, exp_mlir_code) + + def test_fibonacci(self): + + def test_fn(x: core.Tensor) -> core.Tensor: + res, idx = 0, 2 + a, b = 0, 1 + if x == 0 or x == 1: + res = x + else: + while idx <= x: + res = a + b + a = b + b = res + idx = idx + 1 + return res + + mlir_code = mlir_gen(test_fn) + exp_mlir_code = r""" + CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32> + CHECK: %[[r5:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor) -> tensor<*xi1> + CHECK: %[[r7:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor) -> tensor<*xi1> + CHECK: %[[r8:[0-9]+]] = "tfp.Or"(%[[r5]], %[[r7]]) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + + CHECK: %[[r9:[0-9]+]]:4 = "tfp.If"(%[[r8]]) ( { + CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32> + CHECK-NEXT: }, { + CHECK-NEXT: %[[r10:[0-9]+]]:4 = "tfp.While"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { + CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + CHECK-NEXT: %[[r11:[0-9]+]] = "tf.LessEqual"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>) -> tensor<*xi1> + CHECK-NEXT: return %[[r11]] : tensor<*xi1> + CHECK-NEXT: }, { + CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + CHECK-NEXT: %[[r12:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor, tensor) -> tensor + CHECK: %[[r13:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %{{[0-9]+}}) : (tensor, tensor) -> tensor + CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor, tensor, tensor, tensor + CHECK-NEXT: }) : (tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) + CHECK-NEXT: return %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}} : tensor, tensor, tensor, tensor + CHECK-NEXT: }) : (tensor<*xi1>) -> (tensor, tensor, tensor, tensor) + CHECK-NEXT: return %[[r9]]#{{[0-9]+}} : tensor + """ + self._check_code(mlir_code, exp_mlir_code) + + +if __name__ == '__main__': + test.main()