Merge pull request from shraiysh:master

PiperOrigin-RevId: 311248675
Change-Id: Idf82ebbb155efec1624565eb13bd67573b68037a
This commit is contained in:
TensorFlower Gardener 2020-05-12 19:23:12 -07:00
commit 7bffd6c498
14 changed files with 1416 additions and 0 deletions

View File

@ -0,0 +1,41 @@
# Build rules for the pybind11-based Python bindings around MLIR and the
# TensorFlow dialect.
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")

package(licenses = ["notice"])

# Python extension exposing core MLIR IR classes (context, blocks, regions,
# builders, types, attributes) plus a subset of TF dialect ops.
tf_python_pybind_extension(
    name = "mlir_wrapper",
    srcs = [
        "attrs.cc",
        "basic_classes.cc",
        "builders.cc",
        "mlir_wrapper.cc",
        "mlir_wrapper.h",
        "ops.cc",
        "types.cc",
    ],
    module_name = "mlir_wrapper",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@pybind11",
    ],
)

# Python extension wrapping LLVM's FileCheck utility so generated MLIR can be
# verified from Python tests.
tf_python_pybind_extension(
    name = "filecheck_wrapper",
    srcs = ["filecheck_wrapper.cc"],
    module_name = "filecheck_wrapper",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:support",
        "@pybind11",
    ],
)

View File

@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
// Registers MLIR attribute classes with the Python module `m`.
//
// Exposes:
//   - mlir::Attribute: opaque base class (no methods bound).
//   - mlir::IntegerAttr: with "get" taking (Type, int64_t); the
//     overload_cast selects that specific factory overload.
void init_attrs(py::module& m) {
  // The base class must be registered before the derived IntegerAttr binding.
  py::class_<mlir::Attribute>(m, "Attribute");
  py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "IntegerAttr")
      .def("get",
           py::overload_cast<mlir::Type, int64_t>(&mlir::IntegerAttr::get));
}

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/FileCheck.h"
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
// Registers core MLIR IR classes (context, locations, regions, blocks and
// values) with the Python module `m`.
void init_basic_classes(py::module& m) {
  py::class_<mlir::MLIRContext>(m, "MLIRContext").def(py::init<>());

  py::class_<mlir::Location>(m, "Location");
  py::class_<mlir::UnknownLoc>(m, "UnknownLoc")
      .def("get", &mlir::UnknownLoc::get);

  // Regions are owned by their enclosing operation; blocks are returned by
  // reference so Python never takes ownership of them.
  // Fix: "front" was previously registered twice with an identical
  // signature; the redundant duplicate overload is removed.
  py::class_<mlir::Region>(m, "Region")
      .def("back", &mlir::Region::back, py::return_value_policy::reference)
      .def("front", &mlir::Region::front, py::return_value_policy::reference)
      // The region takes ownership of the freshly allocated block.
      .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); })
      .def("push_back", &mlir::Region::push_back)
      .def("size", [](mlir::Region& r) { return r.getBlocks().size(); });

  py::class_<mlir::Block::iterator>(m, "Block_Iterator");
  py::class_<mlir::Block>(m, "Block")
      // Returned by reference: the caller (or the region it is later pushed
      // into) is responsible for the block's lifetime.
      .def("new", ([]() { return new mlir::Block; }),
           py::return_value_policy::reference)
      .def("end", &mlir::Block::end)
      .def("addArgument", &mlir::Block::addArgument);

  py::class_<mlir::Value>(m, "Value").def("getType", &mlir::Value::getType);
  py::class_<mlir::OpResult, mlir::Value>(m, "OpResult");
  py::class_<mlir::BlockArgument, mlir::Value>(m, "BlockArgument");
}

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
// Registers mlir::Builder and mlir::OpBuilder with the Python module `m`.
void init_builders(py::module& m) {
  py::class_<mlir::Builder>(m, "Builder")
      .def(py::init<mlir::MLIRContext*>())
      // Accepts std::vector from Python (converted by pybind11/stl.h) and
      // adapts it to the ArrayRef-based MLIR API.
      .def("getFunctionType",
           [](mlir::Builder& b, std::vector<mlir::Type> inputs,
              std::vector<mlir::Type> outputs) {
             return b.getFunctionType(llvm::ArrayRef<mlir::Type>(inputs),
                                      llvm::ArrayRef<mlir::Type>(outputs));
           });

  py::class_<mlir::OpBuilder>(m, "OpBuilder")
      // Several construction modes: from a context, a region, an existing
      // operation, or an explicit (block, iterator) insertion point.
      .def(py::init<mlir::MLIRContext*>())
      .def(py::init<mlir::Region&>())
      .def(py::init<mlir::Operation*>())
      .def(py::init<mlir::Block*, mlir::Block::iterator>())
      .def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc)
      .def("setInsertionPoint",
           py::overload_cast<mlir::Block*, mlir::Block::iterator>(
               &mlir::OpBuilder::setInsertionPoint))
      .def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint)
      .def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint)
      // The created operation is owned by the IR, hence the reference
      // return policy.
      .def(
          "createOperation",
          [](mlir::OpBuilder& opb, mlir::OperationState& state) {
            return opb.createOperation(state);
          },
          py::return_value_policy::reference)
      .def("getContext", &mlir::OpBuilder::getContext,
           py::return_value_policy::reference);

  py::class_<mlir::OpBuilder::InsertPoint>(m, "OpBuilder_InsertionPoint")
      .def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock);
}

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/FileCheck.h"
#include "llvm/Support/SourceMgr.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
// Python module exposing LLVM's FileCheck as `filecheck_wrapper.check`.
PYBIND11_MODULE(filecheck_wrapper, m) {
  // check(input, check) -> bool
  //
  // Runs the FileCheck directives contained in `check` against `input` and
  // returns true iff all directives match.
  //
  // Fix: take the strings by const reference instead of by value to avoid
  // an extra copy per call, and default-construct the SourceMgr instead of
  // copy-initializing it from a temporary.
  m.def("check", [](const std::string& input, const std::string& check) {
    llvm::FileCheckRequest fcr;
    llvm::FileCheck fc(fcr);
    llvm::SourceMgr SM;
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                          llvm::SMLoc());
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check),
                          llvm::SMLoc());
    llvm::Regex regex = fc.buildCheckPrefixRegex();
    // NOTE(review): readCheckFile's error return is ignored here, matching
    // the original behavior; a malformed check file still falls through to
    // checkInput. Consider surfacing that failure to the caller.
    fc.readCheckFile(SM, llvm::StringRef(check), regex);
    return fc.checkInput(SM, llvm::StringRef(input));
  });
}

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
// Entry point of the `mlir_wrapper` Python extension module. Registers the
// dialect-registration helper and all binding groups defined in the sibling
// translation units (declared in mlir_wrapper.h).
PYBIND11_MODULE(mlir_wrapper, m) {
  // Registers the TF, tf_executor and Standard dialects globally so ops from
  // them can be created afterwards.
  m.def("registerDialects", []() {
    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
    mlir::registerDialect<mlir::StandardOpsDialect>();
  });

  // Basic classes first: later groups bind subclasses/users of these types.
  init_basic_classes(m);
  init_types(m);
  init_builders(m);
  init_ops(m);
  init_attrs(m);
}

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace py = pybind11;

