Graduate the MLIR implementation of TFR from experimental

PiperOrigin-RevId: 333369220
Change-Id: Iafce731c4d80f06eb1e95aee6ea0f67af8caeac5
Feng Liu 2020-09-23 14:08:10 -07:00 committed by TensorFlower Gardener
parent 40ccaf67ca
commit 938cc7bf9c
21 changed files with 3596 additions and 0 deletions

tensorflow/compiler/mlir/tfr/BUILD
@@ -0,0 +1,165 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/tfr/...",
"//tensorflow/compiler/mlir/...",
],
)
filegroup(
name = "tfr_ops_td_files",
srcs = [
"ir/tfr_ops.td",
"//tensorflow/compiler/mlir/tensorflow:ir/tf_op_base.td",
"//tensorflow/compiler/mlir/tensorflow:ir/tf_op_interfaces.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td",
"@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td",
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)
gentbl(
name = "tfr_ops_inc_gen",
tbl_outs = [
(
"-gen-op-decls",
"ir/tfr_ops.h.inc",
),
(
"-gen-op-defs",
"ir/tfr_ops.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfr_ops.td",
td_srcs = [
":tfr_ops_td_files",
],
)
cc_library(
name = "tfr",
srcs = [
"ir/tfr_ops.cc",
"ir/tfr_ops.cc.inc",
],
hdrs = [
"ir/tfr_ops.h",
"ir/tfr_ops.h.inc",
"ir/tfr_types.h",
],
deps = [
":tfr_ops_inc_gen",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
)
cc_library(
name = "utils",
srcs = [
"utils/utils.cc",
],
hdrs = [
"utils/utils.h",
],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "passes",
srcs = [
"passes/canonicalize.cc",
"passes/decompose.cc",
"passes/raise_to_tf.cc",
],
hdrs = [
"passes/passes.h",
],
deps = [
":tfr",
":utils",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToStandard",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
tf_cc_binary(
name = "tfr-opt",
srcs = ["passes/tfr_opt.cc"],
deps = [
":passes",
":tfr",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir:passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
],
)
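# A usage sketch (illustrative): the tfr-opt driver above can run the
# registered TFR passes on an MLIR file, e.g.
#   tfr-opt -tfr-decompose input.mlir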
glob_lit_tests(
data = [
":test_utilities",
],
driver = "//tensorflow/compiler/mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/tfr:tfr-opt",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
"@llvm-project//mlir:run_lit.sh",
],
)

tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
@@ -0,0 +1,590 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include <algorithm>
#include <string>
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/FunctionImplementation.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
namespace mlir {
namespace TFR {
//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for inlining within the TFR dialect.
struct TFRInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
// Returns true if the given region 'src' can be inlined into the region
// 'dest' that is attached to an operation registered to the current dialect.
bool isLegalToInline(Region *dest, Region *src,
BlockAndValueMapping &) const final {
return true;
}
// Returns true if the given operation 'op', that is registered to this
// dialect, can be inlined into the region 'dest' that is attached to an
// operation registered to the current dialect.
bool isLegalToInline(Operation *op, Region *dest,
BlockAndValueMapping &) const final {
return true;
}
// Handle the given inlined terminator by replacing it with a new operation
// as necessary. Required when the region has only one block.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
auto retValOp = dyn_cast<TFRReturnOp>(op);
if (!retValOp) return;
for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
}
}
// Attempts to materialize a conversion for a type mismatch between a call
// from this dialect, and a callable region. This method should generate an
// operation that takes 'input' as the only operand, and produces a single
// result of 'resultType'. If a conversion can not be generated, nullptr
// should be returned.
Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type result_type,
Location conversion_loc) const final {
if (!result_type.isa<IntegerType>()) return nullptr;
return builder.create<TruncateIOp>(conversion_loc, result_type, input);
}
};
} // namespace
//===----------------------------------------------------------------------===//
// TFR Dialect
//===----------------------------------------------------------------------===//
TFRDialect::TFRDialect(MLIRContext *context)
: Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
>();
addInterfaces<TFRInlinerInterface>();
}
bool TFRType::classof(Type type) {
return llvm::isa<TFRDialect>(type.getDialect());
}
//===----------------------------------------------------------------------===//
// Custom op methods
//===----------------------------------------------------------------------===//
static LogicalResult Verify(ConstantTensorOp op) {
auto input_type = op.arg().getType();
auto output_type = op.out().getType();
if (output_type.isa<TFRTensorType>()) {
return success();
}
auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
op.emitError("output type should be static and ranked.");
return failure();
}
if (output_tensor_type.getRank() == 0) {
bool same_scalar = output_tensor_type.getElementType() == input_type;
if (!same_scalar) {
op.emitError("input and output should have the same scalar types.");
}
return success(same_scalar);
}
if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
bool same_element_type = output_tensor_type.getElementType() ==
input_vector_type.getElementType();
bool same_shape =
output_tensor_type.getShape() == input_vector_type.getShape();
if (!same_element_type || !same_shape) {
op.emitError("input and output should have same shape and element type.");
}
return success(same_element_type && same_shape);
}
op.emitError("input can not be converted to an output tensor.");
return failure();
}
static LogicalResult Verify(TFRFuncOp func) {
// Collect all attribute names used by the tensor and tensor list arguments
// and results. Also, collect the names of all the attribute arguments as the
// defined list. Later on, the used attribute names will be verified to be in
// the defined list.
llvm::SmallVector<StringAttr, 4> used_attrs;
// While scanning the arguments, record the start/end indices of each argument
// type, so the order can be verified as well.
// TODO(fengliuai): the attribute arguments with default values need to be
// at the end?
int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
last_tensor_list = -1, first_attr = -1;
for (auto arg : llvm::enumerate(func.getType().getInputs())) {
Type arg_type = arg.value();
if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
if (first_tensor == -1) {
first_tensor = arg.index();
}
last_tensor = arg.index();
auto used = tensor.getAttrKeys();
used_attrs.append(used.begin(), used.end());
continue;
}
if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
if (first_tensor_list == -1) {
first_tensor_list = arg.index();
}
last_tensor_list = arg.index();
auto used = tensor_list.getAttrKeys();
used_attrs.append(used.begin(), used.end());
continue;
}
if (!arg_type.isa<TensorType>()) {
if (first_attr == -1) {
first_attr = arg.index();
}
auto name =
func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
if (!name) {
func.emitError(
llvm::Twine(arg.index()) +
" attribute argument doesn't have a tfr.name attribute.");
return failure();
}
continue;
}
func.emitError("Builtin TensorType isn't allowed as the argument.");
return failure();
}
// Verify the argument order: tensors, tensor list, attributes; and also
// verify there is at most one tensor list argument.
if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
func.emitError(
"tfr.tensor argument should be before tfr.tensor_list argument.");
return failure();
}
if (first_attr != -1 && first_attr < last_tensor_list) {
func.emitError(
"tfr.tensor_list argument should be before non-tensor arguments.");
return failure();
}
if (first_tensor_list != last_tensor_list) {
func.emitError("More than one tfr.tensor_list argument isn't allowed.");
return failure();
}
// Verify the result order: tensor, tensor list, and also verify at most one
// tensor list result.
bool seen_tensor_list = false;
for (auto result_type : func.getType().getResults()) {
if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
if (seen_tensor_list) {
func.emitError(
"tfr.tensor result should be before tfr.tensor_list result.");
return failure();
}
auto used = tensor.getAttrKeys();
used_attrs.append(used.begin(), used.end());
continue;
}
if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
if (seen_tensor_list) {
func.emitError("More than one tfr.tensor_list result isn't allowed.");
return failure();
}
seen_tensor_list = true;
auto used = tensor_list.getAttrKeys();
used_attrs.append(used.begin(), used.end());
continue;
}
func.emitError(
"Results other than tfr.tensor and tfr.tensor_list aren't allowed.");
return failure();
}
// Verify that all the used attributes are in the attribute arguments.
llvm::SmallVector<StringAttr, 4> undefined_attrs;
for (auto attr : used_attrs) {
if (!func.getAttr(attr.getValue())) {
undefined_attrs.push_back(attr);
}
}
if (!undefined_attrs.empty()) {
llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
std::transform(undefined_attrs.begin(), undefined_attrs.end(),
attr_names.begin(),
[](StringAttr attr) { return attr.getValue().str(); });
func.emitError(llvm::Twine("Undefined attributes are used: ",
llvm::join(attr_names, ",")));
return failure();
}
return success();
}
static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) {
auto build_func_type = [](Builder &builder, ArrayRef<Type> arg_types,
ArrayRef<Type> results, impl::VariadicFlag,
std::string &) {
return builder.getFunctionType(arg_types, results);
};
return impl::parseFunctionLikeOp(parser, *result, /*allowVariadic=*/false,
build_func_type);
}
static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) {
FunctionType fn_type = op.getType();
impl::printFunctionLikeOp(p, op, fn_type.getInputs(), /*isVariadic=*/false,
fn_type.getResults());
}
} // namespace TFR
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
namespace mlir {
namespace TFR {
struct ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
PatternRewriter &rewriter) const override {
Location loc = cst_tensor_op.getLoc();
Type out_type = cst_tensor_op.getType();
Operation *new_cst = nullptr;
ArrayAttr array;
if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
llvm::DenseSet<Type> all_types;
for (auto it : array) {
all_types.insert(it.getType());
}
if (all_types.size() != 1) return failure();
ShapedType new_out_type = RankedTensorType::get(
{static_cast<int64_t>(array.size())}, *all_types.begin());
DenseElementsAttr attr =
DenseElementsAttr::get(new_out_type, array.getValue());
new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
if (out_type.isa<TFRTensorType>()) {
new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
}
rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
return success();
}
Attribute scalar;
if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
Type new_out_type = RankedTensorType::get({}, scalar.getType());
new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
if (out_type.isa<TFRTensorType>()) {
new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
}
rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
return success();
}
return failure();
}
};
struct RemoveRedundantCast : public OpRewritePattern<CastOp> {
using OpRewritePattern<CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CastOp cast_op,
PatternRewriter &rewriter) const override {
auto preceding_cast =
llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
if (!preceding_cast) {
return failure();
}
Value input = preceding_cast.arg();
Type input_type = input.getType();
Type output_type = cast_op.getType();
// If the two types are the same, the back-to-back tfr.cast ops can be
// removed.
if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
rewriter.replaceOp(cast_op, {input});
return success();
}
// If the input tensor isn't ranked, replace the pair with a tf.EnsureShape
// op, which can be removed after shape inference or checked at runtime.
if (input_type.isa<UnrankedTensorType>() && output_type.isa<ShapedType>()) {
auto shape = output_type.cast<ShapedType>().getShape();
auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
input, shape_attr);
return success();
}
return failure();
}
};
struct GetTensorShape : public OpRewritePattern<GetShapeOp> {
using OpRewritePattern<GetShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GetShapeOp shape_op,
PatternRewriter &rewriter) const override {
Operation *preceding_op = shape_op.arg().getDefiningOp();
if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
// Replace this pair with shape.shape_of, so the folding works.
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
return success();
}
return failure();
}
};
struct RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
using OpRewritePattern<GetElementOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GetElementOp ge_op,
PatternRewriter &rewriter) const override {
IntegerAttr index;
if (!matchPattern(ge_op.index(), m_Constant(&index))) {
return failure();
}
auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
ge_op.tensor_list().getDefiningOp());
if (!preceding_build_list ||
preceding_build_list.getNumOperands() <= index.getInt()) {
return failure();
}
Value input = preceding_build_list.getOperand(index.getInt());
Type output_type = ge_op.getType();
if (input.getType() != output_type &&
!output_type.isa<UnrankedTensorType>()) {
return failure();
}
rewriter.replaceOp(ge_op, {input});
return success();
}
};
struct BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
using OpRewritePattern<BuildListOp>::OpRewritePattern;
LogicalResult matchAndRewrite(BuildListOp bl_op,
PatternRewriter &rewriter) const override {
SmallVector<Attribute, 4> array_list;
array_list.reserve(bl_op.getNumOperands());
for (const auto &operand : bl_op.getOperands()) {
Attribute array_elt;
if (!matchPattern(operand, m_Constant(&array_elt))) {
return failure();
}
array_list.push_back(array_elt);
}
auto array_attr = rewriter.getArrayAttr(array_list);
rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
return success();
}
};
void ConstantTensorOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ConvertConstToTensorConst>(context);
}
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RemoveRedundantCast>(context);
}
void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<GetTensorShape>(context);
}
void GetElementOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<RemoveRedundantGetElement>(context);
}
void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<BuildConstantListAsAttr>(context);
}
OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "equal op has two operands");
auto ctx = getContext();
if (operands[0] == operands[1]) return BoolAttr::get(/*value=*/true, ctx);
return BoolAttr::get(/*value=*/false, ctx);
}
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
// Return the held attribute value.
return value();
}
// CallableOpInterface
Region *TFRFuncOp::getCallableRegion() {
return isExternal() ? nullptr : &body().front();
}
// CallableOpInterface
ArrayRef<Type> TFRFuncOp::getCallableResults() {
return getType().getResults();
}
//===----------------------------------------------------------------------===//
// Dialect type definitions
//===----------------------------------------------------------------------===//
// Parses a TFR type.
// tfr_type ::= tensor_type | tensor_list_type | attr_type
// string_list ::= `[` string-literal (, string-literal)+ `]`
// tensor_type ::= `tensor`
// | `tensor<` (string-literal | string_list) '>'
// tensor_list_type ::= `tensor_list`
// | `tensor_list<` (string-literal | string_list) '>'
// attr_type ::= `attr`
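//
// For illustration, accepted spellings include (with the dialect prefix):
//   !tfr.tensor, !tfr.tensor<T>, !tfr.tensor<[T, N]>
//   !tfr.tensor_list, !tfr.tensor_list<[T, N]>
//   !tfr.attr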
Type TFRDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
MLIRContext *ctx = loc.getContext();
StringRef typeNameSpelling;
if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
llvm::SmallVector<StringAttr, 4> attrs;
if (succeeded(parser.parseOptionalLess())) {
bool l_square_parsed = false;
if (succeeded(parser.parseOptionalLSquare())) {
l_square_parsed = true;
}
do {
StringRef attr;
if (failed(parser.parseKeyword(&attr))) return {};
attrs.push_back(StringAttr::get(attr, ctx));
} while (succeeded(parser.parseOptionalComma()));
if (l_square_parsed && failed(parser.parseRSquare())) {
parser.emitError(parser.getNameLoc(), "expected ']'");
return {};
}
if (failed(parser.parseGreater())) {
parser.emitError(parser.getNameLoc(), "expected '>'");
return {};
}
}
if (typeNameSpelling == "tensor") {
return TFRTensorType::getChecked(attrs, loc);
} else if (typeNameSpelling == "tensor_list") {
return TFRTensorListType::getChecked(attrs, loc);
} else if (typeNameSpelling == "attr") {
return TFRAttrType::getChecked(loc);
} else {
parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
return {};
}
}
void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
llvm::ArrayRef<StringAttr> attrs;
if (type.isa<TFRAttrType>()) {
os << "attr";
return;
}
if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
attrs = tensor_ty.getAttrKeys();
os << "tensor";
} else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
attrs = tensor_list_ty.getAttrKeys();
os << "tensor_list";
} else {
llvm_unreachable("Unhandled tfr type");
}
if (attrs.empty()) return;
os << "<";
if (attrs.size() > 1) {
os << "[";
}
llvm::interleaveComma(attrs, os,
[&](StringAttr attr) { os << attr.getValue(); });
if (attrs.size() > 1) {
os << "]";
}
os << ">";
}
} // namespace TFR
} // namespace mlir

tensorflow/compiler/mlir/tfr/ir/tfr_ops.h
@@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/FunctionSupport.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir {
namespace TFR {
constexpr char kAttrArgumentNameAttr[] = "tfr.name";
constexpr char kAttrArgumentDefaultAttr[] = "tfr.default";
class TFRDialect : public Dialect {
public:
explicit TFRDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "tfr"; }
// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
// Prints a type registered to this dialect.
void printType(Type ty, DialectAsmPrinter &os) const override;
};
} // namespace TFR
} // namespace mlir
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_

tensorflow/compiler/mlir/tfr/ir/tfr_ops.td
@@ -0,0 +1,436 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for TFR
#ifndef DIALECT_TFR_OPS_
#define DIALECT_TFR_OPS_
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//
def TFR_Dialect : Dialect {
let name = "tfr";
let description = [{
The TensorFlow Composition dialect.
}];
let cppNamespace = "::mlir::TFR";
}
//===----------------------------------------------------------------------===//
// Type classes
//===----------------------------------------------------------------------===//
// tensor argument types
class TFR_Type<string name> : DialectType<TFR_Dialect,
CPred<"$_self.isa<mlir::TFR::" # name # "Type>()">,
"TFR " # name #" type">,
BuildableType<"$_builder.getType<mlir::TFR::" # name # "Type>()">;
def TFR_TensorType : TFR_Type<"TFRTensor">;
def TFR_TensorListType : TFR_Type<"TFRTensorList">;
def TFR_AllTensorTypes : Type<Or<[
TFR_TensorType.predicate,
TFR_TensorListType.predicate]>, "all tensor related types">;
// attribute argument types
def TFR_AttrType : TFR_Type<"TFRAttr">;
def TFR_AttrScalarType: TypeAlias<TF_ElementType, "scalar attribute">;
def TFR_AttrVectorType : VectorOf<[TF_ElementType, TFR_AttrType]>;
def TFR_AllAttrTypes : Type<Or<[
TFR_AttrType.predicate,
Index.predicate,
TFR_AttrScalarType.predicate,
TFR_AttrVectorType.predicate]>, "all attribute related types">;
// all allowed arguments types
def TFR_allowedArgType : Type<Or<[
TFR_AllTensorTypes.predicate,
TFR_AllAttrTypes.predicate]>, "allowed tfr.call operand types">;
def TFR_allowedConstValues : Attr<Or<[
FlatSymbolRefAttr.predicate,
TypeAttr.predicate,
StrAttr.predicate,
ArrayAttr.predicate]>, "allowed tfr.constant value"> {
let storageType = "Attribute";
let returnType = "Attribute";
let convertFromStorage = "$_self";
let constBuilderCall = "$0";
}
// all allowed result types
def TFR_allowedResultType : TypeAlias<TFR_AllTensorTypes,
"allowed tfr.call result types">;
// Standard tensor types and tfr.tensor types can be cast to each other.
def TFR_singleTensorType : Type<Or<[
TFR_TensorType.predicate,
TF_Tensor.predicate]>, "single tensor or tfr.tensor type">;
// all allowed build list input types
def TFR_allowedBuiltListType : Type<Or<[
TFR_TensorType.predicate,
TF_ElementType.predicate]>, "single tfr.tensor or tensor element type">;
// all allowed build list result types
def TFR_allowedListResultType : Type<Or<[
TFR_TensorListType.predicate,
TFR_AttrType.predicate]>, "tfr.tensor_list or tfr.attr type">;
//===----------------------------------------------------------------------===//
// Op classes
//===----------------------------------------------------------------------===//
class TFR_Op<string mnemonic, list<OpTrait> traits> :
Op<TFR_Dialect, mnemonic, traits>;
def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> {
let description = [{
The `call` operation represents a direct call to a function that is within
the same symbol scope as the call. The operands and result types of the
call must match the specified function type. The callee is encoded as a
symbol reference attribute named "callee".
Example:
```mlir
%2 = tfr.call @my_add(%0, %1) : (tfr.tensor, f32) -> tfr.tensor_list
```
Note that the operands of the `call` operation may only have tfr.tensor,
tfr.tensor_list, tfr.attr, and MLIR float and integer types. The results of
the `call` operation may only have tfr.tensor and tfr.tensor_list types.
}];
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<TFR_allowedArgType>:$args);
let results = (outs
Variadic<TFR_allowedResultType>:$outs);
let extraClassDeclaration = [{
StringRef getCallee() { return callee(); }
// Get the argument operands to the called function.
operand_range getArgOperands() { return args(); }
// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return calleeAttr(); }
}];
let assemblyFormat = [{
$callee `(` $args `)` attr-dict `:` functional-type($args, results)
}];
}
def TFR_CastOp : TFR_Op<"cast", [NoSideEffect]> {
let description = [{
The `cast` operation converts an operand of builtin tensor type to the
tfr.tensor type, or vice versa.
Example:
```mlir
%1 = tfr.cast(%0) : tensor<f32> -> !tfr.tensor
%3 = tfr.cast(%1) : !tfr.tensor -> tensor<f32>
```
}];
let arguments = (ins TFR_singleTensorType:$arg);
let results = (outs TFR_singleTensorType:$out);
let extraClassDeclaration = [{
// Return element type of the input tensor type. Only available when the
// input is a MLIR built-in tensor type.
Attribute getInputElementType() {
if (auto ty = arg().getType().dyn_cast<TensorType>()) {
return TypeAttr::get(ty.getElementType());
}
return {};
}
}];
let hasCanonicalizer = 1;
}
def TFR_GetShapeOp : TFR_Op<"get_shape", [NoSideEffect]> {
let description = [{
The `get_shape` operation gets the shape of a tfr.tensor and returns
!shape.shape type.
Example:
```mlir
%1 = "tfr.get_shape"(%0) : !tfr.tensor -> !shape.shape
%1 = tfr.get_shape %0 -> !shape.shape
```
}];
let arguments = (ins TFR_TensorType:$arg);
let results = (outs Shape_ShapeType:$out);
let assemblyFormat = "$arg attr-dict `->` type($out)";
let hasCanonicalizer = 1;
}
def TFR_GetElementTypeOp : TFR_Op<"get_element_type", [NoSideEffect]> {
let description = [{
The `get_element_type` operation gets the element type of a tfr.tensor and
returns !tfr.attr.
Example:
```mlir
%1 = "tfr.get_element_type"(%0) : !tfr.tensor -> !tfr.attr
%1 = tfr.get_element_type %0 -> !tfr.attr
```
}];
let arguments = (ins TFR_TensorType:$arg);
let results = (outs TFR_AttrType:$out);
let assemblyFormat = "$arg attr-dict `->` type($out)";
}
def TFR_EqualOp : TFR_Op<"equal", [NoSideEffect, SameTypeOperands]> {
let description = [{
The `equal` operation compares the values of the tfr.attr type arguments.
The operation returns an i1 boolean indicating if the two values are the
same.
Example:
```mlir
%x = tfr.equal %lhs, %rhs -> i1
%x = "tfr.equal"(%lhs, %rhs) : (!tfr.attr, !tfr.attr) -> i1
```
}];
let arguments = (ins
TFR_AttrType:$lhs,
TFR_AttrType:$rhs
);
let results = (outs BoolLike:$result);
let hasFolder = 1;
let assemblyFormat = "$lhs `,` $rhs attr-dict `->` type($result)";
}
def TFR_ConstOp : TFR_Op<"constant", [ConstantLike, NoSideEffect]> {
let description = [{
The `constant` operation stores a TF op's attribute, which doesn't support
arithmetic operations.
Example:
```mlir
%1 = "tfr.constant"() { value: i32 } : () -> !tfr.attr
%2 = "tfr.constant"() { value: [i32, f32] } : () -> !tfr.attr
%3 = tfr.constant [i32, f32] -> !tfr.attr
%4 = tfr.constant f32 -> !tfr.attr
```
}];
let arguments = (ins TFR_allowedConstValues:$value);
let results = (outs TFR_AttrType:$out);
let hasFolder = 1;
let builders = [OpBuilder<
"OpBuilder &, OperationState &state, Attribute value",
[{
auto* ctx = value.getContext();
state.addAttribute("value", value);
state.addTypes(TFRAttrType::get(ctx));
}]>
];
let assemblyFormat = [{
$value attr-dict `->` type($out)
}];
}
def TFR_ConstantTensorOp : TFR_Op<"constant_tensor", [NoSideEffect]> {
let description = [{
The `constant_tensor` operation converts an operand of a non-tensor type
(e.g. a scalar or vector) to a builtin tensor type or a tfr.tensor type.
If the result is a builtin tensor type, the shape shouldn't change during
the conversion.
Example:
```mlir
%1 = tfr.constant_tensor(%0) : f32 -> tensor<f32>
%3 = tfr.constant_tensor(%2) : vector<1xf32> -> tensor<1xf32>
```
}];
let arguments = (ins TFR_AllAttrTypes:$arg);
let results = (outs TFR_singleTensorType:$out);
let hasCanonicalizer = 1;
let verifier = [{ return Verify(*this); }];
}
def TFR_GetElementOp : TFR_Op<"get_element", [NoSideEffect]> {
let description = [{
The `get_element` operation extracts one tfr.tensor element from a
tfr.tensor_list.
Example:
```mlir
%2 = tfr.get_element %1[%0] : (tfr.tensor, index) -> tfr.tensor
```
}];
let arguments = (ins
TFR_TensorListType:$tensor_list,
Index:$index);
let results = (outs TFR_TensorType:$out);
let hasCanonicalizer = 1;
let assemblyFormat = [{
$tensor_list `[` $index `]` attr-dict `:`
`(` type($tensor_list) `,` type($index) `)` `->` type($out)
}];
}
def TFR_BuildListOp : TFR_Op<"build_list", [NoSideEffect]> {
let description = [{
The `build_list` operation builds a tensor list from a list of tensors, or
a tfr.attr from a list of scalars.
Example:
```mlir
%3 = tfr.build_list(%2, %1, %0) :
(tfr.tensor, tfr.tensor, tfr.tensor) -> tfr.tensor_list
%3 = tfr.build_list(%2, %1, %0) : (i32, i32, i32) -> tfr.attr
```
}];
let arguments = (ins Variadic<TFR_allowedBuiltListType>:$tensors);
let results = (outs TFR_allowedListResultType:$out);
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// Function related classes
//===----------------------------------------------------------------------===//
def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
DeclareOpInterfaceMethods<CallableOpInterface>,
FunctionLike,
IsolatedFromAbove, Symbol]> {
let summary = "TFR Function defines a composition of other ops";
let description = [{
Defines a function that can be used to decompose a TF function call into
the invocation of a set of other TF ops.
Syntax:
```
op ::= `tfr.func` symbol-ref-id `(` argument-list `)` (`->`
function-result-list)? function-attributes? region
```
Example:
```mlir
tfr.func @foo(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list<T>,
%arg2: int {tfr.name="T", tfr.default=1})
attributes {qux: "quux"} {
tfr.return
}
```
Note the arguments are ordered by the following rule:
tfr.tensor > tfr.tensor_list > tfr.attr/i32/...,
and at most one tfr.tensor_list argument is allowed.
}];
let arguments = (ins
TypeAttr:$type,
StrAttr:$sym_name
);
let results = (outs);
// When the region list is empty, the tfr.func is an external function used
// to model the element type constraints of the tf op. Otherwise, there is one
// region containing the composition.
let regions = (region VariadicRegion<AnyRegion>:$body);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, StringRef name, "
"FunctionType type, ArrayRef<NamedAttribute> attrs = {}">
];
let extraClassDeclaration = [{
// FunctionLike trait needs access to the functions below.
friend class OpTrait::FunctionLike<TFRFuncOp>;
// Hooks for the input/output type enumeration in FunctionLike.
unsigned getNumFuncArguments() { return getType().getNumInputs(); }
unsigned getNumFuncResults() { return getType().getNumResults(); }
}];
let verifier = [{ return Verify(*this); }];
let parser = [{ return ParseFuncOp(parser, &result); }];
let printer = [{ PrintFuncOp(p, *this); }];
}
def TFR_TFRReturnOp : TFR_Op<"return", [HasParent<"TFRFuncOp">, NoSideEffect,
ReturnLike, Terminator]> {
let description = [{
A terminator operation for regions that appear in the body of `tfr.func`
functions. The operands to the `tfr.return` are the result values returned
by an invocation of the `tfr.func`.
Note that only tfr.tensor and tfr.tensor_list values can be returned.
}];
let arguments = (ins Variadic<TFR_allowedResultType>:$operands);
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
#endif // DIALECT_TFR_OPS_

tensorflow/compiler/mlir/tfr/ir/tfr_types.h
@@ -0,0 +1,115 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeSupport.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
namespace mlir {
namespace TFR {
class TFRType : public Type {
public:
using Type::Type;
static bool classof(Type type);
};
namespace detail {
struct TFRTypeStorage final
: public TypeStorage,
public llvm::TrailingObjects<TFRTypeStorage, StringAttr> {
using KeyTy = ArrayRef<StringAttr>;
explicit TFRTypeStorage(unsigned num_attrs) : num_attrs(num_attrs) {}
static TFRTypeStorage* construct(TypeStorageAllocator& allocator, KeyTy key) {
// Allocate a new storage instance.
auto byteSize = TFRTypeStorage::totalSizeToAlloc<StringAttr>(key.size());
auto rawMem = allocator.allocate(byteSize, alignof(TFRTypeStorage));
auto result = ::new (rawMem) TFRTypeStorage(key.size());
// Copy in the string attributes into the trailing storage.
std::uninitialized_copy(key.begin(), key.end(),
result->getTrailingObjects<StringAttr>());
return result;
}
bool operator==(const KeyTy& attrs) const { return attrs == GetAttrs(); }
KeyTy GetAttrs() const {
return {getTrailingObjects<StringAttr>(), num_attrs};
}
unsigned num_attrs;
};
template <typename Derived>
class TFRTypeImpl : public Type::TypeBase<Derived, TFRType, TFRTypeStorage> {
public:
using Base = Type::TypeBase<Derived, TFRType, TFRTypeStorage>;
using TFRBase = TFRTypeImpl<Derived>;
using Base::Base;
static Derived get(ArrayRef<StringAttr> attrs, MLIRContext* context) {
return Base::get(context, attrs);
}
static Derived getChecked(ArrayRef<StringAttr> attrs, Location loc) {
return Base::getChecked(loc, attrs);
}
static Derived get(MLIRContext* context) { return get({}, context); }
// TODO(fengliuai): fix the implementation
static LogicalResult verifyConstructionInvariants(
Location loc, ArrayRef<StringAttr> attrs) {
return success();
}
ArrayRef<StringAttr> getAttrKeys() { return Base::getImpl()->GetAttrs(); }
};
} // namespace detail
class TFRTensorType : public detail::TFRTypeImpl<TFRTensorType> {
public:
using TFRBase::TFRBase;
static std::string getTypeName() { return "TFRTensorType"; }
};
class TFRTensorListType : public detail::TFRTypeImpl<TFRTensorListType> {
public:
using TFRBase::TFRBase;
static std::string getTypeName() { return "TFRTensorListType"; }
};
class TFRAttrType : public Type::TypeBase<TFRAttrType, TFRType, TypeStorage> {
public:
using Base::Base;
static std::string getTypeName() { return "TFRAttrType"; }
};
} // namespace TFR
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_

tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
@@ -0,0 +1,160 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <iterator>
#include <memory>
#include "llvm/Support/raw_ostream.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
//===----------------------------------------------------------------------===//
// Canonicalization patterns for the scf.for and scf.if ops. They are used to
// optimize the control flow in the tfr function. Technically, both patterns
// should be upstreamed to be part of the op definition.
// TODO(fengliuai): sync with the llvm upstream for both patterns.
//
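// As an illustrative sketch of the effect of these patterns:
//
//   %r = scf.for %iv = %c0 to %c2 step %c1
//       iter_args(%acc = %init) -> (f32) { ... scf.yield %next : f32 }
//
// with constant bounds is fully unrolled into two copies of its body, and
//
//   %r = scf.if %true -> (f32) { ... } else { ... }
//
// with a constant condition is replaced by the region that is taken.
//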
namespace mlir {
namespace TFR {
namespace {
struct UnrollSCFForOp : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::ForOp for_op,
PatternRewriter &rewriter) const override {
Location loc = for_op.getLoc();
APInt lower_bound, upper_bound, step;
if (!matchPattern(for_op.lowerBound(), m_ConstantInt(&lower_bound)) ||
!matchPattern(for_op.upperBound(), m_ConstantInt(&upper_bound)) ||
!matchPattern(for_op.step(), m_ConstantInt(&step))) {
return failure();
}
uint64_t trip_count = (upper_bound - lower_bound).sdiv(step).getZExtValue();
if (trip_count == 0) return failure();
// TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported
Block *single_block = for_op.getBody();
BlockAndValueMapping mapping;
Value iv = for_op.getInductionVar();
for (auto iter_op :
llvm::zip(for_op.getRegionIterArgs(), for_op.initArgs())) {
mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
}
mapping.map(iv, for_op.lowerBound());
for (uint64_t i = 0; i < trip_count; ++i) {
if (!iv.use_empty()) {
// iv' = iv + step * i;
Value iter = rewriter.create<ConstantIndexOp>(loc, i);
Value step_cst =
rewriter.create<ConstantIndexOp>(loc, step.getSExtValue());
Value stride = rewriter.create<MulIOp>(loc, step_cst, iter);
Value iv_unroll =
rewriter.create<AddIOp>(loc, mapping.lookup(iv), stride);
mapping.map(iv, iv_unroll);
}
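// Clone the loop body once per iteration; the cloned terminator
// (scf.yield) is used below to remap the iter_args and is then erased.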
Operation *terminator_op;
for (auto it = single_block->begin(); it != single_block->end(); ++it) {
terminator_op = rewriter.clone(*it, mapping);
}
// Map the block arguments to the yield results.
for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(),
terminator_op->getOperands())) {
mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
}
rewriter.eraseOp(terminator_op);
}
SmallVector<Value, 4> returned;
for (Value arg : for_op.getRegionIterArgs()) {
returned.push_back(mapping.lookup(arg));
}
rewriter.replaceOp(for_op, returned);
return success();
}
};
// TODO(fengliuai): up stream this pattern.
struct SimplifySCFIfOp : public OpRewritePattern<scf::IfOp> {
using OpRewritePattern<scf::IfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::IfOp if_op,
PatternRewriter &rewriter) const override {
// Then branch
if (matchPattern(if_op.condition(), m_NonZero())) {
return InlineRegion(if_op.getLoc(), rewriter, if_op, &if_op.thenRegion());
}
// Else branch
if (matchPattern(if_op.condition(), m_Zero())) {
if (if_op.elseRegion().empty()) {
// Remove the op
rewriter.eraseOp(if_op);
return success();
} else {
return InlineRegion(if_op.getLoc(), rewriter, if_op,
&if_op.elseRegion());
}
}
// Not a constant condition
return failure();
}
private:
LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter,
Operation *inline_point, Region *region) const;
};
LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
PatternRewriter &rewriter,
Operation *inline_point,
Region *region) const {
InlinerInterface interface(loc.getContext());
if (failed(inlineRegion(interface, region, inline_point, {},
inline_point->getResults(), loc,
/*shouldCloneInlinedRegion=*/true))) {
return failure();
}
// If the inlining was successful then erase the scf.if op.
rewriter.eraseOp(inline_point);
return success();
}
} // namespace
void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
}
} // namespace TFR
} // namespace mlir

tensorflow/compiler/mlir/tfr/passes/decompose.cc
@@ -0,0 +1,280 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
//===----------------------------------------------------------------------===//
// The pass to decompose unregistered TF ops with the TFR compose function.
//
namespace mlir {
namespace TFR {
namespace {
// Decompose the TF ops with the registered composition library.
struct DecomposeTFOpsPass
: public PassWrapper<DecomposeTFOpsPass, FunctionPass> {
explicit DecomposeTFOpsPass(llvm::Optional<ModuleOp> external_tfr_module)
: external_tfr_module(external_tfr_module) {}
void runOnFunction() override;
private:
// Apply canonicalization, mainly constant folding, on the function.
void ApplyCanonicalization();
// Rewrite unregistered TF ops to TFR func call ops. Return failure if all the
// ops are registered or the compose function doesn't exist.
LogicalResult RewriteUnregisteredTFOps();
// Inline the TFR func call ops.
LogicalResult InlineTFRFuncCalls();
// Optional external symbol table to look up the TFR function.
llvm::Optional<ModuleOp> external_tfr_module;
};
void DecomposeTFOpsPass::ApplyCanonicalization() {
OwningRewritePatternList patterns;
auto* context = &getContext();
for (auto* op : context->getRegisteredOperations()) {
op->getCanonicalizationPatterns(patterns, context);
}
populateSCFOpsCanonicalizationPatterns(patterns, context);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
FuncOp func = getFunction();
SymbolTable table(external_tfr_module.hasValue()
? *external_tfr_module
: func.getParentOfType<ModuleOp>());
OpBuilder builder(func);
bool changed = false;
func.walk([&table, &builder, &changed](Operation* op) {
// Only the unregistered ops require decomposition. The remaining ones
// either will be constant folded or lowered by the rules defined in the
// bridge.
if (op->isRegistered()) {
return;
}
// Find out the compose function
auto compose_func_name = GetComposeFuncName(op->getName().getStringRef());
auto compose_func = table.lookup<TFRFuncOp>(compose_func_name);
if (!compose_func || compose_func.isExternal()) {
// There are no decomposition methods defined for this op, skip.
return;
}
auto compose_func_type = compose_func.getType();
builder.setInsertionPoint(op);
TFRTensorType unconstrained_tensor_type = builder.getType<TFRTensorType>();
// Create the new operands. This is mapping the operands from the target
// TF ops to the TFR function arguments. If the TFR function argument is
// a tensor_list, a "tfr.build_list" op is used to concat the available
// TF op operands. If the TFR function argument isn't a tensor/tensor_list,
// a constant is created by using the attribute stored in the TF op or the
// default value in the argument attribute.
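// As a sketch (op and attribute names are illustrative): for an
// unregistered op
//   %r = "tf.MyOp"(%t, %u, %v) {N = 2 : i32}
// whose compose function takes (!tfr.tensor, !tfr.tensor_list,
// i32 {tfr.name = "N"}), the new operands become [tfr.cast(%t),
// tfr.build_list(tfr.cast(%u), tfr.cast(%v)), constant 2 : i32].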
llvm::SmallVector<Value, 4> new_operands;
for (auto arg : llvm::enumerate(compose_func_type.getInputs())) {
if (auto tensor_type = arg.value().dyn_cast<TFRTensorType>()) {
auto casted = builder.create<CastOp>(op->getLoc(), tensor_type,
op->getOperand(arg.index()));
new_operands.push_back(casted);
} else if (auto list_type = arg.value().dyn_cast<TFRTensorListType>()) {
llvm::SmallVector<Value, 4> variadic_operands;
for (int i = arg.index(); i < op->getNumOperands(); i++) {
auto casted = builder.create<CastOp>(
op->getLoc(), unconstrained_tensor_type, op->getOperand(i));
variadic_operands.push_back(casted);
}
auto build_list_op = builder.create<BuildListOp>(
op->getLoc(), list_type, variadic_operands);
new_operands.push_back(build_list_op.out());
} else {
auto attr_name = compose_func.getArgAttrOfType<StringAttr>(
arg.index(), kAttrArgumentNameAttr);
auto attribute = op->getAttr(attr_name.getValue());
if (!attribute) {
attribute =
compose_func.getArgAttr(arg.index(), kAttrArgumentDefaultAttr);
}
Value attr_cst;
// Wrap these special attributes as a special TFR constant, so the SSA
// value has a valid type to be used as TFR function argument. These
// attributes are not expected to be manipulated by the lowering passes.
if (attribute.isa<TypeAttr>() || attribute.isa<ArrayAttr>() ||
attribute.isa<StringAttr>() || attribute.isa<FlatSymbolRefAttr>()) {
TFRAttrType output_type = TFRAttrType::get(builder.getContext());
attr_cst =
builder.create<ConstOp>(op->getLoc(), output_type, attribute);
} else {
attr_cst = builder.create<ConstantOp>(op->getLoc(), attribute);
}
new_operands.push_back(attr_cst);
}
}
// Create the TFR call op
auto new_op = builder.create<CallOp>(
op->getLoc(), compose_func_type.getResults(),
builder.getSymbolRefAttr(compose_func.getName()), new_operands);
// Replace the use of the old op. This is mapping the results from the
// target TF ops to the TFR function returns. If the TFR function return is
// a tensor_list, "tfr.get_element" op is used to extract the required TF
// op result.
llvm::SmallVector<Value, 4> new_results;
for (auto res : llvm::enumerate(compose_func_type.getResults())) {
if (res.value().dyn_cast<TFRTensorType>()) {
new_results.push_back(new_op.getResult(res.index()));
} else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) {
auto index =
builder.create<ConstantOp>(op->getLoc(), builder.getIndexAttr(j));
auto element_op = builder.create<GetElementOp>(
op->getLoc(), unconstrained_tensor_type,
new_op.getResult(res.index()), index.getResult());
new_results.push_back(element_op.out());
}
}
}
for (auto res : llvm::zip(op->getResults(), new_results)) {
auto casted = builder.create<CastOp>(
op->getLoc(), std::get<0>(res).getType(), std::get<1>(res));
std::get<0>(res).replaceAllUsesWith(casted.out());
}
op->erase();
changed |= true;
});
// If `changed` is false, it is considered a failure, so the recursive
// rewrite will stop.
return success(changed);
}
LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
// The Inliner will automatically use the registered dialect inliner.
InlinerInterface inliner(&getContext());
FuncOp func = getFunction();
SymbolTable table(external_tfr_module.hasValue()
? *external_tfr_module
: func.getParentOfType<ModuleOp>());
// The inliner only inlines the TFR call op.
bool changed = false;
auto walk_result = func.walk([&](CallOp call_op) {
auto callee = table.lookup<TFRFuncOp>(call_op.callee());
if (!callee || callee.isExternal()) return WalkResult::advance();
if (failed(inlineCall(inliner,
cast<CallOpInterface>(call_op.getOperation()),
cast<CallableOpInterface>(callee.getOperation()),
callee.getCallableRegion(),
/*shouldCloneInlinedRegion=*/true))) {
// This failure is usually because the decompose function is not defined.
// This call will be raised to TF ops.
return WalkResult::interrupt();
}
call_op.erase();
changed |= true;
return WalkResult::advance();
});
if (walk_result.wasInterrupted()) {
signalPassFailure();
return failure();
}
// If `changed` is false, it is considered a failure, so the recursive
// rewrite will stop.
return success(changed);
}
void DecomposeTFOpsPass::runOnFunction() {
// Set a maximum iteration threshold in case there are infinite loops in the
// call stack.
int max_iterations = 10;
do {
// canonicalization
ApplyCanonicalization();
// Rewrite unregistered TF ops. This fails either because no op can be
// decomposed or because the compose function isn't defined.
auto rewrite_status = RewriteUnregisteredTFOps();
// Inline the TFR call ops until no more tfr.call ops can be inlined.
auto inline_status = InlineTFRFuncCalls();
if (failed(rewrite_status) && failed(inline_status)) {
break;
}
} while (max_iterations-- >= 0);
}
} // namespace
// Creates an instance of the pass to decompose the TF ops.
std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(
llvm::Optional<ModuleOp> tfr_module) {
return std::make_unique<DecomposeTFOpsPass>(tfr_module);
}
static PassRegistration<DecomposeTFOpsPass> pass(
"tfr-decompose",
"Decompose TF ops with the registered composition library.",
[] { return CreateDecomposeTFOpsPass(); });
} // namespace TFR
} // namespace mlir

tensorflow/compiler/mlir/tfr/passes/passes.h
@@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
namespace mlir {
namespace TFR {
void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
// Decompose ops.
std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(
llvm::Optional<ModuleOp> tfr_module = llvm::None);
// Raise to TF ops.
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseToTFOpsPass(
llvm::Optional<ModuleOp> tfr_module = llvm::None,
bool materialize_derived_attrs = false);
} // namespace TFR
} // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_


@ -0,0 +1,474 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
//===----------------------------------------------------------------------===//
// The pass to rewrite the TFR function call ops to TF ops. The callee of each
// TFR function call defines the signature of the corresponding TF op.
//
namespace mlir {
namespace TFR {
namespace {
// This pattern rewrites a "tfr.call" op and the "tfr.cast" ops on its
// operands to a TF op followed by "tfr.cast" ops on the results. The result
// type of the new TF op is an unranked tensor with the derived element type.
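// As a sketch (with a hypothetical op name), the pattern rewrites
//   %t = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor
//   %c = tfr.call @tf__my_op(%t) : (!tfr.tensor) -> !tfr.tensor
// into
//   %n = "tf.MyOp"(%arg0) : (tensor<2x3xf32>) -> tensor<*xf32>
//   %c = "tfr.cast"(%n) : (tensor<*xf32>) -> !tfr.tensor
// provided the external signature tfr.func @tf__my_op_ is in the symbol
// table.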
class RewriteTFRCallOp : public OpRewritePattern<CallOp> {
using OpRewritePattern<CallOp>::OpRewritePattern;
public:
explicit RewriteTFRCallOp(MLIRContext* context, const SymbolTable& table,
bool materialize_derived_attrs)
: OpRewritePattern<CallOp>(context),
symbol_table_(table),
materialize_derived_attrs_(materialize_derived_attrs) {}
LogicalResult matchAndRewrite(CallOp call_op,
PatternRewriter& rewriter) const override;
private:
// Derives the attribute values for the attributes attached to the
// `input_tfr_type`. These attributes are only for the element type of the
// inputs, and this type information has been collected in the `input_types`.
// The result is stored in `derived_attrs` as the named attributes. Returns
// failure if the attributes stored in the `input_tfr_type` violate the
// assumptions.
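// For example (a sketch): a `!tfr.tensor<T>` input whose element type is f32
// yields {"T": f32}; a `!tfr.tensor_list<N, T>` input with three f32
// elements yields {"N": 3, "T": f32}.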
LogicalResult AddDerivedAttrs(
PatternRewriter& rewriter, Type input_tfr_type,
ArrayRef<Attribute> input_types,
llvm::StringMap<Attribute>* derived_attrs) const;
// Collects the operands and attributes for the TF op. At the same time, it
// collects all the derived attribute values to derive the output types of the
// TF op.
LogicalResult CollectInputsAndAttributes(
PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
llvm::StringMap<Attribute>* derived_attrs) const;
// Uses the collected attribute values to derive all the output types.
LogicalResult DeriveOutputTypes(FunctionType signature,
const llvm::StringMap<Attribute>& attrs,
SmallVectorImpl<Type>* output_types) const;
// Creates the TF op and also the necessary tfr.cast ops to replace the
// original TFR call op.
LogicalResult CreateAndReplaceOp(
PatternRewriter& rewriter, CallOp call_op,
const SmallVectorImpl<Type>& output_types,
const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
const llvm::StringMap<Attribute>& derived_attrs) const;
// Adds a tf.Cast op if the tfr.tensor attribute indicated a fixed element
// type.
// TODO(fengliuai): This method is required when the operand types are not set
// by the frontend correctly.
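// For example (a sketch): for an operand typed `!tfr.tensor<i32_>` whose
// incoming element type is f32, a tf.Cast to i32 is inserted before the new
// TF op.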
Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc,
CastOp cast_op, Type input_tfr_type) const {
auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>();
if (!tensor_type) return cast_op.arg();
auto attr_names = tensor_type.getAttrKeys();
if (attr_names.size() != 1) return cast_op.arg();
StringRef tfr_type_attr = attr_names[0].getValue();
if (!fixed_elt_type_attrs_.contains(tfr_type_attr)) return cast_op.arg();
Type result_elt_type;
if (tfr_type_attr == "i32_") {
result_elt_type = rewriter.getI32Type();
} else if (tfr_type_attr == "i64_") {
result_elt_type = rewriter.getI64Type();
} else if (tfr_type_attr == "f32_") {
result_elt_type = rewriter.getF32Type();
} else if (tfr_type_attr == "i1_") {
result_elt_type = rewriter.getI1Type();
} else {
return cast_op.arg();
}
Type original_input_type =
cast_op.getInputElementType().cast<TypeAttr>().getValue();
if (result_elt_type != original_input_type) {
UnrankedTensorType result_type = UnrankedTensorType::get(result_elt_type);
return rewriter.create<TF::CastOp>(loc, result_type, cast_op.arg());
}
return cast_op.arg();
}
// For variadic operands, we have to enforce them to use the same types.
// TODO(fengliuai): This method is required when the operand types are not set
// by the frontend correctly.
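// For example (a sketch): for variadic inputs with element types
// [f32, i32, f32], the second input is wrapped in a tf.Cast to f32, the type
// of the first input.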
void CastValuesToSameType(PatternRewriter& rewriter, Location loc,
const llvm::SmallVectorImpl<Attribute>& input_types,
llvm::SmallVectorImpl<Value>& input_values) const {
if (input_types.size() <= 1) return;
Type target_input_type = input_types[0].cast<TypeAttr>().getValue();
auto result_type = UnrankedTensorType::get(target_input_type);
for (size_t i = 1; i < input_types.size(); ++i) {
Type current_input_type = input_types[i].cast<TypeAttr>().getValue();
if (current_input_type != target_input_type) {
input_values[i] =
rewriter.create<TF::CastOp>(loc, result_type, input_values[i]);
}
}
}
const SymbolTable& symbol_table_;
const bool materialize_derived_attrs_;
const llvm::SmallDenseSet<StringRef, 4> fixed_elt_type_attrs_{"i32_", "i64_",
"f32_", "i1_"};
};
LogicalResult RewriteTFRCallOp::AddDerivedAttrs(
PatternRewriter& rewriter, Type input_tfr_type,
ArrayRef<Attribute> input_types,
llvm::StringMap<Attribute>* derived_attrs) const {
// If there is an attribute associated with the input in the signature, we
// store it as a derived attribute.
if (auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>()) {
auto attr_names = tensor_type.getAttrKeys();
if (attr_names.empty()) return success();
if (attr_names.size() == 1) {
derived_attrs->insert({attr_names[0].getValue(), input_types[0]});
return success();
}
}
// If there is an attribute associated with the input in the signature,
// we store it as a derived attribute.
if (auto list_type = input_tfr_type.dyn_cast<TFRTensorListType>()) {
auto attr_names = list_type.getAttrKeys();
if (attr_names.empty()) return success();
// N*T case
if (attr_names.size() == 2) {
derived_attrs->insert({attr_names[0].getValue(),
rewriter.getI32IntegerAttr(input_types.size())});
// Note that this uses the first element of the list to infer the T value.
// A tf.Cast is required to cast the other inputs to the same type.
derived_attrs->insert({attr_names[1].getValue(), input_types[0]});
return success();
}
// list(dtype) case
if (attr_names.size() == 1) {
derived_attrs->insert(
{attr_names[0].getValue(), rewriter.getArrayAttr(input_types)});
return success();
}
}
return failure();
}
LogicalResult RewriteTFRCallOp::CollectInputsAndAttributes(
PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
llvm::StringMap<Attribute>* derived_attrs) const {
for (const auto& operand : llvm::enumerate(signature.getType().getInputs())) {
// If the index is larger than the operand number of the call_op, the
// default value of the operand needs to be used.
if (operand.index() >= call_op.getNumOperands()) {
auto attr_name = signature.getArgAttrOfType<StringAttr>(
operand.index(), kAttrArgumentNameAttr);
auto attr_value =
signature.getArgAttr(operand.index(), kAttrArgumentDefaultAttr);
arg_attrs->push_back(
rewriter.getNamedAttr(attr_name.getValue(), attr_value));
continue;
}
// The index is valid for the call_op.
Value input = call_op.getOperand(operand.index());
Operation* input_op = input.getDefiningOp();
auto input_tfr_type = signature.getType().getInputs()[operand.index()];
// There are three cases for the preceding input_op:
// 1. The preceding op can be a tfr.cast op, which will be fused into the
// current op, so the result op takes an input with tensor type.
if (auto cast_op = dyn_cast_or_null<CastOp>(input_op)) {
Value input_to_cast = CastToNonDerivedType(rewriter, call_op.getLoc(),
cast_op, input_tfr_type);
inputs->push_back(input_to_cast);
if (failed(AddDerivedAttrs(rewriter, input_tfr_type,
{cast_op.getInputElementType()},
derived_attrs))) {
return failure();
}
continue;
}
// 2. The preceding op is a tfr.build_list op, which collects multiple
// values with tensor types via the tfr.cast ops. These ops will be fused
// into the current op as well, so all the tfr.cast op inputs will become
// inputs to the result op.
if (auto list_op = dyn_cast_or_null<BuildListOp>(input_op)) {
// Find out all the inputs to the build list op
// TODO(fengliuai): make build_list op only take tensor argument
llvm::SmallVector<Attribute, 4> list_input_types;
llvm::SmallVector<Value, 4> list_inputs;
for (auto list_input : list_op.getOperands()) {
auto cast_op = dyn_cast_or_null<CastOp>(list_input.getDefiningOp());
if (!cast_op) return failure();
list_inputs.push_back(cast_op.arg());
list_input_types.push_back(cast_op.getInputElementType());
}
CastValuesToSameType(rewriter, call_op.getLoc(), list_input_types,
list_inputs);
inputs->append(list_inputs.begin(), list_inputs.end());
if (failed(AddDerivedAttrs(rewriter, input_tfr_type, list_input_types,
derived_attrs))) {
return failure();
}
continue;
}
// 3. The preceding op is a constant, thus the value of this constant is
// used to create an attribute of the result op, according to the signature.
Attribute arg_value;
// A failure indicates the argument isn't a constant value, so we should
// not use it as an attribute.
if (!matchPattern(input, m_Constant(&arg_value))) {
return failure();
}
auto attr_name = signature.getArgAttrOfType<StringAttr>(
operand.index(), kAttrArgumentNameAttr);
arg_attrs->push_back(
rewriter.getNamedAttr(attr_name.getValue(), arg_value));
}
return success();
}
// For each output, uses the attribute name associated with the tfr type to
// look up the attribute value from the collected `attrs`, and creates the
// output type of the result op by using the attribute value as the element
// type.
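// For example (a sketch): a `!tfr.tensor_list<N, T>` result with the
// collected attributes {"N": 2, "T": f32} produces two `tensor<*xf32>`
// output types.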
LogicalResult RewriteTFRCallOp::DeriveOutputTypes(
FunctionType signature, const llvm::StringMap<Attribute>& attrs,
SmallVectorImpl<Type>* output_types) const {
for (auto res : llvm::enumerate(signature.getResults())) {
if (auto tensor_type = res.value().dyn_cast<TFRTensorType>()) {
// tfr.tensor should only have one attribute attached.
auto attr_key = tensor_type.getAttrKeys().front();
output_types->push_back(UnrankedTensorType::get(
attrs.lookup(attr_key.getValue()).cast<TypeAttr>().getValue()));
continue;
}
if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
// There are two cases: N*T or list(dtype)
auto attr_keys = list_type.getAttrKeys();
// N*T case
if (attr_keys.size() == 2) {
// The first one is N, and the second one is T.
int list_size =
attrs.lookup(attr_keys[0].getValue()).cast<IntegerAttr>().getInt();
Type elt_type =
attrs.lookup(attr_keys[1].getValue()).cast<TypeAttr>().getValue();
for (int i = 0; i < list_size; ++i) {
output_types->push_back(UnrankedTensorType::get(elt_type));
}
continue;
}
// TODO(fengliuai): list(dtype) case
}
return failure();
}
return success();
}
LogicalResult RewriteTFRCallOp::CreateAndReplaceOp(
PatternRewriter& rewriter, CallOp call_op,
const SmallVectorImpl<Type>& output_types,
const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
const llvm::StringMap<Attribute>& derived_attrs) const {
// Create the new op
Location loc = call_op.getLoc();
rewriter.setInsertionPointAfter(call_op);
std::string tf_op_name = GetTFOpName(call_op.callee());
OperationState new_state(loc, tf_op_name, inputs, output_types, attr_list);
Operation* new_op = rewriter.createOperation(new_state);
if (materialize_derived_attrs_) {
for (const auto& attr : derived_attrs) {
// Add or update the derived attribute with the value. Skip the fixed
// element type attributes, in case they are present in the NodeDef.
if (!fixed_elt_type_attrs_.contains(attr.first())) {
new_op->setAttr(attr.first(), attr.second);
}
}
}
// Create the tfr.cast ops on the results and replace the uses of the
// original call op.
TFRTensorType unconstrained_type = rewriter.getType<TFRTensorType>();
SmallVector<Value, 4> new_results;
for (auto res : llvm::enumerate(call_op.getResultTypes())) {
Type res_type = res.value();
if (res_type.dyn_cast<TFRTensorType>()) {
Value new_res = new_op->getResult(res.index());
auto casted = rewriter.create<CastOp>(loc, res_type, new_res);
new_results.push_back(casted.out());
} else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
SmallVector<Value, 4> tensor_list;
for (int i = res.index(); i < new_op->getNumResults(); i++) {
Value new_res = new_op->getResult(i);
auto casted =
rewriter.create<CastOp>(loc, unconstrained_type, new_res);
tensor_list.push_back(casted.out());
}
auto list_op = rewriter.create<BuildListOp>(loc, res_type, tensor_list);
new_results.push_back(list_op.out());
}
}
rewriter.replaceOp(call_op, new_results);
return success();
}
LogicalResult RewriteTFRCallOp::matchAndRewrite(
CallOp call_op, PatternRewriter& rewriter) const {
// Get the func op and verify that it is external. The type of this external
// func op is used as the signature of the corresponding TF op. All the
// external func ops have a trailing underscore.
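// E.g. a call to @tf__add resolves to the external signature @tf__add_.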
std::string external_callee_name = call_op.callee().str().append("_");
TFRFuncOp func = symbol_table_.lookup<TFRFuncOp>(external_callee_name);
if (!func || !func.isExternal()) return failure();
// Get the inputs and attributes. The attributes include those from the
// argument list and also those derived from the inputs.
SmallVector<Value, 4> inputs;
NamedAttrList argument_attrs;
llvm::StringMap<Attribute> derived_attrs;
if (failed(CollectInputsAndAttributes(rewriter, func, call_op, &inputs,
&argument_attrs, &derived_attrs))) {
return failure();
}
// Derive the output types. The result type is derived by using the
// attributes attached to the result type of the signature. The attribute
// value should be either in the attribute argument list or the derived
// attribute from the input tensors. All the result types are unranked, and
// shape inference should be applied afterwards.
SmallVector<Type, 4> output_types;
// Merge the attributes from the argument list to the derived ones.
for (auto& attr : argument_attrs) {
derived_attrs.insert({attr.first, attr.second});
}
// Derive the output types by using the attributes attached to the tfr
// types.
if (failed(DeriveOutputTypes(func.getType(), derived_attrs, &output_types))) {
return failure();
}
// Create the new op and replace the old TFR call op.
return CreateAndReplaceOp(rewriter, call_op, output_types, inputs,
argument_attrs, derived_attrs);
}
// Raise TFR call ops to the TF ops.
struct RaiseToTFOpsPass : public PassWrapper<RaiseToTFOpsPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<TFRDialect, TF::TensorFlowDialect, scf::SCFDialect,
StandardOpsDialect>();
}
explicit RaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,
bool materialize_derived_attrs)
: external_tfr_module(tfr_module),
materialize_derived_attrs(materialize_derived_attrs) {}
void runOnFunction() override;
private:
llvm::Optional<ModuleOp> external_tfr_module;
const bool materialize_derived_attrs;
};
void RaiseToTFOpsPass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext* ctx = &getContext();
SymbolTable table(external_tfr_module.hasValue()
? *external_tfr_module
: func.getParentOfType<ModuleOp>());
OwningRewritePatternList patterns;
patterns.insert<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs);
for (auto* op : ctx->getRegisteredOperations()) {
op->getCanonicalizationPatterns(patterns, ctx);
}
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
// Creates an instance of the pass to raise TFR call ops to the TF ops.
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseToTFOpsPass(
llvm::Optional<ModuleOp> tfr_module, bool materialize_derived_attrs) {
return std::make_unique<RaiseToTFOpsPass>(tfr_module,
materialize_derived_attrs);
}
static PassRegistration<RaiseToTFOpsPass> pass(
"tfr-raise-to-tf", "Raise all the TFR call ops to TF ops.",
[] { return CreateRaiseToTFOpsPass(); });
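// As a sketch, running `tfr-opt -tfr-raise-to-tf` rewrites
//   %0 = tfr.call @tf__risc_add(%a, %b) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
// into a "tf.RiscAdd" op, with "tfr.cast" ops fused on its operands and
// created on its results, provided the external signature
// tfr.func @tf__risc_add_ is available.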
} // namespace TFR
} // namespace mlir


@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Support/MlirOptMain.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
int main(int argc, char **argv) {
tensorflow::InitMlir y(&argc, &argv);
mlir::registerAllPasses();
mlir::DialectRegistry registry;
registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
mlir::TFR::TFRDialect>();
return failed(mlir::MlirOptMain(argc, argv, "TFR Pass Driver\n", registry));
}


@ -0,0 +1,61 @@
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/tfr/...",
"//tensorflow/compiler/mlir/...",
],
)
filegroup(
name = "decomposition_lib",
srcs = ["decomposition_lib.mlir"],
)
cc_library(
name = "composite_ops_cc",
srcs = ["composite_ops.cc"],
copts = [
"-Wno-unused-result",
"-Wno-unused-variable",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
tf_gen_op_wrapper_py(
name = "composite_ops",
out = "composite_ops.py",
deps = [
":composite_ops_cc",
],
)
cc_library(
name = "test_ops_cc",
testonly = 1,
srcs = ["test_ops.cc"],
copts = [
"-Wno-unused-result",
"-Wno-unused-variable",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)


@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("MyAddN")
.Input("inputs: N * T")
.Output("sum: T")
.Attr("N: int >= 1")
.Attr("T: {numbertype, variant}")
.SetIsCommutative()
.SetIsAggregate();
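// For example, a NodeDef of this op with three float inputs carries N = 3
// and T = DT_FLOAT; in the TF MLIR dialect this corresponds to (a sketch)
//   %sum = "tf.MyAddN"(%a, %b, %c) {N = 3 : i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>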
REGISTER_OP("MyBiasedDense")
.Input("input: T")
.Input("weight: T")
.Input("bias: T")
.Output("out: T")
.Attr("T: {float, int8}")
.Attr("act: {'', 'relu', 'relu6'} = ''");
} // namespace tensorflow


@ -0,0 +1,109 @@
// A test resource file which contains some pre-defined internal tfr.functions
// for decomposition and external tfr.functions for raising the decomposition
// result to the ops in the TF dialect.
//
// All the tfr.func functions are supposed to be translated from the Python
// functions with the tf.composite annotation.
// All the external tfr.func functions model the op signatures defined by the
// OpDefs.
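// For example, a "tf.MyAddN" node is decomposed by @tf__my_add_n below; the
// lookup uses the snake_case op name with the "tf__" prefix.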
tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
%n: i64 {tfr.name="N"}) -> !tfr.tensor {
%index = constant 0 : index
%cst = constant 1 : i64
%eq = cmpi "eq", %n, %cst : i64
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
%res = scf.if %eq -> !tfr.tensor {
scf.yield %v1 : !tfr.tensor
} else {
%step = index_cast %cst : i64 to index
%end = index_cast %n : i64 to index
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor {
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
%reduce_next = tfr.call @tf__add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
scf.yield %reduce_next : !tfr.tensor
}
scf.yield %reduce : !tfr.tensor
}
tfr.return %res : !tfr.tensor
}
// Translated from the tf.compose Python function.
tfr.func @tf__my_biased_dense(%input: !tfr.tensor, %weight: !tfr.tensor,
%bias: !tfr.tensor,
%act: !tfr.attr{tfr.name="act", tfr.default=""}) -> !tfr.tensor {
%dot = tfr.call @tf__mat_mul(%input, %weight) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
%add = tfr.call @tf__add(%dot, %bias) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
%relu = tfr.constant "relu" -> !tfr.attr
%relu6 = tfr.constant "relu6" -> !tfr.attr
%is_relu = tfr.equal %act, %relu -> i1
%res = scf.if %is_relu -> !tfr.tensor {
%applied_relu = tfr.call @tf__relu(%add) : (!tfr.tensor) -> !tfr.tensor
scf.yield %applied_relu : !tfr.tensor
} else {
%is_relu6 = tfr.equal %act, %relu6 -> i1
%res1 = scf.if %is_relu6 -> !tfr.tensor {
%applied_relu6 = tfr.call @tf__relu6(%add) : (!tfr.tensor) -> !tfr.tensor
scf.yield %applied_relu6 : !tfr.tensor
} else {
scf.yield %add : !tfr.tensor
}
scf.yield %res1 : !tfr.tensor
}
tfr.return %res : !tfr.tensor
}
// This is a wrong decomposition, used to verify that tf.Elu isn't decomposed
// since its kernel has been registered.
tfr.func @tf__elu_(%input: !tfr.tensor) -> !tfr.tensor {
tfr.return %input : !tfr.tensor
}
// Translated from:
//
// REGISTER_OP("Add")
// .Input("x: T")
// .Input("y: T")
// .Output("z: T")
// .Attr(
// "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
// "complex64, complex128, string}")
tfr.func @tf__add_(!tfr.tensor<T>, !tfr.tensor<T>)
-> !tfr.tensor<T> attributes{T}
// Translated from:
//
// REGISTER_OP("MatMul")
// .Input("a: T")
// .Input("b: T")
// .Output("product: T")
// .Attr("transpose_a: bool = false")
// .Attr("transpose_b: bool = false")
// .Attr("T: {bfloat16, half, float, double, int32, int64, complex64, complex128}")
// T is a derived attribute.
// transpose_a and transpose_b are materialized attributes.
tfr.func @tf__mat_mul_(!tfr.tensor<T>, !tfr.tensor<T>,
i1 {tfr.name="transpose_a", tfr.default=false},
i1 {tfr.name="transpose_b", tfr.default=false})
-> !tfr.tensor<T> attributes{T}
// Translated from:
//
// REGISTER_OP("Relu")
// .Input("features: T")
// .Output("activations: T")
// .Attr("T: {realnumbertype, qint8}")
// T is a derived attribute.
tfr.func @tf__relu_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
// Translated from:
//
// REGISTER_OP("Relu6")
// .Input("features: T")
// .Output("activations: T")
// .Attr("T: {realnumbertype}")
// T is a derived attribute.
tfr.func @tf__relu6_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}


@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("TestNoOp");
REGISTER_OP("TestIdentityOp")
.Input("input: T")
.Output("output: T")
.Attr("T: numbertype");
REGISTER_OP("TestIdentityNOp")
.Input("input: N * T")
.Output("output: N * T")
.Attr("N: int >= 1")
.Attr("T: numbertype");
REGISTER_OP("TestInputNOp")
.Input("input: N * T")
.Output("output: T")
.Attr("N: int >= 1")
.Attr("T: numbertype");
REGISTER_OP("TestOutputNOp")
.Input("input: T")
.Output("output: N * T")
.Attr("N: int >= 1")
.Attr("T: numbertype");
REGISTER_OP("TestTwoInputsOp")
.Input("lhs: T")
.Input("rhs: T")
.Output("output: T")
.Attr("T: numbertype")
.Attr("pred: bool = false");
REGISTER_OP("TestNumAttrsOp")
.Attr("x1: int = -10")
.Attr("y1: int = 1")
.Attr("x2: float = 0.0")
.Attr("y2: float = -3.0");
REGISTER_OP("TestNonNumAttrsOp")
.Attr("z: shape")
.Attr("x: string = 'hello'")
.Attr("y: type = DT_FLOAT");
REGISTER_OP("TestThreeInputsOp")
.Input("x: T")
.Input("y: T")
.Input("z: T")
.Output("output: T")
.Attr("T: numbertype")
.Attr("act: {'x', 'y', 'z'} = 'z'");
REGISTER_OP("TestTwoOutputsOp")
.Input("input: T")
.Output("output1: T")
.Output("output2: T")
.Attr("T: numbertype");
} // namespace tensorflow


@ -0,0 +1,57 @@
// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s
tfr.func @tf__my_pack(%values: !tfr.tensor_list,
%n: i32 {tfr.name="N"},
%axis: i32 {tfr.name="axis"}) -> !tfr.tensor {
%index = constant 0 : index
%cst = constant 1 : i32
%eq = cmpi "eq", %n, %cst : i32
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
%temp = tfr.call @tf__expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
%res = scf.if %eq -> !tfr.tensor {
scf.yield %temp : !tfr.tensor
} else {
%step = index_cast %cst : i32 to index
%end = index_cast %n : i32 to index
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor {
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
%temp1 = tfr.call @tf__expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
%reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
scf.yield %reduce_next : !tfr.tensor
}
scf.yield %reduce : !tfr.tensor
}
tfr.return %res : !tfr.tensor
}
// CHECK-LABEL: pack_one
func @pack_one(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> {
%0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32>
// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32
// CHECK-NEXT: %[[CAST:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor
// CHECK-NEXT: %[[ED:.*]] = tfr.call @tf__expand_dims(%[[CAST]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
// CHECK-NEXT: %[[BACK:.*]] = "tfr.cast"(%[[ED]]) : (!tfr.tensor) -> tensor<1x2x3xf32>
// CHECK-NEXT: return %[[BACK]] : tensor<1x2x3xf32>
}
// CHECK-LABEL: pack_multiple
func @pack_multiple(%arg0: tensor<2x3xf32>,
%arg1: tensor<2x3xf32>,
%arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
%0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32>
return %0 : tensor<3x2x3xf32>
// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32
// CHECK-NEXT: %[[CAST0:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor
// CHECK-NEXT: %[[CAST1:.*]] = "tfr.cast"(%arg1) : (tensor<2x3xf32>) -> !tfr.tensor
// CHECK-NEXT: %[[CAST2:.*]] = "tfr.cast"(%arg2) : (tensor<2x3xf32>) -> !tfr.tensor
// CHECK-NEXT: %[[EX0:.*]] = tfr.call @tf__expand_dims(%[[CAST0]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
// CHECK-NEXT: %[[EX1:.*]] = tfr.call @tf__expand_dims(%[[CAST1]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
// CHECK-NEXT: %[[CONCAT1:.*]] = tfr.call @tf__risc_concat(%[[EX0]], %[[EX1]], %c0_i32) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
// CHECK-NEXT: %[[EX2:.*]] = tfr.call @tf__expand_dims(%[[CAST2]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
// CHECK-NEXT: %[[CONCAT2:.*]] = tfr.call @tf__risc_concat(%[[CONCAT1]], %[[EX2]], %[[AXIS]]) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
// CHECK-NEXT: %[[BACK:.*]] = "tfr.cast"(%[[CONCAT2]]) : (!tfr.tensor) -> tensor<3x2x3xf32>
// CHECK-NEXT: return %[[BACK]] : tensor<3x2x3xf32>
}


@ -0,0 +1,84 @@
// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s
// CHECK-LABEL: @tf__fake_no_op
tfr.func @tf__fake_no_op(%arg0: !tfr.tensor) -> !tfr.tensor {
tfr.return %arg0 : !tfr.tensor
// CHECK-NEXT: tfr.return %arg0 : !tfr.tensor
}
// CHECK-LABEL: @tf__intermediate
tfr.func @tf__intermediate(%arg0: !tfr.tensor) -> !tfr.tensor {
%0 = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor
tfr.return %0 : !tfr.tensor
// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: tfr.return %[[id]] : !tfr.tensor
}
// CHECK-LABEL: @tf__fused_n
tfr.func @tf__fused_n(
%arg0: !tfr.tensor,
%arg1: !tfr.tensor_list,
%arg2: index {tfr.name="A",tfr.default=1:index})
-> !tfr.tensor_list {
%0 = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor
%1 = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor
%2 = tfr.call @tf__intermediate(%1) : (!tfr.tensor) -> !tfr.tensor
%3 = "tfr.build_list"(%0, %2) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
tfr.return %3 : !tfr.tensor_list
// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: %[[ge:.*]] = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor
// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__intermediate(%[[ge]]) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: %[[bl:.*]] = "tfr.build_list"(%[[id1]], %[[id2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
// CHECK-NEXT: tfr.return %[[bl]] : !tfr.tensor_list
}
//------------------------
// CHECK-LABEL: decompose_tf_no_op
func @decompose_tf_no_op(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
%0 = "tf.FakeNoOp"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string>
return %0 : tensor<1x2x3x4x!tf.string>
// CHECK-NEXT: return %arg0
}
// CHECK-LABEL: decompose_tf_intermediate
func @decompose_tf_intermediate(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
%0 = "tf.Intermediate"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string>
return %0 : tensor<1x2x3x4x!tf.string>
// CHECK-NEXT: %[[casted:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%[[casted]]) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id]]) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string>
// CHECK-NEXT: return %[[back]]
}
// CHECK-LABEL: decompose_fused_n_default
func @decompose_fused_n_default(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
return %0#1 : tensor<f32>
// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
// CHECK-NEXT: %[[in2:.*]] = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__risc(%[[in2]]) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id2]]) : (!tfr.tensor) -> tensor<f32>
// CHECK-NEXT: return %[[back]] : tensor<f32>
}
// CHECK-LABEL: decompose_fused_n
func @decompose_fused_n(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index} : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
return %0#1 : tensor<f32>
// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
// CHECK-NEXT: %[[in1:.*]] = "tfr.cast"(%arg1) : (tensor<f32>) -> !tfr.tensor
// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__risc(%[[in1]]) : (!tfr.tensor) -> !tfr.tensor
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id1]]) : (!tfr.tensor) -> tensor<f32>
// CHECK-NEXT: return %[[back]] : tensor<f32>
}


@ -0,0 +1,235 @@
// RUN: tfr-opt %s -tfr-decompose -tfr-raise-to-tf -canonicalize -verify-diagnostics -split-input-file | FileCheck %s
//=================> User models, from GraphDef <====================
// CHECK-LABEL: my_identity
func @my_identity(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tf.MyIdentity"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
// CHECK-NEXT: return %arg0 : tensor<2x3xf32>
}
// CHECK-LABEL: my_rsqrt
func @my_rsqrt(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
%0 = "tf.MyRsqrt"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
return %0 : tensor<3x2x3xf32>
// CHECK-NEXT: %[[RE:.*]] = "tf.RiscReciprocal"(%arg0) : (tensor<2x3xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[SQRT:.*]] = "tf.RiscSqrt"(%[[RE]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[SQRT]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
}
// CHECK-LABEL: my_leaky_relu
func @my_leaky_relu(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
%0 = "tf.MyLeakyRelu"(%arg0) {alpha=3.0 : f32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
return %0 : tensor<3x2x3xf32>
// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor<f32>, tensor<*xi32>) -> tensor<*xf32>
// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
}
// CHECK-LABEL: my_leaky_relu_with_default
func @my_leaky_relu_with_default(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
%0 = "tf.MyLeakyRelu"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
return %0 : tensor<3x2x3xf32>
// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor<f32>} : () -> tensor<f32>
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor<f32>, tensor<*xi32>) -> tensor<*xf32>
// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
}
// CHECK-LABEL: my_cast
func @my_cast(%arg0: tensor<2x3xf32>) -> tensor<2x3xi32> {
%0 = "tf.MyCast"(%arg0) {Tout=i32} : (tensor<2x3xf32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
// CHECK-NEXT: %[[CAST:.*]] = "tf.RiscCast"(%arg0) {Tout = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CAST]]) {shape = #tf.shape<2x3>} : (tensor<*xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return %[[ES]] : tensor<2x3xi32>
}
// CHECK-LABEL: my_pack_single_input
func @my_pack_single_input(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
%0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
return %0 : tensor<3x2x3xf32>
// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[ED:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ED]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
}
// CHECK-LABEL: my_pack_multiple_inputs
func @my_pack_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
%0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32>
return %0 : tensor<3x2x3xf32>
// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[ED0:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ED1:.*]] = "tf.ExpandDims"(%arg1, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK-NEXT: %[[CC0:.*]] = "tf.RiscConcat"(%[[ED0]], %[[ED1]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ED2:.*]] = "tf.ExpandDims"(%arg2, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK-NEXT: %[[CC1:.*]] = "tf.RiscConcat"(%[[CC0]], %[[ED2]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CC1]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
}
// CHECK-LABEL: my_add_n_single_input
func @my_add_n_single_input(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tf.MyAddN"(%arg0) {N=1:i32} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
// CHECK-NEXT: return %arg0 : tensor<2x3xf32>
}
// CHECK-LABEL: my_add_n_multiple_inputs
func @my_add_n_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tf.MyAddN"(%arg0, %arg1, %arg2) {N=3:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
// CHECK-NEXT: %[[ADD0:.*]] = "tf.RiscAdd"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ADD1:.*]] = "tf.RiscAdd"(%[[ADD0]], %arg2) : (tensor<*xf32>, tensor<2x3xf32>) -> tensor<*xf32>
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ADD1]]) {shape = #tf.shape<2x3>} : (tensor<*xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[ES]] : tensor<2x3xf32>
}
// CHECK-LABEL: my_map_and_batch_dataset
func @my_map_and_batch_dataset(%input: tensor<*x!tf.variant>,
%other1: tensor<*xf32>,
%other2: tensor<*xi32>) -> tensor<*x!tf.variant> {
%0 = "tf.MyMapAndBatchDataset"(%input, %other1, %other2)
{batch_size=1000 : i64, num_parallel_calls = 8 : i64, drop_remainder = 0 : i1,
func = @"__some_func", output_types = [f32], output_shapes = [#tf.shape<>], preserve_cardinality = true}
: (tensor<*x!tf.variant>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf.variant>
return %0 : tensor<*x!tf.variant>
// CHECK-NEXT: %[[BATCH:.*]] = "tf.Const"() {value = dense<1000> : tensor<i64>} : () -> tensor<i64>
// CHECK-NEXT: %[[PARAL:.*]] = "tf.Const"() {value = dense<8> : tensor<i64>} : () -> tensor<i64>
// CHECK-NEXT: %[[KEEP:.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
// CHECK-NEXT: %[[CAST:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32>
// CHECK-NEXT: %[[RET:.*]] = "tf.MapAndBatchDatasetV0"(%arg0, %[[BATCH]], %[[PARAL]], %[[KEEP]], %arg1, %[[CAST]])
// CHECK-SAME: {f = @__some_func, output_shapes = [#tf.shape<>], output_types = [f32], preserve_cardinality = true} : (tensor<*x!tf.variant>, tensor<i64>, tensor<i64>, tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*x!tf.variant>
// CHECK-NEXT: return %[[RET]] : tensor<*x!tf.variant>
}
//=================> decomposition functions, translated from tf.compose api <====================
tfr.func @tf__my_identity(%value: !tfr.tensor) -> !tfr.tensor {
tfr.return %value : !tfr.tensor
}
tfr.func @tf__my_cast(%value: !tfr.tensor, %tout: !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor {
%0 = tfr.call @tf__risc_cast(%value, %tout) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor
tfr.return %0 : !tfr.tensor
}
tfr.func @tf__my_rsqrt(%value: !tfr.tensor) -> !tfr.tensor {
%1 = tfr.call @tf__risc_reciprocal(%value) : (!tfr.tensor) -> !tfr.tensor
%2 = tfr.call @tf__risc_sqrt(%1) : (!tfr.tensor) -> !tfr.tensor
tfr.return %2 : !tfr.tensor
}
tfr.func @tf__my_leaky_relu(%value: !tfr.tensor, %alpha: f32 {tfr.name="alpha", tfr.default=0.2:f32}) -> !tfr.tensor {
%1 = tfr.call @tf__risc_shape(%value) : (!tfr.tensor) -> !tfr.tensor
%2 = "tfr.constant_tensor"(%alpha) : (f32) -> tensor<f32>
%t = "tfr.cast"(%2) : (tensor<f32>) -> !tfr.tensor
%3 = tfr.call @tf__risc_broadcast(%t, %1) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
%4 = tfr.call @tf__risc_maximum(%value, %3) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
tfr.return %4 : !tfr.tensor
}
// TODO(fengliuai): use shape dialect to manipulate the shape then this can be decomposed further.
tfr.func @tf__my_expand_dims(%value: !tfr.tensor, %axis: i32 {tfr.name="axis"}) -> !tfr.tensor {
%axis_cst = "tfr.constant_tensor"(%axis) : (i32) -> tensor<i32>
%dim = "tfr.cast"(%axis_cst) : (tensor<i32>) -> !tfr.tensor
%0 = tfr.call @tf__expand_dims(%value, %dim) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
tfr.return %0 : !tfr.tensor
}
tfr.func @tf__my_pack(%values: !tfr.tensor_list,
%n: i32 {tfr.name="N"},
%axis: i32 {tfr.name="axis"}) -> !tfr.tensor {
%index = constant 0 : index
%cst = constant 1 : i32
%eq = cmpi "eq", %n, %cst : i32
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
%temp = tfr.call @tf__my_expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
%res = scf.if %eq -> !tfr.tensor {
scf.yield %temp : !tfr.tensor
} else {
%step = index_cast %cst : i32 to index
%end = index_cast %n : i32 to index
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor {
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
%temp1 = tfr.call @tf__my_expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
%reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
scf.yield %reduce_next : !tfr.tensor
}
scf.yield %reduce : !tfr.tensor
}
tfr.return %res : !tfr.tensor
}
tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
%n: i32 {tfr.name="N"}) -> !tfr.tensor {
%index = constant 0 : index
%cst = constant 1 : i32
%eq = cmpi "eq", %n, %cst : i32
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
%res = scf.if %eq -> !tfr.tensor {
scf.yield %v1 : !tfr.tensor
} else {
%step = index_cast %cst : i32 to index
%end = index_cast %n : i32 to index
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor {
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
%reduce_next = tfr.call @tf__risc_add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
scf.yield %reduce_next : !tfr.tensor
}
scf.yield %reduce : !tfr.tensor
}
tfr.return %res : !tfr.tensor
}
tfr.func @tf__my_map_and_batch_dataset(
%input_dataset: !tfr.tensor,
%other_arguments: !tfr.tensor_list,
%batch_size: i64 {tfr.name="batch_size"},
%num_parallel_calls: i64 {tfr.name="num_parallel_calls"},
%drop_remainder: i1 {tfr.name="drop_remainder"},
%f: !tfr.attr {tfr.name="func"},
%output_types: !tfr.attr {tfr.name="output_types"},
%output_shapes: !tfr.attr {tfr.name="output_shapes"},
%preserve_cardinality: i1 {tfr.name="preserve_cardinality", tfr.default=false}) -> !tfr.tensor {
%batch = "tfr.constant_tensor"(%batch_size) : (i64) -> tensor<i64>
%batch1 = "tfr.cast"(%batch) : (tensor<i64>) -> !tfr.tensor
%calls = "tfr.constant_tensor"(%num_parallel_calls) : (i64) -> tensor<i64>
%calls1 = "tfr.cast"(%calls) : (tensor<i64>) -> !tfr.tensor
%drop = "tfr.constant_tensor"(%drop_remainder) : (i1) -> tensor<i1>
%drop1 = "tfr.cast"(%drop) : (tensor<i1>) -> !tfr.tensor
%ret = tfr.call @tf__map_and_batch_dataset_v0(%input_dataset, %batch1, %calls1, %drop1, %other_arguments, %f, %output_types, %output_shapes, %preserve_cardinality)
: (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor_list, !tfr.attr, !tfr.attr, !tfr.attr, i1) -> !tfr.tensor
tfr.return %ret : !tfr.tensor
}
//=================> signatures of the primitive ops with kernels, modeled as external TFR functions <==
tfr.func @tf__risc_cast_(!tfr.tensor, !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor<Tout> attributes{Tout}
tfr.func @tf__risc_add_(!tfr.tensor<T>, !tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
tfr.func @tf__risc_concat_(!tfr.tensor<T>, !tfr.tensor<T>, i32{tfr.name="axis"}) -> !tfr.tensor<T> attributes{T}
tfr.func @tf__risc_broadcast_(!tfr.tensor<T>, !tfr.tensor<Tidx>) -> !tfr.tensor<T> attributes{T, Tidx}
tfr.func @tf__risc_reciprocal_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
tfr.func @tf__risc_sqrt_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
tfr.func @tf__risc_shape_(!tfr.tensor, !tfr.attr{tfr.name="T", tfr.default=i32}) -> !tfr.tensor<T> attributes{T}
tfr.func @tf__risc_maximum_(!tfr.tensor<T>, !tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
tfr.func @tf__expand_dims_(!tfr.tensor<T>, !tfr.tensor<Tdim>) -> !tfr.tensor<T> attributes{T, Tdim}
tfr.func @tf__map_and_batch_dataset_v0_(!tfr.tensor<T>, !tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor_list<Targuments>,
!tfr.attr{tfr.name="f"}, !tfr.attr{tfr.name="output_types"}, !tfr.attr{tfr.name="output_shapes"}, i1{tfr.name="preserve_cardinality"})
-> !tfr.tensor<T> attributes{T, Targuments}


@ -0,0 +1,381 @@
// RUN: tfr-opt %s -verify-diagnostics -split-input-file | tfr-opt | FileCheck %s
// RUN: tfr-opt %s -canonicalize -verify-diagnostics -split-input-file | FileCheck %s -check-prefix=CANON
// Tests for types, ops with custom constraints, verifiers, printer or parser
// methods.
// CHECK-LABEL: tensor_type_noconstraint
func @tensor_type_noconstraint() -> !tfr.tensor
// -----
// CHECK-LABEL: tensor_type
func @tensor_type() -> !tfr.tensor<T>
// -----
// CHECK-LABEL: tensor_list_type_noconstraint
func @tensor_list_type_noconstraint() -> !tfr.tensor_list
// -----
// CHECK-LABEL: tensor_list_type_array_like
func @tensor_list_type_array_like() -> !tfr.tensor_list<[N, T]>
// -----
// CHECK-LABEL: tensor_list_type_tuple_like
func @tensor_list_type_tuple_like() -> !tfr.tensor_list<input_T>
// -----
// expected-error@+1 {{unbalanced '>' character in pretty dialect name}}
func @tensor_invalid_1() -> !tfr.tensor<[N, T>
// -----
// expected-error@+1 {{unexpected nul or EOF in pretty dialect name}}
func @tensor_invalid_2() -> !tfr.tensor<[N, T]
// -----
// CHECK-LABEL: call_op
func @call_op(%arg0: !tfr.tensor<T>, %arg1: !tfr.tensor_list<TL>, %arg2: i32) -> !tfr.tensor<K> {
%0 = tfr.call @Foo(%arg0, %arg1, %arg2) : (!tfr.tensor<T>, !tfr.tensor_list<TL>, i32) -> !tfr.tensor<K>
return %0 : !tfr.tensor<K>
}
// -----
// CHECK-LABEL: call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K>
func @call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K> {
%0 = tfr.call @Bar(%arg0) : (i32) -> !tfr.tensor<K>
return %0 : !tfr.tensor<K>
}
// -----
func @call_op_invalid_1(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
// expected-error@+1 {{got 'tensor<?xf32>'}}
%0 = tfr.call @Huu(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
return %0 : !tfr.tensor<K>
}
// -----
// CHECK-LABEL: get_shape
func @get_shape(%arg0: !tfr.tensor) -> (!shape.shape, !shape.shape) {
%0 = tfr.get_shape %arg0 -> !shape.shape
%1 = "tfr.get_shape"(%arg0) : (!tfr.tensor) -> !shape.shape
return %0, %1 : !shape.shape, !shape.shape
}
// -----
// CHECK-LABEL: get_real_shape
// CANON-LABEL: get_real_shape
func @get_real_shape(%arg0: tensor<1x2xf32>) -> tensor<1xindex> {
%0 = "tfr.cast"(%arg0) : (tensor<1x2xf32>) -> !tfr.tensor
%1 = tfr.get_shape %0 -> !shape.shape
%2 = shape.to_extent_tensor %1 : !shape.shape -> tensor<1xindex>
return %2 : tensor<1xindex>
// CANON-NEXT: %[[s:.*]] = shape.const_shape [1, 2] : tensor<?xindex>
// CANON-NEXT: %[[e:.*]] = shape.to_extent_tensor %[[s]] : tensor<?xindex> -> tensor<1xindex>
// CANON-NEXT: return %[[e]] : tensor<1xindex>
}
// -----
func @get_element_type(%arg0: !tfr.tensor) -> (!tfr.attr, !tfr.attr) {
%0 = tfr.get_element_type %arg0 -> !tfr.attr
%1 = "tfr.get_element_type"(%arg0) : (!tfr.tensor) -> !tfr.attr
return %0, %1 : !tfr.attr, !tfr.attr
}
// -----
// CHECK-LABEL: from_tf_tensor
func @from_tf_tensor(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
%0 = "tfr.cast"(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
return %0 : !tfr.tensor<K>
}
// -----
// CHECK-LABEL: to_tf_tensor
func @to_tf_tensor(%arg0: !tfr.tensor<T>) -> tensor<?xi32> {
%0 = "tfr.cast"(%arg0) : (!tfr.tensor<T>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// CHECK-LABEL: constant
func @constant() -> (!tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr) {
%0 = tfr.constant f32 -> !tfr.attr
%1 = tfr.constant [f32, i32] -> !tfr.attr
%2 = "tfr.constant"() {value = f32} : () -> !tfr.attr
%3 = "tfr.constant"() {value = [f32, i32]} : () -> !tfr.attr
return %0, %1, %2, %3 : !tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr
}
// -----
// CHECK-LABEL: equal
// CANON-LABEL: equal
func @equal() -> (i1, i1, i1, i1) {
%0 = tfr.constant f32 -> !tfr.attr
%1 = tfr.constant f32 -> !tfr.attr
%2 = tfr.constant i32 -> !tfr.attr
%same_type = tfr.equal %0,%1 -> i1
%diff_type = tfr.equal %0,%2 -> i1
%3 = tfr.constant "hello" -> !tfr.attr
%4 = tfr.constant "hello" -> !tfr.attr
%5 = tfr.constant "how are you" -> !tfr.attr
%same_str = tfr.equal %3,%4 -> i1
%diff_str = tfr.equal %3,%5 -> i1
return %same_type, %diff_type, %same_str, %diff_str : i1, i1, i1, i1
// CANON-NEXT: %true = constant true
// CANON-NEXT: %false = constant false
// CANON-NEXT: return %true, %false, %true, %false : i1, i1, i1, i1
}
// -----
// CHECK-LABEL: constant_tensor_scalar
func @constant_tensor_scalar(%arg0: i32) -> tensor<i32> {
%0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<i32>
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: constant_tensor_vector
func @constant_tensor_vector(%arg0: vector<1x2xi32>) -> tensor<1x2xi32> {
%0 = "tfr.constant_tensor"(%arg0) : (vector<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
// -----
// CHECK-LABEL: constant_tensor_array
// CANON-LABEL: constant_tensor_array
func @constant_tensor_array() -> !tfr.tensor {
%0 = tfr.constant [1, -1, 3] -> !tfr.attr
%1 = "tfr.constant_tensor"(%0) : (!tfr.attr) -> !tfr.tensor
return %1 : !tfr.tensor
// CANON-NEXT: "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
// CANON-NEXT: "tfr.cast"(%0) : (tensor<3xi64>) -> !tfr.tensor
// CANON-NEXT: return
}
// -----
// CHECK-LABEL: constant_tensor_scalar
// CANON-LABEL: constant_tensor_scalar
func @constant_tensor_scalar() -> !tfr.tensor {
%0 = "std.constant"() {value = 42 : i32} : () -> i32
%1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor
return %1 : !tfr.tensor
// CANON-NEXT: "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
// CANON-NEXT: "tfr.cast"(%0) : (tensor<i32>) -> !tfr.tensor
// CANON-NEXT: return
}
// -----
func @constant_tensor_invalid_0(%arg0: i32) -> tensor<f32> {
// expected-error@+1 {{input and output should have the same scalar types.}}
%0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<f32>
return %0 : tensor<f32>
}
// -----
func @constant_tensor_invalid_1(%arg0: vector<1xi32>) -> tensor<?xi32> {
// expected-error@+1 {{output type should be static and ranked}}
%0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
func @constant_tensor_invalid_2(%arg0: vector<1xi32>) -> tensor<1xf32> {
// expected-error@+1 {{input and output should have same shape and element type}}
%0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}
// -----
func @constant_tensor_invalid_3(%arg0: vector<1xi32>) -> tensor<1x1xi32> {
// expected-error@+1 {{input and output should have same shape and element type}}
%0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1x1xi32>
return %0 : tensor<1x1xi32>
}
// -----
func @constant_tensor_invalid_4(%arg0: i32) -> tensor<1x1xi32> {
// expected-error@+1 {{input can not be converted to an output tensor}}
%0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<1x1xi32>
return %0 : tensor<1x1xi32>
}
// -----
// CHECK-LABEL: get_element
func @get_element(%arg0: !tfr.tensor_list<T>) -> !tfr.tensor {
%cst = "std.constant"() {value = 1 : index} : () -> index
%0 = tfr.get_element %arg0[%cst] : (!tfr.tensor_list<T>, index) -> !tfr.tensor
return %0 : !tfr.tensor
}
// -----
// CHECK-LABEL: build_list
func @build_list(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> !tfr.tensor_list {
%0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
return %0 : !tfr.tensor_list
}
// -----
// CHECK-LABEL: build_const_list
// CANON-LABEL: build_const_list
func @build_const_list() -> !tfr.attr {
%0 = "std.constant"() {value = 42 : i32} : () -> i32
%1 = "std.constant"() {value = 41 : i32} : () -> i32
%2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
return %2 : !tfr.attr
// CANON-NEXT: %[[c:.*]] = tfr.constant [42 : i32, 41 : i32] -> !tfr.attr
// CANON-NEXT: return %[[c]] : !tfr.attr
}
// -----
// CHECK-LABEL: tfr.func
tfr.func @External(%arg0: !tfr.tensor<A>,
%arg1: !tfr.tensor_list<C>,
%arg2: i32 {tfr.name = "A"},
%arg3: !tfr.attr {tfr.name = "T"})
-> (!tfr.tensor<A>, !tfr.tensor_list<C>)
attributes {A, C}
// -----
// CHECK-LABEL: tfr.func
tfr.func @Foo(%arg0: !tfr.tensor<A>,
%arg1: !tfr.tensor_list<C>,
%arg2: i32 {tfr.name = "A"},
%arg3: vector<1xi32> {tfr.name = "C"})
-> (!tfr.tensor<A>, !tfr.tensor_list<C>)
attributes {A, C} {
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<C>
}
// -----
// CHECK-LABEL: tfr.func
tfr.func @Bar(%arg0: !tfr.tensor<A>,
%arg2: i32 {tfr.name = "B"},
%arg3: vector<1xi32> {tfr.name = "C"})
-> (!tfr.tensor<A>, !tfr.tensor<A>)
attributes {A} {
tfr.return %arg0, %arg0 : !tfr.tensor<A>, !tfr.tensor<A>
}
// -----
// expected-error@+1 {{Undefined attributes are used: A}}
tfr.func @Foo_undefined_attr(%arg0: !tfr.tensor<A>,
%arg1: !tfr.tensor_list<A>,
%arg2: i32 {tfr.name = "A"},
%arg3: vector<1xi32> {tfr.name = "C"}) ->
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}
// -----
// expected-error@+1 {{3 attribute argument doesn't have a tfr.name attribute}}
tfr.func @Foo_unnamed_attr(%arg0: !tfr.tensor<A>,
%arg1: !tfr.tensor_list<A>,
%arg2: i32 {tfr.name = "A"},
%arg3: vector<1xi32>) ->
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}
// -----
// expected-error@+1 {{tfr.tensor_list argument should be before non tensor arguments}}
tfr.func @Foo_invalid_arg_order(%arg0: !tfr.tensor<A>,
%arg2: i32 {tfr.name = "A"},
%arg1: !tfr.tensor_list<A>,
%arg3: vector<1xi32> {tfr.name = "C"}) ->
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}
// -----
// expected-error@+1 {{tfr.tensor argument should be before tfr.tensor_list argument.}}
tfr.func @Foo_invalid_arg_order0(
%arg1: !tfr.tensor_list,
%arg0: !tfr.tensor,
%arg2: i32 {tfr.name = "A"},
%arg3: vector<1xi32> {tfr.name = "C"}) ->
(!tfr.tensor, !tfr.tensor_list) {
tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
}
// -----
// expected-error@+1 {{tfr.tensor result should be before tfr.tensor_list result}}
tfr.func @Foo_invalid_result_order(%arg0: !tfr.tensor<A>,
%arg1: !tfr.tensor_list<A>,
%arg2: i32 {tfr.name = "A"},
%arg3: vector<1xi32> {tfr.name = "C"}) ->
(!tfr.tensor_list<A>, !tfr.tensor<A>) {
tfr.return %arg1, %arg0 : !tfr.tensor_list<A>, !tfr.tensor<A>
}
// -----
// expected-error@+1 {{More than one tfr.tensor_list argument isn't allowed}}
tfr.func @Foo_multiple_tensor_list_args(%arg0: !tfr.tensor<A>,
%arg1: !tfr.tensor_list<A>,
%arg2: !tfr.tensor_list<A>,
%arg3: i32 {tfr.name = "A"},
%arg4: vector<1xi32> {tfr.name = "C"}) ->
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}
// -----
// expected-error@+1 {{More than one tfr.tensor_list result isn't allowed}}
tfr.func @Foo_multiple_tensor_list_results(%arg0: !tfr.tensor<C>,
%arg1: !tfr.tensor_list<A>,
%arg2: i32 {tfr.name = "A"},
%arg3: vector<1xi32> {tfr.name = "C"}) ->
(!tfr.tensor_list<A>, !tfr.tensor_list<A>) {
tfr.return %arg1, %arg1 : !tfr.tensor_list<A>, !tfr.tensor_list<A>
}
// -----
// expected-error@+1 {{None tfr.tensor/tfr.tensor_list results aren't allowed as a result}}
tfr.func @Foo_return_attr(%arg0: !tfr.tensor<C>,
%arg1: !tfr.tensor_list<A>,
%arg2: i32 {tfr.name = "A"},
%arg3: vector<1xi32> {tfr.name = "C"}) -> i32 {
tfr.return %arg2 : i32
}
@@ -0,0 +1,76 @@
// RUN: tfr-opt %s -tfr-raise-to-tf -verify-diagnostics -split-input-file | FileCheck %s
tfr.func @tf__risc_same_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes {T}
tfr.func @tf__risc_concat_(!tfr.tensor_list<N, T>) -> !tfr.tensor<T> attributes {T, N}
tfr.func @tf__risc_split_(!tfr.tensor<T>, i32 {tfr.name="N"}) -> !tfr.tensor_list<N, T> attributes {T, N}
tfr.func @tf__risc_cast_(!tfr.tensor, !tfr.attr {tfr.name="K"}) -> !tfr.tensor<K> attributes {T, K}
// CHECK-LABEL: decompose_tf_same
func @decompose_tf_same(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
%0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
%1 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor
%2 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string>
return %2 : tensor<1x2x3x4x!tf.string>
// CHECK: %[[id:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string>
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id]]) {shape = #tf.shape<1x2x3x4>} : (tensor<*x!tf.string>) -> tensor<1x2x3x4x!tf.string>
// CHECK: return %[[es]] : tensor<1x2x3x4x!tf.string>
}
// CHECK-LABEL: decompose_tf_consecutive
func @decompose_tf_consecutive(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
%1 = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
%2 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor
%3 = tfr.call @tf__risc_same(%1) : (!tfr.tensor) -> !tfr.tensor
%4 = "tfr.cast"(%3) : (!tfr.tensor) -> tensor<f32>
return %4 : tensor<f32>
// CHECK: %[[id0:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string>
// CHECK: %[[id2:.*]] = "tf.RiscSame"(%arg2) : (tensor<f32>) -> tensor<*xf32>
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id2]]) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor<f32>
// CHECK: return %[[es]] : tensor<f32>
}
// CHECK-LABEL: decompose_tf_concat_n
func @decompose_tf_concat_n(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<3xf32> {
%0 = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
%1 = "tfr.cast"(%arg1) : (tensor<f32>) -> !tfr.tensor
%2 = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
%3 = "tfr.build_list"(%0, %1, %2) : (!tfr.tensor, !tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
%concat = tfr.call @tf__risc_concat(%3) : (!tfr.tensor_list) -> !tfr.tensor
%4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<3xf32>
return %4 : tensor<3xf32>
// CHECK: %[[concat:.*]] = "tf.RiscConcat"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[concat]]) {shape = #tf.shape<3>} : (tensor<*xf32>) -> tensor<3xf32>
// CHECK: return %[[es]] : tensor<3xf32>
}
// CHECK-LABEL: decompose_tf_split
func @decompose_tf_split(%arg0: tensor<3xf32>) -> (tensor<f32>) {
%0 = "tfr.cast"(%arg0) : (tensor<3xf32>) -> !tfr.tensor
%n = std.constant 3: i32
%split = tfr.call @tf__risc_split(%0, %n) : (!tfr.tensor, i32) -> !tfr.tensor_list
%i0 = std.constant 0: index
%s0 = tfr.get_element %split[%i0] : (!tfr.tensor_list, index) -> !tfr.tensor
%4 = "tfr.cast"(%s0) : (!tfr.tensor) -> tensor<f32>
return %4 : tensor<f32>
// CHECK: %[[split:.*]]:3 = "tf.RiscSplit"(%arg0) {N = 3 : i32} : (tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[split]]#0) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor<f32>
// CHECK: return %[[es]] : tensor<f32>
}
// CHECK-LABEL: decompose_tf_cast
func @decompose_tf_cast(%arg0: tensor<f32>) -> tensor<i32> {
%0 = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
%t = tfr.constant i32 -> !tfr.attr
%concat = tfr.call @tf__risc_cast(%0, %t) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor
%4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<i32>
return %4 : tensor<i32>
// CHECK: %[[tfcast:.*]] = "tf.RiscCast"(%arg0) {K = i32} : (tensor<f32>) -> tensor<*xi32>
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf.shape<>} : (tensor<*xi32>) -> tensor<i32>
// CHECK: return %[[es]] : tensor<i32>
}
@@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace TFR {
std::string GetComposeFuncName(StringRef tf_op_name) {
  std::string compose_func_name;
  for (int i = 0; i < tf_op_name.size(); ++i) {
    if (tf_op_name[i] == '_') {
      // TF op names never contain "_"; "_Arg" and "_RetVal" are special
      // internal op names, so return an empty string to skip the
      // decomposition.
      return {};
    }
    if (tf_op_name[i] == '.') {
      compose_func_name.push_back('_');
    } else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') {
      compose_func_name.push_back('_');
      compose_func_name.push_back(tf_op_name[i] + 'a' - 'A');
    } else {
      compose_func_name.push_back(tf_op_name[i]);
    }
  }
  return compose_func_name;
}
std::string GetTFOpName(StringRef compose_func_name) {
  std::string tf_op_name;
  bool after_underscore = false;
  for (int i = 0; i < compose_func_name.size(); ++i) {
    if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') {
      // A valid TFR function name never contains uppercase letters.
      return {};
    }
    if (after_underscore) {
      if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') {
        tf_op_name.push_back(compose_func_name[i] + 'A' - 'a');
        after_underscore = false;
      } else {
        // The character after a "_" must be a lowercase letter.
        return {};
      }
    } else if (compose_func_name[i] == '_') {  // First '_' in a run.
      if (i + 1 < compose_func_name.size() && compose_func_name[i + 1] == '_') {
        // A "__" pair encodes the '.' namespace separator.
        tf_op_name.push_back('.');
        i++;
      }
      after_underscore = true;
    } else {
      tf_op_name.push_back(compose_func_name[i]);
    }
  }
  if (after_underscore) {
    // A trailing "_" is malformed.
    return {};
  }
  return tf_op_name;
}
} // namespace TFR
} // namespace mlir
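A quick way to sanity-check GetComposeFuncName is a tiny standalone driver. The sketch below is not part of this change; it assumes a TensorFlow build environment in which these headers resolve, and the `main` driver is purely illustrative:

#include <cassert>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"

int main() {
  // CamelCase TF op names map to lowercase '_'-separated TFR names, with the
  // '.' namespace separator encoded as "__".
  assert(mlir::TFR::GetComposeFuncName("tf.Pack") == "tf__pack");
  assert(mlir::TFR::GetComposeFuncName("tf.ConcatV2") == "tf__concat_v2");
  // Names containing '_' (e.g. the internal "_Arg") yield "" so the
  // decomposition is skipped.
  assert(mlir::TFR::GetComposeFuncName("_Arg").empty());
  return 0;
}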
@@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_
#include <string>
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace TFR {
// This is a hardcoded rule for mapping a TF op name to the corresponding
// TFR function name. Examples:
// tf.Pack => tf__pack
// tf.ConcatV2 => tf__concat_v2
std::string GetComposeFuncName(StringRef tf_op_name);
// This is a hardcoded rule for mapping a TFR function name back to the
// corresponding TF op name. Examples:
// tf__pack => tf.Pack
// tf__concat_v2 => tf.ConcatV2
std::string GetTFOpName(StringRef compose_func_name);
} // namespace TFR
} // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_
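For the reverse direction, a similar sketch (again assuming the TensorFlow build environment; the driver itself is hypothetical) shows GetTFOpName round-tripping the mapping and rejecting malformed names:

#include <cassert>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"

int main() {
  // "__" decodes to the '.' separator and "_x" decodes to 'X'.
  assert(mlir::TFR::GetTFOpName("tf__pack") == "tf.Pack");
  assert(mlir::TFR::GetTFOpName("tf__concat_v2") == "tf.ConcatV2");
  // Embedded uppercase letters and trailing underscores yield "".
  assert(mlir::TFR::GetTFOpName("tf__Pack").empty());
  assert(mlir::TFR::GetTFOpName("tf__pack_").empty());
  return 0;
}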