Graduate the mlir implementation of tfr from experimental
PiperOrigin-RevId: 333369220 Change-Id: Iafce731c4d80f06eb1e95aee6ea0f67af8caeac5
This commit is contained in:
parent
40ccaf67ca
commit
938cc7bf9c
165
tensorflow/compiler/mlir/tfr/BUILD
Normal file
165
tensorflow/compiler/mlir/tfr/BUILD
Normal file
@ -0,0 +1,165 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load(
|
||||
"//third_party/mlir:tblgen.bzl",
|
||||
"gentbl",
|
||||
)
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
":friends",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/tfr/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tfr_ops_td_files",
|
||||
srcs = [
|
||||
"ir/tfr_ops.td",
|
||||
"//tensorflow/compiler/mlir/tensorflow:ir/tf_op_base.td",
|
||||
"//tensorflow/compiler/mlir/tensorflow:ir/tf_op_interfaces.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td",
|
||||
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tfr_ops_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-decls",
|
||||
"ir/tfr_ops.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-defs",
|
||||
"ir/tfr_ops.cc.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/tfr_ops.td",
|
||||
td_srcs = [
|
||||
":tfr_ops_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfr",
|
||||
srcs = [
|
||||
"ir/tfr_ops.cc",
|
||||
"ir/tfr_ops.cc.inc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/tfr_ops.h",
|
||||
"ir/tfr_ops.h.inc",
|
||||
"ir/tfr_types.h",
|
||||
],
|
||||
deps = [
|
||||
":tfr_ops_inc_gen",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "utils",
|
||||
srcs = [
|
||||
"utils/utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/utils.h",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "passes",
|
||||
srcs = [
|
||||
"passes/canonicalize.cc",
|
||||
"passes/decompose.cc",
|
||||
"passes/raise_to_tf.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"passes/passes.h",
|
||||
],
|
||||
deps = [
|
||||
":tfr",
|
||||
":utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:SCFToStandard",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tfr-opt",
|
||||
srcs = ["passes/tfr_opt.cc"],
|
||||
deps = [
|
||||
":passes",
|
||||
":tfr",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/compiler/mlir:passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "//tensorflow/compiler/mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/tfr:tfr-opt",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
"@llvm-project//mlir:run_lit.sh",
|
||||
],
|
||||
)
|
590
tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
Normal file
590
tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
Normal file
@ -0,0 +1,590 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/FunctionImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace TFR {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InlinerInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class defines the interface for inlining within the TFR dialect.
|
||||
struct TFRInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
|
||||
// Returns true if the given region 'src' can be inlined into the region
|
||||
// 'dest' that is attached to an operation registered to the current dialect.
|
||||
bool isLegalToInline(Region *dest, Region *src,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if the given operation 'op', that is registered to this
|
||||
// dialect, can be inlined into the region 'dest' that is attached to an
|
||||
// operation registered to the current dialect.
|
||||
bool isLegalToInline(Operation *op, Region *dest,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Handle the given inlined terminator by replacing it with a new operation
|
||||
// as necessary. Required when the region has only one block.
|
||||
void handleTerminator(Operation *op,
|
||||
ArrayRef<Value> valuesToRepl) const final {
|
||||
auto retValOp = dyn_cast<TFRReturnOp>(op);
|
||||
if (!retValOp) return;
|
||||
|
||||
for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
|
||||
std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
|
||||
}
|
||||
}
|
||||
|
||||
// Attempts to materialize a conversion for a type mismatch between a call
|
||||
// from this dialect, and a callable region. This method should generate an
|
||||
// operation that takes 'input' as the only operand, and produces a single
|
||||
// result of 'resultType'. If a conversion can not be generated, nullptr
|
||||
// should be returned.
|
||||
Operation *materializeCallConversion(OpBuilder &builder, Value input,
|
||||
Type result_type,
|
||||
Location conversion_loc) const final {
|
||||
if (!result_type.isa<IntegerType>()) return nullptr;
|
||||
return builder.create<TruncateIOp>(conversion_loc, result_type, input);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFR Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TFRDialect::TFRDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
|
||||
addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
|
||||
>();
|
||||
|
||||
addInterfaces<TFRInlinerInterface>();
|
||||
}
|
||||
|
||||
bool TFRType::classof(Type type) {
|
||||
return llvm::isa<TFRDialect>(type.getDialect());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Custom op methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(ConstantTensorOp op) {
|
||||
auto input_type = op.arg().getType();
|
||||
auto output_type = op.out().getType();
|
||||
|
||||
if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
|
||||
if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
|
||||
op.emitError("output type should be static and ranked.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (output_tensor_type.getRank() == 0) {
|
||||
bool same_scalar = output_tensor_type.getElementType() == input_type;
|
||||
if (!same_scalar) {
|
||||
op.emitError("input and output should have the same scalar types.");
|
||||
}
|
||||
return success(same_scalar);
|
||||
}
|
||||
|
||||
if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
|
||||
bool same_element_type = output_tensor_type.getElementType() ==
|
||||
input_vector_type.getElementType();
|
||||
bool same_shape =
|
||||
output_tensor_type.getShape() == input_vector_type.getShape();
|
||||
if (!same_element_type || !same_shape) {
|
||||
op.emitError("input and output should have same shape and element type.");
|
||||
}
|
||||
return success(same_element_type && same_shape);
|
||||
}
|
||||
|
||||
op.emitError("input can not be converted to an output tensor.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(TFRFuncOp func) {
|
||||
// Collect all attribute names used by the tensor and tensor list arguments
|
||||
// and returns. Also, collect the names of all the attribute arguments as the
|
||||
// defined list. Later on, the used attribute names will be verified to be in
|
||||
// the defined list.
|
||||
llvm::SmallVector<StringAttr, 4> used_attrs;
|
||||
|
||||
// While scanning the arguments, record the start/end indices of each argument
|
||||
// type, so the order can be verified as well.
|
||||
// TODO(fengliuai): the attribute arguments with default values need to be
|
||||
// at the end?
|
||||
int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
|
||||
last_tensor_list = -1, first_attr = -1;
|
||||
|
||||
for (auto arg : llvm::enumerate(func.getType().getInputs())) {
|
||||
Type arg_type = arg.value();
|
||||
|
||||
if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
|
||||
if (first_tensor == -1) {
|
||||
first_tensor = arg.index();
|
||||
}
|
||||
last_tensor = arg.index();
|
||||
auto used = tensor.getAttrKeys();
|
||||
used_attrs.append(used.begin(), used.end());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
|
||||
if (first_tensor_list == -1) {
|
||||
first_tensor_list = arg.index();
|
||||
}
|
||||
last_tensor_list = arg.index();
|
||||
auto used = tensor_list.getAttrKeys();
|
||||
used_attrs.append(used.begin(), used.end());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!arg_type.isa<TensorType>()) {
|
||||
if (first_attr == -1) {
|
||||
first_attr = arg.index();
|
||||
}
|
||||
auto name =
|
||||
func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
|
||||
if (!name) {
|
||||
func.emitError(
|
||||
llvm::Twine(arg.index()) +
|
||||
" attribute argument doesn't have a tfr.name attribute.");
|
||||
return failure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
func.emitError("Builtin TensorType isn't allowed as the argument.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Verify the argument order: tensors, tensor list, attributes; and also
|
||||
// verify there is at most one tensor list argument.
|
||||
if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
|
||||
func.emitError(
|
||||
"tfr.tensor argument should be before tfr.tensor_list argument.");
|
||||
return failure();
|
||||
}
|
||||
if (first_attr != -1 && first_attr < last_tensor_list) {
|
||||
func.emitError(
|
||||
"tfr.tensor_list argument should be before non tensor arguments.");
|
||||
return failure();
|
||||
}
|
||||
if (first_tensor_list != last_tensor_list) {
|
||||
func.emitError("More than one tfr.tensor_list argument isn't allowed.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Verify the result order: tensor, tensor list, and also verify at most one
|
||||
// tensor list result.
|
||||
bool seen_tensor_list = false;
|
||||
for (auto result_type : func.getType().getResults()) {
|
||||
if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
|
||||
if (seen_tensor_list) {
|
||||
func.emitError(
|
||||
"tfr.tensor result should be before tfr.tensor_list result.");
|
||||
return failure();
|
||||
}
|
||||
auto used = tensor.getAttrKeys();
|
||||
used_attrs.append(used.begin(), used.end());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
|
||||
if (seen_tensor_list) {
|
||||
func.emitError("More than one tfr.tensor_list result isn't allowed.");
|
||||
return failure();
|
||||
}
|
||||
seen_tensor_list = true;
|
||||
auto used = tensor_list.getAttrKeys();
|
||||
used_attrs.append(used.begin(), used.end());
|
||||
continue;
|
||||
}
|
||||
|
||||
func.emitError(
|
||||
"None tfr.tensor/tfr.tensor_list results aren't allowed as a "
|
||||
"result.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Verify that all the used attributes are in the attribute arguments.
|
||||
llvm::SmallVector<StringAttr, 4> undefined_attrs;
|
||||
for (auto attr : used_attrs) {
|
||||
if (!func.getAttr(attr.getValue())) {
|
||||
undefined_attrs.push_back(attr);
|
||||
}
|
||||
}
|
||||
if (!undefined_attrs.empty()) {
|
||||
llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
|
||||
std::transform(undefined_attrs.begin(), undefined_attrs.end(),
|
||||
attr_names.begin(),
|
||||
[](StringAttr attr) { return attr.getValue().str(); });
|
||||
func.emitError(llvm::Twine("Undefined attributes are used: ",
|
||||
llvm::join(attr_names, ",")));
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) {
|
||||
auto build_func_type = [](Builder &builder, ArrayRef<Type> arg_types,
|
||||
ArrayRef<Type> results, impl::VariadicFlag,
|
||||
std::string &) {
|
||||
return builder.getFunctionType(arg_types, results);
|
||||
};
|
||||
return impl::parseFunctionLikeOp(parser, *result, /*allowVariadic=*/false,
|
||||
build_func_type);
|
||||
}
|
||||
|
||||
static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) {
|
||||
FunctionType fn_type = op.getType();
|
||||
impl::printFunctionLikeOp(p, op, fn_type.getInputs(), /*isVariadic=*/false,
|
||||
fn_type.getResults());
|
||||
}
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
struct ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
|
||||
using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = cst_tensor_op.getLoc();
|
||||
Type out_type = cst_tensor_op.getType();
|
||||
Operation *new_cst = nullptr;
|
||||
|
||||
ArrayAttr array;
|
||||
if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
|
||||
llvm::DenseSet<Type> all_types;
|
||||
for (auto it : array) {
|
||||
all_types.insert(it.getType());
|
||||
}
|
||||
if (all_types.size() != 1) return failure();
|
||||
ShapedType new_out_type = RankedTensorType::get(
|
||||
{static_cast<int64_t>(array.size())}, *all_types.begin());
|
||||
DenseElementsAttr attr =
|
||||
DenseElementsAttr::get(new_out_type, array.getValue());
|
||||
new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
|
||||
if (out_type.isa<TFRTensorType>()) {
|
||||
new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
|
||||
}
|
||||
rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
Attribute scalar;
|
||||
if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
|
||||
Type new_out_type = RankedTensorType::get({}, scalar.getType());
|
||||
new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
|
||||
if (out_type.isa<TFRTensorType>()) {
|
||||
new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
|
||||
}
|
||||
rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
struct RemoveRedundantCast : public OpRewritePattern<CastOp> {
|
||||
using OpRewritePattern<CastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CastOp cast_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto preceding_cast =
|
||||
llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
|
||||
if (!preceding_cast) {
|
||||
return failure();
|
||||
}
|
||||
Value input = preceding_cast.arg();
|
||||
Type input_type = input.getType();
|
||||
Type output_type = cast_op.getType();
|
||||
|
||||
// If the two types are the same, the back-to-back tfr.cast ops can be
|
||||
// removed.
|
||||
if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
|
||||
rewriter.replaceOp(cast_op, {input});
|
||||
return success();
|
||||
}
|
||||
|
||||
// If the rank of the input tensor isn't ranked, we replace the pair
|
||||
// with tf.EnsureShape op so it can be removed after shape inference or
|
||||
// confirmed at runtime.
|
||||
if (input_type.isa<UnrankedTensorType>() && output_type.isa<ShapedType>()) {
|
||||
auto shape = output_type.cast<ShapedType>().getShape();
|
||||
auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
|
||||
rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
|
||||
input, shape_attr);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct GetTensorShape : public OpRewritePattern<GetShapeOp> {
|
||||
using OpRewritePattern<GetShapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(GetShapeOp shape_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *preceding_op = shape_op.arg().getDefiningOp();
|
||||
if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
|
||||
// replace this pair by shape.shape_of, so the folding works.
|
||||
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
struct RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
|
||||
using OpRewritePattern<GetElementOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(GetElementOp ge_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
IntegerAttr index;
|
||||
if (!matchPattern(ge_op.index(), m_Constant(&index))) {
|
||||
return failure();
|
||||
}
|
||||
auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
|
||||
ge_op.tensor_list().getDefiningOp());
|
||||
if (!preceding_build_list ||
|
||||
preceding_build_list.getNumOperands() <= index.getInt()) {
|
||||
return failure();
|
||||
}
|
||||
Value input = preceding_build_list.getOperand(index.getInt());
|
||||
Type output_type = ge_op.getType();
|
||||
if (input.getType() != output_type &&
|
||||
!output_type.isa<UnrankedTensorType>()) {
|
||||
return failure();
|
||||
}
|
||||
rewriter.replaceOp(ge_op, {input});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
|
||||
using OpRewritePattern<BuildListOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(BuildListOp bl_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Attribute, 4> array_list;
|
||||
array_list.reserve(bl_op.getNumOperands());
|
||||
for (const auto &operand : bl_op.getOperands()) {
|
||||
Attribute array_elt;
|
||||
if (!matchPattern(operand, m_Constant(&array_elt))) {
|
||||
return failure();
|
||||
}
|
||||
array_list.push_back(array_elt);
|
||||
}
|
||||
auto array_attr = rewriter.getArrayAttr(array_list);
|
||||
rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void ConstantTensorOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ConvertConstToTensorConst>(context);
|
||||
}
|
||||
|
||||
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<RemoveRedundantCast>(context);
|
||||
}
|
||||
|
||||
void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<GetTensorShape>(context);
|
||||
}
|
||||
|
||||
void GetElementOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<RemoveRedundantGetElement>(context);
|
||||
}
|
||||
|
||||
void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<BuildConstantListAsAttr>(context);
|
||||
}
|
||||
|
||||
OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "equal op has two operands");
|
||||
auto ctx = getContext();
|
||||
if (operands[0] == operands[1]) return BoolAttr::get(/*value=*/true, ctx);
|
||||
return BoolAttr::get(/*value=*/false, ctx);
|
||||
}
|
||||
|
||||
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
|
||||
// Return the held attribute value.
|
||||
return value();
|
||||
}
|
||||
|
||||
// CallableOpInterface
|
||||
Region *TFRFuncOp::getCallableRegion() {
|
||||
return isExternal() ? nullptr : &body().front();
|
||||
}
|
||||
|
||||
// CallableOpInterface
|
||||
ArrayRef<Type> TFRFuncOp::getCallableResults() {
|
||||
return getType().getResults();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Parses a TFR type.
|
||||
// tfr_type ::= tensor_type | tensor_list_type | attr_type
|
||||
// string_list ::= `[` string-literal (, string-literal)+ `]`
|
||||
// tensor_type ::= `tensor`
|
||||
// | `tensor<` (string-literal | string_list) '>'
|
||||
// tensor_list_type ::= `tensor_list`
|
||||
// | `tensor_list<` (string-literal | string_list) '>'
|
||||
// attr_type ::= `attr`
|
||||
Type TFRDialect::parseType(DialectAsmParser &parser) const {
|
||||
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
|
||||
MLIRContext *ctx = loc.getContext();
|
||||
|
||||
StringRef typeNameSpelling;
|
||||
if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
|
||||
llvm::SmallVector<StringAttr, 4> attrs;
|
||||
if (succeeded(parser.parseOptionalLess())) {
|
||||
bool l_square_parsed = false;
|
||||
if (succeeded(parser.parseOptionalLSquare())) {
|
||||
l_square_parsed = true;
|
||||
}
|
||||
|
||||
do {
|
||||
StringRef attr;
|
||||
if (failed(parser.parseKeyword(&attr))) return {};
|
||||
attrs.push_back(StringAttr::get(attr, ctx));
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
|
||||
if (l_square_parsed && failed(parser.parseRSquare())) {
|
||||
parser.emitError(parser.getNameLoc(), "expected ']'");
|
||||
}
|
||||
|
||||
if (failed(parser.parseGreater())) {
|
||||
parser.emitError(parser.getNameLoc(), "expected '>'");
|
||||
}
|
||||
}
|
||||
|
||||
if (typeNameSpelling == "tensor") {
|
||||
return TFRTensorType::getChecked(attrs, loc);
|
||||
} else if (typeNameSpelling == "tensor_list") {
|
||||
return TFRTensorListType::getChecked(attrs, loc);
|
||||
} else if (typeNameSpelling == "attr") {
|
||||
return TFRAttrType::getChecked(loc);
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
llvm::ArrayRef<StringAttr> attrs;
|
||||
|
||||
if (type.isa<TFRAttrType>()) {
|
||||
os << "attr";
|
||||
return;
|
||||
}
|
||||
if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
|
||||
attrs = tensor_ty.getAttrKeys();
|
||||
os << "tensor";
|
||||
} else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
|
||||
attrs = tensor_list_ty.getAttrKeys();
|
||||
os << "tensor_list";
|
||||
} else {
|
||||
llvm_unreachable("Unhandled tfr type");
|
||||
}
|
||||
|
||||
if (attrs.empty()) return;
|
||||
os << "<";
|
||||
|
||||
if (attrs.size() > 1) {
|
||||
os << "[";
|
||||
}
|
||||
|
||||
llvm::interleaveComma(attrs, os,
|
||||
[&](StringAttr attr) { os << attr.getValue(); });
|
||||
|
||||
if (attrs.size() > 1) {
|
||||
os << "]";
|
||||
}
|
||||
os << ">";
|
||||
}
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
55
tensorflow/compiler/mlir/tfr/ir/tfr_ops.h
Normal file
55
tensorflow/compiler/mlir/tfr/ir/tfr_ops.h
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/FunctionSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
constexpr char kAttrArgumentNameAttr[] = "tfr.name";
|
||||
constexpr char kAttrArgumentDefaultAttr[] = "tfr.default";
|
||||
|
||||
class TFRDialect : public Dialect {
|
||||
public:
|
||||
explicit TFRDialect(MLIRContext *context);
|
||||
|
||||
static StringRef getDialectNamespace() { return "tfr"; }
|
||||
|
||||
// Parse a type registered to this dialect.
|
||||
Type parseType(DialectAsmParser &parser) const override;
|
||||
|
||||
// Prints a type registered to this dialect.
|
||||
void printType(Type ty, DialectAsmPrinter &os) const override;
|
||||
};
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h.inc"
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
|
436
tensorflow/compiler/mlir/tfr/ir/tfr_ops.td
Normal file
436
tensorflow/compiler/mlir/tfr/ir/tfr_ops.td
Normal file
@ -0,0 +1,436 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the operation definition file for TFR
|
||||
|
||||
#ifndef DIALECT_TFR_OPS_
|
||||
#define DIALECT_TFR_OPS_
|
||||
|
||||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TFR_Dialect : Dialect {
|
||||
let name = "tfr";
|
||||
|
||||
let description = [{
|
||||
The TensorFlow Composition dialect.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::TFR";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// tensor argument types
|
||||
class TFR_Type<string name> : DialectType<TFR_Dialect,
|
||||
CPred<"$_self.isa<mlir::TFR::" # name # "Type>()">,
|
||||
"TFR " # name #" type">,
|
||||
BuildableType<"$_builder.getType<mlir::TFR::" # name # "Type>()">;
|
||||
def TFR_TensorType : TFR_Type<"TFRTensor">;
|
||||
def TFR_TensorListType : TFR_Type<"TFRTensorList">;
|
||||
def TFR_AllTensorTypes : Type<Or<[
|
||||
TFR_TensorType.predicate,
|
||||
TFR_TensorListType.predicate]>, "all tensor related types">;
|
||||
|
||||
// attribute argument types
|
||||
def TFR_AttrType : TFR_Type<"TFRAttr">;
|
||||
def TFR_AttrScalarType: TypeAlias<TF_ElementType, "scalar attribute">;
|
||||
def TFR_AttrVectorType : VectorOf<[TF_ElementType, TFR_AttrType]>;
|
||||
def TFR_AllAttrTypes : Type<Or<[
|
||||
TFR_AttrType.predicate,
|
||||
Index.predicate,
|
||||
TFR_AttrScalarType.predicate,
|
||||
TFR_AttrVectorType.predicate]>, "all attribute related types">;
|
||||
|
||||
// all allowed arguments types
|
||||
def TFR_allowedArgType : Type<Or<[
|
||||
TFR_AllTensorTypes.predicate,
|
||||
TFR_AllAttrTypes.predicate]>, "allowed tfr.call operand types">;
|
||||
|
||||
def TFR_allowedConstValues : Attr<Or<[
|
||||
FlatSymbolRefAttr.predicate,
|
||||
TypeAttr.predicate,
|
||||
StrAttr.predicate,
|
||||
ArrayAttr.predicate]>, "allowed tfr.constant value"> {
|
||||
let storageType = "Attribute";
|
||||
let returnType = "Attribute";
|
||||
let convertFromStorage = "$_self";
|
||||
let constBuilderCall = "$0";
|
||||
}
|
||||
|
||||
// all allowed result types
|
||||
def TFR_allowedResultType : TypeAlias<TFR_AllTensorTypes,
|
||||
"allowed tfr.call result types">;
|
||||
|
||||
// standard tensor type and tfr.tensor types can be casted to each other.
|
||||
def TFR_singleTensorType : Type<Or<[
|
||||
TFR_TensorType.predicate,
|
||||
TF_Tensor.predicate]>, "single tensor or tfr.tensor type">;
|
||||
|
||||
// all allowed build list input types
|
||||
def TFR_allowedBuiltListType : Type<Or<[
|
||||
TFR_TensorType.predicate,
|
||||
TF_ElementType.predicate]>, "single tfr.tensor or tensor element type">;
|
||||
|
||||
// all allowed build list result types
|
||||
def TFR_allowedListResultType : Type<Or<[
|
||||
TFR_TensorListType.predicate,
|
||||
TFR_AttrType.predicate]>, "tfr.tensor_list or tfr.attr type">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TFR_Op<string mnemonic, list<OpTrait> traits> :
|
||||
Op<TFR_Dialect, mnemonic, traits>;
|
||||
|
||||
def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> {
|
||||
let description = [{
|
||||
The `call` operation represents a direct call to a function that is within
|
||||
the same symbol scope as the callee. The operands and result types of the
|
||||
call must match the specified function type. The callee is encoded as a
|
||||
symbol reference attribute named "callee".
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = tfr.call @my_add(%0, %1) : (tfr.tensor, f32) -> tfr.tensor_list
|
||||
```
|
||||
|
||||
Note that the operands of the `call` operation can only be with tfr.tensor,
|
||||
tfr.tensor_list, tfr.attr and mlir float and integer types. The results of
|
||||
the `call` operation can only be with tfr.tensor and tfr.tensor_list types.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$callee,
|
||||
Variadic<TFR_allowedArgType>:$args);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TFR_allowedResultType>:$outs);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
StringRef getCallee() { return callee(); }
|
||||
|
||||
// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() { return args(); }
|
||||
|
||||
// Return the callee of this operation.
|
||||
CallInterfaceCallable getCallableForCallee() { return calleeAttr(); }
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$callee `(` $args `)` attr-dict `:` functional-type($args, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def TFR_CastOp : TFR_Op<"cast", [NoSideEffect]> {
|
||||
let description = [{
|
||||
The `cast` operation converts the operand with built-in tensor type to
|
||||
tfr.tensor type, or vice versa.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%1 = tfr.cast(%0) : tensor<f32> -> !tfr.tensor
|
||||
%3 = tfr.cast(%1) : !tfr.tensor -> tensor<f32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TFR_singleTensorType:$arg);
|
||||
|
||||
let results = (outs TFR_singleTensorType:$out);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Return element type of the input tensor type. Only available when the
|
||||
// input is a MLIR built-in tensor type.
|
||||
Attribute getInputElementType() {
|
||||
if (auto ty = arg().getType().dyn_cast<TensorType>()) {
|
||||
return TypeAttr::get(ty.getElementType());
|
||||
}
|
||||
return {};
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TFR_GetShapeOp : TFR_Op<"get_shape", [NoSideEffect]> {
|
||||
let description = [{
|
||||
The `get_shape` operation gets the shape of a tfr.tensor and returns
|
||||
!shape.shape type.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%1 = "tfr.get_shape"(%0) : !tfr.tensor -> !shape.shape
|
||||
%1 = tfr.get_shape %0 -> !shape.shape
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TFR_TensorType:$arg);
|
||||
|
||||
let results = (outs Shape_ShapeType:$out);
|
||||
|
||||
let assemblyFormat = "$arg attr-dict `->` type($out)";
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TFR_GetElementTypeOp : TFR_Op<"get_element_type", [NoSideEffect]> {
|
||||
let description = [{
|
||||
The `get_element_type` operation gets the element type of a tfr.tensor and
|
||||
returns !tfr.attr.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%1 = "tfr.get_element_type"(%0) : !tfr.tensor -> !tfr.attr
|
||||
%1 = tfr.get_element_type %0 -> !tfr.attr
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TFR_TensorType:$arg);
|
||||
|
||||
let results = (outs TFR_AttrType:$out);
|
||||
|
||||
let assemblyFormat = "$arg attr-dict `->` type($out)";
|
||||
}
|
||||
|
||||
def TFR_EqualOp : TFR_Op<"equal", [NoSideEffect, SameTypeOperands]> {
|
||||
let description = [{
|
||||
The `equal` operation compares the values of the tfr.attr type arguments.
|
||||
The operation returns an i1 boolean indicating if the two values are the
|
||||
same.
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%x = tfr.equal %lhs, %rhs -> i1
|
||||
%x = "tfr.equal"(%lhs, %rhs) : (!tfr.attr, !tfr.attr) -> i1
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFR_AttrType:$lhs,
|
||||
TFR_AttrType:$rhs
|
||||
);
|
||||
let results = (outs BoolLike:$result);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `->` type($result)";
|
||||
}
|
||||
|
||||
def TFR_ConstOp : TFR_Op<"constant", [ConstantLike, NoSideEffect]> {
|
||||
let description = [{
|
||||
The `attr` operation stores TF op's attribute, which doesn't support
|
||||
arithmetic operations.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%1 = "tfr.constant"() { value: i32 } : () -> !tfr.attr
|
||||
%2 = "tfr.constant"() { value: [i32, f32] } : () -> !tfr.attr
|
||||
%3 = tfr.constant [i32, f32] -> !tfr.attr
|
||||
%4 = tfr.constant f32 -> !tfr.attr
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TFR_allowedConstValues:$value);
|
||||
|
||||
let results = (outs TFR_AttrType:$out);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &, OperationState &state, Attribute value",
|
||||
[{
|
||||
auto* ctx = value.getContext();
|
||||
state.addAttribute("value", value);
|
||||
state.addTypes(TFRAttrType::get(ctx));
|
||||
}]>
|
||||
];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$value attr-dict `->` type($out)
|
||||
}];
|
||||
}
|
||||
|
||||
def TFR_ConstantTensorOp : TFR_Op<"constant_tensor", [NoSideEffect]> {
|
||||
let description = [{
|
||||
The `constant_tensor` operation converts the operand with non-built-in
|
||||
tensor type to built-in tensor type or tfr.tensor type. If it is built-in
|
||||
tensor type, the shape shouldn't be changed during the conversion.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%1 = tfr.contant_tensor(%0) : f32 -> tensor<f32>
|
||||
%3 = tfr.contant_tensor(%2) : vector<1xf32> -> tensor<1xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TFR_AllAttrTypes:$arg);
|
||||
|
||||
let results = (outs TFR_singleTensorType:$out);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TFR_GetElementOp : TFR_Op<"get_element", [NoSideEffect]> {
|
||||
let description = [{
|
||||
The `get_element` operation extracts one tfr.tensor element from a
|
||||
tfr.tensor_list.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = tfr.get_element %1[%0] : (tfr.tensor, index) -> tfr.tensor
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFR_TensorListType:$tensor_list,
|
||||
Index:$index);
|
||||
|
||||
let results = (outs TFR_TensorType:$out);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$tensor_list `[` $index `]` attr-dict `:`
|
||||
`(` type($tensor_list) `,` type($index) `)` `->` type($out)
|
||||
}];
|
||||
}
|
||||
|
||||
def TFR_BuildListOp : TFR_Op<"build_list", [NoSideEffect]> {
|
||||
let description = [{
|
||||
The `build_list` operation builds a tensor list from a list of tensors, or
|
||||
an tfr.attr from a list of scalars.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%3 = tfr.build_list(%2, %1, %0) :
|
||||
(tfr.tensor, tfr.tensor, tfr.tensor) -> tfr.tensor_list
|
||||
%3 = tfr.build_list(%2, %1, %0) : (i32, i32, i32) -> tfr.attr
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TFR_allowedBuiltListType>:$tensors);
|
||||
|
||||
let results = (outs TFR_allowedListResultType:$out);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Function related classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
|
||||
DeclareOpInterfaceMethods<CallableOpInterface>,
|
||||
FunctionLike,
|
||||
IsolatedFromAbove, Symbol]> {
|
||||
let summary = "TFR Function defines a composition of other ops";
|
||||
|
||||
let description = [{
|
||||
Defines a function that can be used to decompose an TF function call to
|
||||
the invocation of a set of other TF ops.
|
||||
|
||||
Syntax:
|
||||
|
||||
```
|
||||
op ::= `tfr.func` symbol-ref-id `(` argument-list `)` (`->`
|
||||
function-result-list)? function-attributes? region
|
||||
```
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
tfr.func @foo(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list<T>,
|
||||
%arg2: int {tfr.name="T", tfr.default=1})
|
||||
attributes {qux: "quux"} {
|
||||
tfr.return
|
||||
}
|
||||
```
|
||||
|
||||
Note the arguments are ordered by the following rule:
|
||||
tfr.tensor > tfr.tensor_list > tfr.attr/i32/...,
|
||||
and only one trfr.tensor_list argument is allowed.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TypeAttr:$type,
|
||||
StrAttr:$sym_name
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
// When the regions is empty, the tfr.func is an external function and used
|
||||
// to model the element type constraints of the tf op. Otherwise, there is one
|
||||
// region containing the composition.
|
||||
let regions = (region VariadicRegion<AnyRegion>:$body);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result, StringRef name, "
|
||||
"FunctionType type, ArrayRef<NamedAttribute> attrs = {}">
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// FunctionLike trait needs access to the functions below.
|
||||
friend class OpTrait::FunctionLike<TFRFuncOp>;
|
||||
|
||||
// Hooks for the input/output type enumeration in FunctionLike .
|
||||
unsigned getNumFuncArguments() { return getType().getNumInputs(); }
|
||||
unsigned getNumFuncResults() { return getType().getNumResults(); }
|
||||
}];
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
let parser = [{ return ParseFuncOp(parser, &result); }];
|
||||
let printer = [{ PrintFuncOp(p, *this); }];
|
||||
}
|
||||
|
||||
def TFR_TFRReturnOp : TFR_Op<"return", [HasParent<"TFRFuncOp">, NoSideEffect,
|
||||
ReturnLike, Terminator]> {
|
||||
let description = [{
|
||||
A terminator operation for regions that appear in the body of `tfr.func`
|
||||
functions. The operands to the `tfr.return` are the result values returned
|
||||
by an invocation of the `tfr.func`.
|
||||
|
||||
Note that only the tfr.tensor and tfr.tensor_list can be returned.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TFR_allowedResultType>:$operands);
|
||||
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
}
|
||||
|
||||
#endif // DIALECT_TFR_OPS_
|
115
tensorflow/compiler/mlir/tfr/ir/tfr_types.h
Normal file
115
tensorflow/compiler/mlir/tfr/ir/tfr_types.h
Normal file
@ -0,0 +1,115 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
class TFRType : public Type {
|
||||
public:
|
||||
using Type::Type;
|
||||
|
||||
static bool classof(Type type);
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct TFRTypeStorage final
|
||||
: public TypeStorage,
|
||||
public llvm::TrailingObjects<TFRTypeStorage, StringAttr> {
|
||||
using KeyTy = ArrayRef<StringAttr>;
|
||||
|
||||
explicit TFRTypeStorage(unsigned num_attrs) : num_attrs(num_attrs) {}
|
||||
|
||||
static TFRTypeStorage* construct(TypeStorageAllocator& allocator, KeyTy key) {
|
||||
// Allocate a new storage instance.
|
||||
auto byteSize = TFRTypeStorage::totalSizeToAlloc<StringAttr>(key.size());
|
||||
auto rawMem = allocator.allocate(byteSize, alignof(TFRTypeStorage));
|
||||
auto result = ::new (rawMem) TFRTypeStorage(key.size());
|
||||
|
||||
// Copy in the string attributes into the trailing storage.
|
||||
std::uninitialized_copy(key.begin(), key.end(),
|
||||
result->getTrailingObjects<StringAttr>());
|
||||
return result;
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy& attrs) const { return attrs == GetAttrs(); }
|
||||
|
||||
KeyTy GetAttrs() const {
|
||||
return {getTrailingObjects<StringAttr>(), num_attrs};
|
||||
}
|
||||
|
||||
unsigned num_attrs;
|
||||
};
|
||||
|
||||
template <typename Derived>
|
||||
class TFRTypeImpl : public Type::TypeBase<Derived, TFRType, TFRTypeStorage> {
|
||||
public:
|
||||
using Base = Type::TypeBase<Derived, TFRType, TFRTypeStorage>;
|
||||
using TFRBase = TFRTypeImpl<Derived>;
|
||||
using Base::Base;
|
||||
|
||||
static Derived get(ArrayRef<StringAttr> attrs, MLIRContext* context) {
|
||||
return Base::get(context, attrs);
|
||||
}
|
||||
|
||||
static Derived getChecked(ArrayRef<StringAttr> attrs, Location loc) {
|
||||
return Base::getChecked(loc, attrs);
|
||||
}
|
||||
|
||||
static Derived get(MLIRContext* context) { return get({}, context); }
|
||||
|
||||
// TODO(fengliuai): fix the implementation
|
||||
static LogicalResult verifyConstructionInvariants(
|
||||
Location loc, ArrayRef<StringAttr> attrs) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ArrayRef<StringAttr> getAttrKeys() { return Base::getImpl()->GetAttrs(); }
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
class TFRTensorType : public detail::TFRTypeImpl<TFRTensorType> {
|
||||
public:
|
||||
using TFRBase::TFRBase;
|
||||
static std::string getTypeName() { return "TFRTensorType"; }
|
||||
};
|
||||
|
||||
class TFRTensorListType : public detail::TFRTypeImpl<TFRTensorListType> {
|
||||
public:
|
||||
using TFRBase::TFRBase;
|
||||
static std::string getTypeName() { return "TFRTensorListType"; }
|
||||
};
|
||||
|
||||
class TFRAttrType : public Type::TypeBase<TFRAttrType, TFRType, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static std::string getTypeName() { return "TFRAttrType"; }
|
||||
};
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
|
160
tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
Normal file
160
tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
Normal file
@ -0,0 +1,160 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
|
||||
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Canonicalization patterns for the scf.for and scf.if ops. They are used to
|
||||
// optimize the control flow in the tfr function. Technically, both patterns
|
||||
// should be upstreamed to be part of the op definition.
|
||||
// TODO(fengliuai): sync with the llvm upstream for both patterns.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
namespace {
|
||||
|
||||
struct UnrollSCFForOp : public OpRewritePattern<scf::ForOp> {
|
||||
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(scf::ForOp for_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = for_op.getLoc();
|
||||
APInt lower_bound, upper_bound, step;
|
||||
if (!matchPattern(for_op.lowerBound(), m_ConstantInt(&lower_bound)) ||
|
||||
!matchPattern(for_op.upperBound(), m_ConstantInt(&upper_bound)) ||
|
||||
!matchPattern(for_op.step(), m_ConstantInt(&step))) {
|
||||
return failure();
|
||||
}
|
||||
uint64_t trip_count = (upper_bound - lower_bound).sdiv(step).getZExtValue();
|
||||
if (trip_count <= 0) return failure();
|
||||
|
||||
// TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported
|
||||
|
||||
Block *single_block = for_op.getBody();
|
||||
BlockAndValueMapping mapping;
|
||||
Value iv = for_op.getInductionVar();
|
||||
for (auto iter_op :
|
||||
llvm::zip(for_op.getRegionIterArgs(), for_op.initArgs())) {
|
||||
mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
|
||||
}
|
||||
mapping.map(iv, for_op.lowerBound());
|
||||
for (auto i = 0; i < trip_count; ++i) {
|
||||
if (!iv.use_empty()) {
|
||||
// iv' = iv + step * i;
|
||||
Value iter = rewriter.create<ConstantIndexOp>(loc, i);
|
||||
Value step_cst =
|
||||
rewriter.create<ConstantIndexOp>(loc, step.getSExtValue());
|
||||
Value stride = rewriter.create<MulIOp>(loc, step_cst, iter);
|
||||
Value iv_unroll =
|
||||
rewriter.create<AddIOp>(loc, mapping.lookup(iv), stride);
|
||||
mapping.map(iv, iv_unroll);
|
||||
}
|
||||
|
||||
Operation *terminator_op;
|
||||
for (auto it = single_block->begin(); it != single_block->end(); ++it) {
|
||||
terminator_op = rewriter.clone(*it, mapping);
|
||||
}
|
||||
// Map the block arguments to the yield results.
|
||||
for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(),
|
||||
terminator_op->getOperands())) {
|
||||
mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
|
||||
}
|
||||
rewriter.eraseOp(terminator_op);
|
||||
}
|
||||
SmallVector<Value, 4> returned;
|
||||
for (Value arg : for_op.getRegionIterArgs()) {
|
||||
returned.push_back(mapping.lookup(arg));
|
||||
}
|
||||
rewriter.replaceOp(for_op, returned);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(fengliuai): up stream this pattern.
|
||||
struct SimplifySCFIfOp : public OpRewritePattern<scf::IfOp> {
|
||||
using OpRewritePattern<scf::IfOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(scf::IfOp if_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Then branch
|
||||
if (matchPattern(if_op.condition(), m_NonZero())) {
|
||||
return InlineRegion(if_op.getLoc(), rewriter, if_op, &if_op.thenRegion());
|
||||
}
|
||||
|
||||
// Else branch
|
||||
if (matchPattern(if_op.condition(), m_Zero())) {
|
||||
if (if_op.elseRegion().empty()) {
|
||||
// Remove the op
|
||||
rewriter.eraseOp(if_op);
|
||||
return success();
|
||||
} else {
|
||||
return InlineRegion(if_op.getLoc(), rewriter, if_op,
|
||||
&if_op.elseRegion());
|
||||
}
|
||||
}
|
||||
|
||||
// Not a constant condition
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter,
|
||||
Operation *inline_point, Region *region) const;
|
||||
};
|
||||
|
||||
LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
|
||||
PatternRewriter &rewriter,
|
||||
Operation *inline_point,
|
||||
Region *region) const {
|
||||
InlinerInterface interface(loc.getContext());
|
||||
if (failed(inlineRegion(interface, region, inline_point, {},
|
||||
inline_point->getResults(), loc,
|
||||
/*shouldCloneInlinedRegion=*/true))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// If the inlining was successful then erase the scf.if op.
|
||||
rewriter.eraseOp(inline_point);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
|
||||
}
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
280
tensorflow/compiler/mlir/tfr/passes/decompose.cc
Normal file
280
tensorflow/compiler/mlir/tfr/passes/decompose.cc
Normal file
@ -0,0 +1,280 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The pass to decompose unregistered TF ops with the TFR compose function.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
namespace {
|
||||
|
||||
// Decompose the TF ops with the registered composition library.
|
||||
struct DecomposeTFOpsPass
|
||||
: public PassWrapper<DecomposeTFOpsPass, FunctionPass> {
|
||||
|
||||
explicit DecomposeTFOpsPass(llvm::Optional<ModuleOp> external_tfr_module)
|
||||
: external_tfr_module(external_tfr_module) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
// Apply canonicalization, mainly constant folding, on the function.
|
||||
void ApplyCanonicalization();
|
||||
|
||||
// Rewrite unregistered TF ops to TFR func call ops. Return failure if all the
|
||||
// ops are registered or the compose function doesn't exist.
|
||||
LogicalResult RewriteUnregisteredTFOps();
|
||||
|
||||
// Inline the TFR func call ops.
|
||||
LogicalResult InlineTFRFuncCalls();
|
||||
|
||||
// Optional external symbol table to look up the TFR function.
|
||||
llvm::Optional<ModuleOp> external_tfr_module;
|
||||
};
|
||||
|
||||
void DecomposeTFOpsPass::ApplyCanonicalization() {
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
auto* context = &getContext();
|
||||
for (auto* op : context->getRegisteredOperations()) {
|
||||
op->getCanonicalizationPatterns(patterns, context);
|
||||
}
|
||||
populateSCFOpsCanonicalizationPatterns(patterns, context);
|
||||
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
|
||||
FuncOp func = getFunction();
|
||||
SymbolTable table(external_tfr_module.hasValue()
|
||||
? *external_tfr_module
|
||||
: func.getParentOfType<ModuleOp>());
|
||||
OpBuilder builder(func);
|
||||
bool changed = false;
|
||||
func.walk([&table, &builder, &changed](Operation* op) {
|
||||
// Only the un-registered ops requires decomposition. The remaining ones
|
||||
// either will be constant folded or lowered by the rules defined in the
|
||||
// bridge.
|
||||
if (op->isRegistered()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Find out the compose function
|
||||
auto compose_func_name = GetComposeFuncName(op->getName().getStringRef());
|
||||
auto compose_func = table.lookup<TFRFuncOp>(compose_func_name);
|
||||
if (!compose_func || compose_func.isExternal()) {
|
||||
// There are no decomposition methods defined for this op, skip.
|
||||
return;
|
||||
}
|
||||
|
||||
auto compose_func_type = compose_func.getType();
|
||||
builder.setInsertionPoint(op);
|
||||
TFRTensorType unconstrainted_tensor_type = builder.getType<TFRTensorType>();
|
||||
|
||||
// Create the new operands. This is mapping the operands from the target
|
||||
// TF ops to the TFR function arguments. If the TFR function argument is
|
||||
// a tensor_list, a "tfr.build_list" op is used to concat the available
|
||||
// TF op operands. If the TFR function argument isn't a tensor/tensor_list,
|
||||
// a constant is created by using the attribute stored in the TF op or the
|
||||
// default value in the argument attribute.
|
||||
llvm::SmallVector<Value, 4> new_operands;
|
||||
for (auto arg : llvm::enumerate(compose_func_type.getInputs())) {
|
||||
if (auto tensor_type = arg.value().dyn_cast<TFRTensorType>()) {
|
||||
auto casted = builder.create<CastOp>(op->getLoc(), tensor_type,
|
||||
op->getOperand(arg.index()));
|
||||
new_operands.push_back(casted);
|
||||
} else if (auto list_type = arg.value().dyn_cast<TFRTensorListType>()) {
|
||||
llvm::SmallVector<Value, 4> variadic_operands;
|
||||
for (int i = arg.index(); i < op->getNumOperands(); i++) {
|
||||
auto casted = builder.create<CastOp>(
|
||||
op->getLoc(), unconstrainted_tensor_type, op->getOperand(i));
|
||||
variadic_operands.push_back(casted);
|
||||
}
|
||||
auto build_list_op = builder.create<BuildListOp>(
|
||||
op->getLoc(), list_type, variadic_operands);
|
||||
new_operands.push_back(build_list_op.out());
|
||||
} else {
|
||||
auto attr_name = compose_func.getArgAttrOfType<StringAttr>(
|
||||
arg.index(), kAttrArgumentNameAttr);
|
||||
auto attribute = op->getAttr(attr_name.getValue());
|
||||
if (!attribute) {
|
||||
attribute =
|
||||
compose_func.getArgAttr(arg.index(), kAttrArgumentDefaultAttr);
|
||||
}
|
||||
Value attr_cst;
|
||||
// Wrap these special attributes as a special TFR constant, so the SSA
|
||||
// value has a valid type to be used as TFR function argument. These
|
||||
// attributes are not expected to be manipulated by the lowering passes.
|
||||
if (attribute.isa<TypeAttr>() || attribute.isa<ArrayAttr>() ||
|
||||
attribute.isa<StringAttr>() || attribute.isa<FlatSymbolRefAttr>()) {
|
||||
TFRAttrType output_type = TFRAttrType::get(builder.getContext());
|
||||
attr_cst =
|
||||
builder.create<ConstOp>(op->getLoc(), output_type, attribute);
|
||||
} else {
|
||||
attr_cst = builder.create<ConstantOp>(op->getLoc(), attribute);
|
||||
}
|
||||
new_operands.push_back(attr_cst);
|
||||
}
|
||||
}
|
||||
|
||||
// Create the TFR call op
|
||||
auto new_op = builder.create<CallOp>(
|
||||
op->getLoc(), compose_func_type.getResults(),
|
||||
builder.getSymbolRefAttr(compose_func.getName()), new_operands);
|
||||
|
||||
// Replace the use of the old op. This is mapping the results from the
|
||||
// target TF ops to the TFR function returns. If the TFR function return is
|
||||
// a tensor_list, "tfr.get_element" op is used to extract the required TF
|
||||
// op result.
|
||||
llvm::SmallVector<Value, 4> new_results;
|
||||
for (auto res : llvm::enumerate(compose_func_type.getResults())) {
|
||||
if (res.value().dyn_cast<TFRTensorType>()) {
|
||||
new_results.push_back(new_op.getResult(res.index()));
|
||||
} else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
|
||||
for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) {
|
||||
auto index =
|
||||
builder.create<ConstantOp>(op->getLoc(), builder.getIndexAttr(j));
|
||||
auto element_op = builder.create<GetElementOp>(
|
||||
op->getLoc(), unconstrainted_tensor_type,
|
||||
new_op.getResult(res.index()), index.getResult());
|
||||
new_results.push_back(element_op.out());
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto res : llvm::zip(op->getResults(), new_results)) {
|
||||
auto casted = builder.create<CastOp>(
|
||||
op->getLoc(), std::get<0>(res).getType(), std::get<1>(res));
|
||||
std::get<0>(res).replaceAllUsesWith(casted.out());
|
||||
}
|
||||
op->erase();
|
||||
changed |= true;
|
||||
});
|
||||
|
||||
// If `changed` is false, it is considered as a failure, so the recursive
|
||||
// rewrite will stop.
|
||||
return success(changed);
|
||||
}
|
||||
|
||||
LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
|
||||
// The Inliner will automatically use the registered dialect inliner.
|
||||
InlinerInterface inliner(&getContext());
|
||||
FuncOp func = getFunction();
|
||||
SymbolTable table(external_tfr_module.hasValue()
|
||||
? *external_tfr_module
|
||||
: func.getParentOfType<ModuleOp>());
|
||||
|
||||
// The inliner only inlines the TFR call op.
|
||||
bool changed = false;
|
||||
auto walk_result = func.walk([&](CallOp call_op) {
|
||||
auto callee = table.lookup<TFRFuncOp>(call_op.callee());
|
||||
if (!callee || callee.isExternal()) return WalkResult::advance();
|
||||
if (failed(inlineCall(inliner,
|
||||
cast<CallOpInterface>(call_op.getOperation()),
|
||||
cast<CallableOpInterface>(callee.getOperation()),
|
||||
callee.getCallableRegion(),
|
||||
/**shouldCloneInLinedRegion=*/true))) {
|
||||
// This failure is usually because the decompose function is not defined.
|
||||
// This call will be raised to TF ops.
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
call_op.erase();
|
||||
changed |= true;
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walk_result.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// If `changed` is false, it is considered as a failure, so the recursive
|
||||
// rewrite will stop.
|
||||
return success(changed);
|
||||
}
|
||||
|
||||
void DecomposeTFOpsPass::runOnFunction() {
|
||||
// Set a maximum iteration threshold in case there are infinite loops in the
|
||||
// call stack.
|
||||
int max_iterators = 10;
|
||||
do {
|
||||
// canonicalization
|
||||
ApplyCanonicalization();
|
||||
|
||||
// rewrite unregistered tf ops. Failed either because no ops can be
|
||||
// decomposed or the compose function isn't defined.
|
||||
auto rewrite_status = RewriteUnregisteredTFOps();
|
||||
// inline the tfr call op until there are no tfr.call op can be inlined.
|
||||
auto inline_status = InlineTFRFuncCalls();
|
||||
|
||||
if (failed(rewrite_status) && failed(inline_status)) {
|
||||
break;
|
||||
}
|
||||
} while (max_iterators-- >= 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the pass to decompose the TF ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(
|
||||
llvm::Optional<ModuleOp> tfr_module) {
|
||||
return std::make_unique<DecomposeTFOpsPass>(tfr_module);
|
||||
}
|
||||
|
||||
static PassRegistration<DecomposeTFOpsPass> pass(
|
||||
"tfr-decompose",
|
||||
"Decompose TF ops with the registered composition library.",
|
||||
[] { return CreateDecomposeTFOpsPass(); });
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
44
tensorflow/compiler/mlir/tfr/passes/passes.h
Normal file
44
tensorflow/compiler/mlir/tfr/passes/passes.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_
|
||||
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
|
||||
// Decompose ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(
|
||||
llvm::Optional<ModuleOp> tfr_module = llvm::None);
|
||||
|
||||
// Raise to TF ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseToTFOpsPass(
|
||||
llvm::Optional<ModuleOp> tfr_module = llvm::None,
|
||||
bool materialize_derived_attrs = false);
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_
|
474
tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
Normal file
474
tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
Normal file
@ -0,0 +1,474 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The pass to rewrite the TFR function call ops by TF ops. The callee of the
|
||||
// TFR function call defines the signatures of the TF ops.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
namespace {
|
||||
|
||||
// This pattern is to rewrite the "tfr.call" op and the "tfr.cast" ops on the
|
||||
// operands by a TF op with "tfr.cast" ops on the results. The result type of
|
||||
// the new TF op is an unranked tensor with element type derived.
|
||||
class RewriteTFRCallOp : public OpRewritePattern<CallOp> {
|
||||
using OpRewritePattern<CallOp>::OpRewritePattern;
|
||||
|
||||
public:
|
||||
explicit RewriteTFRCallOp(MLIRContext* context, const SymbolTable& table,
|
||||
bool materialize_derived_attrs)
|
||||
: OpRewritePattern<CallOp>(context),
|
||||
symbol_table_(table),
|
||||
materialize_derived_attrs_(materialize_derived_attrs) {}
|
||||
|
||||
LogicalResult matchAndRewrite(CallOp call_op,
|
||||
PatternRewriter& rewriter) const override;
|
||||
|
||||
private:
|
||||
// Derives the attribute values for the attributes attached to the
|
||||
// `input_tfr_type`. These attributes are only for the element type of the
|
||||
// inputs, and these type information has been collected in the `input_types`.
|
||||
// The result is stored in `derived_attrs` as the named attributes. Returns
|
||||
// failure if the attributes stored in the `input_tfr_type` violates the
|
||||
// assumptions.
|
||||
LogicalResult AddDerivedAttrs(
|
||||
PatternRewriter& rewriter, Type input_tfr_type,
|
||||
ArrayRef<Attribute> input_types,
|
||||
llvm::StringMap<Attribute>* derived_attrs) const;
|
||||
|
||||
// Collects the operands and attributes for the TF op. At the same time, it
|
||||
// collects all the derived attribute values to derive the output types of the
|
||||
// TF op.
|
||||
LogicalResult CollectInputsAndAttributes(
|
||||
PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
|
||||
SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
|
||||
llvm::StringMap<Attribute>* derived_attrs) const;
|
||||
|
||||
// Uses the collected attribute values to derive all the output types.
|
||||
LogicalResult DeriveOutputTypes(FunctionType signature,
|
||||
const llvm::StringMap<Attribute>& attrs,
|
||||
SmallVectorImpl<Type>* output_types) const;
|
||||
|
||||
// Creates the TF op and also the necessary tfr.cast ops to replace the
|
||||
// original TFR call op.
|
||||
LogicalResult CreateAndReplaceOp(
|
||||
PatternRewriter& rewriter, CallOp call_op,
|
||||
const SmallVectorImpl<Type>& output_types,
|
||||
const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
|
||||
const llvm::StringMap<Attribute>& derived_attrs) const;
|
||||
|
||||
// Adds a tf.Cast op if the tfr.tensor attribute indicated a fixed element
|
||||
// type.
|
||||
// TODO(fengliuai): This method is required when the operand types are not set
|
||||
// by the frontend correctly.
|
||||
Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc,
|
||||
CastOp cast_op, Type input_tfr_type) const {
|
||||
auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>();
|
||||
if (!tensor_type) return cast_op.arg();
|
||||
|
||||
auto attr_names = tensor_type.getAttrKeys();
|
||||
if (attr_names.empty() || attr_names.size() > 1) return cast_op.arg();
|
||||
StringRef tfr_type_attr = attr_names[0].getValue();
|
||||
if (!fixed_elt_type_attrs_.contains(tfr_type_attr)) return cast_op.arg();
|
||||
|
||||
Type result_elt_type;
|
||||
if (tfr_type_attr == "i32_") {
|
||||
result_elt_type = rewriter.getI32Type();
|
||||
} else if (tfr_type_attr == "i64_") {
|
||||
result_elt_type = rewriter.getI64Type();
|
||||
} else if (tfr_type_attr == "f32_") {
|
||||
result_elt_type = rewriter.getF32Type();
|
||||
} else if (tfr_type_attr == "i1_") {
|
||||
result_elt_type = rewriter.getI1Type();
|
||||
} else {
|
||||
return cast_op.arg();
|
||||
}
|
||||
|
||||
Type original_input_type =
|
||||
cast_op.getInputElementType().cast<TypeAttr>().getValue();
|
||||
if (result_elt_type != original_input_type) {
|
||||
UnrankedTensorType result_type = UnrankedTensorType::get(result_elt_type);
|
||||
return rewriter.create<TF::CastOp>(loc, result_type, cast_op.arg());
|
||||
}
|
||||
return cast_op.arg();
|
||||
}
|
||||
|
||||
// For variadic operands, we have to enforce them to use the same types.
|
||||
// TODO(fengliuai): This method is required when the operand types are not set
|
||||
// by the frontend correctly.
|
||||
void CastValuesToSameType(PatternRewriter& rewriter, Location loc,
|
||||
const llvm::SmallVectorImpl<Attribute>& input_types,
|
||||
llvm::SmallVectorImpl<Value>& input_values) const {
|
||||
if (input_types.size() <= 1) return;
|
||||
|
||||
Type target_input_type = input_types[0].cast<TypeAttr>().getValue();
|
||||
auto result_type = UnrankedTensorType::get(target_input_type);
|
||||
for (auto i = 1; i < input_types.size(); ++i) {
|
||||
Type current_input_type = input_types[i].cast<TypeAttr>().getValue();
|
||||
if (current_input_type != target_input_type) {
|
||||
input_values[i] =
|
||||
rewriter.create<TF::CastOp>(loc, result_type, input_values[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const SymbolTable& symbol_table_;
|
||||
const bool materialize_derived_attrs_;
|
||||
const llvm::SmallDenseSet<StringRef, 4> fixed_elt_type_attrs_{"i32_", "i64_",
|
||||
"f32_", "i1_"};
|
||||
};
|
||||
|
||||
LogicalResult RewriteTFRCallOp::AddDerivedAttrs(
|
||||
PatternRewriter& rewriter, Type input_tfr_type,
|
||||
ArrayRef<Attribute> input_types,
|
||||
llvm::StringMap<Attribute>* derived_attrs) const {
|
||||
// If there is an attribute associated to the input in the signature, we
|
||||
// store it as an derived attribute.
|
||||
if (auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>()) {
|
||||
auto attr_names = tensor_type.getAttrKeys();
|
||||
if (attr_names.empty()) return success();
|
||||
|
||||
if (attr_names.size() == 1) {
|
||||
derived_attrs->insert({attr_names[0].getValue(), input_types[0]});
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// If there is an attribute associated to the input in the signature,
|
||||
// we store it as an derived attribute.
|
||||
if (auto list_type = input_tfr_type.dyn_cast<TFRTensorListType>()) {
|
||||
auto attr_names = list_type.getAttrKeys();
|
||||
if (attr_names.empty()) return success();
|
||||
|
||||
// N*T case
|
||||
if (attr_names.size() == 2) {
|
||||
derived_attrs->insert({attr_names[0].getValue(),
|
||||
rewriter.getI32IntegerAttr(input_types.size())});
|
||||
// Note that this uses the first element of the list to infer the T value.
|
||||
// A tf.Cast is required to cast the other inputs to the same type.
|
||||
derived_attrs->insert({attr_names[1].getValue(), input_types[0]});
|
||||
return success();
|
||||
}
|
||||
|
||||
// list(dtype) case
|
||||
if (attr_names.size() == 1) {
|
||||
derived_attrs->insert(
|
||||
{attr_names[0].getValue(), rewriter.getArrayAttr(input_types)});
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult RewriteTFRCallOp::CollectInputsAndAttributes(
|
||||
PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
|
||||
SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
|
||||
llvm::StringMap<Attribute>* derived_attrs) const {
|
||||
for (const auto& operand : llvm::enumerate(signature.getType().getInputs())) {
|
||||
// If the index is larger than the operand number of the call_op, the
|
||||
// default value of the operand needs to be used.
|
||||
if (operand.index() >= call_op.getNumOperands()) {
|
||||
auto attr_name = signature.getArgAttrOfType<StringAttr>(
|
||||
operand.index(), kAttrArgumentNameAttr);
|
||||
auto attr_value =
|
||||
signature.getArgAttr(operand.index(), kAttrArgumentDefaultAttr);
|
||||
arg_attrs->push_back(
|
||||
rewriter.getNamedAttr(attr_name.getValue(), attr_value));
|
||||
continue;
|
||||
}
|
||||
|
||||
// The index is valid for the call_op.
|
||||
Value input = call_op.getOperand(operand.index());
|
||||
Operation* input_op = input.getDefiningOp();
|
||||
auto input_tfr_type = signature.getType().getInputs()[operand.index()];
|
||||
|
||||
// There are three cases for the preceding input_op:
|
||||
|
||||
// 1. The preceding op can be a tfr.cast op, which will be fused to the
|
||||
// current op, so the result op has input with tensor type.
|
||||
if (auto cast_op = dyn_cast_or_null<CastOp>(input_op)) {
|
||||
Value input_to_cast = CastToNonDerivedType(rewriter, call_op.getLoc(),
|
||||
cast_op, input_tfr_type);
|
||||
inputs->push_back(input_to_cast);
|
||||
if (failed(AddDerivedAttrs(rewriter, input_tfr_type,
|
||||
{cast_op.getInputElementType()},
|
||||
derived_attrs))) {
|
||||
return failure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2. The preceding op is a tfr.build_list op, which collects multiple
|
||||
// values with tensor types via the tfr.cast ops. These ops will be fused
|
||||
// to the current op as well, so all the tfr.cast op inputs will be inputs
|
||||
// to the result op.
|
||||
if (auto list_op = dyn_cast_or_null<BuildListOp>(input_op)) {
|
||||
// Find out all the inputs to the build list op
|
||||
// TODO(fengliuai): make build_list op only take tensor argument
|
||||
llvm::SmallVector<Attribute, 4> list_input_types;
|
||||
llvm::SmallVector<Value, 4> list_inputs;
|
||||
for (auto list_input : list_op.getOperands()) {
|
||||
auto cast_op = dyn_cast_or_null<CastOp>(list_input.getDefiningOp());
|
||||
if (!cast_op) return failure();
|
||||
list_inputs.push_back(cast_op.arg());
|
||||
list_input_types.push_back(cast_op.getInputElementType());
|
||||
}
|
||||
CastValuesToSameType(rewriter, call_op.getLoc(), list_input_types,
|
||||
list_inputs);
|
||||
inputs->append(list_inputs.begin(), list_inputs.end());
|
||||
if (failed(AddDerivedAttrs(rewriter, input_tfr_type, list_input_types,
|
||||
derived_attrs))) {
|
||||
return failure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// 3. The preceding op is a constant, thus the value of this constant is
|
||||
// used to create an attribute of the result op, according to the signature.
|
||||
Attribute arg_value;
|
||||
// A failure indicates the argument isn't a constant value, so we should
|
||||
// not use it as an attribute.
|
||||
if (!matchPattern(input, m_Constant(&arg_value))) {
|
||||
return failure();
|
||||
}
|
||||
auto attr_name = signature.getArgAttrOfType<StringAttr>(
|
||||
operand.index(), kAttrArgumentNameAttr);
|
||||
arg_attrs->push_back(
|
||||
rewriter.getNamedAttr(attr_name.getValue(), arg_value));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// For each output, uses the attribute name associated to the tfr types to find
|
||||
// out the attribute value from the collected `attrs` and create the output type
|
||||
// of the result op by using the attribute value as the element type.
|
||||
LogicalResult RewriteTFRCallOp::DeriveOutputTypes(
|
||||
FunctionType signature, const llvm::StringMap<Attribute>& attrs,
|
||||
SmallVectorImpl<Type>* output_types) const {
|
||||
for (auto res : llvm::enumerate(signature.getResults())) {
|
||||
if (auto tensor_type = res.value().dyn_cast<TFRTensorType>()) {
|
||||
// tfr.tensor should only have one attribute attached.
|
||||
auto attr_key = tensor_type.getAttrKeys().front();
|
||||
output_types->push_back(UnrankedTensorType::get(
|
||||
attrs.lookup(attr_key.getValue()).cast<TypeAttr>().getValue()));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
|
||||
// There are two cases: N*T or list(dtype)
|
||||
auto attr_keys = list_type.getAttrKeys();
|
||||
// N*T case
|
||||
if (attr_keys.size() == 2) {
|
||||
// The first one is N, and the second one is T
|
||||
int list_size =
|
||||
attrs.lookup(attr_keys[0].getValue()).cast<IntegerAttr>().getInt();
|
||||
Type list_type =
|
||||
attrs.lookup(attr_keys[1].getValue()).cast<TypeAttr>().getValue();
|
||||
for (int i = 0; i < list_size; ++i) {
|
||||
output_types->push_back(UnrankedTensorType::get(list_type));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// TODO(fengliuai): list(dtype) case
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult RewriteTFRCallOp::CreateAndReplaceOp(
|
||||
PatternRewriter& rewriter, CallOp call_op,
|
||||
const SmallVectorImpl<Type>& output_types,
|
||||
const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
|
||||
const llvm::StringMap<Attribute>& derived_attrs) const {
|
||||
// Create the new op
|
||||
Location loc = call_op.getLoc();
|
||||
rewriter.setInsertionPointAfter(call_op);
|
||||
std::string tf_op_name = GetTFOpName(call_op.callee());
|
||||
OperationState new_state(loc, tf_op_name, inputs, output_types, attr_list);
|
||||
Operation* new_op = rewriter.createOperation(new_state);
|
||||
if (materialize_derived_attrs_) {
|
||||
for (const auto& attr : derived_attrs) {
|
||||
// Add or update the derived attribute with the value. Skip the fixed
|
||||
// element type attributes, in case they are present in the NodeDef.
|
||||
if (!fixed_elt_type_attrs_.contains(attr.first())) {
|
||||
new_op->setAttr(attr.first(), attr.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the tfr.cast ops on the results and replace the uses of the
|
||||
// original call op.
|
||||
TFRTensorType unconstrainted_type = rewriter.getType<TFRTensorType>();
|
||||
SmallVector<Value, 4> new_results;
|
||||
for (auto res : llvm::enumerate(call_op.getResultTypes())) {
|
||||
Type res_type = res.value();
|
||||
if (res_type.dyn_cast<TFRTensorType>()) {
|
||||
Value new_res = new_op->getResult(res.index());
|
||||
auto casted = rewriter.create<CastOp>(loc, res_type, new_res);
|
||||
new_results.push_back(casted.out());
|
||||
} else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
|
||||
SmallVector<Value, 4> tensor_list;
|
||||
for (int i = res.index(); i < new_op->getNumResults(); i++) {
|
||||
Value new_res = new_op->getResult(i);
|
||||
auto casted =
|
||||
rewriter.create<CastOp>(loc, unconstrainted_type, new_res);
|
||||
tensor_list.push_back(casted.out());
|
||||
}
|
||||
auto list_op = rewriter.create<BuildListOp>(loc, res_type, tensor_list);
|
||||
new_results.push_back(list_op.out());
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(call_op, new_results);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult RewriteTFRCallOp::matchAndRewrite(
|
||||
CallOp call_op, PatternRewriter& rewriter) const {
|
||||
// Get the func op and verify that it is external. The type of this external
|
||||
// func op is used as the signature of the corresponding TF ops. All the
|
||||
// external func ops have the trailing underscore.
|
||||
std::string external_callee_name = call_op.callee().str().append("_");
|
||||
TFRFuncOp func = symbol_table_.lookup<TFRFuncOp>(external_callee_name);
|
||||
if (!func || !func.isExternal()) return failure();
|
||||
// Get the inputs and attributes. The attributes include these from the
|
||||
// argument list and also these derived from the inputs.
|
||||
SmallVector<Value, 4> inputs;
|
||||
NamedAttrList argument_attrs;
|
||||
llvm::StringMap<Attribute> derived_attrs;
|
||||
if (failed(CollectInputsAndAttributes(rewriter, func, call_op, &inputs,
|
||||
&argument_attrs, &derived_attrs))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Derive the output types. The result type is derived by using the
|
||||
// attributes attched to the result type of the signature. The attribute
|
||||
// value should be either in the attribute argument list or the derived
|
||||
// attribute from the input tensors. All the result type
|
||||
// are unranked, and shape inference should be applied afterwards.
|
||||
SmallVector<Type, 4> output_types;
|
||||
|
||||
// Merge the attributes from the argument list to the derived ones.
|
||||
for (auto& attr : argument_attrs) {
|
||||
derived_attrs.insert({attr.first, attr.second});
|
||||
}
|
||||
|
||||
// Derive the output types by using the attributes attached to the tfr
|
||||
// types.
|
||||
if (failed(DeriveOutputTypes(func.getType(), derived_attrs, &output_types))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Create the new op and replace the old TFR call op.
|
||||
return CreateAndReplaceOp(rewriter, call_op, output_types, inputs,
|
||||
argument_attrs, derived_attrs);
|
||||
}
|
||||
|
||||
// Raise TFR call ops to the TF ops.
|
||||
struct RaiseToTFOpsPass : public PassWrapper<RaiseToTFOpsPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<TFRDialect, TF::TensorFlowDialect, scf::SCFDialect,
|
||||
StandardOpsDialect>();
|
||||
}
|
||||
|
||||
explicit RaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,
|
||||
bool materialize_derived_attrs)
|
||||
: external_tfr_module(tfr_module),
|
||||
materialize_derived_attrs(materialize_derived_attrs) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
llvm::Optional<ModuleOp> external_tfr_module;
|
||||
const bool materialize_derived_attrs;
|
||||
};
|
||||
|
||||
void RaiseToTFOpsPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
MLIRContext* ctx = &getContext();
|
||||
SymbolTable table(external_tfr_module.hasValue()
|
||||
? *external_tfr_module
|
||||
: func.getParentOfType<ModuleOp>());
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs);
|
||||
for (auto* op : ctx->getRegisteredOperations()) {
|
||||
op->getCanonicalizationPatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the pass to raise TFR call ops to the TF ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseToTFOpsPass(
|
||||
llvm::Optional<ModuleOp> tfr_module, bool materialize_derived_attrs) {
|
||||
return std::make_unique<RaiseToTFOpsPass>(tfr_module,
|
||||
materialize_derived_attrs);
|
||||
}
|
||||
|
||||
static PassRegistration<RaiseToTFOpsPass> pass(
|
||||
"tfr-raise-to-tf", "Raise all the TFR call ops to TF ops.",
|
||||
[] { return CreateRaiseToTFOpsPass(); });
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
37
tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc
Normal file
37
tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/InitAllDialects.h" // from @llvm-project
|
||||
#include "mlir/InitAllPasses.h" // from @llvm-project
|
||||
#include "mlir/Support/MlirOptMain.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
tensorflow::InitMlir y(&argc, &argv);
|
||||
|
||||
mlir::registerAllPasses();
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
|
||||
mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
|
||||
mlir::TFR::TFRDialect>();
|
||||
return failed(mlir::MlirOptMain(argc, argv, "TFR Pass Driver\n", registry));
|
||||
}
|
61
tensorflow/compiler/mlir/tfr/resources/BUILD
Normal file
61
tensorflow/compiler/mlir/tfr/resources/BUILD
Normal file
@ -0,0 +1,61 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
":friends",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/tfr/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "decomposition_lib",
|
||||
srcs = ["decomposition_lib.mlir"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "composite_ops_cc",
|
||||
srcs = ["composite_ops.cc"],
|
||||
copts = [
|
||||
"-Wno-unused-result",
|
||||
"-Wno-unused-variable",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "composite_ops",
|
||||
out = "composite_ops.py",
|
||||
deps = [
|
||||
":composite_ops_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_ops_cc",
|
||||
testonly = 1,
|
||||
srcs = ["test_ops.cc"],
|
||||
copts = [
|
||||
"-Wno-unused-result",
|
||||
"-Wno-unused-variable",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
39
tensorflow/compiler/mlir/tfr/resources/composite_ops.cc
Normal file
39
tensorflow/compiler/mlir/tfr/resources/composite_ops.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("MyAddN")
|
||||
.Input("inputs: N * T")
|
||||
.Output("sum: T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: {numbertype, variant}")
|
||||
.SetIsCommutative()
|
||||
.SetIsAggregate();
|
||||
|
||||
REGISTER_OP("MyBiasedDense")
|
||||
.Input("input: T")
|
||||
.Input("weight: T")
|
||||
.Input("bias: T")
|
||||
.Output("out: T")
|
||||
.Attr("T: {float, int8}")
|
||||
.Attr("act: {'', 'relu', 'relu6'} = ''");
|
||||
|
||||
} // namespace tensorflow
|
109
tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir
Normal file
109
tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir
Normal file
@ -0,0 +1,109 @@
|
||||
// A test resource file which contains some pre-defined internal tfr.functions
|
||||
// for decomposition and external tfr.functions for raising the decomposition
|
||||
// result to the ops in the TF dialect.
|
||||
//
|
||||
// All the tfr.func functions are supposed to be translated from the Python
|
||||
// function with tf.composite annotation.
|
||||
// All the external tfr.func functions modeles the op signature defined by
|
||||
// OpDefs.
|
||||
|
||||
tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
|
||||
%n: i64 {tfr.name="N"}) -> !tfr.tensor {
|
||||
%index = constant 0 : index
|
||||
%cst = constant 1 : i64
|
||||
%eq = cmpi "eq", %n, %cst : i64
|
||||
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%res = scf.if %eq -> !tfr.tensor {
|
||||
scf.yield %v1 : !tfr.tensor
|
||||
} else {
|
||||
%step = index_cast %cst : i64 to index
|
||||
%end = index_cast %n : i64 to index
|
||||
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor {
|
||||
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%reduce_next = tfr.call @tf__add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
|
||||
scf.yield %reduce_next : !tfr.tensor
|
||||
}
|
||||
scf.yield %reduce : !tfr.tensor
|
||||
}
|
||||
tfr.return %res : !tfr.tensor
|
||||
}
|
||||
|
||||
// Translated from tf.compose Python function.
|
||||
tfr.func @tf__my_biased_dense(%input: !tfr.tensor, %weight: !tfr.tensor,
|
||||
%bias: !tfr.tensor,
|
||||
%act: !tfr.attr{tfr.name="act", tfr.default=""}) -> !tfr.tensor {
|
||||
%dot = tfr.call @tf__mat_mul(%input, %weight) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
|
||||
%add = tfr.call @tf__add(%dot, %bias) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
|
||||
|
||||
%relu = tfr.constant "relu" -> !tfr.attr
|
||||
%relu6 = tfr.constant "relu6" -> !tfr.attr
|
||||
|
||||
%is_relu = tfr.equal %act, %relu -> i1
|
||||
%res = scf.if %is_relu -> !tfr.tensor {
|
||||
%applied_relu = tfr.call @tf__relu(%add) : (!tfr.tensor) -> !tfr.tensor
|
||||
scf.yield %applied_relu : !tfr.tensor
|
||||
} else {
|
||||
%is_relu6 = tfr.equal %act, %relu6 -> i1
|
||||
%res1 = scf.if %is_relu6 -> !tfr.tensor {
|
||||
%applied_relu6 = tfr.call @tf__relu6(%add) : (!tfr.tensor) -> !tfr.tensor
|
||||
scf.yield %applied_relu6 : !tfr.tensor
|
||||
} else {
|
||||
scf.yield %add : !tfr.tensor
|
||||
}
|
||||
scf.yield %res1 : !tfr.tensor
|
||||
}
|
||||
tfr.return %res : !tfr.tensor
|
||||
}
|
||||
|
||||
// This is a wong decomposition and used to verify that tf.Elu isn't decomposed
|
||||
// since its kernel has been registered.
|
||||
tfr.func @tf__elu_(%input: !tfr.tensor) -> !tfr.tensor {
|
||||
tfr.return %input : !tfr.tensor
|
||||
}
|
||||
|
||||
// Translated from:
|
||||
//
|
||||
// REGISTER_OP("Add")
|
||||
// .Input("x: T")
|
||||
// .Input("y: T")
|
||||
// .Output("z: T")
|
||||
// .Attr(
|
||||
// "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
|
||||
// "complex64, complex128, string}")
|
||||
tfr.func @tf__add_(!tfr.tensor<T>, !tfr.tensor<T>)
|
||||
-> !tfr.tensor<T> attributes{T}
|
||||
|
||||
// Translated from:
|
||||
//
|
||||
// REGISTER_OP("MatMul")
|
||||
// .Input("a: T")
|
||||
// .Input("b: T")
|
||||
// .Output("product: T")
|
||||
// .Attr("transpose_a: bool = false")
|
||||
// .Attr("transpose_b: bool = false")
|
||||
// .Attr("T: {bfloat16, half, float, double, int32, int64, complex64, complex128}")
|
||||
// T is a derived attribute.
|
||||
// transpose_a and transpose_b is materialized attributes.
|
||||
tfr.func @tf__mat_mul_(!tfr.tensor<T>, !tfr.tensor<T>,
|
||||
i1 {tfr.name="transpose_a", tfr.default=false},
|
||||
i1 {tfr.name="transpose_b", tfr.default=false})
|
||||
-> !tfr.tensor<T> attributes{T}
|
||||
|
||||
// Translated from:
|
||||
//
|
||||
// REGISTER_OP("Relu")
|
||||
// .Input("features: T")
|
||||
// .Output("activations: T")
|
||||
// .Attr("T: {realnumbertype, qint8}")
|
||||
// T is a derived attribute.
|
||||
tfr.func @tf__relu_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
|
||||
|
||||
|
||||
// Translated from:
|
||||
//
|
||||
// REGISTER_OP("Relu6")
|
||||
// .Input("features: T")
|
||||
// .Output("activations: T")
|
||||
// .Attr("T: {realnumbertype}")
|
||||
// T is a derived attribute.
|
||||
tfr.func @tf__relu6_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
|
78
tensorflow/compiler/mlir/tfr/resources/test_ops.cc
Normal file
78
tensorflow/compiler/mlir/tfr/resources/test_ops.cc
Normal file
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("TestNoOp");
|
||||
|
||||
REGISTER_OP("TestIdentityOp")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: numbertype");
|
||||
|
||||
REGISTER_OP("TestIdentityNOp")
|
||||
.Input("input: N * T")
|
||||
.Output("output: N * T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: numbertype");
|
||||
|
||||
REGISTER_OP("TestInputNOp")
|
||||
.Input("input: N * T")
|
||||
.Output("output: T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: numbertype");
|
||||
|
||||
REGISTER_OP("TestOutputNOp")
|
||||
.Input("input: T")
|
||||
.Output("output: N * T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: numbertype");
|
||||
|
||||
REGISTER_OP("TestTwoInputsOp")
|
||||
.Input("lhs: T")
|
||||
.Input("rhs: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("pred: bool = false");
|
||||
|
||||
REGISTER_OP("TestNumAttrsOp")
|
||||
.Attr("x1: int = -10")
|
||||
.Attr("y1: int = 1")
|
||||
.Attr("x2: float = 0.0")
|
||||
.Attr("y2: float = -3.0");
|
||||
|
||||
REGISTER_OP("TestNonNumAttrsOp")
|
||||
.Attr("z: shape")
|
||||
.Attr("x: string = 'hello'")
|
||||
.Attr("y: type = DT_FLOAT");
|
||||
|
||||
REGISTER_OP("TestThreeInputsOp")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Input("z: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("act: {'x', 'y', 'z'} = 'z'");
|
||||
|
||||
REGISTER_OP("TestTwoOutputsOp")
|
||||
.Input("input: T")
|
||||
.Output("output1: T")
|
||||
.Output("output2: T")
|
||||
.Attr("T: numbertype");
|
||||
|
||||
} // namespace tensorflow
|
57
tensorflow/compiler/mlir/tfr/tests/control_flow.mlir
Normal file
57
tensorflow/compiler/mlir/tfr/tests/control_flow.mlir
Normal file
@ -0,0 +1,57 @@
|
||||
// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s
|
||||
|
||||
tfr.func @tf__my_pack(%values: !tfr.tensor_list,
|
||||
%n: i32 {tfr.name="N"},
|
||||
%axis: i32 {tfr.name="axis"}) -> !tfr.tensor {
|
||||
%index = constant 0 : index
|
||||
%cst = constant 1 : i32
|
||||
%eq = cmpi "eq", %n, %cst : i32
|
||||
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%temp = tfr.call @tf__expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
%res = scf.if %eq -> !tfr.tensor {
|
||||
scf.yield %temp : !tfr.tensor
|
||||
} else {
|
||||
%step = index_cast %cst : i32 to index
|
||||
%end = index_cast %n : i32 to index
|
||||
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor {
|
||||
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%temp1 = tfr.call @tf__expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
%reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
|
||||
scf.yield %reduce_next : !tfr.tensor
|
||||
}
|
||||
scf.yield %reduce : !tfr.tensor
|
||||
}
|
||||
tfr.return %res : !tfr.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: pack_one
|
||||
func @pack_one(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
%0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32
|
||||
// CHECK-NEXT: %[[CAST:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[ED:.*]] = tfr.call @tf__expand_dims(%[[CAST]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[BACK:.*]] = "tfr.cast"(%[[ED]]) : (!tfr.tensor) -> tensor<1x2x3xf32>
|
||||
// CHECK-NEXT: return %[[BACK]] : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: pack_multiple
|
||||
func @pack_multiple(%arg0: tensor<2x3xf32>,
|
||||
%arg1: tensor<2x3xf32>,
|
||||
%arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
|
||||
%0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32>
|
||||
return %0 : tensor<3x2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32
|
||||
// CHECK-NEXT: %[[CAST0:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[CAST1:.*]] = "tfr.cast"(%arg1) : (tensor<2x3xf32>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[CAST2:.*]] = "tfr.cast"(%arg2) : (tensor<2x3xf32>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[EX0:.*]] = tfr.call @tf__expand_dims(%[[CAST0]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[EX1:.*]] = tfr.call @tf__expand_dims(%[[CAST1]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[CONCAT1:.*]] = tfr.call @tf__risc_concat(%[[EX0]], %[[EX1]], %c0_i32) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[EX2:.*]] = tfr.call @tf__expand_dims(%[[CAST2]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[CONCAT2:.*]] = tfr.call @tf__risc_concat(%[[CONCAT1]], %[[EX2]], %[[AXIS]]) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[BACK:.*]] = "tfr.cast"(%[[CONCAT2]]) : (!tfr.tensor) -> tensor<3x2x3xf32>
|
||||
// CHECK-NEXT: return %[[BACK]] : tensor<3x2x3xf32>
|
||||
}
|
84
tensorflow/compiler/mlir/tfr/tests/decompose.mlir
Normal file
84
tensorflow/compiler/mlir/tfr/tests/decompose.mlir
Normal file
@ -0,0 +1,84 @@
|
||||
// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @tf__fake_no_op
|
||||
tfr.func @tf__fake_no_op(%arg0: !tfr.tensor) -> !tfr.tensor {
|
||||
tfr.return %arg0 : !tfr.tensor
|
||||
|
||||
// CHECK-NEXT: tfr.return %arg0 : !tfr.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tf__intermediate
|
||||
tfr.func @tf__intermediate(%arg0: !tfr.tensor) -> !tfr.tensor {
|
||||
%0 = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor
|
||||
tfr.return %0 : !tfr.tensor
|
||||
|
||||
// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: tfr.return %[[id]] : !tfr.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tf__fused_n
|
||||
tfr.func @tf__fused_n(
|
||||
%arg0: !tfr.tensor,
|
||||
%arg1: !tfr.tensor_list,
|
||||
%arg2: index {tfr.name="A",tfr.default=1:index})
|
||||
-> !tfr.tensor_list {
|
||||
%0 = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor
|
||||
%1 = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%2 = tfr.call @tf__intermediate(%1) : (!tfr.tensor) -> !tfr.tensor
|
||||
%3 = "tfr.build_list"(%0, %2) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
|
||||
tfr.return %3 : !tfr.tensor_list
|
||||
|
||||
// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[ge:.*]] = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__intermediate(%[[ge]]) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[bl:.*]] = "tfr.build_list"(%[[id1]], %[[id2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
|
||||
// CHECK-NEXT: tfr.return %[[bl]] : !tfr.tensor_list
|
||||
}
|
||||
|
||||
//------------------------
|
||||
|
||||
// CHECK-LABEL: decompose_tf_no_op
|
||||
func @decompose_tf_no_op(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
|
||||
%0 = "tf.FakeNoOp"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string>
|
||||
return %0 : tensor<1x2x3x4x!tf.string>
|
||||
|
||||
// CHECK-NEXT: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: decompose_tf_intermediate
|
||||
func @decompose_tf_intermediate(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
|
||||
%0 = "tf.Intermediate"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string>
|
||||
return %0 : tensor<1x2x3x4x!tf.string>
|
||||
|
||||
// CHECK-NEXT: %[[casted:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%[[casted]]) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id]]) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string>
|
||||
// CHECK-NEXT: return %[[back]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: decompose_fused_n_default
|
||||
func @decompose_fused_n_default(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
|
||||
%0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
|
||||
return %0#1 : tensor<f32>
|
||||
|
||||
// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[in2:.*]] = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__risc(%[[in2]]) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id2]]) : (!tfr.tensor) -> tensor<f32>
|
||||
// CHECK-NEXT: return %[[back]] : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: decompose_fused_n
|
||||
func @decompose_fused_n(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
|
||||
%0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index} : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
|
||||
return %0#1 : tensor<f32>
|
||||
|
||||
// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[in1:.*]] = "tfr.cast"(%arg1) : (tensor<f32>) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__risc(%[[in1]]) : (!tfr.tensor) -> !tfr.tensor
|
||||
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id1]]) : (!tfr.tensor) -> tensor<f32>
|
||||
// CHECK-NEXT: return %[[back]] : tensor<f32>
|
||||
}
|
||||
|
235
tensorflow/compiler/mlir/tfr/tests/end2end.mlir
Normal file
235
tensorflow/compiler/mlir/tfr/tests/end2end.mlir
Normal file
@ -0,0 +1,235 @@
|
||||
// RUN: tfr-opt %s -tfr-decompose -tfr-raise-to-tf -canonicalize -verify-diagnostics -split-input-file | FileCheck %s
|
||||
|
||||
//=================> User models, from GraphDef <====================
|
||||
|
||||
// CHECK-LABEL: my_identity
|
||||
func @my_identity(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%0 = "tf.MyIdentity"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %0 : tensor<2x3xf32>
|
||||
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_rsqrt
|
||||
func @my_rsqrt(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
|
||||
%0 = "tf.MyRsqrt"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
|
||||
return %0 : tensor<3x2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[RE:.*]] = "tf.RiscReciprocal"(%arg0) : (tensor<2x3xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[SQRT:.*]] = "tf.RiscSqrt"(%[[RE]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[SQRT]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_leaky_relu
|
||||
func @my_leaky_relu(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
|
||||
%0 = "tf.MyLeakyRelu"(%arg0) {alpha=3.0 : f32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
|
||||
return %0 : tensor<3x2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
|
||||
// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor<f32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_leaky_relu_with_default
|
||||
func @my_leaky_relu_with_default(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
|
||||
%0 = "tf.MyLeakyRelu"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
|
||||
return %0 : tensor<3x2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
|
||||
// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor<f32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_cast
|
||||
func @my_cast(%arg0: tensor<2x3xf32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.MyCast"(%arg0) {Tout=i32} : (tensor<2x3xf32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-NEXT: %[[CAST:.*]] = "tf.RiscCast"(%arg0) {Tout = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
|
||||
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CAST]]) {shape = #tf.shape<2x3>} : (tensor<*xi32>) -> tensor<2x3xi32>
|
||||
// CHECK-NEXT: return %[[ES]] : tensor<2x3xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_pack_single_input
|
||||
func @my_pack_single_input(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
|
||||
%0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
|
||||
return %0 : tensor<3x2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: %[[ED:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ED]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_pack_multiple_inputs
|
||||
func @my_pack_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> {
|
||||
%0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32>
|
||||
return %0 : tensor<3x2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: %[[ED0:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ED1:.*]] = "tf.ExpandDims"(%arg1, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[CC0:.*]] = "tf.RiscConcat"(%[[ED0]], %[[ED1]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ED2:.*]] = "tf.ExpandDims"(%arg2, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[CC1:.*]] = "tf.RiscConcat"(%[[CC0]], %[[ED2]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CC1]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_add_n_single_input
|
||||
func @my_add_n_single_input(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%0 = "tf.MyAddN"(%arg0) {N=1:i32} : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %0 : tensor<2x3xf32>
|
||||
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_add_n_multiple_inputs
|
||||
func @my_add_n_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%0 = "tf.MyAddN"(%arg0, %arg1, %arg2) {N=3:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %0 : tensor<2x3xf32>
|
||||
|
||||
// CHECK-NEXT: %[[ADD0:.*]] = "tf.RiscAdd"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ADD1:.*]] = "tf.RiscAdd"(%[[ADD0]], %arg2) : (tensor<*xf32>, tensor<2x3xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ADD1]]) {shape = #tf.shape<2x3>} : (tensor<*xf32>) -> tensor<2x3xf32>
|
||||
// CHECK-NEXT: return %[[ES]] : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: my_map_and_batch_dataset
|
||||
func @my_map_and_batch_dataset(%input: tensor<*x!tf.variant>,
|
||||
%other1: tensor<*xf32>,
|
||||
%other2: tensor<*xi32>) -> tensor<*x!tf.variant> {
|
||||
%0 = "tf.MyMapAndBatchDataset"(%input, %other1, %other2)
|
||||
{batch_size=1000 : i64, num_parallel_calls = 8 : i64, drop_remainder = 0 : i1,
|
||||
func = @"__some_func", output_types = [f32], output_shapes = [#tf.shape<>], preserve_cardinality = true}
|
||||
: (tensor<*x!tf.variant>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf.variant>
|
||||
return %0 : tensor<*x!tf.variant>
|
||||
|
||||
// CHECK-NEXT: %[[BATCH:.*]] = "tf.Const"() {value = dense<1000> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-NEXT: %[[PARAL:.*]] = "tf.Const"() {value = dense<8> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-NEXT: %[[KEEP:.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
|
||||
// CHECK-NEXT: %[[CAST:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[RET:.*]] = "tf.MapAndBatchDatasetV0"(%arg0, %[[BATCH]], %[[PARAL]], %[[KEEP]], %arg1, %[[CAST]])
|
||||
// CHECK-SAME: {f = @__some_func, output_shapes = [#tf.shape<>], output_types = [f32], preserve_cardinality = true} : (tensor<*x!tf.variant>, tensor<i64>, tensor<i64>, tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*x!tf.variant>
|
||||
// CHECK-NEXT: return %[[RET]] : tensor<*x!tf.variant>
|
||||
}
|
||||
|
||||
//=================> decomposition functions, translated from tf.compose api <====================
|
||||
tfr.func @tf__my_identity(%value: !tfr.tensor) -> !tfr.tensor {
|
||||
tfr.return %value : !tfr.tensor
|
||||
}
|
||||
|
||||
tfr.func @tf__my_cast(%value: !tfr.tensor, %tout: !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor {
|
||||
%0 = tfr.call @tf__risc_cast(%value, %tout) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor
|
||||
tfr.return %0 : !tfr.tensor
|
||||
}
|
||||
|
||||
tfr.func @tf__my_rsqrt(%value: !tfr.tensor) -> !tfr.tensor {
|
||||
%1 = tfr.call @tf__risc_reciprocal(%value) : (!tfr.tensor) -> !tfr.tensor
|
||||
%2 = tfr.call @tf__risc_sqrt(%1) : (!tfr.tensor) -> !tfr.tensor
|
||||
tfr.return %2 : !tfr.tensor
|
||||
}
|
||||
|
||||
tfr.func @tf__my_leaky_relu(%value: !tfr.tensor, %alpha: f32 {tfr.name="alpha", tfr.default=0.2:f32}) -> !tfr.tensor {
|
||||
%1 = tfr.call @tf__risc_shape(%value) : (!tfr.tensor) -> !tfr.tensor
|
||||
%2 = "tfr.constant_tensor"(%alpha) : (f32) -> tensor<f32>
|
||||
%t = "tfr.cast"(%2) : (tensor<f32>) -> !tfr.tensor
|
||||
%3 = tfr.call @tf__risc_broadcast(%t, %1) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
|
||||
%4 = tfr.call @tf__risc_maximum(%value, %3) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
|
||||
tfr.return %4 : !tfr.tensor
|
||||
}
|
||||
|
||||
// TODO(fengliuai): use shape dialect to manipulate the shape then this can be decomposed further.
|
||||
tfr.func @tf__my_expand_dims(%value: !tfr.tensor, %axis: i32 {tfr.name="axis"}) -> !tfr.tensor {
|
||||
%axis_cst = "tfr.constant_tensor"(%axis) : (i32) -> tensor<i32>
|
||||
%dim = "tfr.cast"(%axis_cst) : (tensor<i32>) -> !tfr.tensor
|
||||
%0 = tfr.call @tf__expand_dims(%value, %dim) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
|
||||
tfr.return %0 : !tfr.tensor
|
||||
}
|
||||
|
||||
tfr.func @tf__my_pack(%values: !tfr.tensor_list,
|
||||
%n: i32 {tfr.name="N"},
|
||||
%axis: i32 {tfr.name="axis"}) -> !tfr.tensor {
|
||||
%index = constant 0 : index
|
||||
%cst = constant 1 : i32
|
||||
%eq = cmpi "eq", %n, %cst : i32
|
||||
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%temp = tfr.call @tf__my_expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
%res = scf.if %eq -> !tfr.tensor {
|
||||
scf.yield %temp : !tfr.tensor
|
||||
} else {
|
||||
%step = index_cast %cst : i32 to index
|
||||
%end = index_cast %n : i32 to index
|
||||
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor {
|
||||
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%temp1 = tfr.call @tf__my_expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor
|
||||
%reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor
|
||||
scf.yield %reduce_next : !tfr.tensor
|
||||
}
|
||||
scf.yield %reduce : !tfr.tensor
|
||||
}
|
||||
tfr.return %res : !tfr.tensor
|
||||
}
|
||||
|
||||
tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
|
||||
%n: i32 {tfr.name="N"}) -> !tfr.tensor {
|
||||
%index = constant 0 : index
|
||||
%cst = constant 1 : i32
|
||||
%eq = cmpi "eq", %n, %cst : i32
|
||||
%v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%res = scf.if %eq -> !tfr.tensor {
|
||||
scf.yield %v1 : !tfr.tensor
|
||||
} else {
|
||||
%step = index_cast %cst : i32 to index
|
||||
%end = index_cast %n : i32 to index
|
||||
%reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor {
|
||||
%v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%reduce_next = tfr.call @tf__risc_add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
|
||||
scf.yield %reduce_next : !tfr.tensor
|
||||
}
|
||||
scf.yield %reduce : !tfr.tensor
|
||||
}
|
||||
tfr.return %res : !tfr.tensor
|
||||
}
|
||||
|
||||
tfr.func @tf__my_map_and_batch_dataset(
|
||||
%input_dataset: !tfr.tensor,
|
||||
%other_arguments: !tfr.tensor_list,
|
||||
%batch_size: i64 {tfr.name="batch_size"},
|
||||
%num_parallel_calls: i64 {tfr.name="num_parallel_calls"},
|
||||
%drop_remainder: i1 {tfr.name="drop_remainder"},
|
||||
%f: !tfr.attr {tfr.name="func"},
|
||||
%output_types: !tfr.attr {tfr.name="output_types"},
|
||||
%output_shapes: !tfr.attr {tfr.name="output_shapes"},
|
||||
%preserve_cardinality: i1 {tfr.name="preserve_cardinality", tfr.default=false}) -> !tfr.tensor {
|
||||
%batch = "tfr.constant_tensor"(%batch_size) : (i64) -> tensor<i64>
|
||||
%batch1 = "tfr.cast"(%batch) : (tensor<i64>) -> !tfr.tensor
|
||||
%calls = "tfr.constant_tensor"(%num_parallel_calls) : (i64) -> tensor<i64>
|
||||
%calls1 = "tfr.cast"(%calls) : (tensor<i64>) -> !tfr.tensor
|
||||
%drop = "tfr.constant_tensor"(%drop_remainder) : (i1) -> tensor<i1>
|
||||
%drop1 = "tfr.cast"(%drop) : (tensor<i1>) -> !tfr.tensor
|
||||
%ret = tfr.call @tf__map_and_batch_dataset_v0(%input_dataset, %batch1, %calls1, %drop1, %other_arguments, %f, %output_types, %output_shapes, %preserve_cardinality)
|
||||
: (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor_list, !tfr.attr, !tfr.attr, !tfr.attr, i1) -> !tfr.tensor
|
||||
tfr.return %ret : !tfr.tensor
|
||||
}
|
||||
|
||||
//=================> signatures of the primitive ops with kernels, modeled as external TFR function <==
|
||||
tfr.func @tf__risc_cast_(!tfr.tensor, !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor<Tout> attributes{Tout}
|
||||
tfr.func @tf__risc_add_(!tfr.tensor<T>, !tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
|
||||
tfr.func @tf__risc_concat_(!tfr.tensor<T>, !tfr.tensor<T>, i32{tfr.name="axis"}) -> !tfr.tensor<T> attributes{T}
|
||||
tfr.func @tf__risc_broadcast_(!tfr.tensor<T>, !tfr.tensor<Tidx>) -> !tfr.tensor<T> attributes{T, Tidx}
|
||||
tfr.func @tf__risc_reciprocal_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
|
||||
tfr.func @tf__risc_sqrt_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
|
||||
tfr.func @tf__risc_shape_(!tfr.tensor, !tfr.attr{tfr.name="T", tfr.default=i32}) -> !tfr.tensor<T> attributes{T}
|
||||
tfr.func @tf__risc_maximum_(!tfr.tensor<T>, !tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
|
||||
tfr.func @tf__expand_dims_(!tfr.tensor<T>, !tfr.tensor<Tdim>) -> !tfr.tensor<T> attributes{T, Tdim}
|
||||
tfr.func @tf__map_and_batch_dataset_v0_(!tfr.tensor<T>, !tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor_list<Targuments>,
|
||||
!tfr.attr{tfr.name="f"}, !tfr.attr{tfr.name="output_types"}, !tfr.attr{tfr.name="output_shapes"}, i1{tfr.name="preserve_cardinality"})
|
||||
-> !tfr.tensor<T> attributes{T, Targuments}
|
381
tensorflow/compiler/mlir/tfr/tests/ops.mlir
Normal file
381
tensorflow/compiler/mlir/tfr/tests/ops.mlir
Normal file
@ -0,0 +1,381 @@
|
||||
// RUN: tfr-opt %s -verify-diagnostics -split-input-file | tfr-opt | FileCheck %s
|
||||
// RUN: tfr-opt %s -canonicalize -verify-diagnostics -split-input-file | FileCheck %s -check-prefix=CANON
|
||||
|
||||
// Tests for types, ops with custom constraints, verifiers, printer or parser
|
||||
// methods.
|
||||
|
||||
// CHECK-LABEL: tensor_type_noconstraint
|
||||
func @tensor_type_noconstraint() -> !tfr.tensor
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tensor_type
|
||||
func @tensor_type() -> !tfr.tensor<T>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tensor_list_type_noconstraint
|
||||
func @tensor_list_type_noconstraint() -> !tfr.tensor_list
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tensor_list_type_array_like
|
||||
func @tensor_list_type_array_like() -> !tfr.tensor_list<[N, T]>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tensor_list_type_tuple_like
|
||||
func @tensor_list_type_tuple_like() -> !tfr.tensor_list<input_T>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{unbalanced '>' character in pretty dialect name}}
|
||||
func @tensor_invalid_1() -> !tfr.tensor<[N, T>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{unexpected nul or EOF in pretty dialect name}}
|
||||
func @tensor_invalid_2() -> !tfr.tensor<[N, T]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: call_op
|
||||
func @call_op(%arg0: !tfr.tensor<T>, %arg1: !tfr.tensor_list<TL>, %arg2: i32) -> !tfr.tensor<K> {
|
||||
%0 = tfr.call @Foo(%arg0, %arg1, %arg2) : (!tfr.tensor<T>, !tfr.tensor_list<TL>, i32) -> !tfr.tensor<K>
|
||||
return %0 : !tfr.tensor<K>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K>
|
||||
func @call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K> {
|
||||
%0 = tfr.call @Bar(%arg0) : (i32) -> !tfr.tensor<K>
|
||||
return %0 : !tfr.tensor<K>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @call_op_invalid_1(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
|
||||
// expected-error@+1 {{got 'tensor<?xf32>'}}
|
||||
%0 = tfr.call @Huu(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
|
||||
return %0 : !tfr.tensor<K>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: get_shape
|
||||
func @get_shape(%arg0: !tfr.tensor) -> (!shape.shape, !shape.shape) {
|
||||
%0 = tfr.get_shape %arg0 -> !shape.shape
|
||||
%1 = "tfr.get_shape"(%arg0) : (!tfr.tensor) -> !shape.shape
|
||||
return %0, %1 : !shape.shape, !shape.shape
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: get_real_shape
|
||||
// CANON-LABEL: get_real_shape
|
||||
func @get_real_shape(%arg0: tensor<1x2xf32>) -> tensor<1xindex> {
|
||||
%0 = "tfr.cast"(%arg0) : (tensor<1x2xf32>) -> !tfr.tensor
|
||||
%1 = tfr.get_shape %0 -> !shape.shape
|
||||
%2 = shape.to_extent_tensor %1 : !shape.shape -> tensor<1xindex>
|
||||
return %2 : tensor<1xindex>
|
||||
|
||||
// CANON-NEXT: %[[s:.*]] = shape.const_shape [1, 2] : tensor<?xindex>
|
||||
// CANON-NEXT: %[[e:.*]] = shape.to_extent_tensor %[[s]] : tensor<?xindex> -> tensor<1xindex>
|
||||
// CANON-NEXT: return %[[e]] : tensor<1xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @get_element_type(%arg0: !tfr.tensor) -> (!tfr.attr, !tfr.attr) {
|
||||
%0 = tfr.get_element_type %arg0 -> !tfr.attr
|
||||
%1 = "tfr.get_element_type"(%arg0) : (!tfr.tensor) -> !tfr.attr
|
||||
return %0, %1 : !tfr.attr, !tfr.attr
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: from_tf_tensor
|
||||
func @from_tf_tensor(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
|
||||
%0 = "tfr.cast"(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
|
||||
return %0 : !tfr.tensor<K>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: to_tf_tensor
|
||||
func @to_tf_tensor(%arg0: !tfr.tensor<T>) -> tensor<?xi32> {
|
||||
%0 = "tfr.cast"(%arg0) : (!tfr.tensor<T>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: constant
|
||||
func @constant() -> (!tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr) {
|
||||
%0 = tfr.constant f32 -> !tfr.attr
|
||||
%1 = tfr.constant [f32, i32] -> !tfr.attr
|
||||
%2 = "tfr.constant"() {value = f32} : () -> !tfr.attr
|
||||
%3 = "tfr.constant"() {value = [f32, i32]} : () -> !tfr.attr
|
||||
return %0, %1, %2, %3 : !tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: equal
|
||||
// CANON-LABEL: equal
|
||||
func @equal() -> (i1, i1, i1, i1) {
|
||||
%0 = tfr.constant f32 -> !tfr.attr
|
||||
%1 = tfr.constant f32 -> !tfr.attr
|
||||
%2 = tfr.constant i32 -> !tfr.attr
|
||||
%same_type = tfr.equal %0,%1 -> i1
|
||||
%diff_type = tfr.equal %0,%2 -> i1
|
||||
|
||||
%3 = tfr.constant "hello" -> !tfr.attr
|
||||
%4 = tfr.constant "hello" -> !tfr.attr
|
||||
%5 = tfr.constant "how are you" -> !tfr.attr
|
||||
%same_str = tfr.equal %3,%4 -> i1
|
||||
%diff_str = tfr.equal %3,%5 -> i1
|
||||
return %same_type, %diff_type, %same_str, %diff_str : i1, i1, i1, i1
|
||||
|
||||
// CANON-NEXT: %true = constant true
|
||||
// CANON-NEXT: %false = constant false
|
||||
// CANON-NEXT: return %true, %false, %true, %false : i1, i1, i1, i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: constant_tensor_scalar
|
||||
func @constant_tensor_scalar(%arg0: i32) -> tensor<i32> {
|
||||
%0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: constant_tensor_vector
|
||||
func @constant_tensor_vector(%arg0: vector<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "tfr.constant_tensor"(%arg0) : (vector<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: constant_tensor_array
|
||||
// CANON-LABEL: constant_tensor_array
|
||||
func @constant_tensor_array() -> !tfr.tensor {
|
||||
%0 = tfr.constant [1, -1, 3] -> !tfr.attr
|
||||
%1 = "tfr.constant_tensor"(%0) : (!tfr.attr) -> !tfr.tensor
|
||||
return %1 : !tfr.tensor
|
||||
|
||||
// CANON-NEXT: "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
|
||||
// CANON-NEXT: "tfr.cast"(%0) : (tensor<3xi64>) -> !tfr.tensor
|
||||
// CANON-NEXT: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: constant_tensor_scalar
|
||||
// CANON-LABEL: constant_tensor_scalar
|
||||
func @constant_tensor_scalar() -> !tfr.tensor {
|
||||
%0 = "std.constant"() {value = 42 : i32} : () -> i32
|
||||
%1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor
|
||||
return %1 : !tfr.tensor
|
||||
|
||||
// CANON-NEXT: "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
|
||||
// CANON-NEXT: "tfr.cast"(%0) : (tensor<i32>) -> !tfr.tensor
|
||||
// CANON-NEXT: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @constant_tensor_invalid_0(%arg0: i32) -> tensor<f32> {
|
||||
// expected-error@+1 {{input and output should have the same scalar types.}}
|
||||
%0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @constant_tensor_invalid_1(%arg0: vector<1xi32>) -> tensor<?xi32> {
|
||||
// expected-error@+1 {{output type should be static and ranked}}
|
||||
%0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @constant_tensor_invalid_2(%arg0: vector<1xi32>) -> tensor<1xf32> {
|
||||
// expected-error@+1 {{input and output should have same shape and element type}}
|
||||
%0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @constant_tensor_invalid_3(%arg0: vector<1xi32>) -> tensor<1x1xi32> {
|
||||
// expected-error@+1 {{input and output should have same shape and element type}}
|
||||
%0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1x1xi32>
|
||||
return %0 : tensor<1x1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @constant_tensor_invalid_4(%arg0: i32) -> tensor<1x1xi32> {
|
||||
// expected-error@+1 {{input can not be converted to an output tensor}}
|
||||
%0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<1x1xi32>
|
||||
return %0 : tensor<1x1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: get_element
|
||||
func @get_element(%arg0: !tfr.tensor_list<T>) -> !tfr.tensor {
|
||||
%cst = "std.constant"() {value = 1 : index} : () -> index
|
||||
%0 = tfr.get_element %arg0[%cst] : (!tfr.tensor_list<T>, index) -> !tfr.tensor
|
||||
return %0 : !tfr.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: build_list
|
||||
func @build_list(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> !tfr.tensor_list {
|
||||
%0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
|
||||
return %0 : !tfr.tensor_list
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: build_const_list
|
||||
// CANON-LABEL: build_const_list
|
||||
func @build_const_list() -> !tfr.attr {
|
||||
%0 = "std.constant"() {value = 42 : i32} : () -> i32
|
||||
%1 = "std.constant"() {value = 41 : i32} : () -> i32
|
||||
%2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
|
||||
return %2 : !tfr.attr
|
||||
|
||||
// CANON-NEXT: %[[c:.*]] = tfr.constant [42 : i32, 41 : i32] -> !tfr.attr
|
||||
// CANON-NEXT: return %[[c]] : !tfr.attr
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tfr.func
|
||||
tfr.func @External(%arg0: !tfr.tensor<A>,
|
||||
%arg1: !tfr.tensor_list<C>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: !tfr.attr {tfr.name = "T"})
|
||||
-> (!tfr.tensor<A>, !tfr.tensor_list<C>)
|
||||
attributes {A, C}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tfr.func
|
||||
tfr.func @Foo(%arg0: !tfr.tensor<A>,
|
||||
%arg1: !tfr.tensor_list<C>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: vector<1xi32> {tfr.name = "C"})
|
||||
-> (!tfr.tensor<A>, !tfr.tensor_list<C>)
|
||||
attributes {A, C} {
|
||||
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<C>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tfr.func
|
||||
tfr.func @Bar(%arg0: !tfr.tensor<A>,
|
||||
%arg2: i32 {tfr.name = "B"},
|
||||
%arg3: vector<1xi32> {tfr.name = "C"})
|
||||
-> (!tfr.tensor<A>, !tfr.tensor<A>)
|
||||
attributes {A} {
|
||||
tfr.return %arg0, %arg0 : !tfr.tensor<A>, !tfr.tensor<A>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{Undefined attributes are used: A}}
|
||||
tfr.func @Foo_undefined_attr(%arg0: !tfr.tensor<A>,
|
||||
%arg1: !tfr.tensor_list<A>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: vector<1xi32> {tfr.name = "C"}) ->
|
||||
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
|
||||
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{3 attribute argument doesn't have a tfr.name attribute}}
|
||||
tfr.func @Foo_unnamed_attr(%arg0: !tfr.tensor<A>,
|
||||
%arg1: !tfr.tensor_list<A>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: vector<1xi32>) ->
|
||||
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
|
||||
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{tfr.tensor_list argument should be before non tensor arguments}}
|
||||
tfr.func @Foo_invalid_arg_order(%arg0: !tfr.tensor<A>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg1: !tfr.tensor_list<A>,
|
||||
%arg3: vector<1xi32> {tfr.name = "C"}) ->
|
||||
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
|
||||
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{tfr.tensor argument should be before tfr.tensor_list argument.}}
|
||||
tfr.func @Foo_invalid_arg_order0(
|
||||
%arg1: !tfr.tensor_list,
|
||||
%arg0: !tfr.tensor,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: vector<1xi32> {tfr.name = "C"}) ->
|
||||
(!tfr.tensor, !tfr.tensor_list) {
|
||||
tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{tfr.tensor result should be before tfr.tensor_list result}}
|
||||
tfr.func @Foo_invalid_result_order(%arg0: !tfr.tensor<A>,
|
||||
%arg1: !tfr.tensor_list<A>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: vector<1xi32> {tfr.name = "C"}) ->
|
||||
(!tfr.tensor_list<A>, !tfr.tensor<A>) {
|
||||
tfr.return %arg1, %arg0 : !tfr.tensor_list<A>, !tfr.tensor<A>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{More than one tfr.tensor_list argument isn't allowed}}
|
||||
tfr.func @Foo_multiple_tensor_list_args(%arg0: !tfr.tensor<A>,
|
||||
%arg1: !tfr.tensor_list<A>,
|
||||
%arg2: !tfr.tensor_list<A>,
|
||||
%arg3: i32 {tfr.name = "A"},
|
||||
%arg4: vector<1xi32> {tfr.name = "C"}) ->
|
||||
(!tfr.tensor<A>, !tfr.tensor_list<A>) {
|
||||
tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{More than one tfr.tensor_list result isn't allowed}}
|
||||
tfr.func @Foo_multiple_tensor_list_results(%arg0: !tfr.tensor<C>,
|
||||
%arg1: !tfr.tensor_list<A>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: vector<1xi32> {tfr.name = "C"}) ->
|
||||
(!tfr.tensor_list<A>, !tfr.tensor_list<A>) {
|
||||
tfr.return %arg1, %arg1 : !tfr.tensor_list<A>, !tfr.tensor_list<A>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{None tfr.tensor/tfr.tensor_list results aren't allowed as a result}}
|
||||
tfr.func @Foo_return_attr(%arg0: !tfr.tensor<C>,
|
||||
%arg1: !tfr.tensor_list<A>,
|
||||
%arg2: i32 {tfr.name = "A"},
|
||||
%arg3: vector<1xi32> {tfr.name = "C"}) -> i32 {
|
||||
tfr.return %arg2 : i32
|
||||
}
|
76
tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
Normal file
76
tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
Normal file
@ -0,0 +1,76 @@
|
||||
// RUN: tfr-opt %s -tfr-raise-to-tf -verify-diagnostics -split-input-file | FileCheck %s
|
||||
|
||||
tfr.func @tf__risc_same_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes {T}
|
||||
tfr.func @tf__risc_concat_(!tfr.tensor_list<N, T>) -> !tfr.tensor<T> attributes {T, N}
|
||||
tfr.func @tf__risc_split_(!tfr.tensor<T>, i32 {tfr.name="N"}) -> !tfr.tensor_list<N, T> attributes {T, N}
|
||||
tfr.func @tf__risc_cast_(!tfr.tensor, !tfr.attr {tfr.name="K"}) -> !tfr.tensor<K> attributes {T, K}
|
||||
|
||||
// CHECK-LABEL: decompose_tf_same
|
||||
func @decompose_tf_same(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
|
||||
%0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
|
||||
%1 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor
|
||||
%2 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string>
|
||||
return %2 : tensor<1x2x3x4x!tf.string>
|
||||
|
||||
// CHECK: %[[id:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string>
|
||||
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id]]) {shape = #tf.shape<1x2x3x4>} : (tensor<*x!tf.string>) -> tensor<1x2x3x4x!tf.string>
|
||||
// CHECK: return %[[es]] : tensor<1x2x3x4x!tf.string>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: decompose_tf_consecutive
|
||||
func @decompose_tf_consecutive(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
|
||||
%1 = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
|
||||
%2 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor
|
||||
%3 = tfr.call @tf__risc_same(%1) : (!tfr.tensor) -> !tfr.tensor
|
||||
%4 = "tfr.cast"(%3) : (!tfr.tensor) -> tensor<f32>
|
||||
return %4 : tensor<f32>
|
||||
|
||||
// CHECK: %[[id0:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string>
|
||||
// CHECK: %[[id2:.*]] = "tf.RiscSame"(%arg2) : (tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id2]]) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor<f32>
|
||||
// CHECK: return %[[es]] : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: decompose_tf_concat_n
|
||||
func @decompose_tf_concat_n(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<3xf32> {
|
||||
%0 = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
|
||||
%1 = "tfr.cast"(%arg1) : (tensor<f32>) -> !tfr.tensor
|
||||
%2 = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
|
||||
%3 = "tfr.build_list"(%0, %1, %2) : (!tfr.tensor, !tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
|
||||
%concat = tfr.call @tf__risc_concat(%3) : (!tfr.tensor_list) -> !tfr.tensor
|
||||
%4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<3xf32>
|
||||
return %4 : tensor<3xf32>
|
||||
|
||||
// CHECK: %[[concat:.*]] = "tf.RiscConcat"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[concat]]) {shape = #tf.shape<3>} : (tensor<*xf32>) -> tensor<3xf32>
|
||||
// CHECK: return %[[es]] : tensor<3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: decompose_tf_split
|
||||
func @decompose_tf_split(%arg0: tensor<3xf32>) -> (tensor<f32>) {
|
||||
%0 = "tfr.cast"(%arg0) : (tensor<3xf32>) -> !tfr.tensor
|
||||
%n = std.constant 3: i32
|
||||
%split = tfr.call @tf__risc_split(%0, %n) : (!tfr.tensor, i32) -> !tfr.tensor_list
|
||||
%i0 = std.constant 0: index
|
||||
%s0 = tfr.get_element %split[%i0] : (!tfr.tensor_list, index) -> !tfr.tensor
|
||||
%4 = "tfr.cast"(%s0) : (!tfr.tensor) -> tensor<f32>
|
||||
return %4 : tensor<f32>
|
||||
|
||||
// CHECK: %[[split:.*]]:3 = "tf.RiscSplit"(%arg0) {N = 3 : i32} : (tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
||||
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[split]]#0) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor<f32>
|
||||
// CHECK: return %[[es]] : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: decompose_tf_cast
|
||||
func @decompose_tf_cast(%arg0: tensor<f32>) -> tensor<i32> {
|
||||
%0 = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
|
||||
%t = tfr.constant i32 -> !tfr.attr
|
||||
%concat = tfr.call @tf__risc_cast(%0, %t) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor
|
||||
%4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<i32>
|
||||
return %4 : tensor<i32>
|
||||
|
||||
// CHECK: %[[tfcast:.*]] = "tf.RiscCast"(%arg0) {K = i32} : (tensor<f32>) -> tensor<*xi32>
|
||||
// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf.shape<>} : (tensor<*xi32>) -> tensor<i32>
|
||||
// CHECK: return %[[es]] : tensor<i32>
|
||||
}
|
78
tensorflow/compiler/mlir/tfr/utils/utils.cc
Normal file
78
tensorflow/compiler/mlir/tfr/utils/utils.cc
Normal file
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
std::string GetComposeFuncName(StringRef tf_op_name) {
|
||||
std::string compose_func_name;
|
||||
for (int i = 0; i < tf_op_name.size(); ++i) {
|
||||
if (tf_op_name[i] == '_') {
|
||||
// The field name must not contain "_"s. "_Arg" and "_RetVal" are special
|
||||
// op names and we can return empty string to skip the decomposition.
|
||||
return {};
|
||||
}
|
||||
if (tf_op_name[i] == '.') {
|
||||
compose_func_name.push_back('_');
|
||||
} else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') {
|
||||
compose_func_name.push_back('_');
|
||||
compose_func_name.push_back(tf_op_name[i] + 'a' - 'A');
|
||||
} else {
|
||||
compose_func_name.push_back(tf_op_name[i]);
|
||||
}
|
||||
}
|
||||
return compose_func_name;
|
||||
}
|
||||
|
||||
std::string GetTFOpName(StringRef compose_func_name) {
|
||||
std::string tf_op_name;
|
||||
bool after_underscore = false;
|
||||
for (int i = 0; i < compose_func_name.size(); ++i) {
|
||||
if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') {
|
||||
// The field name must not contain uppercase letters.
|
||||
return {};
|
||||
}
|
||||
if (after_underscore) {
|
||||
if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') {
|
||||
tf_op_name.push_back(compose_func_name[i] + 'A' - 'a');
|
||||
after_underscore = false;
|
||||
} else {
|
||||
// The character after a "_" must be a lowercase letter.
|
||||
return {};
|
||||
}
|
||||
} else if (compose_func_name[i] == '_') { // first time visit '_'
|
||||
if (i + 1 < compose_func_name.size() && compose_func_name[i + 1] == '_') {
|
||||
tf_op_name.push_back('.');
|
||||
i++;
|
||||
}
|
||||
after_underscore = true;
|
||||
} else {
|
||||
tf_op_name.push_back(compose_func_name[i]);
|
||||
}
|
||||
}
|
||||
if (after_underscore) {
|
||||
// Trailing "_".
|
||||
return {};
|
||||
}
|
||||
return tf_op_name;
|
||||
}
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
42
tensorflow/compiler/mlir/tfr/utils/utils.h
Normal file
42
tensorflow/compiler/mlir/tfr/utils/utils.h
Normal file
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFR {
|
||||
|
||||
// This is a hardcoded rule for mapping a TF op name to the corresponding
|
||||
// TFR function name. Examples:
|
||||
// tf.Pack => tf__pack
|
||||
// tf.ConcatV2 => tf__concat_v2
|
||||
// TODO(fengliuai): move to an util file.
|
||||
std::string GetComposeFuncName(StringRef tf_op_name);
|
||||
|
||||
// This is a hardcoded rule for mapping a TFR function op name to the
|
||||
// corresponding TF opname. Examples:
|
||||
// tf__pack -> tf.Pack
|
||||
// tf__concat_v2 => tf.ConcatV2
|
||||
std::string GetTFOpName(StringRef compose_func_name);
|
||||
|
||||
} // namespace TFR
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
|
Loading…
x
Reference in New Issue
Block a user