// Binding-registration hooks; each is defined in its own .cc file and called
// once from the PYBIND11_MODULE block in mlir_wrapper.cc.
void init_basic_classes(py::module& m);
void init_types(py::module& m);
void init_builders(py::module& m);
void init_ops(py::module& m);
void init_attrs(py::module& m);

#endif  // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H

View File

@ -0,0 +1,194 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
// Registers generic MLIR operation classes and a subset of TensorFlow
// dialect ops with the Python module `m`. Each Tf_* class exposes a single
// "create" that builds the op at the builder's insertion point and returns
// the generic mlir::Operation*.
void init_ops(py::module& m) {
  // Operations are owned by the IR (their enclosing block); py::nodelete
  // prevents Python from ever destroying one.
  py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
      m, "Operation")
      .def("getRegion", &mlir::Operation::getRegion,
           py::return_value_policy::reference)
      .def("getResult", &mlir::Operation::getResult)
      .def("dump", &mlir::Operation::dump)
      .def("getNumResults", &mlir::Operation::getNumResults);

  py::class_<mlir::OperationState>(m, "OperationState")
      .def(py::init([](mlir::Location loc, std::string name) {
        return mlir::OperationState(loc, llvm::StringRef(name));
      }))
      // std::vector arguments come from Python lists (pybind11/stl.h) and
      // are adapted to the ArrayRef-based MLIR API.
      .def("addTypes",
           [](mlir::OperationState& state, std::vector<mlir::Type> tys) {
             state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
           })
      .def("addOperands",
           [](mlir::OperationState& os, std::vector<mlir::Value> ops) {
             os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
           })
      .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
           py::return_value_policy::reference);

  py::class_<mlir::ModuleOp>(m, "ModuleOp")
      .def("create",
           [](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
      .def("push_back",
           [](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); })
      .def("dump", &mlir::ModuleOp::dump)
      // Renders the module as textual MLIR.
      .def("getAsStr", [](mlir::ModuleOp& m) {
        std::string str;
        llvm::raw_string_ostream os(str);
        m.print(os);
        return os.str();
      });

  py::class_<mlir::FuncOp>(m, "FuncOp")
      .def("create",
           [](mlir::Location location, std::string name,
              mlir::FunctionType type) {
             // Also creates the entry block so callers can immediately start
             // inserting operations into the body.
             auto func = mlir::FuncOp::create(location, name, type);
             func.addEntryBlock();
             return func;
           })
      .def(
          "getBody",
          [](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); },
          py::return_value_policy::reference)
      .def("getArguments",
           [](mlir::FuncOp& f) { return f.getArguments().vec(); })
      .def("getName", [](mlir::FuncOp& f) { return f.getName().str(); })
      .def("getType", &mlir::FuncOp::getType);

  py::class_<mlir::ReturnOp>(m, "ReturnOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              std::vector<mlir::Value> values) -> mlir::Operation* {
             return opb
                 .create<mlir::ReturnOp>(loc,
                                         mlir::ArrayRef<mlir::Value>(values))
                 .getOperation();
           });

  // mlir::TF::AddV2Op
  py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
           });

  py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
              mlir::Value reduction_indices,
              bool keep_dims = false) -> mlir::Operation* {
             // Result type is fixed to i1 (a boolean reduction).
             return opb
                 .create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
                                          reduction_indices, keep_dims)
                 .getOperation();
           });

  // mlir::TF::ConstOp
  py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              mlir::Attribute value) -> mlir::Operation* {
             return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
           });

  // mlir::TF::EqualOp
  py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             // incompatible_shape_error is fixed to true.
             return opb
                 .create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
                 .getOperation();
           });

  // mlir::TF::GreaterEqualOp
  py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
                 .getOperation();
           });

  // mlir::TF::GreaterOp
  py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
           });

  // mlir::TF::LegacyCallOp
  py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              std::vector<mlir::Type> output, std::vector<mlir::Value> args,
              std::string f) -> mlir::Operation* {
             return opb
                 .create<mlir::TF::LegacyCallOp>(
                     loc, mlir::ArrayRef<mlir::Type>(output),
                     mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
                 .getOperation();
           });

  // mlir::TF::LessEqualOp
  py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
           });

  // mlir::TF::LessOp
  py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
           });

  // mlir::TF::NegOp
  py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              mlir::Value x) -> mlir::Operation* {
             return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
           });

  // NOTE(review): unlike Tf_EqualOp above, this builds the BoolAttr via
  // mlir::BoolAttr::get rather than opb.getBoolAttr — equivalent result,
  // but inconsistent style; consider unifying.
  py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
      .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
                        mlir::Value y) {
        return opb
            .create<mlir::TF::NotEqualOp>(
                loc, x, y, mlir::BoolAttr::get(true, opb.getContext()))
            .getOperation();
      });

  // mlir::TF::SubOp
  py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
           });
}

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
// Registers MLIR type classes with the Python module `m`.
void init_types(py::module& m) {
  // Type
  py::class_<mlir::Type> Type(m, "Type");
  Type.def("getKind", &mlir::Type::getKind);

  // Type Enums — only BF16 is exposed for now.
  py::enum_<mlir::StandardTypes::Kind>(Type, "StandardTypes_Kind")
      .value("BF16", mlir::StandardTypes::BF16);

  // Type Sub-classes; each exposes its static "get" factory.
  py::class_<mlir::FunctionType, mlir::Type>(m, "FunctionType")
      .def("getResults",
           [](mlir::FunctionType& ft) { return ft.getResults().vec(); });

  py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
      .def("get", &mlir::FloatType::get);

  // overload_cast pins the (width, context) overload of IntegerType::get.
  py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
      .def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
                      &mlir::IntegerType::get));

  py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
      .def("get", &mlir::UnrankedTensorType::get);

  // Shape arrives as a Python list of ints and is adapted to ArrayRef.
  py::class_<mlir::RankedTensorType, mlir::Type>(m, "RankedTensorType")
      .def("get", [](std::vector<int64_t> shape, mlir::Type ty) {
        return mlir::RankedTensorType::get(mlir::ArrayRef<int64_t>(shape), ty);
      });
}

View File

@ -0,0 +1,22 @@
# Build rules for the Python-side tf_program package: the pywrap layer over
# the mlir_wrapper extension and the AST-to-MLIR generator.
package(licenses = ["notice"])

# Thin Python wrapper re-exporting the pybind11 mlir_wrapper extension.
py_library(
    name = "pywrap_tfd",
    srcs = ["pywrap_tfd.py"],
    deps = [
        "//tensorflow/compiler/mlir/python/mlir_wrapper",
    ],
)

# AST visitor that lowers Python functions to MLIR, using autograph's static
# analyses (activity, liveness, reaching definitions).
py_library(
    name = "mlir_gen",
    srcs = ["mlir_gen.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":pywrap_tfd",
        "//tensorflow/python/autograph/pyct",
        "//tensorflow/python/autograph/pyct/static_analysis",
        "//tensorflow/python/types",
        "@gast_archive//:gast",
    ],
)

View File

@ -0,0 +1,456 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""mlir_gen: Generate mlir code from python code."""
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gast as ast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
import tensorflow.python.tf_program.pywrap_tfd as tfp
from tensorflow.python.types import core
class SymbolTable(object):
  """Scoped symbol table for python code.

  Maintains a stack of scopes; each scope maps names both to values
  ('symbols') and to their MLIR types ('types'). Lookups search from the
  innermost scope outwards.
  """

  def __init__(self):
    self.symbols = []
    self.enter_scope()

  def enter_scope(self):
    """Enter a new scope - at function level."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[-1]

  def insert_symbol(self, name, value):
    """Binds `name` to `value` (and its type) in the innermost scope."""
    self.curr_table['symbols'][name] = value
    self.curr_table['types'][name] = value.getType()
    return value

  def insert_type(self, name, type_):
    """Records only the type for `name` in the innermost scope."""
    self.curr_table['types'][name] = type_

  def exit_scope(self):
    """Pops the innermost scope."""
    self.symbols.pop()
    self.curr_table = self.symbols[-1]

  def lookup(self, name):
    """Returns the value bound to `name`, or None if it is unbound."""
    for table in reversed(self.symbols):
      if name in table['symbols']:
        return table['symbols'][name]
    return None

  def lookup_type(self, name):
    """Returns the type recorded for `name`, or None if unknown."""
    for table in reversed(self.symbols):
      if name in table['types']:
        return table['types'][name]
    return None

  def __repr__(self):
    return '\n'.join(
        '  ' * idx + str(table) for idx, table in enumerate(self.symbols))
class ProcessType(ast.NodeVisitor):
  """Visit a node and return processed type Currently only visits annotations and gives their type.
  """

  def __init__(self, prog, ctx):
    # prog: the TFProgram being built (its .ctx is the MLIR context).
    # ctx: the autograph transformer context (provides the namespace).
    self.prog = prog
    self.ctx = ctx

  def visit_Attribute(self, node):
    # Supported: core.Tensor
    value = self.visit(node.value)
    if value is None or not hasattr(value, node.attr):
      raise AttributeError(str(type(value)) + ' has no attribute ' + node.attr)
    attr = getattr(value, node.attr)
    if attr == core.Tensor:
      # core.Tensor annotations map to an unranked tensor of i32
      # (NOTE(review): element type is hard-coded to 32-bit int here).
      return tfp.UnrankedTensorType.get(tfp.IntegerType.get(32, self.prog.ctx))
    return attr

  def visit_Name(self, node):
    # `int` -> i32, `bool` -> i1; otherwise resolve through the caller's
    # namespace. NOTE(review): an unknown name falls through and implicitly
    # returns None.
    if node.id == 'int':
      return tfp.IntegerType.get(32, self.prog.ctx)
    if node.id == 'bool':
      return tfp.IntegerType.get(1, self.prog.ctx)
    if node.id in self.ctx.info.namespace:
      return self.ctx.info.namespace[node.id]
class MLIRGen(ast.NodeVisitor):
"""Visit the AST and generate MLIR code Requires liveness, reading_definitions.
"""
def __init__(self, ctx):
self.ctx = ctx
self.symbol_table = SymbolTable()
self.prog = tfp.TFProgram()
self.opbuilder = None
def visit_block(self, block):
return [self.visit(item) for item in block]
def process_type(self, node):
return ProcessType(self.prog, self.ctx).visit(node)
def visit_Assign(self, node):
value = self.visit(node.value)
if isinstance(value, tuple):
# If it is a tuple of values, assign one to each in targets
# TODO: This currently is assuming that all elts in targets[0] are Name
# objects. This might not be always True.
for key, val in zip(node.targets[0].elts, value):
self.symbol_table.insert_symbol(key.id, val)
else:
self.symbol_table.insert_symbol(node.targets[0].id, value)
def visit_BinOp(self, node):
left = self.visit(node.left)
right = self.visit(node.right)
if isinstance(node.op, ast.Sub):
return tfp.Tf_SubOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
left, right).getResult(0)
if isinstance(node.op, ast.Add):
return tfp.Tf_AddV2Op.create(self.opbuilder,
self.opbuilder.getUnknownLoc(), left,
right).getResult(0)
def visit_BoolOp(self, node):
values = [self.visit(value) for value in node.values]
if isinstance(node.op, ast.Or):
return tfp.OrOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
values).getResult(0)
if isinstance(node.op, ast.And):
return tfp.AndOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
values).getResult(0)
def visit_Call(self, node):
func = self.visit(node.func)
args = [self.visit(arg) for arg in node.args]
callop = tfp.Tf_LegacyCallOp.create(self.opbuilder,
self.opbuilder.getUnknownLoc(),
func.getType().getResults(), args,
func.getName())
if callop.getNumResults() == 1:
return callop[0]
return tuple(callop.getResult(idx) for idx in range(callop.getNumResults()))
def visit_Compare(self, node):
left = self.visit(node.left)
opb = self.opbuilder
for op, right in zip(node.ops, node.comparators):
if isinstance(op, ast.Eq):
left = tfp.Tf_EqualOp.create(opb, opb.getUnknownLoc(), left,
self.visit(right)).getResult(0)
elif isinstance(op, ast.Lt):
left = tfp.Tf_LessOp.create(opb, opb.getUnknownLoc(), left,
self.visit(right)).getResult(0)
elif isinstance(op, ast.LtE):
left = tfp.Tf_LessEqualOp.create(opb, opb.getUnknownLoc(), left,
self.visit(right)).getResult(0)
elif isinstance(op, ast.Gt):
left = tfp.Tf_GreaterOp.create(opb, opb.getUnknownLoc(), left,
self.visit(right)).getResult(0)
elif isinstance(op, ast.GtE):
left = tfp.Tf_GreaterEqualOp.create(opb, opb.getUnknownLoc(), left,
self.visit(right)).getResult(0)
elif isinstance(op, ast.NotEq):
left = tfp.Tf_NotEqualOp.create(opb, opb.getUnknownLoc(), left,
self.visit(right)).getResult(0)
else:
raise NotImplementedError('CompareOp operator not recognized')
return left
def visit_Constant(self, node):
opb = self.opbuilder
value = None
if isinstance(node.value, int):
value = tfp.Tf_ConstOp.create(
opb, opb.getUnknownLoc(),
tfp.IntegerAttr.get(
tfp.IntegerType.get(32, self.prog.ctx), node.value)).getResult(0)
return value
def visit_FunctionDef(self, node):
# Cache the current builder
cache_builder = self.opbuilder
inputs, outputs = [], []
for arg in node.args.args:
inputs.append(self.process_type(arg.annotation))
if node.returns:
outputs = [self.process_type(node.returns)]
currfunc = self.prog.add_function(
self.ctx.namer.new_symbol(node.name, []),
self.prog.get_function_type(inputs, outputs))
# Add the function to symbol table and enter new scope
self.symbol_table.insert_symbol(node.name, currfunc)
self.symbol_table.enter_scope()
# Add arguments to symbol table
for arg, value in zip(node.args.args, currfunc.getArguments()):
self.symbol_table.insert_symbol(arg.id, value)
self.opbuilder = tfp.OpBuilder(currfunc.getBody())
self.visit_block(node.body)
self.symbol_table.exit_scope()
self.opbuilder = cache_builder
def visit_If(self, node):
cond = self.visit(node.test)
# Create ifop
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
modified_in_cond = list(body_scope.modified | orelse_scope.modified)
outputs = [
self.symbol_table.lookup_type(str(var)) for var in modified_in_cond
]
ifop = tfp.IfOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), cond,
outputs)
# Cache the builder
cache_builder = self.opbuilder
# Visit body
self.opbuilder = tfp.OpBuilder(ifop.getRegion(0))
# Enter scope to avoid values generated inside the region to come in symbol
# table
self.symbol_table.enter_scope()
for stmt in node.body:
self.visit(stmt)
retvals = [
self.symbol_table.lookup(str(varname)) for varname in modified_in_cond
]
tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
self.symbol_table.exit_scope()
# Visit orelse
self.opbuilder = tfp.OpBuilder(ifop.getRegion(1))
self.symbol_table.enter_scope()
for stmt in node.orelse:
self.visit(stmt)
retvals = [
self.symbol_table.lookup(str(varname)) for varname in modified_in_cond
]
tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
self.symbol_table.exit_scope()
# Reset builder and enter return values in symbol table
self.opbuilder = cache_builder
for idx, var in enumerate(modified_in_cond):
self.symbol_table.insert_symbol(str(var), ifop.getResult(idx))
if ifop.getNumResults() == 1:
return ifop.getResult(0)
return tuple(ifop.getResult(i) for i in range(ifop.getNumResults()))
def visit_Name(self, node):
if self.symbol_table.lookup(node.id):
return self.symbol_table.lookup(node.id)
raise NotImplementedError('Symbol not found' + node.id)
def visit_Return(self, node):
opb = self.opbuilder
value = self.visit(node.value)
if isinstance(value, tuple):
# For more than one return values
return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), list(value))
return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), [value])
def visit_Tuple(self, node):
return tuple(self.visit(elt) for elt in node.elts)
def visit_UnaryOp(self, node):
operand = self.visit(node.operand)
if isinstance(node.op, ast.USub):
return tfp.Tf_NegOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
operand).getResult(0)
def _get_basic_loop_vars(self, modified, live_in, live_out):
# [This is directly from
# tensorflow/python/autograph/converters/control_flow.py]
# The loop variables corresponding to simple symbols (e.g. `x`).
basic_loop_vars = []
for s in modified:
if s.is_composite():
# TODO: Raise an error when this happens for a TF loop.
continue
# Variables not live into or out of the loop are considered local to the
# loop.
if s not in live_in and s not in live_out:
continue
basic_loop_vars.append(s)
return frozenset(basic_loop_vars)
def _get_composite_loop_vars(self, modified, live_in):
# [This is directly from
# tensorflow/python/autograph/converters/control_flow.py]
# The loop variables corresponding to composite symbols (e.g. `self.x`).
composite_loop_vars = []
for s in modified:
if not s.is_composite():
continue
# Mutations made to objects created inside the loop will appear as writes
# to composite symbols. Because these mutations appear as modifications
# made to composite symbols, we check whether the composite's parent is
# actually live into the loop.
# Example:
# while cond:
# x = Foo()
# x.foo = 2 * x.foo # x.foo is live into the loop, but x is not.
#
# Note that some parents might not be symbols - for example, in x['foo'],
# 'foo' is a parent, but it's a literal, not a symbol. We don't check the
# liveness of literals.
support_set_symbols = tuple(
sss for sss in s.support_set if sss.is_symbol())
if not all(sss in live_in for sss in support_set_symbols):
continue
composite_loop_vars.append(s)
return frozenset(composite_loop_vars)
def _get_loop_vars(self, node, modified):
# [This is directly from python/autograph/converters/control_flow.py]
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
reserved_symbols = body_scope.referenced
basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out)
composite_loop_vars = self._get_composite_loop_vars(modified, live_in)
loop_vars = tuple(basic_loop_vars | composite_loop_vars)
# Variable that are used or defined inside the loop, but not defined
# before entering the loop. Only simple variables must be defined. The
# composite ones will be implicitly checked at runtime.
undefined_lives = basic_loop_vars - defined_in
return loop_vars, reserved_symbols, undefined_lives
def visit_While(self, node):
  """Lowers a python `while` loop to a `tfp.While` operation.

  The op takes the loop variables' current values as operands and has two
  regions: region 0 computes the loop condition, region 1 is the loop body.
  Each region gets one block argument per loop variable; after lowering, the
  op's results are bound back to the loop-variable names in the symbol table.
  """
  # Create a new WhileOp.
  # `inputs` are initial values for loop variables.
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  loop_vars, _, _ = self._get_loop_vars(node, body_scope.modified)
  inputs = [self.symbol_table.lookup(str(name)) for name in loop_vars]
  types = [input_.getType() for input_ in inputs]
  while_op = tfp.WhileOp.create(self.opbuilder,
                                self.opbuilder.getUnknownLoc(), inputs, types)

  # Cache the current builder so it can be restored after emitting into the
  # op's regions.
  cache_builder = self.opbuilder

  # Process cond: each loop variable becomes a block argument of region 0,
  # and the region returns the lowered loop-test expression.
  self.symbol_table.enter_scope()
  for input_, type_ in zip(loop_vars, types):
    self.symbol_table.insert_symbol(
        str(input_),
        while_op.getRegion(0).front().addArgument(type_))
  self.opbuilder = tfp.OpBuilder(while_op.getRegion(0))
  tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                      [self.visit(node.test)])
  self.symbol_table.exit_scope()

  # Process body: region 1 also takes the loop variables as block arguments
  # and returns their updated values for the next iteration.
  self.symbol_table.enter_scope()
  for input_, type_ in zip(loop_vars, types):
    self.symbol_table.insert_symbol(
        str(input_),
        while_op.getRegion(1).front().addArgument(type_))
  self.opbuilder = tfp.OpBuilder(while_op.getRegion(1))
  self.visit_block(node.body)
  tfp.ReturnOp.create(
      self.opbuilder, self.opbuilder.getUnknownLoc(),
      [self.symbol_table.lookup(str(name)) for name in loop_vars])
  self.symbol_table.exit_scope()

  # Enter the op's results as the new values of the loop-variable symbols.
  for idx, var in enumerate(loop_vars):
    self.symbol_table.insert_symbol(str(var), while_op.getResult(idx))

  # Restore builder.
  self.opbuilder = cache_builder
def mlir_gen_internal(node, entity_info):
  """Lowers an already-parsed AST `node` to an MLIR module (TFProgram).

  Args:
    node: the unprocessed AST of the entity to lower.
    entity_info: transformer.EntityInfo describing the entity.

  Returns:
    The generated TFProgram.
  """
  cfgs = cfg.build(node)
  lowering_ctx = transformer.Context(entity_info, naming.Namer({}), None)
  # The static-analysis passes below annotate the AST in order; each pass
  # consumes the annotations produced by the previous ones.
  node = qual_names.resolve(node)
  node = activity.resolve(node, lowering_ctx)
  node = reaching_definitions.resolve(node, lowering_ctx, cfgs)
  node = reaching_fndefs.resolve(node, lowering_ctx, cfgs)
  node = liveness.resolve(node, lowering_ctx, cfgs)
  generator = MLIRGen(lowering_ctx)
  generator.visit(node)
  return generator.prog
def mlir_gen(func):
  """Parses a python function and returns the corresponding TFProgram.

  Args:
    func: a python callable to lower.

  Returns:
    The TFProgram generated from `func`.
  """
  ast_node, src = parser.parse_entity(func, future_features=())
  info = transformer.EntityInfo(
      name=func.__name__,
      source_code=src,
      source_file=None,
      future_features=(),
      namespace=inspect_utils.getnamespace(func))
  return mlir_gen_internal(ast_node, info)
def mlir_gen_from_source(source=None, src_file=None):
  """Parse a function as either a string or from a supplied file path and return a TFProgram.

  Args:
    source: python source text to parse; takes precedence over `src_file`.
    src_file: path of a file whose contents are read when `source` is None.

  Returns:
    The TFProgram generated from the parsed module.
  """
  if source is None:
    # Use a context manager so the file handle is closed deterministically
    # instead of leaking until garbage collection.
    with open(src_file) as f:
      source = f.read()
  node = ast.parse(source)
  entity_info = transformer.EntityInfo(
      name='mlir_module',
      source_code=source,
      source_file=None,
      future_features=(),
      namespace={})
  return mlir_gen_internal(node, entity_info)

View File

@ -0,0 +1,159 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Intermediate between python bindings for MLIR and mlir generation for tensorflow program.
This passes most of the mlir classes as is, but adds a few new operations and
the basic structure for a tensorflow program.
"""
# pylint: disable=invalid-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.compiler.mlir.python.mlir_wrapper import mlir_wrapper as mlir
# Class definitions: builders and basic IR structure re-exported from the
# pybind11 wrapper module.
OpBuilder = mlir.OpBuilder
Block = mlir.Block

# Types exposed by the wrapper.
Type = mlir.Type
IntegerType = mlir.IntegerType
FloatType = mlir.FloatType
RankedTensorType = mlir.RankedTensorType
UnrankedTensorType = mlir.UnrankedTensorType
IntegerAttr = mlir.IntegerAttr

# Standard dialect ops.
ReturnOp = mlir.ReturnOp

# TF dialect ops.
Tf_AnyOp = mlir.Tf_AnyOp
Tf_AddV2Op = mlir.Tf_AddV2Op
Tf_ConstOp = mlir.Tf_ConstOp
Tf_EqualOp = mlir.Tf_EqualOp
Tf_GreaterEqualOp = mlir.Tf_GreaterEqualOp
Tf_GreaterOp = mlir.Tf_GreaterOp
Tf_LegacyCallOp = mlir.Tf_LegacyCallOp
Tf_LessEqualOp = mlir.Tf_LessEqualOp
Tf_LessOp = mlir.Tf_LessOp
Tf_NegOp = mlir.Tf_NegOp
Tf_NotEqualOp = mlir.Tf_NotEqualOp
Tf_SubOp = mlir.Tf_SubOp
class IfOp(object):
  """Conditional operation.

  tfp.If(cond) ({body}, {orelse}) : type

  If `cond` is true, `body` is executed, otherwise `orelse` is executed.
  """

  @classmethod
  def create(cls, opb, loc, cond, outputs):
    """Builds a tfp.If operation with two empty single-block regions."""
    op_state = mlir.OperationState(loc, "tfp.If")
    op_state.addOperands([cond])
    op_state.addTypes(outputs)
    # Region 0 holds `body`, region 1 holds `orelse`.
    for _ in range(2):
      op_state.addRegion().push_back(Block.new())
    return opb.createOperation(op_state)
class OrOp(object):
  """Variadic logical OR.

  tfp.Or(ops...) is like tf.Any, except that the first dimension is opened
  into `ops`.

  Returns a tensor of 1-bit integers which is the logical OR of the
  corresponding elements in ops...
  """

  @classmethod
  def create(cls, opb, loc, values):
    """Builds a tfp.Or operation over `values`."""
    result_type = UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))
    op_state = mlir.OperationState(loc, "tfp.Or")
    op_state.addOperands(values)
    op_state.addTypes([result_type])
    return opb.createOperation(op_state)
class AndOp(object):
  """Variadic logical AND.

  tfp.And(ops...) is like tf.All, except that the first dimension is opened
  to `ops`.

  Returns a tensor of 1-bit integers which is the logical AND of the
  corresponding elements in ops...
  """

  @classmethod
  def create(cls, opb, loc, values):
    """Builds a tfp.And operation over `values`."""
    result_type = UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))
    op_state = mlir.OperationState(loc, "tfp.And")
    op_state.addOperands(values)
    op_state.addTypes([result_type])
    return opb.createOperation(op_state)
class WhileOp(object):
  """tfp.While(init-vals, {
    ^bb1(cond-args):
      cond-region
      return cond
  }, {
    ^bb1(body-args):
      body-region
  })

  As long as `cond-region` returns a "true"-like value, the body-region
  is executed and the arguments are replaced by its return values for the
  next iteration.
  """

  @classmethod
  def create(cls, opb, loc, inputs, outputs):
    """Builds a tfp.While operation with empty cond and body regions."""
    op_state = mlir.OperationState(loc, "tfp.While")
    op_state.addOperands(inputs)
    op_state.addTypes(outputs)
    # Region 0 is the condition, region 1 is the body; each starts with one
    # empty block.
    for _ in range(2):
      op_state.addRegion().push_back(Block.new())
    return opb.createOperation(op_state)
class TFProgram(object):
  """Python wrap for a Tensorflow Program (essentially an mlir Module)."""

  def __init__(self):
    # Dialects must be registered before any op in them can be created.
    mlir.registerDialects()
    self.ctx = mlir.MLIRContext()
    self.builder = mlir.Builder(self.ctx)
    self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx))
    # The function currently being built; set by add_function.
    self.curr_func = None

  def add_function(self, name, func_type):
    """Creates a FuncOp, appends it to the module, and makes it current."""
    self.curr_func = mlir.FuncOp.create(
        mlir.UnknownLoc.get(self.ctx), name, func_type)
    self.module.push_back(self.curr_func)
    return self.curr_func

  def get_function_type(self, inputs, outputs):
    """Returns an mlir FunctionType built from input/output type lists."""
    return self.builder.getFunctionType(inputs, outputs)

  def dump(self):
    # Delegates to MLIR's module dump().
    self.module.dump()

  def __str__(self):
    return self.module.getAsStr()

View File

@ -0,0 +1,20 @@
package(licenses = ["notice"])
# Unit tests for the AutoGraph-to-MLIR generator (mlir_gen).
py_test(
    name = "mlir_gen_test",
    size = "small",
    testonly = True,
    srcs = ["mlir_gen_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    # Excluded from OSS python2 and pip test runs.
    tags = [
        "no_oss_py2",
        "no_pip",
    ],
    deps = [
        "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/tf_program:mlir_gen",
        "//tensorflow/python/types",
    ],
)

View File

@ -0,0 +1,247 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `mlir_gen` module"""
# pylint: disable=missing-function-docstring
# pylint: disable=invalid-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import test
from tensorflow.python.types import core
from tensorflow.python.tf_program.mlir_gen import mlir_gen
import tensorflow.compiler.mlir.python.mlir_wrapper.filecheck_wrapper as fw
class MLIRGenTestBase(test.TestCase):
  """Base class providing FileCheck-based assertions on generated MLIR."""

  def _check_code(self, mlir_code, exp_mlir_code):
    """Asserts that `mlir_code` satisfies the FileCheck directives."""
    matched = fw.check(str(mlir_code), exp_mlir_code)
    return self.assertTrue(matched)
class MLIRGenTest(MLIRGenTestBase):
  """MLIR Generation Tests for Tensorflow Program.

  Each test lowers a small python function via mlir_gen and verifies the
  produced MLIR against FileCheck directives. The expected-code variable is
  consistently named `exp_mlir_code` (previously `mlir_code_exp` in two
  tests).
  """

  def test_simple(self):
    """An empty function lowers to a func with the same name."""

    def test_fn():
      pass

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: @test_fn
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_argument(self):
    """A Tensor argument becomes an unranked-tensor block argument."""

    def test_fn(x: core.Tensor) -> core.Tensor:
      return x

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32> {
CHECK-NEXT: return %arg0 : tensor<*xi32>
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_constant(self):
    """An int literal lowers to tf.Const."""

    def test_fn() -> int:
      return 23

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: func @test_fn() -> i32
CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
CHECK: return %[[r0]] : tensor<i32>
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_BoolOp(self):
    """`and`/`or` chains lower to variadic tfp.And/tfp.Or ops."""

    def test_fn(x: bool, y: bool) -> bool:
      return x or y or x and x and y

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: func @test_fn(%arg0: i1, %arg1: i1) -> i1
CHECK: %[[r0:[0-9]+]] = "tfp.And"(%arg0, %arg0, %arg1) : (i1, i1, i1) -> tensor<*xi1>
CHECK: %[[r1:[0-9]+]] = "tfp.Or"(%arg0, %arg1, %[[r0]]) : (i1, i1, tensor<*xi1>) -> tensor<*xi1>
return %[[r1]] : tensor<*xi1>
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_Call(self):
    """Nested function calls lower to tf.LegacyCall and separate funcs."""

    def test_fn():

      def f1():
        return 23

      def f2():
        return f1()

      f2()

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: func @test_fn()
CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f2} : () -> ()
CHECK: }
CHECK-LABEL: func @f1() {
CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
CHECK: return %[[r0]] : tensor<i32>
CHECK: }
CHECK-LABEL: func @f2() {
CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f1} : () -> ()
}
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_Compare(self):
    """Chained comparisons lower to a sequence of tf comparison ops."""

    def test_fn(x: core.Tensor, y: core.Tensor, z: core.Tensor):
      return x > y < z

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>)
CHECK: %[[r0:[0-9]+]] = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
CHECK: %[[r1:[0-9]+]] = "tf.Less"(%[[r0]], %arg2) : (tensor<*xi1>, tensor<*xi32>) -> tensor<*xi1>
CHECK: return %[[r1]] : tensor<*xi1>
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_Assign_BinOp(self):
    """Assignment of a binary expression lowers to tf.AddV2/tf.Sub."""

    def test_fn() -> int:
      y = 12 + 23 - 24
      return y

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: func @test_fn() -> i32
CHECK: %[[r0:[0-9]+]] = "tf.AddV2"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
CHECK: %[[r1:[0-9]+]] = "tf.Sub"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
CHECK: return %[[r1]] : tensor<i32>
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_if(self):
    """if/elif/else lowers to nested tfp.If ops with two regions each."""

    def test_fn(x: core.Tensor) -> int:
      res = 0
      if x > 0:
        res = 1
      elif x < 0:
        res = -1
      else:
        res = 0
      return res

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> i32
CHECK: %[[r1:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
CHECK-NEXT: %[[r2:[0-9]+]] = "tfp.If"(%[[r1]]) ( {
CHECK: return %{{[0-9]+}} : tensor<i32>
CHECK-NEXT: }, {
CHECK: %[[r3:[0-9]+]] = "tf.Less"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
CHECK: %[[r4:[0-9]+]] = "tfp.If"(%[[r3]]) ( {
CHECK: %[[r5:[0-9]+]] = "tf.Neg"(%{{[0-9]+}}) : (tensor<i32>) -> tensor<i32>
CHECK: return %[[r5]] : tensor<i32>
CHECK-NEXT: }, {
CHECK: return %{{[0-9]+}} : tensor<i32>
CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
CHECK: return %[[r4]] : tensor<i32>
CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
CHECK-NEXT: return %[[r2]] : tensor<i32>
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_while(self):
    """A while loop lowers to tfp.While with cond and body regions."""

    def test_fn(x: core.Tensor) -> core.Tensor:
      s = 0
      while x > 0:
        s = s + x
      return s

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>
CHECK: %[[r1:[0-9]+]] = "tfp.While"(%0) ( {
CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
CHECK: %[[r2:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
CHECK-NEXT: return %[[r2]] : tensor<*xi1>
CHECK-NEXT: }, {
CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
CHECK: %[[r3:[0-9]+]] = "tf.AddV2"(%arg1, %arg0) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi32>
CHECK-NEXT: return %[[r3]] : tensor<*xi32>
CHECK-NEXT: }) : (tensor<i32>) -> tensor<i32>
CHECK-NEXT: return %[[r1]] : tensor<i32>
"""
    self._check_code(mlir_code, exp_mlir_code)

  def test_fibonacci(self):
    """End-to-end: branching plus a multi-variable while loop."""

    def test_fn(x: core.Tensor) -> core.Tensor:
      res, idx = 0, 2
      a, b = 0, 1
      if x == 0 or x == 1:
        res = x
      else:
        while idx <= x:
          res = a + b
          a = b
          b = res
          idx = idx + 1
      return res

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>
CHECK: %[[r5:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
CHECK: %[[r7:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
CHECK: %[[r8:[0-9]+]] = "tfp.Or"(%[[r5]], %[[r7]]) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1>
CHECK: %[[r9:[0-9]+]]:4 = "tfp.If"(%[[r8]]) ( {
CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>
CHECK-NEXT: }, {
CHECK-NEXT: %[[r10:[0-9]+]]:4 = "tfp.While"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
CHECK-NEXT: %[[r11:[0-9]+]] = "tf.LessEqual"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>) -> tensor<*xi1>
CHECK-NEXT: return %[[r11]] : tensor<*xi1>
CHECK-NEXT: }, {
CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
CHECK-NEXT: %[[r12:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
CHECK: %[[r13:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
CHECK-NEXT: }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
CHECK-NEXT: return %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
CHECK-NEXT: }) : (tensor<*xi1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
CHECK-NEXT: return %[[r9]]#{{[0-9]+}} : tensor<i32>
"""
    self._check_code(mlir_code, exp_mlir_code)
# Run the MLIR generation tests when executed directly.
if __name__ == '__main__':
  test.main()