From 938cc7bf9c4f354361e18e3ba485af53e602d341 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 23 Sep 2020 14:08:10 -0700 Subject: [PATCH] Graduate the mlir implementation of tfr from experimental PiperOrigin-RevId: 333369220 Change-Id: Iafce731c4d80f06eb1e95aee6ea0f67af8caeac5 --- tensorflow/compiler/mlir/tfr/BUILD | 165 +++++ tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc | 590 ++++++++++++++++++ tensorflow/compiler/mlir/tfr/ir/tfr_ops.h | 55 ++ tensorflow/compiler/mlir/tfr/ir/tfr_ops.td | 436 +++++++++++++ tensorflow/compiler/mlir/tfr/ir/tfr_types.h | 115 ++++ .../compiler/mlir/tfr/passes/canonicalize.cc | 160 +++++ .../compiler/mlir/tfr/passes/decompose.cc | 280 +++++++++ tensorflow/compiler/mlir/tfr/passes/passes.h | 44 ++ .../compiler/mlir/tfr/passes/raise_to_tf.cc | 474 ++++++++++++++ .../compiler/mlir/tfr/passes/tfr_opt.cc | 37 ++ tensorflow/compiler/mlir/tfr/resources/BUILD | 61 ++ .../mlir/tfr/resources/composite_ops.cc | 39 ++ .../mlir/tfr/resources/decomposition_lib.mlir | 109 ++++ .../compiler/mlir/tfr/resources/test_ops.cc | 78 +++ .../compiler/mlir/tfr/tests/control_flow.mlir | 57 ++ .../compiler/mlir/tfr/tests/decompose.mlir | 84 +++ .../compiler/mlir/tfr/tests/end2end.mlir | 235 +++++++ tensorflow/compiler/mlir/tfr/tests/ops.mlir | 381 +++++++++++ .../compiler/mlir/tfr/tests/raise_to_tf.mlir | 76 +++ tensorflow/compiler/mlir/tfr/utils/utils.cc | 78 +++ tensorflow/compiler/mlir/tfr/utils/utils.h | 42 ++ 21 files changed, 3596 insertions(+) create mode 100644 tensorflow/compiler/mlir/tfr/BUILD create mode 100644 tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc create mode 100644 tensorflow/compiler/mlir/tfr/ir/tfr_ops.h create mode 100644 tensorflow/compiler/mlir/tfr/ir/tfr_ops.td create mode 100644 tensorflow/compiler/mlir/tfr/ir/tfr_types.h create mode 100644 tensorflow/compiler/mlir/tfr/passes/canonicalize.cc create mode 100644 tensorflow/compiler/mlir/tfr/passes/decompose.cc create mode 100644 tensorflow/compiler/mlir/tfr/passes/passes.h create mode 100644 tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc create mode 100644 tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc create mode 100644 tensorflow/compiler/mlir/tfr/resources/BUILD create mode 100644 tensorflow/compiler/mlir/tfr/resources/composite_ops.cc create mode 100644 tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir create mode 100644 tensorflow/compiler/mlir/tfr/resources/test_ops.cc create mode 100644 tensorflow/compiler/mlir/tfr/tests/control_flow.mlir create mode 100644 tensorflow/compiler/mlir/tfr/tests/decompose.mlir create mode 100644 tensorflow/compiler/mlir/tfr/tests/end2end.mlir create mode 100644 tensorflow/compiler/mlir/tfr/tests/ops.mlir create mode 100644 tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir create mode 100644 tensorflow/compiler/mlir/tfr/utils/utils.cc create mode 100644 tensorflow/compiler/mlir/tfr/utils/utils.h diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD new file mode 100644 index 00000000000..0a9305c8aea --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -0,0 +1,165 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//third_party/mlir:tblgen.bzl", + "gentbl", +) +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//learning/brain/experimental/mlir/tfr/...", + "//tensorflow/compiler/mlir/...", + 
], +) + +filegroup( + name = "tfr_ops_td_files", + srcs = [ + "ir/tfr_ops.td", + "//tensorflow/compiler/mlir/tensorflow:ir/tf_op_base.td", + "//tensorflow/compiler/mlir/tensorflow:ir/tf_op_interfaces.td", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td", + "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", + ], +) + +gentbl( + name = "tfr_ops_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls", + "ir/tfr_ops.h.inc", + ), + ( + "-gen-op-defs", + "ir/tfr_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tfr_ops.td", + td_srcs = [ + ":tfr_ops_td_files", + ], +) + +cc_library( + name = "tfr", + srcs = [ + "ir/tfr_ops.cc", + "ir/tfr_ops.cc.inc", + ], + hdrs = [ + "ir/tfr_ops.h", + "ir/tfr_ops.h.inc", + "ir/tfr_types.h", + ], + deps = [ + ":tfr_ops_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "utils", + srcs = [ + "utils/utils.cc", + ], + hdrs = [ + "utils/utils.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "passes", + srcs = [ + "passes/canonicalize.cc", + "passes/decompose.cc", + "passes/raise_to_tf.cc", + ], + hdrs = [ + "passes/passes.h", + ], + deps = [ + ":tfr", + ":utils", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/memory", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToStandard", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], + alwayslink = 1, +) + +tf_cc_binary( + name = "tfr-opt", + srcs = ["passes/tfr_opt.cc"], + deps = [ + ":passes", + ":tfr", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir:passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + ], +) + +glob_lit_tests( + data = [ + ":test_utilities", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir/tfr:tfr-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + "@llvm-project//mlir:run_lit.sh", + ], +) diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc new file mode 100644 index 00000000000..c0ef5c3b387 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -0,0 +1,590 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" + +#include +#include + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/FunctionImplementation.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" + +namespace mlir { + +namespace TFR { + +//===----------------------------------------------------------------------===// +// InlinerInterface +//===----------------------------------------------------------------------===// + +namespace { +/// This class defines the interface for inlining within the TFR dialect. +struct TFRInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + // Returns true if the given region 'src' can be inlined into the region + // 'dest' that is attached to an operation registered to the current dialect. 
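+  // As an illustrative sketch (not code from this patch): inlining lets a
+  // call such as
+  //   %0 = tfr.call @tf__my_op(%x) : (!tfr.tensor) -> !tfr.tensor
+  // be replaced by a clone of the body of @tf__my_op at the call site.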
+ bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &) const final { + return true; + } + + // Returns true if the given operation 'op', that is registered to this + // dialect, can be inlined into the region 'dest' that is attached to an + // operation registered to the current dialect. + bool isLegalToInline(Operation *op, Region *dest, + BlockAndValueMapping &) const final { + return true; + } + + // Handle the given inlined terminator by replacing it with a new operation + // as necessary. Required when the region has only one block. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + auto retValOp = dyn_cast(op); + if (!retValOp) return; + + for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) { + std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value)); + } + } + + // Attempts to materialize a conversion for a type mismatch between a call + // from this dialect, and a callable region. This method should generate an + // operation that takes 'input' as the only operand, and produces a single + // result of 'resultType'. If a conversion can not be generated, nullptr + // should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type result_type, + Location conversion_loc) const final { + if (!result_type.isa()) return nullptr; + return builder.create(conversion_loc, result_type, input); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// TFR Dialect +//===----------------------------------------------------------------------===// + +TFRDialect::TFRDialect(MLIRContext *context) + : Dialect(/*name=*/"tfr", context, TypeID::get()) { + addTypes(); + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc" + >(); + + addInterfaces(); +} + +bool TFRType::classof(Type type) { + return llvm::isa(type.getDialect()); +} + +//===----------------------------------------------------------------------===// +// Custom op methods +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ConstantTensorOp op) { + auto input_type = op.arg().getType(); + auto output_type = op.out().getType(); + + if (auto output_tensor_type = output_type.dyn_cast()) { + return success(); + } + + auto output_tensor_type = output_type.dyn_cast(); + if (!output_tensor_type || !output_tensor_type.hasStaticShape()) { + op.emitError("output type should be static and ranked."); + return failure(); + } + + if (output_tensor_type.getRank() == 0) { + bool same_scalar = output_tensor_type.getElementType() == input_type; + if (!same_scalar) { + op.emitError("input and output should have the same scalar types."); + } + return success(same_scalar); + } + + if (auto input_vector_type = input_type.dyn_cast()) { + bool same_element_type = output_tensor_type.getElementType() == + input_vector_type.getElementType(); + bool same_shape = + output_tensor_type.getShape() == input_vector_type.getShape(); + if (!same_element_type || !same_shape) { + op.emitError("input and output should have same shape and element type."); + } + return success(same_element_type && same_shape); + } + + op.emitError("input can not be converted to an output tensor."); + return failure(); +} + +static LogicalResult Verify(TFRFuncOp func) { + // Collect all attribute names used by the tensor and tensor list arguments + // and returns. 
Also, collect the names of all the attribute arguments as the + // defined list. Later on, the used attribute names will be verified to be in + // the defined list. + llvm::SmallVector used_attrs; + + // While scanning the arguments, record the start/end indices of each argument + // type, so the order can be verified as well. + // TODO(fengliuai): the attribute arguments with default values need to be + // at the end? + int first_tensor = -1, last_tensor = -1, first_tensor_list = -1, + last_tensor_list = -1, first_attr = -1; + + for (auto arg : llvm::enumerate(func.getType().getInputs())) { + Type arg_type = arg.value(); + + if (auto tensor = arg_type.dyn_cast()) { + if (first_tensor == -1) { + first_tensor = arg.index(); + } + last_tensor = arg.index(); + auto used = tensor.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + if (auto tensor_list = arg_type.dyn_cast()) { + if (first_tensor_list == -1) { + first_tensor_list = arg.index(); + } + last_tensor_list = arg.index(); + auto used = tensor_list.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + if (!arg_type.isa()) { + if (first_attr == -1) { + first_attr = arg.index(); + } + auto name = + func.getArgAttrOfType(arg.index(), kAttrArgumentNameAttr); + if (!name) { + func.emitError( + llvm::Twine(arg.index()) + + " attribute argument doesn't have a tfr.name attribute."); + return failure(); + } + continue; + } + + func.emitError("Builtin TensorType isn't allowed as the argument."); + return failure(); + } + + // Verify the argument order: tensors, tensor list, attributes; and also + // verify there is at most one tensor list argument. + if (first_tensor_list != -1 && first_tensor_list < last_tensor) { + func.emitError( + "tfr.tensor argument should be before tfr.tensor_list argument."); + return failure(); + } + if (first_attr != -1 && first_attr < last_tensor_list) { + func.emitError( + "tfr.tensor_list argument should be before non tensor arguments."); + return failure(); + } + if (first_tensor_list != last_tensor_list) { + func.emitError("More than one tfr.tensor_list argument isn't allowed."); + return failure(); + } + + // Verify the result order: tensor, tensor list, and also verify at most one + // tensor list result. + bool seen_tensor_list = false; + for (auto result_type : func.getType().getResults()) { + if (auto tensor = result_type.dyn_cast()) { + if (seen_tensor_list) { + func.emitError( + "tfr.tensor result should be before tfr.tensor_list result."); + return failure(); + } + auto used = tensor.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + if (auto tensor_list = result_type.dyn_cast()) { + if (seen_tensor_list) { + func.emitError("More than one tfr.tensor_list result isn't allowed."); + return failure(); + } + seen_tensor_list = true; + auto used = tensor_list.getAttrKeys(); + used_attrs.append(used.begin(), used.end()); + continue; + } + + func.emitError( + "None tfr.tensor/tfr.tensor_list results aren't allowed as a " + "result."); + return failure(); + } + + // Verify that all the used attributes are in the attribute arguments. 
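+  // For example (an illustrative signature, not from this patch): in
+  //   tfr.func @tf__cast(%x: !tfr.tensor<T>) -> !tfr.tensor<K>
+  // both "T" and "K" must also be declared as attribute arguments carrying a
+  // tfr.name attribute; otherwise the "Undefined attributes" error below is
+  // emitted.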
+ llvm::SmallVector undefined_attrs; + for (auto attr : used_attrs) { + if (!func.getAttr(attr.getValue())) { + undefined_attrs.push_back(attr); + } + } + if (!undefined_attrs.empty()) { + llvm::SmallVector attr_names(undefined_attrs.size()); + std::transform(undefined_attrs.begin(), undefined_attrs.end(), + attr_names.begin(), + [](StringAttr attr) { return attr.getValue().str(); }); + func.emitError(llvm::Twine("Undefined attributes are used: ", + llvm::join(attr_names, ","))); + return failure(); + } + + return success(); +} + +static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) { + auto build_func_type = [](Builder &builder, ArrayRef arg_types, + ArrayRef results, impl::VariadicFlag, + std::string &) { + return builder.getFunctionType(arg_types, results); + }; + return impl::parseFunctionLikeOp(parser, *result, /*allowVariadic=*/false, + build_func_type); +} + +static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) { + FunctionType fn_type = op.getType(); + impl::printFunctionLikeOp(p, op, fn_type.getInputs(), /*isVariadic=*/false, + fn_type.getResults()); +} + +} // namespace TFR +} // namespace mlir + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc" + +namespace mlir { +namespace TFR { +struct ConvertConstToTensorConst : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op, + PatternRewriter &rewriter) const override { + Location loc = cst_tensor_op.getLoc(); + Type out_type = cst_tensor_op.getType(); + Operation *new_cst = nullptr; + + ArrayAttr array; + if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) { + llvm::DenseSet all_types; + for (auto it : array) { + all_types.insert(it.getType()); + } + if (all_types.size() != 1) return failure(); + ShapedType new_out_type = RankedTensorType::get( + {static_cast(array.size())}, *all_types.begin()); + DenseElementsAttr attr = + DenseElementsAttr::get(new_out_type, array.getValue()); + new_cst = rewriter.create(loc, new_out_type, attr); + if (out_type.isa()) { + new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); + } + rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); + return success(); + } + + Attribute scalar; + if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) { + Type new_out_type = RankedTensorType::get({}, scalar.getType()); + new_cst = rewriter.create(loc, new_out_type, scalar); + if (out_type.isa()) { + new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); + } + rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); + return success(); + } + return failure(); + } +}; + +struct RemoveRedundantCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CastOp cast_op, + PatternRewriter &rewriter) const override { + auto preceding_cast = + llvm::dyn_cast_or_null(cast_op.arg().getDefiningOp()); + if (!preceding_cast) { + return failure(); + } + Value input = preceding_cast.arg(); + Type input_type = input.getType(); + Type output_type = cast_op.getType(); + + // If the two types are the same, the back-to-back tfr.cast ops can be + // removed. 
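+    // For instance (a sketch):
+    //   %1 = tfr.cast(%0) : tensor<2xf32> -> !tfr.tensor
+    //   %2 = tfr.cast(%1) : !tfr.tensor -> tensor<2xf32>
+    // folds so that the uses of %2 are replaced by %0 directly.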
+ if (input_type == output_type || output_type.isa()) { + rewriter.replaceOp(cast_op, {input}); + return success(); + } + + // If the rank of the input tensor isn't ranked, we replace the pair + // with tf.EnsureShape op so it can be removed after shape inference or + // confirmed at runtime. + if (input_type.isa() && output_type.isa()) { + auto shape = output_type.cast().getShape(); + auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape); + rewriter.replaceOpWithNewOp(cast_op, output_type, + input, shape_attr); + } + + return success(); + } +}; + +struct GetTensorShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetShapeOp shape_op, + PatternRewriter &rewriter) const override { + Operation *preceding_op = shape_op.arg().getDefiningOp(); + if (auto cast_op = llvm::dyn_cast_or_null(preceding_op)) { + // replace this pair by shape.shape_of, so the folding works. + rewriter.replaceOpWithNewOp(shape_op, cast_op.arg()); + return success(); + } + return failure(); + } +}; + +struct RemoveRedundantGetElement : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetElementOp ge_op, + PatternRewriter &rewriter) const override { + IntegerAttr index; + if (!matchPattern(ge_op.index(), m_Constant(&index))) { + return failure(); + } + auto preceding_build_list = llvm::dyn_cast_or_null( + ge_op.tensor_list().getDefiningOp()); + if (!preceding_build_list || + preceding_build_list.getNumOperands() <= index.getInt()) { + return failure(); + } + Value input = preceding_build_list.getOperand(index.getInt()); + Type output_type = ge_op.getType(); + if (input.getType() != output_type && + !output_type.isa()) { + return failure(); + } + rewriter.replaceOp(ge_op, {input}); + return success(); + } +}; + +struct BuildConstantListAsAttr : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BuildListOp bl_op, + PatternRewriter &rewriter) const override { + SmallVector array_list; + array_list.reserve(bl_op.getNumOperands()); + for (const auto &operand : bl_op.getOperands()) { + Attribute array_elt; + if (!matchPattern(operand, m_Constant(&array_elt))) { + return failure(); + } + array_list.push_back(array_elt); + } + auto array_attr = rewriter.getArrayAttr(array_list); + rewriter.replaceOpWithNewOp(bl_op, array_attr); + return success(); + } +}; + +void ConstantTensorOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +void GetElementOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult TFR::EqualOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "equal op has two operands"); + auto ctx = getContext(); + if (operands[0] == operands[1]) return BoolAttr::get(/*value=*/true, ctx); + return BoolAttr::get(/*value=*/false, ctx); +} + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + + // Return the held attribute 
value. + return value(); +} + +// CallableOpInterface +Region *TFRFuncOp::getCallableRegion() { + return isExternal() ? nullptr : &body().front(); +} + +// CallableOpInterface +ArrayRef TFRFuncOp::getCallableResults() { + return getType().getResults(); +} + +//===----------------------------------------------------------------------===// +// Dialect type definitions +//===----------------------------------------------------------------------===// + +// Parses a TFR type. +// tfr_type ::= tensor_type | tensor_list_type | attr_type +// string_list ::= `[` string-literal (, string-literal)+ `]` +// tensor_type ::= `tensor` +// | `tensor<` (string-literal | string_list) '>' +// tensor_list_type ::= `tensor_list` +// | `tensor_list<` (string-literal | string_list) '>' +// attr_type ::= `attr` +Type TFRDialect::parseType(DialectAsmParser &parser) const { + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); + MLIRContext *ctx = loc.getContext(); + + StringRef typeNameSpelling; + if (failed(parser.parseKeyword(&typeNameSpelling))) return {}; + llvm::SmallVector attrs; + if (succeeded(parser.parseOptionalLess())) { + bool l_square_parsed = false; + if (succeeded(parser.parseOptionalLSquare())) { + l_square_parsed = true; + } + + do { + StringRef attr; + if (failed(parser.parseKeyword(&attr))) return {}; + attrs.push_back(StringAttr::get(attr, ctx)); + } while (succeeded(parser.parseOptionalComma())); + + if (l_square_parsed && failed(parser.parseRSquare())) { + parser.emitError(parser.getNameLoc(), "expected ']'"); + } + + if (failed(parser.parseGreater())) { + parser.emitError(parser.getNameLoc(), "expected '>'"); + } + } + + if (typeNameSpelling == "tensor") { + return TFRTensorType::getChecked(attrs, loc); + } else if (typeNameSpelling == "tensor_list") { + return TFRTensorListType::getChecked(attrs, loc); + } else if (typeNameSpelling == "attr") { + return TFRAttrType::getChecked(loc); + } else { + parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling); + return {}; + } +} + +void TFRDialect::printType(Type type, DialectAsmPrinter &os) const { + llvm::ArrayRef attrs; + + if (type.isa()) { + os << "attr"; + return; + } + if (auto tensor_ty = type.dyn_cast()) { + attrs = tensor_ty.getAttrKeys(); + os << "tensor"; + } else if (auto tensor_list_ty = type.dyn_cast()) { + attrs = tensor_list_ty.getAttrKeys(); + os << "tensor_list"; + } else { + llvm_unreachable("Unhandled tfr type"); + } + + if (attrs.empty()) return; + os << "<"; + + if (attrs.size() > 1) { + os << "["; + } + + llvm::interleaveComma(attrs, os, + [&](StringAttr attr) { os << attr.getValue(); }); + + if (attrs.size() > 1) { + os << "]"; + } + os << ">"; +} + +} // namespace TFR +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h new file mode 100644 index 00000000000..cb36ee28351 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/FunctionSupport.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +constexpr char kAttrArgumentNameAttr[] = "tfr.name"; +constexpr char kAttrArgumentDefaultAttr[] = "tfr.default"; + +class TFRDialect : public Dialect { + public: + explicit TFRDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "tfr"; } + + // Parse a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + // Prints a type registered to this dialect. + void printType(Type ty, DialectAsmPrinter &os) const override; +}; + +} // namespace TFR +} // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td new file mode 100644 index 00000000000..2918336603b --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -0,0 +1,436 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for TFR + +#ifndef DIALECT_TFR_OPS_ +#define DIALECT_TFR_OPS_ + +include "mlir/Dialect/Shape/IR/ShapeBase.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +def TFR_Dialect : Dialect { + let name = "tfr"; + + let description = [{ + The TensorFlow Composition dialect. 
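+
+    A composition is written as a `tfr.func` that maps an op to calls of
+    other TF ops. As an illustrative sketch (this particular signature is
+    hypothetical):
+
+    ```mlir
+    tfr.func @tf__my_op(%x: !tfr.tensor, %y: !tfr.tensor) -> !tfr.tensor {
+      %r = tfr.call @tf__add(%x, %y) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
+      tfr.return %r : !tfr.tensor
+    }
+    ```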
+ }]; + + let cppNamespace = "::mlir::TFR"; +} + +//===----------------------------------------------------------------------===// +// Type classes +//===----------------------------------------------------------------------===// + +// tensor argument types +class TFR_Type : DialectType()">, + "TFR " # name #" type">, + BuildableType<"$_builder.getType()">; +def TFR_TensorType : TFR_Type<"TFRTensor">; +def TFR_TensorListType : TFR_Type<"TFRTensorList">; +def TFR_AllTensorTypes : Type, "all tensor related types">; + +// attribute argument types +def TFR_AttrType : TFR_Type<"TFRAttr">; +def TFR_AttrScalarType: TypeAlias; +def TFR_AttrVectorType : VectorOf<[TF_ElementType, TFR_AttrType]>; +def TFR_AllAttrTypes : Type, "all attribute related types">; + +// all allowed arguments types +def TFR_allowedArgType : Type, "allowed tfr.call operand types">; + +def TFR_allowedConstValues : Attr, "allowed tfr.constant value"> { + let storageType = "Attribute"; + let returnType = "Attribute"; + let convertFromStorage = "$_self"; + let constBuilderCall = "$0"; +} + +// all allowed result types +def TFR_allowedResultType : TypeAlias; + +// standard tensor type and tfr.tensor types can be casted to each other. +def TFR_singleTensorType : Type, "single tensor or tfr.tensor type">; + +// all allowed build list input types +def TFR_allowedBuiltListType : Type, "single tfr.tensor or tensor element type">; + +// all allowed build list result types +def TFR_allowedListResultType : Type, "tfr.tensor_list or tfr.attr type">; + +//===----------------------------------------------------------------------===// +// Op classes +//===----------------------------------------------------------------------===// + +class TFR_Op traits> : + Op; + +def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> { + let description = [{ + The `call` operation represents a direct call to a function that is within + the same symbol scope as the callee. The operands and result types of the + call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tfr.call @my_add(%0, %1) : (tfr.tensor, f32) -> tfr.tensor_list + ``` + + Note that the operands of the `call` operation can only be with tfr.tensor, + tfr.tensor_list, tfr.attr and mlir float and integer types. The results of + the `call` operation can only be with tfr.tensor and tfr.tensor_list types. + }]; + + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$args); + + let results = (outs + Variadic:$outs); + + let extraClassDeclaration = [{ + StringRef getCallee() { return callee(); } + + // Get the argument operands to the called function. + operand_range getArgOperands() { return args(); } + + // Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { return calleeAttr(); } + }]; + + let assemblyFormat = [{ + $callee `(` $args `)` attr-dict `:` functional-type($args, results) + }]; +} + +def TFR_CastOp : TFR_Op<"cast", [NoSideEffect]> { + let description = [{ + The `cast` operation converts the operand with built-in tensor type to + tfr.tensor type, or vice versa. + + Example: + + ```mlir + %1 = tfr.cast(%0) : tensor -> !tfr.tensor + %3 = tfr.cast(%1) : !tfr.tensor -> tensor + ``` + }]; + + let arguments = (ins TFR_singleTensorType:$arg); + + let results = (outs TFR_singleTensorType:$out); + + let extraClassDeclaration = [{ + // Return element type of the input tensor type. Only available when the + // input is a MLIR built-in tensor type. 
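+    // For example (a sketch): an input of type tensor<8xf32> yields a
+    // TypeAttr holding f32, while a !tfr.tensor input yields a null
+    // attribute.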
+    Attribute getInputElementType() {
+      if (auto ty = arg().getType().dyn_cast<TensorType>()) {
+        return TypeAttr::get(ty.getElementType());
+      }
+      return {};
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
+def TFR_GetShapeOp : TFR_Op<"get_shape", [NoSideEffect]> {
+  let description = [{
+    The `get_shape` operation gets the shape of a tfr.tensor and returns the
+    !shape.shape type.
+
+    Example:
+
+    ```mlir
+    %1 = "tfr.get_shape"(%0) : !tfr.tensor -> !shape.shape
+    %1 = tfr.get_shape %0 -> !shape.shape
+    ```
+  }];
+
+  let arguments = (ins TFR_TensorType:$arg);
+
+  let results = (outs Shape_ShapeType:$out);
+
+  let assemblyFormat = "$arg attr-dict `->` type($out)";
+
+  let hasCanonicalizer = 1;
+}
+
+def TFR_GetElementTypeOp : TFR_Op<"get_element_type", [NoSideEffect]> {
+  let description = [{
+    The `get_element_type` operation gets the element type of a tfr.tensor and
+    returns !tfr.attr.
+
+    Example:
+
+    ```mlir
+    %1 = "tfr.get_element_type"(%0) : !tfr.tensor -> !tfr.attr
+    %1 = tfr.get_element_type %0 -> !tfr.attr
+    ```
+  }];
+
+  let arguments = (ins TFR_TensorType:$arg);
+
+  let results = (outs TFR_AttrType:$out);
+
+  let assemblyFormat = "$arg attr-dict `->` type($out)";
+}
+
+def TFR_EqualOp : TFR_Op<"equal", [NoSideEffect, SameTypeOperands]> {
+  let description = [{
+    The `equal` operation compares the values of the tfr.attr type arguments.
+    The operation returns an i1 boolean indicating if the two values are the
+    same.
+
+    Example:
+
+    ```mlir
+    %x = tfr.equal %lhs, %rhs -> i1
+    %x = "tfr.equal"(%lhs, %rhs) : (!tfr.attr, !tfr.attr) -> i1
+    ```
+  }];
+
+  let arguments = (ins
+    TFR_AttrType:$lhs,
+    TFR_AttrType:$rhs
+  );
+  let results = (outs BoolLike:$result);
+
+  let hasFolder = 1;
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `->` type($result)";
+}
+
+def TFR_ConstOp : TFR_Op<"constant", [ConstantLike, NoSideEffect]> {
+  let description = [{
+    The `constant` operation stores a TF op attribute, which doesn't support
+    arithmetic operations.
+
+    Example:
+
+    ```mlir
+    %1 = "tfr.constant"() { value: i32 } : () -> !tfr.attr
+    %2 = "tfr.constant"() { value: [i32, f32] } : () -> !tfr.attr
+    %3 = tfr.constant [i32, f32] -> !tfr.attr
+    %4 = tfr.constant f32 -> !tfr.attr
+    ```
+  }];
+
+  let arguments = (ins TFR_allowedConstValues:$value);
+
+  let results = (outs TFR_AttrType:$out);
+
+  let hasFolder = 1;
+
+  let builders = [OpBuilder<
+    "OpBuilder &, OperationState &state, Attribute value",
+    [{
+      auto* ctx = value.getContext();
+      state.addAttribute("value", value);
+      state.addTypes(TFRAttrType::get(ctx));
+    }]>
+  ];
+
+  let assemblyFormat = [{
+    $value attr-dict `->` type($out)
+  }];
+}
+
+def TFR_ConstantTensorOp : TFR_Op<"constant_tensor", [NoSideEffect]> {
+  let description = [{
+    The `constant_tensor` operation converts the operand with non-built-in
+    tensor type to built-in tensor type or tfr.tensor type. If it is a
+    built-in tensor type, the shape shouldn't be changed during the
+    conversion.
+
+    Example:
+
+    ```mlir
+    %1 = tfr.constant_tensor(%0) : f32 -> tensor<f32>
+    %3 = tfr.constant_tensor(%2) : vector<1xf32> -> tensor<1xf32>
+    ```
+  }];
+
+  let arguments = (ins TFR_AllAttrTypes:$arg);
+
+  let results = (outs TFR_singleTensorType:$out);
+
+  let hasCanonicalizer = 1;
+
+  let verifier = [{ return Verify(*this); }];
+}
+
+def TFR_GetElementOp : TFR_Op<"get_element", [NoSideEffect]> {
+  let description = [{
+    The `get_element` operation extracts one tfr.tensor element from a
+    tfr.tensor_list.
+
+    Example:
+
+    ```mlir
+    %2 = tfr.get_element %1[%0] : (tfr.tensor, index) -> tfr.tensor
+    ```
+  }];
+
+  let arguments = (ins
+    TFR_TensorListType:$tensor_list,
+    Index:$index);
+
+  let results = (outs TFR_TensorType:$out);
+
+  let hasCanonicalizer = 1;
+
+  let assemblyFormat = [{
+    $tensor_list `[` $index `]` attr-dict `:`
+    `(` type($tensor_list) `,` type($index) `)` `->` type($out)
+  }];
+}
+
+def TFR_BuildListOp : TFR_Op<"build_list", [NoSideEffect]> {
+  let description = [{
+    The `build_list` operation builds a tensor list from a list of tensors, or
+    a tfr.attr from a list of scalars.
+
+    Example:
+
+    ```mlir
+    %3 = tfr.build_list(%2, %1, %0) :
+      (tfr.tensor, tfr.tensor, tfr.tensor) -> tfr.tensor_list
+    %3 = tfr.build_list(%2, %1, %0) : (i32, i32, i32) -> tfr.attr
+    ```
+  }];
+
+  let arguments = (ins Variadic<TFR_allowedBuiltListType>:$tensors);
+
+  let results = (outs TFR_allowedListResultType:$out);
+
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Function related classes
+//===----------------------------------------------------------------------===//
+
+def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
+                                    DeclareOpInterfaceMethods<CallableOpInterface>,
+                                    FunctionLike,
+                                    IsolatedFromAbove, Symbol]> {
+  let summary = "TFR Function defines a composition of other ops";
+
+  let description = [{
+    Defines a function that can be used to decompose a TF function call into
+    invocations of a set of other TF ops.
+
+    Syntax:
+
+    ```
+    op ::= `tfr.func` symbol-ref-id `(` argument-list `)` (`->`
+    function-result-list)? function-attributes? region
+    ```
+
+    Example:
+
+    ```mlir
+    tfr.func @foo(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list,
+                  %arg2: i32 {tfr.name="T", tfr.default=1})
+        attributes {qux: "quux"} {
+      tfr.return
+    }
+    ```
+
+    Note the arguments are ordered by the following rule:
+      tfr.tensor > tfr.tensor_list > tfr.attr/i32/...,
+    and only one tfr.tensor_list argument is allowed.
+  }];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name
+  );
+
+  let results = (outs);
+
+  // When the region is empty, the tfr.func is an external function and is
+  // used to model the element type constraints of the TF op. Otherwise, there
+  // is one region containing the composition.
+  let regions = (region VariadicRegion<AnyRegion>:$body);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, StringRef name, "
+              "FunctionType type, ArrayRef<NamedAttribute> attrs = {}">
+  ];
+
+  let extraClassDeclaration = [{
+    // FunctionLike trait needs access to the functions below.
+    friend class OpTrait::FunctionLike<TFRFuncOp>;
+
+    // Hooks for the input/output type enumeration in FunctionLike.
+    unsigned getNumFuncArguments() { return getType().getNumInputs(); }
+    unsigned getNumFuncResults() { return getType().getNumResults(); }
+  }];
+
+  let verifier = [{ return Verify(*this); }];
+  let parser = [{ return ParseFuncOp(parser, &result); }];
+  let printer = [{ PrintFuncOp(p, *this); }];
+}
+
+def TFR_TFRReturnOp : TFR_Op<"return", [HasParent<"TFRFuncOp">, NoSideEffect,
+                                        ReturnLike, Terminator]> {
+  let description = [{
+    A terminator operation for regions that appear in the body of `tfr.func`
+    functions. The operands to the `tfr.return` are the result values returned
+    by an invocation of the `tfr.func`.
+
+    Note that only tfr.tensor and tfr.tensor_list values can be returned.
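+
+    Example (an illustrative sketch):
+
+    ```mlir
+    tfr.func @tf__my_identity(%x: !tfr.tensor) -> !tfr.tensor {
+      tfr.return %x : !tfr.tensor
+    }
+    ```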
+  }];
+
+  let arguments = (ins Variadic<TFR_allowedResultType>:$operands);
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+#endif // DIALECT_TFR_OPS_
diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h
new file mode 100644
index 00000000000..4bda8f34658
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h
@@ -0,0 +1,115 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
+#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
+
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/TypeSupport.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFR {
+
+class TFRType : public Type {
+ public:
+  using Type::Type;
+
+  static bool classof(Type type);
+};
+
+namespace detail {
+
+struct TFRTypeStorage final
+    : public TypeStorage,
+      public llvm::TrailingObjects<TFRTypeStorage, StringAttr> {
+  using KeyTy = ArrayRef<StringAttr>;
+
+  explicit TFRTypeStorage(unsigned num_attrs) : num_attrs(num_attrs) {}
+
+  static TFRTypeStorage* construct(TypeStorageAllocator& allocator, KeyTy key) {
+    // Allocate a new storage instance.
+    auto byteSize = TFRTypeStorage::totalSizeToAlloc<StringAttr>(key.size());
+    auto rawMem = allocator.allocate(byteSize, alignof(TFRTypeStorage));
+    auto result = ::new (rawMem) TFRTypeStorage(key.size());
+
+    // Copy the string attributes into the trailing storage.
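+    // The keys are the strings spelled inside the type syntax, e.g. (a
+    // sketch) !tfr.tensor<T> stores ["T"] and !tfr.tensor<[T, K]> stores
+    // ["T", "K"].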
+    std::uninitialized_copy(key.begin(), key.end(),
+                            result->getTrailingObjects<StringAttr>());
+    return result;
+  }
+
+  bool operator==(const KeyTy& attrs) const { return attrs == GetAttrs(); }
+
+  KeyTy GetAttrs() const {
+    return {getTrailingObjects<StringAttr>(), num_attrs};
+  }
+
+  unsigned num_attrs;
+};
+
+template <typename Derived>
+class TFRTypeImpl : public Type::TypeBase<Derived, TFRType, TFRTypeStorage> {
+ public:
+  using Base = Type::TypeBase<Derived, TFRType, TFRTypeStorage>;
+  using TFRBase = TFRTypeImpl<Derived>;
+  using Base::Base;
+
+  static Derived get(ArrayRef<StringAttr> attrs, MLIRContext* context) {
+    return Base::get(context, attrs);
+  }
+
+  static Derived getChecked(ArrayRef<StringAttr> attrs, Location loc) {
+    return Base::getChecked(loc, attrs);
+  }
+
+  static Derived get(MLIRContext* context) { return get({}, context); }
+
+  // TODO(fengliuai): fix the implementation
+  static LogicalResult verifyConstructionInvariants(
+      Location loc, ArrayRef<StringAttr> attrs) {
+    return success();
+  }
+
+  ArrayRef<StringAttr> getAttrKeys() { return Base::getImpl()->GetAttrs(); }
+};
+}  // namespace detail
+
+class TFRTensorType : public detail::TFRTypeImpl<TFRTensorType> {
+ public:
+  using TFRBase::TFRBase;
+  static std::string getTypeName() { return "TFRTensorType"; }
+};
+
+class TFRTensorListType : public detail::TFRTypeImpl<TFRTensorListType> {
+ public:
+  using TFRBase::TFRBase;
+  static std::string getTypeName() { return "TFRTensorListType"; }
+};
+
+class TFRAttrType : public Type::TypeBase<TFRAttrType, TFRType, TypeStorage> {
+ public:
+  using Base::Base;
+  static std::string getTypeName() { return "TFRAttrType"; }
+};
+
+}  // namespace TFR
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
new file mode 100644
index 00000000000..d399a10a35e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc
@@ -0,0 +1,160 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/compiler/mlir/tfr/passes/passes.h" + +//===----------------------------------------------------------------------===// +// Canonicalization patterns for the scf.for and scf.if ops. They are used to +// optimize the control flow in the tfr function. Technically, both patterns +// should be upstreamed to be part of the op definition. +// TODO(fengliuai): sync with the llvm upstream for both patterns. +// +namespace mlir { +namespace TFR { + +namespace { + +struct UnrollSCFForOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ForOp for_op, + PatternRewriter &rewriter) const override { + Location loc = for_op.getLoc(); + APInt lower_bound, upper_bound, step; + if (!matchPattern(for_op.lowerBound(), m_ConstantInt(&lower_bound)) || + !matchPattern(for_op.upperBound(), m_ConstantInt(&upper_bound)) || + !matchPattern(for_op.step(), m_ConstantInt(&step))) { + return failure(); + } + uint64_t trip_count = (upper_bound - lower_bound).sdiv(step).getZExtValue(); + if (trip_count <= 0) return failure(); + + // TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported + + Block *single_block = for_op.getBody(); + BlockAndValueMapping mapping; + Value iv = for_op.getInductionVar(); + for (auto iter_op : + llvm::zip(for_op.getRegionIterArgs(), for_op.initArgs())) { + mapping.map(std::get<0>(iter_op), std::get<1>(iter_op)); + } + mapping.map(iv, for_op.lowerBound()); + for (auto i = 0; i < trip_count; ++i) { + if (!iv.use_empty()) { + // iv' = iv + step * i; + Value iter = rewriter.create(loc, i); + Value step_cst = + rewriter.create(loc, step.getSExtValue()); + Value stride = rewriter.create(loc, step_cst, iter); + Value iv_unroll = + rewriter.create(loc, mapping.lookup(iv), stride); + mapping.map(iv, iv_unroll); + } + + Operation *terminator_op; + for (auto it = single_block->begin(); it != single_block->end(); ++it) { + terminator_op = rewriter.clone(*it, mapping); + } + // Map the block arguments to the yield results. + for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(), + terminator_op->getOperands())) { + mapping.map(std::get<0>(iter_op), std::get<1>(iter_op)); + } + rewriter.eraseOp(terminator_op); + } + SmallVector returned; + for (Value arg : for_op.getRegionIterArgs()) { + returned.push_back(mapping.lookup(arg)); + } + rewriter.replaceOp(for_op, returned); + return success(); + } +}; + +// TODO(fengliuai): up stream this pattern. 
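+// Inlines the region of an scf.if op with a constant condition, e.g. (a
+// sketch):
+//   scf.if %true { "foo.op"() : () -> () }
+// is replaced by the contents of its then-region, while an scf.if with a
+// false condition and no else-region is erased.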
+struct SimplifySCFIfOp : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern<scf::IfOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(scf::IfOp if_op,
+                                PatternRewriter &rewriter) const override {
+    // Then branch
+    if (matchPattern(if_op.condition(), m_NonZero())) {
+      return InlineRegion(if_op.getLoc(), rewriter, if_op, &if_op.thenRegion());
+    }
+
+    // Else branch
+    if (matchPattern(if_op.condition(), m_Zero())) {
+      if (if_op.elseRegion().empty()) {
+        // Remove the op
+        rewriter.eraseOp(if_op);
+        return success();
+      } else {
+        return InlineRegion(if_op.getLoc(), rewriter, if_op,
+                            &if_op.elseRegion());
+      }
+    }
+
+    // Not a constant condition
+    return failure();
+  }
+
+ private:
+  LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter,
+                             Operation *inline_point, Region *region) const;
+};
+
+LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
+                                            PatternRewriter &rewriter,
+                                            Operation *inline_point,
+                                            Region *region) const {
+  InlinerInterface interface(loc.getContext());
+  if (failed(inlineRegion(interface, region, inline_point, {},
+                          inline_point->getResults(), loc,
+                          /*shouldCloneInlinedRegion=*/true))) {
+    return failure();
+  }
+
+  // If the inlining was successful then erase the scf.if op.
+  rewriter.eraseOp(inline_point);
+  return success();
+}
+
+}  // namespace
+
+void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
+}
+
+}  // namespace TFR
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc
new file mode 100644
index 00000000000..9265437cca9
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc
@@ -0,0 +1,280 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" +#include "tensorflow/compiler/mlir/tfr/passes/passes.h" +#include "tensorflow/compiler/mlir/tfr/utils/utils.h" + +//===----------------------------------------------------------------------===// +// The pass to decompose unregistered TF ops with the TFR compose function. +// +namespace mlir { +namespace TFR { + +namespace { + +// Decompose the TF ops with the registered composition library. +struct DecomposeTFOpsPass + : public PassWrapper { + + explicit DecomposeTFOpsPass(llvm::Optional external_tfr_module) + : external_tfr_module(external_tfr_module) {} + + void runOnFunction() override; + + private: + // Apply canonicalization, mainly constant folding, on the function. + void ApplyCanonicalization(); + + // Rewrite unregistered TF ops to TFR func call ops. Return failure if all the + // ops are registered or the compose function doesn't exist. + LogicalResult RewriteUnregisteredTFOps(); + + // Inline the TFR func call ops. + LogicalResult InlineTFRFuncCalls(); + + // Optional external symbol table to look up the TFR function. + llvm::Optional external_tfr_module; +}; + +void DecomposeTFOpsPass::ApplyCanonicalization() { + OwningRewritePatternList patterns; + + auto* context = &getContext(); + for (auto* op : context->getRegisteredOperations()) { + op->getCanonicalizationPatterns(patterns, context); + } + populateSCFOpsCanonicalizationPatterns(patterns, context); + + applyPatternsAndFoldGreedily(getFunction(), patterns); +} + +LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { + FuncOp func = getFunction(); + SymbolTable table(external_tfr_module.hasValue() + ? *external_tfr_module + : func.getParentOfType()); + OpBuilder builder(func); + bool changed = false; + func.walk([&table, &builder, &changed](Operation* op) { + // Only the un-registered ops requires decomposition. The remaining ones + // either will be constant folded or lowered by the rules defined in the + // bridge. 
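+    // As an illustrative sketch: an unregistered op such as
+    //   %0 = "tf.MyAddN"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // is rewritten below into a tfr.call of the composition function
+    // @tf__my_add_n looked up from the TFR library.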
+ if (op->isRegistered()) { + return; + } + + // Find out the compose function + auto compose_func_name = GetComposeFuncName(op->getName().getStringRef()); + auto compose_func = table.lookup(compose_func_name); + if (!compose_func || compose_func.isExternal()) { + // There are no decomposition methods defined for this op, skip. + return; + } + + auto compose_func_type = compose_func.getType(); + builder.setInsertionPoint(op); + TFRTensorType unconstrainted_tensor_type = builder.getType(); + + // Create the new operands. This is mapping the operands from the target + // TF ops to the TFR function arguments. If the TFR function argument is + // a tensor_list, a "tfr.build_list" op is used to concat the available + // TF op operands. If the TFR function argument isn't a tensor/tensor_list, + // a constant is created by using the attribute stored in the TF op or the + // default value in the argument attribute. + llvm::SmallVector new_operands; + for (auto arg : llvm::enumerate(compose_func_type.getInputs())) { + if (auto tensor_type = arg.value().dyn_cast()) { + auto casted = builder.create(op->getLoc(), tensor_type, + op->getOperand(arg.index())); + new_operands.push_back(casted); + } else if (auto list_type = arg.value().dyn_cast()) { + llvm::SmallVector variadic_operands; + for (int i = arg.index(); i < op->getNumOperands(); i++) { + auto casted = builder.create( + op->getLoc(), unconstrainted_tensor_type, op->getOperand(i)); + variadic_operands.push_back(casted); + } + auto build_list_op = builder.create( + op->getLoc(), list_type, variadic_operands); + new_operands.push_back(build_list_op.out()); + } else { + auto attr_name = compose_func.getArgAttrOfType( + arg.index(), kAttrArgumentNameAttr); + auto attribute = op->getAttr(attr_name.getValue()); + if (!attribute) { + attribute = + compose_func.getArgAttr(arg.index(), kAttrArgumentDefaultAttr); + } + Value attr_cst; + // Wrap these special attributes as a special TFR constant, so the SSA + // value has a valid type to be used as TFR function argument. These + // attributes are not expected to be manipulated by the lowering passes. + if (attribute.isa() || attribute.isa() || + attribute.isa() || attribute.isa()) { + TFRAttrType output_type = TFRAttrType::get(builder.getContext()); + attr_cst = + builder.create(op->getLoc(), output_type, attribute); + } else { + attr_cst = builder.create(op->getLoc(), attribute); + } + new_operands.push_back(attr_cst); + } + } + + // Create the TFR call op + auto new_op = builder.create( + op->getLoc(), compose_func_type.getResults(), + builder.getSymbolRefAttr(compose_func.getName()), new_operands); + + // Replace the use of the old op. This is mapping the results from the + // target TF ops to the TFR function returns. If the TFR function return is + // a tensor_list, "tfr.get_element" op is used to extract the required TF + // op result. 
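+    // E.g. (a sketch): when the composition returns a tfr.tensor_list
+    // covering the remaining results, each old result i is rebuilt as
+    //   %e = tfr.get_element %list[%i] : (!tfr.tensor_list, index) -> !tfr.tensor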
+    llvm::SmallVector<Value, 4> new_results;
+    for (auto res : llvm::enumerate(compose_func_type.getResults())) {
+      if (res.value().dyn_cast<TFRTensorType>()) {
+        new_results.push_back(new_op.getResult(res.index()));
+      } else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
+        for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) {
+          auto index = builder.create<ConstantOp>(op->getLoc(),
+                                                  builder.getIndexAttr(j));
+          auto element_op = builder.create<GetElementOp>(
+              op->getLoc(), unconstrainted_tensor_type,
+              new_op.getResult(res.index()), index.getResult());
+          new_results.push_back(element_op.out());
+        }
+      }
+    }
+    for (auto res : llvm::zip(op->getResults(), new_results)) {
+      auto casted = builder.create<CastOp>(
+          op->getLoc(), std::get<0>(res).getType(), std::get<1>(res));
+      std::get<0>(res).replaceAllUsesWith(casted.out());
+    }
+    op->erase();
+    changed |= true;
+  });
+
+  // If `changed` is false, it is considered as a failure, so the recursive
+  // rewrite will stop.
+  return success(changed);
+}
+
+LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
+  // The Inliner will automatically use the registered dialect inliner.
+  InlinerInterface inliner(&getContext());
+  FuncOp func = getFunction();
+  SymbolTable table(external_tfr_module.hasValue()
+                        ? *external_tfr_module
+                        : func.getParentOfType<ModuleOp>());
+
+  // The inliner only inlines the TFR call op.
+  bool changed = false;
+  auto walk_result = func.walk([&](CallOp call_op) {
+    auto callee = table.lookup<TFRFuncOp>(call_op.callee());
+    if (!callee || callee.isExternal()) return WalkResult::advance();
+    if (failed(inlineCall(inliner,
+                          cast<CallOpInterface>(call_op.getOperation()),
+                          cast<CallableOpInterface>(callee.getOperation()),
+                          callee.getCallableRegion(),
+                          /*shouldCloneInlinedRegion=*/true))) {
+      // This failure is usually because the decompose function is not defined.
+      // This call will be raised to TF ops.
+      return WalkResult::interrupt();
+    }
+    call_op.erase();
+    changed |= true;
+    return WalkResult::advance();
+  });
+
+  if (walk_result.wasInterrupted()) {
+    signalPassFailure();
+    return failure();
+  }
+
+  // If `changed` is false, it is considered as a failure, so the recursive
+  // rewrite will stop.
+  return success(changed);
+}
+
+void DecomposeTFOpsPass::runOnFunction() {
+  // Set a maximum iteration threshold in case there are infinite loops in the
+  // call stack.
+  int max_iterators = 10;
+  do {
+    // canonicalization
+    ApplyCanonicalization();
+
+    // Rewrite the unregistered TF ops. This fails either when no ops can be
+    // decomposed or when the compose functions aren't defined.
+    auto rewrite_status = RewriteUnregisteredTFOps();
+    // Inline the TFR call ops until none of them can be inlined.
+    auto inline_status = InlineTFRFuncCalls();
+
+    if (failed(rewrite_status) && failed(inline_status)) {
+      break;
+    }
+  } while (max_iterators-- >= 0);
+}
+
+}  // namespace
+
+// Creates an instance of the pass to decompose the TF ops.
+std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(
+    llvm::Optional<ModuleOp> tfr_module) {
+  return std::make_unique<DecomposeTFOpsPass>(tfr_module);
+}
+
+static PassRegistration<DecomposeTFOpsPass> pass(
+    "tfr-decompose",
+    "Decompose TF ops with the registered composition library.",
+    [] { return CreateDecomposeTFOpsPass(); });
+
+}  // namespace TFR
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tfr/passes/passes.h b/tensorflow/compiler/mlir/tfr/passes/passes.h
new file mode 100644
index 00000000000..5c27d81ace8
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/passes/passes.h
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_
+
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFR {
+
+void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context);
+
+// Decompose ops.
+std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(
+    llvm::Optional<ModuleOp> tfr_module = llvm::None);
+
+// Raise to TF ops.
+std::unique_ptr<OperationPass<FuncOp>> CreateRaiseToTFOpsPass(
+    llvm::Optional<ModuleOp> tfr_module = llvm::None,
+    bool materialize_derived_attrs = false);
+
+}  // namespace TFR
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_PASSES_H_
diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
new file mode 100644
index 00000000000..f3fe9618c62
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
@@ -0,0 +1,474 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
+#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
+#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
+#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
+
+//===----------------------------------------------------------------------===//
+// The pass to rewrite the TFR function call ops with TF ops. The callee of the
+// TFR function call defines the signatures of the TF ops.
+//
+namespace mlir {
+namespace TFR {
+
+namespace {
+
+// This pattern rewrites the "tfr.call" op and the "tfr.cast" ops on its
+// operands to a TF op with "tfr.cast" ops on the results. The result type of
+// the new TF op is an unranked tensor with the derived element type.
+class RewriteTFRCallOp : public OpRewritePattern<CallOp> {
+  using OpRewritePattern<CallOp>::OpRewritePattern;
+
+ public:
+  explicit RewriteTFRCallOp(MLIRContext* context, const SymbolTable& table,
+                            bool materialize_derived_attrs)
+      : OpRewritePattern<CallOp>(context),
+        symbol_table_(table),
+        materialize_derived_attrs_(materialize_derived_attrs) {}
+
+  LogicalResult matchAndRewrite(CallOp call_op,
+                                PatternRewriter& rewriter) const override;
+
+ private:
+  // Derives the attribute values for the attributes attached to the
+  // `input_tfr_type`. These attributes are only for the element type of the
+  // inputs, and this type information has been collected in the
+  // `input_types`. The result is stored in `derived_attrs` as the named
+  // attributes. Returns failure if the attributes stored in the
+  // `input_tfr_type` violate the assumptions.
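+  // Sketch of the expected mapping (names assumed for illustration): for an
+  // input typed !tfr.tensor<T> whose actual element type is f32, this
+  // inserts {"T" -> f32} into `derived_attrs`; for !tfr.tensor_list<N, T>
+  // fed by three f32 inputs it inserts {"N" -> 3, "T" -> f32}.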
+  LogicalResult AddDerivedAttrs(
+      PatternRewriter& rewriter, Type input_tfr_type,
+      ArrayRef<Attribute> input_types,
+      llvm::StringMap<Attribute>* derived_attrs) const;
+
+  // Collects the operands and attributes for the TF op. At the same time, it
+  // collects all the derived attribute values to derive the output types of
+  // the TF op.
+  LogicalResult CollectInputsAndAttributes(
+      PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
+      SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
+      llvm::StringMap<Attribute>* derived_attrs) const;
+
+  // Uses the collected attribute values to derive all the output types.
+  LogicalResult DeriveOutputTypes(FunctionType signature,
+                                  const llvm::StringMap<Attribute>& attrs,
+                                  SmallVectorImpl<Type>* output_types) const;
+
+  // Creates the TF op and also the necessary tfr.cast ops to replace the
+  // original TFR call op.
+  LogicalResult CreateAndReplaceOp(
+      PatternRewriter& rewriter, CallOp call_op,
+      const SmallVectorImpl<Type>& output_types,
+      const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
+      const llvm::StringMap<Attribute>& derived_attrs) const;
+
+  // Adds a tf.Cast op if the tfr.tensor attribute indicates a fixed element
+  // type.
+  // TODO(fengliuai): This method is required when the operand types are not
+  // set by the frontend correctly.
+  Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc,
+                             CastOp cast_op, Type input_tfr_type) const {
+    auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>();
+    if (!tensor_type) return cast_op.arg();
+
+    auto attr_names = tensor_type.getAttrKeys();
+    if (attr_names.empty() || attr_names.size() > 1) return cast_op.arg();
+    StringRef tfr_type_attr = attr_names[0].getValue();
+    if (!fixed_elt_type_attrs_.contains(tfr_type_attr)) return cast_op.arg();
+
+    Type result_elt_type;
+    if (tfr_type_attr == "i32_") {
+      result_elt_type = rewriter.getI32Type();
+    } else if (tfr_type_attr == "i64_") {
+      result_elt_type = rewriter.getI64Type();
+    } else if (tfr_type_attr == "f32_") {
+      result_elt_type = rewriter.getF32Type();
+    } else if (tfr_type_attr == "i1_") {
+      result_elt_type = rewriter.getI1Type();
+    } else {
+      return cast_op.arg();
+    }
+
+    Type original_input_type =
+        cast_op.getInputElementType().cast<TypeAttr>().getValue();
+    if (result_elt_type != original_input_type) {
+      UnrankedTensorType result_type =
+          UnrankedTensorType::get(result_elt_type);
+      return rewriter.create<TF::CastOp>(loc, result_type, cast_op.arg());
+    }
+    return cast_op.arg();
+  }
+
+  // For variadic operands, we have to enforce them to use the same types.
+  // TODO(fengliuai): This method is required when the operand types are not
+  // set by the frontend correctly.
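+  // For illustration, with assumed element types: given variadic operands
+  // typed [f32, i32, f32], the i32 operand receives a tf.Cast to f32 so
+  // every operand matches input_types[0].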
+  void CastValuesToSameType(
+      PatternRewriter& rewriter, Location loc,
+      const llvm::SmallVectorImpl<Attribute>& input_types,
+      llvm::SmallVectorImpl<Value>& input_values) const {
+    if (input_types.size() <= 1) return;
+
+    Type target_input_type = input_types[0].cast<TypeAttr>().getValue();
+    auto result_type = UnrankedTensorType::get(target_input_type);
+    for (auto i = 1; i < input_types.size(); ++i) {
+      Type current_input_type = input_types[i].cast<TypeAttr>().getValue();
+      if (current_input_type != target_input_type) {
+        input_values[i] =
+            rewriter.create<TF::CastOp>(loc, result_type, input_values[i]);
+      }
+    }
+  }
+
+  const SymbolTable& symbol_table_;
+  const bool materialize_derived_attrs_;
+  const llvm::SmallDenseSet<StringRef, 4> fixed_elt_type_attrs_{
+      "i32_", "i64_", "f32_", "i1_"};
+};
+
+LogicalResult RewriteTFRCallOp::AddDerivedAttrs(
+    PatternRewriter& rewriter, Type input_tfr_type,
+    ArrayRef<Attribute> input_types,
+    llvm::StringMap<Attribute>* derived_attrs) const {
+  // If there is an attribute associated to the input in the signature, we
+  // store it as a derived attribute.
+  if (auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>()) {
+    auto attr_names = tensor_type.getAttrKeys();
+    if (attr_names.empty()) return success();
+
+    if (attr_names.size() == 1) {
+      derived_attrs->insert({attr_names[0].getValue(), input_types[0]});
+      return success();
+    }
+  }
+
+  // If there is an attribute associated to the input in the signature,
+  // we store it as a derived attribute.
+  if (auto list_type = input_tfr_type.dyn_cast<TFRTensorListType>()) {
+    auto attr_names = list_type.getAttrKeys();
+    if (attr_names.empty()) return success();
+
+    // N*T case
+    if (attr_names.size() == 2) {
+      derived_attrs->insert({attr_names[0].getValue(),
+                             rewriter.getI32IntegerAttr(input_types.size())});
+      // Note that this uses the first element of the list to infer the T
+      // value. A tf.Cast is required to cast the other inputs to the same
+      // type.
+      derived_attrs->insert({attr_names[1].getValue(), input_types[0]});
+      return success();
+    }
+
+    // list(dtype) case
+    if (attr_names.size() == 1) {
+      derived_attrs->insert(
+          {attr_names[0].getValue(), rewriter.getArrayAttr(input_types)});
+      return success();
+    }
+  }
+
+  return failure();
+}
+
+LogicalResult RewriteTFRCallOp::CollectInputsAndAttributes(
+    PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
+    SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
+    llvm::StringMap<Attribute>* derived_attrs) const {
+  for (const auto& operand :
+       llvm::enumerate(signature.getType().getInputs())) {
+    // If the index is larger than the operand number of the call_op, the
+    // default value of the operand needs to be used.
+    if (operand.index() >= call_op.getNumOperands()) {
+      auto attr_name = signature.getArgAttrOfType<StringAttr>(
+          operand.index(), kAttrArgumentNameAttr);
+      auto attr_value =
+          signature.getArgAttr(operand.index(), kAttrArgumentDefaultAttr);
+      arg_attrs->push_back(
+          rewriter.getNamedAttr(attr_name.getValue(), attr_value));
+      continue;
+    }
+
+    // The index is valid for the call_op.
+    Value input = call_op.getOperand(operand.index());
+    Operation* input_op = input.getDefiningOp();
+    auto input_tfr_type = signature.getType().getInputs()[operand.index()];
+
+    // There are three cases for the preceding input_op:
+
+    // 1. The preceding op can be a tfr.cast op, which will be fused to the
+    // current op, so the result op has an input with tensor type.
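+    // A minimal sketch of this case (SSA names made up for illustration):
+    //   %t = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
+    //   %r = tfr.call @some_op_(%t) : (!tfr.tensor) -> !tfr.tensor
+    // The cast is fused away and %arg0 becomes a direct operand of the
+    // raised TF op.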
+    if (auto cast_op = dyn_cast_or_null<CastOp>(input_op)) {
+      Value input_to_cast = CastToNonDerivedType(rewriter, call_op.getLoc(),
+                                                 cast_op, input_tfr_type);
+      inputs->push_back(input_to_cast);
+      if (failed(AddDerivedAttrs(rewriter, input_tfr_type,
+                                 {cast_op.getInputElementType()},
+                                 derived_attrs))) {
+        return failure();
+      }
+      continue;
+    }
+
+    // 2. The preceding op is a tfr.build_list op, which collects multiple
+    // values with tensor types via the tfr.cast ops. These ops will be fused
+    // to the current op as well, so all the tfr.cast op inputs will be inputs
+    // to the result op.
+    if (auto list_op = dyn_cast_or_null<BuildListOp>(input_op)) {
+      // Find out all the inputs to the build list op
+      // TODO(fengliuai): make build_list op only take tensor argument
+      llvm::SmallVector<Attribute, 4> list_input_types;
+      llvm::SmallVector<Value, 4> list_inputs;
+      for (auto list_input : list_op.getOperands()) {
+        auto cast_op = dyn_cast_or_null<CastOp>(list_input.getDefiningOp());
+        if (!cast_op) return failure();
+        list_inputs.push_back(cast_op.arg());
+        list_input_types.push_back(cast_op.getInputElementType());
+      }
+      CastValuesToSameType(rewriter, call_op.getLoc(), list_input_types,
+                           list_inputs);
+      inputs->append(list_inputs.begin(), list_inputs.end());
+      if (failed(AddDerivedAttrs(rewriter, input_tfr_type, list_input_types,
+                                 derived_attrs))) {
+        return failure();
+      }
+      continue;
+    }
+
+    // 3. The preceding op is a constant, thus the value of this constant is
+    // used to create an attribute of the result op, according to the
+    // signature.
+    Attribute arg_value;
+    // A failure indicates the argument isn't a constant value, so we should
+    // not use it as an attribute.
+    if (!matchPattern(input, m_Constant(&arg_value))) {
+      return failure();
+    }
+    auto attr_name = signature.getArgAttrOfType<StringAttr>(
+        operand.index(), kAttrArgumentNameAttr);
+    arg_attrs->push_back(
+        rewriter.getNamedAttr(attr_name.getValue(), arg_value));
+  }
+  return success();
+}
+
+// For each output, uses the attribute name associated to the tfr types to
+// find out the attribute value from the collected `attrs` and creates the
+// output type of the result op by using the attribute value as the element
+// type.
+LogicalResult RewriteTFRCallOp::DeriveOutputTypes(
+    FunctionType signature, const llvm::StringMap<Attribute>& attrs,
+    SmallVectorImpl<Type>* output_types) const {
+  for (auto res : llvm::enumerate(signature.getResults())) {
+    if (auto tensor_type = res.value().dyn_cast<TFRTensorType>()) {
+      // tfr.tensor should only have one attribute attached.
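+      // e.g., assuming a result typed !tfr.tensor<T> and attrs = {"T" -> f32},
+      // the derived output type below is the unranked tensor<*xf32>.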
+      auto attr_key = tensor_type.getAttrKeys().front();
+      output_types->push_back(UnrankedTensorType::get(
+          attrs.lookup(attr_key.getValue()).cast<TypeAttr>().getValue()));
+      continue;
+    }
+
+    if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
+      // There are two cases: N*T or list(dtype)
+      auto attr_keys = list_type.getAttrKeys();
+      // N*T case
+      if (attr_keys.size() == 2) {
+        // The first one is N, and the second one is T
+        int list_size =
+            attrs.lookup(attr_keys[0].getValue()).cast<IntegerAttr>().getInt();
+        Type list_type =
+            attrs.lookup(attr_keys[1].getValue()).cast<TypeAttr>().getValue();
+        for (int i = 0; i < list_size; ++i) {
+          output_types->push_back(UnrankedTensorType::get(list_type));
+        }
+        continue;
+      }
+      // TODO(fengliuai): list(dtype) case
+    }
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult RewriteTFRCallOp::CreateAndReplaceOp(
+    PatternRewriter& rewriter, CallOp call_op,
+    const SmallVectorImpl<Type>& output_types,
+    const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
+    const llvm::StringMap<Attribute>& derived_attrs) const {
+  // Create the new op
+  Location loc = call_op.getLoc();
+  rewriter.setInsertionPointAfter(call_op);
+  std::string tf_op_name = GetTFOpName(call_op.callee());
+  OperationState new_state(loc, tf_op_name, inputs, output_types, attr_list);
+  Operation* new_op = rewriter.createOperation(new_state);
+  if (materialize_derived_attrs_) {
+    for (const auto& attr : derived_attrs) {
+      // Add or update the derived attribute with the value. Skip the fixed
+      // element type attributes, in case they are present in the NodeDef.
+      if (!fixed_elt_type_attrs_.contains(attr.first())) {
+        new_op->setAttr(attr.first(), attr.second);
+      }
+    }
+  }
+
+  // Create the tfr.cast ops on the results and replace the uses of the
+  // original call op.
+  TFRTensorType unconstrainted_type = rewriter.getType<TFRTensorType>();
+  SmallVector<Value, 4> new_results;
+  for (auto res : llvm::enumerate(call_op.getResultTypes())) {
+    Type res_type = res.value();
+    if (res_type.dyn_cast<TFRTensorType>()) {
+      Value new_res = new_op->getResult(res.index());
+      auto casted = rewriter.create<CastOp>(loc, res_type, new_res);
+      new_results.push_back(casted.out());
+    } else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
+      SmallVector<Value, 4> tensor_list;
+      for (int i = res.index(); i < new_op->getNumResults(); i++) {
+        Value new_res = new_op->getResult(i);
+        auto casted =
+            rewriter.create<CastOp>(loc, unconstrainted_type, new_res);
+        tensor_list.push_back(casted.out());
+      }
+      auto list_op =
+          rewriter.create<BuildListOp>(loc, res_type, tensor_list);
+      new_results.push_back(list_op.out());
+    }
+  }
+  rewriter.replaceOp(call_op, new_results);
+  return success();
+}
+
+LogicalResult RewriteTFRCallOp::matchAndRewrite(
+    CallOp call_op, PatternRewriter& rewriter) const {
+  // Get the func op and verify that it is external. The type of this external
+  // func op is used as the signature of the corresponding TF ops. All the
+  // external func ops have the trailing underscore.
+  std::string external_callee_name = call_op.callee().str().append("_");
+  TFRFuncOp func = symbol_table_.lookup<TFRFuncOp>(external_callee_name);
+  if (!func || !func.isExternal()) return failure();
+  // Get the inputs and attributes. The attributes include those from the
+  // argument list and also those derived from the inputs.
+  SmallVector<Value, 4> inputs;
+  NamedAttrList argument_attrs;
+  llvm::StringMap<Attribute> derived_attrs;
+  if (failed(CollectInputsAndAttributes(rewriter, func, call_op, &inputs,
+                                        &argument_attrs, &derived_attrs))) {
+    return failure();
+  }
+
+  // Derive the output types. The result type is derived by using the
+  // attributes attached to the result type of the signature. The attribute
+  // value should be either in the attribute argument list or the derived
+  // attribute from the input tensors. All the result types
+  // are unranked, and shape inference should be applied afterwards.
+  SmallVector<Type, 4> output_types;
+
+  // Merge the attributes from the argument list to the derived ones.
+  for (auto& attr : argument_attrs) {
+    derived_attrs.insert({attr.first, attr.second});
+  }
+
+  // Derive the output types by using the attributes attached to the tfr
+  // types.
+  if (failed(
+          DeriveOutputTypes(func.getType(), derived_attrs, &output_types))) {
+    return failure();
+  }
+
+  // Create the new op and replace the old TFR call op.
+  return CreateAndReplaceOp(rewriter, call_op, output_types, inputs,
+                            argument_attrs, derived_attrs);
+}
+
+// Raise TFR call ops to the TF ops.
+struct RaiseToTFOpsPass
+    : public PassWrapper<RaiseToTFOpsPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<TF::TensorFlowDialect>();
+  }
+
+  explicit RaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,
+                            bool materialize_derived_attrs)
+      : external_tfr_module(tfr_module),
+        materialize_derived_attrs(materialize_derived_attrs) {}
+
+  void runOnFunction() override;
+
+ private:
+  llvm::Optional<ModuleOp> external_tfr_module;
+  const bool materialize_derived_attrs;
+};
+
+void RaiseToTFOpsPass::runOnFunction() {
+  FuncOp func = getFunction();
+  MLIRContext* ctx = &getContext();
+  SymbolTable table(external_tfr_module.hasValue()
+                        ? *external_tfr_module
+                        : func.getParentOfType<ModuleOp>());
+
+  OwningRewritePatternList patterns;
+  patterns.insert<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs);
+  for (auto* op : ctx->getRegisteredOperations()) {
+    op->getCanonicalizationPatterns(patterns, ctx);
+  }
+
+  applyPatternsAndFoldGreedily(func, patterns);
+}
+}  // namespace
+
+// Creates an instance of the pass to raise TFR call ops to the TF ops.
+std::unique_ptr<OperationPass<FuncOp>> CreateRaiseToTFOpsPass(
+    llvm::Optional<ModuleOp> tfr_module, bool materialize_derived_attrs) {
+  return std::make_unique<RaiseToTFOpsPass>(tfr_module,
+                                            materialize_derived_attrs);
+}
+
+static PassRegistration<RaiseToTFOpsPass> pass(
+    "tfr-raise-to-tf", "Raise all the TFR call ops to TF ops.",
+    [] { return CreateRaiseToTFOpsPass(); });
+
+}  // namespace TFR
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc b/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc
new file mode 100644
index 00000000000..8f06f278369
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" + +int main(int argc, char **argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + + mlir::DialectRegistry registry; + registry.insert(); + return failed(mlir::MlirOptMain(argc, argv, "TFR Pass Driver\n", registry)); +} diff --git a/tensorflow/compiler/mlir/tfr/resources/BUILD b/tensorflow/compiler/mlir/tfr/resources/BUILD new file mode 100644 index 00000000000..6abb3917e57 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/resources/BUILD @@ -0,0 +1,61 @@ +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//learning/brain/experimental/mlir/tfr/...", + "//tensorflow/compiler/mlir/...", + ], +) + +filegroup( + name = "decomposition_lib", + srcs = ["decomposition_lib.mlir"], +) + +cc_library( + name = "composite_ops_cc", + srcs = ["composite_ops.cc"], + copts = [ + "-Wno-unused-result", + "-Wno-unused-variable", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_gen_op_wrapper_py( + name = "composite_ops", + out = "composite_ops.py", + deps = [ + ":composite_ops_cc", + ], +) + +cc_library( + name = "test_ops_cc", + testonly = 1, + srcs = ["test_ops.cc"], + copts = [ + "-Wno-unused-result", + "-Wno-unused-variable", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc new file mode 100644 index 00000000000..8120625bc89 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("MyAddN")
+    .Input("inputs: N * T")
+    .Output("sum: T")
+    .Attr("N: int >= 1")
+    .Attr("T: {numbertype, variant}")
+    .SetIsCommutative()
+    .SetIsAggregate();
+
+REGISTER_OP("MyBiasedDense")
+    .Input("input: T")
+    .Input("weight: T")
+    .Input("bias: T")
+    .Output("out: T")
+    .Attr("T: {float, int8}")
+    .Attr("act: {'', 'relu', 'relu6'} = ''");
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir b/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir
new file mode 100644
index 00000000000..f67d24c9fec
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir
@@ -0,0 +1,109 @@
+// A test resource file which contains some pre-defined internal tfr.func
+// functions for decomposition and external tfr.func functions for raising the
+// decomposition result to the ops in the TF dialect.
+//
+// All the tfr.func functions are supposed to be translated from the Python
+// function with tf.composite annotation.
+// All the external tfr.func functions model the op signature defined by
+// OpDefs.
+
+tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
+                       %n: i64 {tfr.name="N"}) -> !tfr.tensor {
+  %index = constant 0 : index
+  %cst = constant 1 : i64
+  %eq = cmpi "eq", %n, %cst : i64
+  %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor
+  %res = scf.if %eq -> !tfr.tensor {
+    scf.yield %v1 : !tfr.tensor
+  } else {
+    %step = index_cast %cst : i64 to index
+    %end = index_cast %n : i64 to index
+    %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor {
+      %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor
+      %reduce_next = tfr.call @tf__add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
+      scf.yield %reduce_next : !tfr.tensor
+    }
+    scf.yield %reduce : !tfr.tensor
+  }
+  tfr.return %res : !tfr.tensor
+}
+
+// Translated from tf.compose Python function.
+tfr.func @tf__my_biased_dense(%input: !tfr.tensor, %weight: !tfr.tensor,
+                              %bias: !tfr.tensor,
+                              %act: !tfr.attr{tfr.name="act", tfr.default=""}) -> !tfr.tensor {
+  %dot = tfr.call @tf__mat_mul(%input, %weight) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
+  %add = tfr.call @tf__add(%dot, %bias) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor
+
+  %relu = tfr.constant "relu" -> !tfr.attr
+  %relu6 = tfr.constant "relu6" -> !tfr.attr
+
+  %is_relu = tfr.equal %act, %relu -> i1
+  %res = scf.if %is_relu -> !tfr.tensor {
+    %applied_relu = tfr.call @tf__relu(%add) : (!tfr.tensor) -> !tfr.tensor
+    scf.yield %applied_relu : !tfr.tensor
+  } else {
+    %is_relu6 = tfr.equal %act, %relu6 -> i1
+    %res1 = scf.if %is_relu6 -> !tfr.tensor {
+      %applied_relu6 = tfr.call @tf__relu6(%add) : (!tfr.tensor) -> !tfr.tensor
+      scf.yield %applied_relu6 : !tfr.tensor
+    } else {
+      scf.yield %add : !tfr.tensor
+    }
+    scf.yield %res1 : !tfr.tensor
+  }
+  tfr.return %res : !tfr.tensor
+}
+
+// This is a wrong decomposition and is used to verify that tf.Elu isn't
+// decomposed, since its kernel has been registered.
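+// (For reference: the decompose pass skips any op whose kernel is
+// registered, so this body is never inlined for tf.Elu; it only guards the
+// test against accidental decomposition.)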
+tfr.func @tf__elu_(%input: !tfr.tensor) -> !tfr.tensor {
+  tfr.return %input : !tfr.tensor
+}
+
+// Translated from:
+//
+// REGISTER_OP("Add")
+//   .Input("x: T")
+//   .Input("y: T")
+//   .Output("z: T")
+//   .Attr(
+//       "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
+//       "complex64, complex128, string}")
+tfr.func @tf__add_(!tfr.tensor<T>, !tfr.tensor<T>)
+    -> !tfr.tensor<T> attributes{T}
+
+// Translated from:
+//
+// REGISTER_OP("MatMul")
+//   .Input("a: T")
+//   .Input("b: T")
+//   .Output("product: T")
+//   .Attr("transpose_a: bool = false")
+//   .Attr("transpose_b: bool = false")
+//   .Attr("T: {bfloat16, half, float, double, int32, int64, complex64, complex128}")
+// T is a derived attribute.
+// transpose_a and transpose_b are materialized attributes.
+tfr.func @tf__mat_mul_(!tfr.tensor<T>, !tfr.tensor<T>,
+                       i1 {tfr.name="transpose_a", tfr.default=false},
+                       i1 {tfr.name="transpose_b", tfr.default=false})
+    -> !tfr.tensor<T> attributes{T}
+
+// Translated from:
+//
+// REGISTER_OP("Relu")
+//   .Input("features: T")
+//   .Output("activations: T")
+//   .Attr("T: {realnumbertype, qint8}")
+// T is a derived attribute.
+tfr.func @tf__relu_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
+
+
+// Translated from:
+//
+// REGISTER_OP("Relu6")
+//   .Input("features: T")
+//   .Output("activations: T")
+//   .Attr("T: {realnumbertype}")
+// T is a derived attribute.
+tfr.func @tf__relu6_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
diff --git a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc
new file mode 100644
index 00000000000..bff7fe0c18c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("TestNoOp"); + +REGISTER_OP("TestIdentityOp") + .Input("input: T") + .Output("output: T") + .Attr("T: numbertype"); + +REGISTER_OP("TestIdentityNOp") + .Input("input: N * T") + .Output("output: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype"); + +REGISTER_OP("TestInputNOp") + .Input("input: N * T") + .Output("output: T") + .Attr("N: int >= 1") + .Attr("T: numbertype"); + +REGISTER_OP("TestOutputNOp") + .Input("input: T") + .Output("output: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype"); + +REGISTER_OP("TestTwoInputsOp") + .Input("lhs: T") + .Input("rhs: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("pred: bool = false"); + +REGISTER_OP("TestNumAttrsOp") + .Attr("x1: int = -10") + .Attr("y1: int = 1") + .Attr("x2: float = 0.0") + .Attr("y2: float = -3.0"); + +REGISTER_OP("TestNonNumAttrsOp") + .Attr("z: shape") + .Attr("x: string = 'hello'") + .Attr("y: type = DT_FLOAT"); + +REGISTER_OP("TestThreeInputsOp") + .Input("x: T") + .Input("y: T") + .Input("z: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("act: {'x', 'y', 'z'} = 'z'"); + +REGISTER_OP("TestTwoOutputsOp") + .Input("input: T") + .Output("output1: T") + .Output("output2: T") + .Attr("T: numbertype"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/tests/control_flow.mlir b/tensorflow/compiler/mlir/tfr/tests/control_flow.mlir new file mode 100644 index 00000000000..8dacd57653f --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/control_flow.mlir @@ -0,0 +1,57 @@ +// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s + +tfr.func @tf__my_pack(%values: !tfr.tensor_list, + %n: i32 {tfr.name="N"}, + %axis: i32 {tfr.name="axis"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i32 + %eq = cmpi "eq", %n, %cst : i32 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp = tfr.call @tf__expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %temp : !tfr.tensor + } else { + %step = index_cast %cst : i32 to index + %end = index_cast %n : i32 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp1 = tfr.call @tf__expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +// CHECK-LABEL: pack_one +func @pack_one(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> { + %0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<1x2x3xf32> + return %0 : tensor<1x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32 +// CHECK-NEXT: %[[CAST:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[ED:.*]] = tfr.call @tf__expand_dims(%[[CAST]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[BACK:.*]] = "tfr.cast"(%[[ED]]) : (!tfr.tensor) -> tensor<1x2x3xf32> +// CHECK-NEXT: return %[[BACK]] : tensor<1x2x3xf32> +} + +// CHECK-LABEL: pack_multiple +func @pack_multiple(%arg0: tensor<2x3xf32>, + %arg1: 
tensor<2x3xf32>, + %arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = constant 0 : i32 +// CHECK-NEXT: %[[CAST0:.*]] = "tfr.cast"(%arg0) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[CAST1:.*]] = "tfr.cast"(%arg1) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[CAST2:.*]] = "tfr.cast"(%arg2) : (tensor<2x3xf32>) -> !tfr.tensor +// CHECK-NEXT: %[[EX0:.*]] = tfr.call @tf__expand_dims(%[[CAST0]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[EX1:.*]] = tfr.call @tf__expand_dims(%[[CAST1]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[CONCAT1:.*]] = tfr.call @tf__risc_concat(%[[EX0]], %[[EX1]], %c0_i32) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[EX2:.*]] = tfr.call @tf__expand_dims(%[[CAST2]], %[[AXIS]]) : (!tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[CONCAT2:.*]] = tfr.call @tf__risc_concat(%[[CONCAT1]], %[[EX2]], %[[AXIS]]) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor +// CHECK-NEXT: %[[BACK:.*]] = "tfr.cast"(%[[CONCAT2]]) : (!tfr.tensor) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[BACK]] : tensor<3x2x3xf32> +} diff --git a/tensorflow/compiler/mlir/tfr/tests/decompose.mlir b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir new file mode 100644 index 00000000000..97f12c9fedb --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir @@ -0,0 +1,84 @@ +// RUN: tfr-opt %s -tfr-decompose -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK-LABEL: @tf__fake_no_op +tfr.func @tf__fake_no_op(%arg0: !tfr.tensor) -> !tfr.tensor { + tfr.return %arg0 : !tfr.tensor + +// CHECK-NEXT: tfr.return %arg0 : !tfr.tensor +} + +// CHECK-LABEL: @tf__intermediate +tfr.func @tf__intermediate(%arg0: !tfr.tensor) -> !tfr.tensor { + %0 = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor + tfr.return %0 : !tfr.tensor + +// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%arg0) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: tfr.return %[[id]] : !tfr.tensor +} + +// CHECK-LABEL: @tf__fused_n +tfr.func @tf__fused_n( + %arg0: !tfr.tensor, + %arg1: !tfr.tensor_list, + %arg2: index {tfr.name="A",tfr.default=1:index}) + -> !tfr.tensor_list { + %0 = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor + %1 = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor + %2 = tfr.call @tf__intermediate(%1) : (!tfr.tensor) -> !tfr.tensor + %3 = "tfr.build_list"(%0, %2) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + tfr.return %3 : !tfr.tensor_list + +// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__intermediate(%arg0) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[ge:.*]] = tfr.get_element %arg1[%arg2] : (!tfr.tensor_list, index) -> !tfr.tensor +// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__intermediate(%[[ge]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[bl:.*]] = "tfr.build_list"(%[[id1]], %[[id2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list +// CHECK-NEXT: tfr.return %[[bl]] : !tfr.tensor_list +} + +//------------------------ + +// CHECK-LABEL: decompose_tf_no_op +func @decompose_tf_no_op(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> { + %0 = "tf.FakeNoOp"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> + return %0 : tensor<1x2x3x4x!tf.string> + +// CHECK-NEXT: return %arg0 +} + +// CHECK-LABEL: decompose_tf_intermediate +func 
@decompose_tf_intermediate(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> { + %0 = "tf.Intermediate"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> + return %0 : tensor<1x2x3x4x!tf.string> + +// CHECK-NEXT: %[[casted:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor +// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%[[casted]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id]]) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string> +// CHECK-NEXT: return %[[back]] +} + +// CHECK-LABEL: decompose_fused_n_default +func @decompose_fused_n_default(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor, %arg2: tensor) -> tensor { + %0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) : (tensor<1x2x3x4x!tf.string>, tensor, tensor) -> (tensor<1x2x3x4x!tf.string>, tensor) + return %0#1 : tensor + +// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor +// CHECK-NEXT: %[[in2:.*]] = "tfr.cast"(%arg2) : (tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id2:.*]] = tfr.call @tf__risc(%[[in2]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id2]]) : (!tfr.tensor) -> tensor +// CHECK-NEXT: return %[[back]] : tensor +} + +// CHECK-LABEL: decompose_fused_n +func @decompose_fused_n(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor, %arg2: tensor) -> tensor { + %0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index} : (tensor<1x2x3x4x!tf.string>, tensor, tensor) -> (tensor<1x2x3x4x!tf.string>, tensor) + return %0#1 : tensor + +// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor +// CHECK-NEXT: %[[in1:.*]] = "tfr.cast"(%arg1) : (tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__risc(%[[in1]]) : (!tfr.tensor) -> !tfr.tensor +// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id1]]) : (!tfr.tensor) -> tensor +// CHECK-NEXT: return %[[back]] : tensor +} + diff --git a/tensorflow/compiler/mlir/tfr/tests/end2end.mlir b/tensorflow/compiler/mlir/tfr/tests/end2end.mlir new file mode 100644 index 00000000000..5738020ccdb --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/end2end.mlir @@ -0,0 +1,235 @@ +// RUN: tfr-opt %s -tfr-decompose -tfr-raise-to-tf -canonicalize -verify-diagnostics -split-input-file | FileCheck %s + +//=================> User models, from GraphDef <==================== + +// CHECK-LABEL: my_identity +func @my_identity(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = "tf.MyIdentity"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %0 : tensor<2x3xf32> + +// CHECK-NEXT: return %arg0 : tensor<2x3xf32> +} + +// CHECK-LABEL: my_rsqrt +func @my_rsqrt(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyRsqrt"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[RE:.*]] = "tf.RiscReciprocal"(%arg0) : (tensor<2x3xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[SQRT:.*]] = "tf.RiscSqrt"(%[[RE]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[SQRT]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_leaky_relu +func @my_leaky_relu(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyLeakyRelu"(%arg0) {alpha=3.0 : f32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : 
tensor<3x2x3xf32> + +// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor} : () -> tensor +// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32> +// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor, tensor<*xi32>) -> tensor<*xf32> +// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_leaky_relu_with_default +func @my_leaky_relu_with_default(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyLeakyRelu"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor} : () -> tensor +// CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32> +// CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor, tensor<*xi32>) -> tensor<*xf32> +// CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_cast +func @my_cast(%arg0: tensor<2x3xf32>) -> tensor<2x3xi32> { + %0 = "tf.MyCast"(%arg0) {Tout=i32} : (tensor<2x3xf32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> + +// CHECK-NEXT: %[[CAST:.*]] = "tf.RiscCast"(%arg0) {Tout = i32} : (tensor<2x3xf32>) -> tensor<*xi32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CAST]]) {shape = #tf.shape<2x3>} : (tensor<*xi32>) -> tensor<2x3xi32> +// CHECK-NEXT: return %[[ES]] : tensor<2x3xi32> +} + +// CHECK-LABEL: my_pack_single_input +func @my_pack_single_input(%arg0: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-NEXT: %[[ED:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ED]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_pack_multiple_inputs +func @my_pack_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<3x2x3xf32> { + %0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32> + return %0 : tensor<3x2x3xf32> + +// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-NEXT: %[[ED0:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[ED1:.*]] = "tf.ExpandDims"(%arg1, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[CC0:.*]] = "tf.RiscConcat"(%[[ED0]], %[[ED1]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ED2:.*]] = "tf.ExpandDims"(%arg2, %[[AXIS]]) : (tensor<2x3xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[CC1:.*]] = "tf.RiscConcat"(%[[CC0]], %[[ED2]]) {axis = 0 : i32} : (tensor<*xf32>, 
tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CC1]]) {shape = #tf.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32> +} + +// CHECK-LABEL: my_add_n_single_input +func @my_add_n_single_input(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = "tf.MyAddN"(%arg0) {N=1:i32} : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %0 : tensor<2x3xf32> + +// CHECK-NEXT: return %arg0 : tensor<2x3xf32> +} + +// CHECK-LABEL: my_add_n_multiple_inputs +func @my_add_n_multiple_inputs(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = "tf.MyAddN"(%arg0, %arg1, %arg2) {N=3:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %0 : tensor<2x3xf32> + +// CHECK-NEXT: %[[ADD0:.*]] = "tf.RiscAdd"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ADD1:.*]] = "tf.RiscAdd"(%[[ADD0]], %arg2) : (tensor<*xf32>, tensor<2x3xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ADD1]]) {shape = #tf.shape<2x3>} : (tensor<*xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: return %[[ES]] : tensor<2x3xf32> +} + +// CHECK-LABEL: my_map_and_batch_dataset +func @my_map_and_batch_dataset(%input: tensor<*x!tf.variant>, + %other1: tensor<*xf32>, + %other2: tensor<*xi32>) -> tensor<*x!tf.variant> { + %0 = "tf.MyMapAndBatchDataset"(%input, %other1, %other2) + {batch_size=1000 : i64, num_parallel_calls = 8 : i64, drop_remainder = 0 : i1, + func = @"__some_func", output_types = [f32], output_shapes = [#tf.shape<>], preserve_cardinality = true} + : (tensor<*x!tf.variant>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf.variant> + return %0 : tensor<*x!tf.variant> + +// CHECK-NEXT: %[[BATCH:.*]] = "tf.Const"() {value = dense<1000> : tensor} : () -> tensor +// CHECK-NEXT: %[[PARAL:.*]] = "tf.Const"() {value = dense<8> : tensor} : () -> tensor +// CHECK-NEXT: %[[KEEP:.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor +// CHECK-NEXT: %[[CAST:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> +// CHECK-NEXT: %[[RET:.*]] = "tf.MapAndBatchDatasetV0"(%arg0, %[[BATCH]], %[[PARAL]], %[[KEEP]], %arg1, %[[CAST]]) +// CHECK-SAME: {f = @__some_func, output_shapes = [#tf.shape<>], output_types = [f32], preserve_cardinality = true} : (tensor<*x!tf.variant>, tensor, tensor, tensor, tensor<*xf32>, tensor<*xf32>) -> tensor<*x!tf.variant> +// CHECK-NEXT: return %[[RET]] : tensor<*x!tf.variant> +} + +//=================> decomposition functions, translated from tf.compose api <==================== +tfr.func @tf__my_identity(%value: !tfr.tensor) -> !tfr.tensor { + tfr.return %value : !tfr.tensor +} + +tfr.func @tf__my_cast(%value: !tfr.tensor, %tout: !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor { + %0 = tfr.call @tf__risc_cast(%value, %tout) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor + tfr.return %0 : !tfr.tensor +} + +tfr.func @tf__my_rsqrt(%value: !tfr.tensor) -> !tfr.tensor { + %1 = tfr.call @tf__risc_reciprocal(%value) : (!tfr.tensor) -> !tfr.tensor + %2 = tfr.call @tf__risc_sqrt(%1) : (!tfr.tensor) -> !tfr.tensor + tfr.return %2 : !tfr.tensor +} + +tfr.func @tf__my_leaky_relu(%value: !tfr.tensor, %alpha: f32 {tfr.name="alpha", tfr.default=0.2:f32}) -> !tfr.tensor { + %1 = tfr.call @tf__risc_shape(%value) : (!tfr.tensor) -> !tfr.tensor + %2 = "tfr.constant_tensor"(%alpha) : (f32) -> tensor + %t = "tfr.cast"(%2) : (tensor) -> !tfr.tensor + %3 = tfr.call @tf__risc_broadcast(%t, %1) : (!tfr.tensor, !tfr.tensor) -> 
!tfr.tensor + %4 = tfr.call @tf__risc_maximum(%value, %3) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + tfr.return %4 : !tfr.tensor +} + +// TODO(fengliuai): use shape dialect to manipulate the shape then this can be decomposed further. +tfr.func @tf__my_expand_dims(%value: !tfr.tensor, %axis: i32 {tfr.name="axis"}) -> !tfr.tensor { + %axis_cst = "tfr.constant_tensor"(%axis) : (i32) -> tensor + %dim = "tfr.cast"(%axis_cst) : (tensor) -> !tfr.tensor + %0 = tfr.call @tf__expand_dims(%value, %dim) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + tfr.return %0 : !tfr.tensor +} + +tfr.func @tf__my_pack(%values: !tfr.tensor_list, + %n: i32 {tfr.name="N"}, + %axis: i32 {tfr.name="axis"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i32 + %eq = cmpi "eq", %n, %cst : i32 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp = tfr.call @tf__my_expand_dims(%v1, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %temp : !tfr.tensor + } else { + %step = index_cast %cst : i32 to index + %end = index_cast %n : i32 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%temp) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %temp1 = tfr.call @tf__my_expand_dims(%v, %axis) : (!tfr.tensor, i32) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_concat(%reduce_iter, %temp1, %axis) : (!tfr.tensor, !tfr.tensor, i32) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +tfr.func @tf__my_add_n(%values: !tfr.tensor_list, + %n: i32 {tfr.name="N"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i32 + %eq = cmpi "eq", %n, %cst : i32 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %v1 : !tfr.tensor + } else { + %step = index_cast %cst : i32 to index + %end = index_cast %n : i32 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +tfr.func @tf__my_map_and_batch_dataset( + %input_dataset: !tfr.tensor, + %other_arguments: !tfr.tensor_list, + %batch_size: i64 {tfr.name="batch_size"}, + %num_parallel_calls: i64 {tfr.name="num_parallel_calls"}, + %drop_remainder: i1 {tfr.name="drop_remainder"}, + %f: !tfr.attr {tfr.name="func"}, + %output_types: !tfr.attr {tfr.name="output_types"}, + %output_shapes: !tfr.attr {tfr.name="output_shapes"}, + %preserve_cardinality: i1 {tfr.name="preserve_cardinality", tfr.default=false}) -> !tfr.tensor { + %batch = "tfr.constant_tensor"(%batch_size) : (i64) -> tensor + %batch1 = "tfr.cast"(%batch) : (tensor) -> !tfr.tensor + %calls = "tfr.constant_tensor"(%num_parallel_calls) : (i64) -> tensor + %calls1 = "tfr.cast"(%calls) : (tensor) -> !tfr.tensor + %drop = "tfr.constant_tensor"(%drop_remainder) : (i1) -> tensor + %drop1 = "tfr.cast"(%drop) : (tensor) -> !tfr.tensor + %ret = tfr.call @tf__map_and_batch_dataset_v0(%input_dataset, %batch1, %calls1, %drop1, %other_arguments, %f, %output_types, %output_shapes, %preserve_cardinality) + : (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, 
!tfr.tensor_list, !tfr.attr, !tfr.attr, !tfr.attr, i1) -> !tfr.tensor + tfr.return %ret : !tfr.tensor +} + +//=================> signatures of the primitive ops with kernels, modeled as external TFR function <== +tfr.func @tf__risc_cast_(!tfr.tensor, !tfr.attr{tfr.name="Tout"}) -> !tfr.tensor attributes{Tout} +tfr.func @tf__risc_add_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_concat_(!tfr.tensor, !tfr.tensor, i32{tfr.name="axis"}) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_broadcast_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T, Tidx} +tfr.func @tf__risc_reciprocal_(!tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_sqrt_(!tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_shape_(!tfr.tensor, !tfr.attr{tfr.name="T", tfr.default=i32}) -> !tfr.tensor attributes{T} +tfr.func @tf__risc_maximum_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T} +tfr.func @tf__expand_dims_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T, Tdim} +tfr.func @tf__map_and_batch_dataset_v0_(!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor_list, + !tfr.attr{tfr.name="f"}, !tfr.attr{tfr.name="output_types"}, !tfr.attr{tfr.name="output_shapes"}, i1{tfr.name="preserve_cardinality"}) + -> !tfr.tensor attributes{T, Targuments} diff --git a/tensorflow/compiler/mlir/tfr/tests/ops.mlir b/tensorflow/compiler/mlir/tfr/tests/ops.mlir new file mode 100644 index 00000000000..b074985c591 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tests/ops.mlir @@ -0,0 +1,381 @@ +// RUN: tfr-opt %s -verify-diagnostics -split-input-file | tfr-opt | FileCheck %s +// RUN: tfr-opt %s -canonicalize -verify-diagnostics -split-input-file | FileCheck %s -check-prefix=CANON + +// Tests for types, ops with custom constraints, verifiers, printer or parser +// methods. 
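+//
+// As a usage sketch (hypothetical invocation, assuming a built tfr-opt is on
+// PATH), a single check prefix can be exercised with:
+//   tfr-opt ops.mlir -canonicalize -split-input-file | FileCheck ops.mlir --check-prefix=CANON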
+
+// CHECK-LABEL: tensor_type_noconstraint
+func @tensor_type_noconstraint() -> !tfr.tensor
+
+// -----
+
+// CHECK-LABEL: tensor_type
+func @tensor_type() -> !tfr.tensor<T>
+
+// -----
+
+// CHECK-LABEL: tensor_list_type_noconstraint
+func @tensor_list_type_noconstraint() -> !tfr.tensor_list
+
+// -----
+
+// CHECK-LABEL: tensor_list_type_array_like
+func @tensor_list_type_array_like() -> !tfr.tensor_list<[N, T]>
+
+// -----
+
+// CHECK-LABEL: tensor_list_type_tuple_like
+func @tensor_list_type_tuple_like() -> !tfr.tensor_list<N, T>
+
+// -----
+
+// expected-error@+1 {{unbalanced '>' character in pretty dialect name}}
+func @tensor_invalid_1() -> !tfr.tensor<[N, T>
+
+// -----
+
+// expected-error@+1 {{unexpected nul or EOF in pretty dialect name}}
+func @tensor_invalid_2() -> !tfr.tensor<[N, T]
+
+// -----
+
+// CHECK-LABEL: call_op
+func @call_op(%arg0: !tfr.tensor, %arg1: !tfr.tensor_list, %arg2: i32) -> !tfr.tensor {
+  %0 = tfr.call @Foo(%arg0, %arg1, %arg2) : (!tfr.tensor, !tfr.tensor_list, i32) -> !tfr.tensor
+  return %0 : !tfr.tensor
+}
+
+// -----
+
+// CHECK-LABEL: call_op_arg_attr(%arg0: i32) -> !tfr.tensor
+func @call_op_arg_attr(%arg0: i32) -> !tfr.tensor {
+  %0 = tfr.call @Bar(%arg0) : (i32) -> !tfr.tensor
+  return %0 : !tfr.tensor
+}
+
+// -----
+
+func @call_op_invalid_1(%arg0: tensor<?xi32>) -> !tfr.tensor {
+  // expected-error@+1 {{got 'tensor<?xi32>'}}
+  %0 = tfr.call @Huu(%arg0) : (tensor<?xi32>) -> !tfr.tensor
+  return %0 : !tfr.tensor
+}
+
+// -----
+
+// CHECK-LABEL: get_shape
+func @get_shape(%arg0: !tfr.tensor) -> (!shape.shape, !shape.shape) {
+  %0 = tfr.get_shape %arg0 -> !shape.shape
+  %1 = "tfr.get_shape"(%arg0) : (!tfr.tensor) -> !shape.shape
+  return %0, %1 : !shape.shape, !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: get_real_shape
+// CANON-LABEL: get_real_shape
+func @get_real_shape(%arg0: tensor<1x2xf32>) -> tensor<1xindex> {
+  %0 = "tfr.cast"(%arg0) : (tensor<1x2xf32>) -> !tfr.tensor
+  %1 = tfr.get_shape %0 -> !shape.shape
+  %2 = shape.to_extent_tensor %1 : !shape.shape -> tensor<1xindex>
+  return %2 : tensor<1xindex>
+
+// CANON-NEXT: %[[s:.*]] = shape.const_shape [1, 2] : tensor<?xindex>
+// CANON-NEXT: %[[e:.*]] = shape.to_extent_tensor %[[s]] : tensor<?xindex> -> tensor<1xindex>
+// CANON-NEXT: return %[[e]] : tensor<1xindex>
+}
+
+// -----
+
+func @get_element_type(%arg0: !tfr.tensor) -> (!tfr.attr, !tfr.attr) {
+  %0 = tfr.get_element_type %arg0 -> !tfr.attr
+  %1 = "tfr.get_element_type"(%arg0) : (!tfr.tensor) -> !tfr.attr
+  return %0, %1 : !tfr.attr, !tfr.attr
+}
+
+// -----
+
+// CHECK-LABEL: from_tf_tensor
+func @from_tf_tensor(%arg0: tensor<?xi32>) -> !tfr.tensor {
+  %0 = "tfr.cast"(%arg0) : (tensor<?xi32>) -> !tfr.tensor
+  return %0 : !tfr.tensor
+}
+
+// -----
+
+// CHECK-LABEL: to_tf_tensor
+func @to_tf_tensor(%arg0: !tfr.tensor) -> tensor<?xi32> {
+  %0 = "tfr.cast"(%arg0) : (!tfr.tensor) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: constant
+func @constant() -> (!tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr) {
+  %0 = tfr.constant f32 -> !tfr.attr
+  %1 = tfr.constant [f32, i32] -> !tfr.attr
+  %2 = "tfr.constant"() {value = f32} : () -> !tfr.attr
+  %3 = "tfr.constant"() {value = [f32, i32]} : () -> !tfr.attr
+  return %0, %1, %2, %3 : !tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr
+}
+
+// -----
+
+// CHECK-LABEL: equal
+// CANON-LABEL: equal
+func @equal() -> (i1, i1, i1, i1) {
+  %0 = tfr.constant f32 -> !tfr.attr
+  %1 = tfr.constant f32 -> !tfr.attr
+  %2 = tfr.constant i32 -> !tfr.attr
+  %same_type = tfr.equal %0,%1 -> i1
+  %diff_type = tfr.equal %0,%2 -> i1
+
+  %3 = tfr.constant "hello" -> !tfr.attr
+  %4 = tfr.constant "hello" -> !tfr.attr
+  %5 = tfr.constant "how are you" -> !tfr.attr
+  %same_str = tfr.equal %3,%4 -> i1
+  %diff_str = tfr.equal %3,%5 -> i1
+  return %same_type, %diff_type, %same_str, %diff_str : i1, i1, i1, i1
+
+// CANON-NEXT: %true = constant true
+// CANON-NEXT: %false = constant false
+// CANON-NEXT: return %true, %false, %true, %false : i1, i1, i1, i1
+}
+
+// -----
+
+// CHECK-LABEL: constant_tensor_scalar
+func @constant_tensor_scalar(%arg0: i32) -> tensor<i32> {
+  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: constant_tensor_vector
+func @constant_tensor_vector(%arg0: vector<1x2xi32>) -> tensor<1x2xi32> {
+  %0 = "tfr.constant_tensor"(%arg0) : (vector<1x2xi32>) -> tensor<1x2xi32>
+  return %0 : tensor<1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: constant_tensor_array
+// CANON-LABEL: constant_tensor_array
+func @constant_tensor_array() -> !tfr.tensor {
+  %0 = tfr.constant [1, -1, 3] -> !tfr.attr
+  %1 = "tfr.constant_tensor"(%0) : (!tfr.attr) -> !tfr.tensor
+  return %1 : !tfr.tensor
+
+// CANON-NEXT: "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CANON-NEXT: "tfr.cast"(%0) : (tensor<3xi64>) -> !tfr.tensor
+// CANON-NEXT: return
+}
+
+// -----
+
+// CHECK-LABEL: constant_tensor_scalar
+// CANON-LABEL: constant_tensor_scalar
+func @constant_tensor_scalar() -> !tfr.tensor {
+  %0 = "std.constant"() {value = 42 : i32} : () -> i32
+  %1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor
+  return %1 : !tfr.tensor
+
+// CANON-NEXT: "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
+// CANON-NEXT: "tfr.cast"(%0) : (tensor<i32>) -> !tfr.tensor
+// CANON-NEXT: return
}
+
+// -----
+
+func @constant_tensor_invalid_0(%arg0: i32) -> tensor<f32> {
+  // expected-error@+1 {{input and output should have the same scalar types.}}
+  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func @constant_tensor_invalid_1(%arg0: vector<1xi32>) -> tensor<?xi32> {
+  // expected-error@+1 {{output type should be static and ranked}}
+  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// -----
+
+func @constant_tensor_invalid_2(%arg0: vector<1xi32>) -> tensor<1xf32> {
+  // expected-error@+1 {{input and output should have same shape and element type}}
+  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+func @constant_tensor_invalid_3(%arg0: vector<1xi32>) -> tensor<1x1xi32> {
+  // expected-error@+1 {{input and output should have same shape and element type}}
+  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1x1xi32>
+  return %0 : tensor<1x1xi32>
+}
+
+// -----
+
+func @constant_tensor_invalid_4(%arg0: i32) -> tensor<1x1xi32> {
+  // expected-error@+1 {{input can not be converted to an output tensor}}
+  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<1x1xi32>
+  return %0 : tensor<1x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: get_element
+func @get_element(%arg0: !tfr.tensor_list) -> !tfr.tensor {
+  %cst = "std.constant"() {value = 1 : index} : () -> index
+  %0 = tfr.get_element %arg0[%cst] : (!tfr.tensor_list, index) -> !tfr.tensor
+  return %0 : !tfr.tensor
+}
+
+// -----
+
+// CHECK-LABEL: build_list
+func @build_list(%arg0: !tfr.tensor, %arg1: !tfr.tensor) -> !tfr.tensor_list {
+  %0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
+  return %0 : !tfr.tensor_list
+}
+
+// -----
+
+// CHECK-LABEL: build_const_list
+// CANON-LABEL: build_const_list
+func @build_const_list() -> !tfr.attr {
+  %0 = "std.constant"() {value = 42 : i32} : () -> i32
+  %1 = "std.constant"() {value = 41 : i32} : () -> i32
+  %2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
+  return %2 : !tfr.attr
+
+// CANON-NEXT: %[[c:.*]] = tfr.constant [42 : i32, 41 : i32] -> !tfr.attr
+// CANON-NEXT: return %[[c]] : !tfr.attr
+}
+
+// -----
+
+// CHECK-LABEL: tfr.func
+tfr.func @External(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: !tfr.attr {tfr.name = "T"})
+  -> (!tfr.tensor, !tfr.tensor_list)
+  attributes {A, C}
+
+// -----
+
+// CHECK-LABEL: tfr.func
+tfr.func @Foo(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32> {tfr.name = "C"})
+  -> (!tfr.tensor, !tfr.tensor_list)
+  attributes {A, C} {
+  tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
+}
+
+// -----
+
+// CHECK-LABEL: tfr.func
+tfr.func @Bar(%arg0: !tfr.tensor,
+              %arg2: i32 {tfr.name = "B"},
+              %arg3: vector<1xi32> {tfr.name = "C"})
+  -> (!tfr.tensor, !tfr.tensor)
+  attributes {A} {
+  tfr.return %arg0, %arg0 : !tfr.tensor, !tfr.tensor
+}
+
+// -----
+
+// expected-error@+1 {{Undefined attributes are used: A}}
+tfr.func @Foo_undefined_attr(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32> {tfr.name = "C"}) ->
+    (!tfr.tensor, !tfr.tensor_list) {
+  tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
+}
+
+// -----
+
+// expected-error@+1 {{3 attribute argument doesn't have a tfr.name attribute}}
+tfr.func @Foo_unnamed_attr(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32>) ->
+    (!tfr.tensor, !tfr.tensor_list) {
+  tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
+}
+
+// -----
+
+// expected-error@+1 {{tfr.tensor_list argument should be before non tensor arguments}}
+tfr.func @Foo_invalid_arg_order(%arg0: !tfr.tensor,
+              %arg2: i32 {tfr.name = "A"},
+              %arg1: !tfr.tensor_list,
+              %arg3: vector<1xi32> {tfr.name = "C"}) ->
+    (!tfr.tensor, !tfr.tensor_list) {
+  tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
+}
+
+// -----
+
+// expected-error@+1 {{tfr.tensor argument should be before tfr.tensor_list argument.}}
+tfr.func @Foo_invalid_arg_order0(
+              %arg1: !tfr.tensor_list,
+              %arg0: !tfr.tensor,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32> {tfr.name = "C"}) ->
+    (!tfr.tensor, !tfr.tensor_list) {
+  tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
+}
+
+// -----
+
+// expected-error@+1 {{tfr.tensor result should be before tfr.tensor_list result}}
+tfr.func @Foo_invalid_result_order(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32> {tfr.name = "C"}) ->
+    (!tfr.tensor_list, !tfr.tensor) {
+  tfr.return %arg1, %arg0 : !tfr.tensor_list, !tfr.tensor
+}
+
+// -----
+
+// expected-error@+1 {{More than one tfr.tensor_list argument isn't allowed}}
+tfr.func @Foo_multiple_tensor_list_args(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: !tfr.tensor_list,
+              %arg3: i32 {tfr.name = "A"},
+              %arg4: vector<1xi32> {tfr.name = "C"}) ->
+    (!tfr.tensor, !tfr.tensor_list) {
+  tfr.return %arg0, %arg1 : !tfr.tensor, !tfr.tensor_list
+}
+
+// -----
+
+// expected-error@+1 {{More than one tfr.tensor_list result isn't allowed}}
+tfr.func @Foo_multiple_tensor_list_results(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32> {tfr.name = "C"}) ->
+    (!tfr.tensor_list, !tfr.tensor_list) {
+  tfr.return %arg1, %arg1 : !tfr.tensor_list, !tfr.tensor_list
+}
+
+// -----
+
+// expected-error@+1 {{None tfr.tensor/tfr.tensor_list results aren't allowed as a result}}
+tfr.func @Foo_return_attr(%arg0: !tfr.tensor,
+              %arg1: !tfr.tensor_list,
+              %arg2: i32 {tfr.name = "A"},
+              %arg3: vector<1xi32> {tfr.name = "C"}) -> i32 {
+  tfr.return %arg2 : i32
+}
diff --git a/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
new file mode 100644
index 00000000000..41d0ee6271d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
@@ -0,0 +1,76 @@
+// RUN: tfr-opt %s -tfr-raise-to-tf -verify-diagnostics -split-input-file | FileCheck %s
+
+tfr.func @tf__risc_same_(!tfr.tensor<T>) -> !tfr.tensor<T> attributes {T}
+tfr.func @tf__risc_concat_(!tfr.tensor_list<N, T>) -> !tfr.tensor<T> attributes {T, N}
+tfr.func @tf__risc_split_(!tfr.tensor<T>, i32 {tfr.name="N"}) -> !tfr.tensor_list<N, T> attributes {T, N}
+tfr.func @tf__risc_cast_(!tfr.tensor<T>, !tfr.attr {tfr.name="K"}) -> !tfr.tensor<K> attributes {T, K}
+
+// CHECK-LABEL: decompose_tf_same
+func @decompose_tf_same(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
+  %0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
+  %1 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor
+  %2 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string>
+  return %2 : tensor<1x2x3x4x!tf.string>
+
+// CHECK: %[[id:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id]]) {shape = #tf.shape<1x2x3x4>} : (tensor<*x!tf.string>) -> tensor<1x2x3x4x!tf.string>
+// CHECK: return %[[es]] : tensor<1x2x3x4x!tf.string>
+}
+
+// CHECK-LABEL: decompose_tf_consecutive
+func @decompose_tf_consecutive(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  %0 = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
+  %1 = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
+  %2 = tfr.call @tf__risc_same(%0) : (!tfr.tensor) -> !tfr.tensor
+  %3 = tfr.call @tf__risc_same(%1) : (!tfr.tensor) -> !tfr.tensor
+  %4 = "tfr.cast"(%3) : (!tfr.tensor) -> tensor<f32>
+  return %4 : tensor<f32>
+
+// CHECK: %[[id0:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> tensor<*x!tf.string>
+// CHECK: %[[id2:.*]] = "tf.RiscSame"(%arg2) : (tensor<f32>) -> tensor<*xf32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id2]]) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor<f32>
+// CHECK: return %[[es]] : tensor<f32>
+}
+
+// CHECK-LABEL: decompose_tf_concat_n
+func @decompose_tf_concat_n(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<3xf32> {
+  %0 = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
+  %1 = "tfr.cast"(%arg1) : (tensor<f32>) -> !tfr.tensor
+  %2 = "tfr.cast"(%arg2) : (tensor<f32>) -> !tfr.tensor
+  %3 = "tfr.build_list"(%0, %1, %2) : (!tfr.tensor, !tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
+  %concat = tfr.call @tf__risc_concat(%3) : (!tfr.tensor_list) -> !tfr.tensor
+  %4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<3xf32>
+  return %4 : tensor<3xf32>
+
+// CHECK: %[[concat:.*]] = "tf.RiscConcat"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[concat]]) {shape = #tf.shape<3>} : (tensor<*xf32>) -> tensor<3xf32>
+// CHECK: return %[[es]] : tensor<3xf32>
+}
+
+// CHECK-LABEL: decompose_tf_split
+func @decompose_tf_split(%arg0: tensor<3xf32>) -> (tensor<f32>) {
+  %0 = "tfr.cast"(%arg0) : (tensor<3xf32>) -> !tfr.tensor
+  %n = std.constant 3: i32
+  %split = tfr.call @tf__risc_split(%0, %n) : (!tfr.tensor, i32) -> !tfr.tensor_list
+  %i0 = std.constant 0: index
+  %s0 = tfr.get_element %split[%i0] : (!tfr.tensor_list, index) -> !tfr.tensor
+  %4 = "tfr.cast"(%s0) : (!tfr.tensor) -> tensor<f32>
+  return %4 : tensor<f32>
+
+// CHECK: %[[split:.*]]:3 = "tf.RiscSplit"(%arg0) {N = 3 : i32} : (tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[split]]#0) {shape = #tf.shape<>} : (tensor<*xf32>) -> tensor<f32>
+// CHECK: return %[[es]] : tensor<f32>
+}
+
+// CHECK-LABEL: decompose_tf_cast
+func @decompose_tf_cast(%arg0: tensor<f32>) -> tensor<i32> {
+  %0 = "tfr.cast"(%arg0) : (tensor<f32>) -> !tfr.tensor
+  %t = tfr.constant i32 -> !tfr.attr
+  %concat = tfr.call @tf__risc_cast(%0, %t) : (!tfr.tensor, !tfr.attr) -> !tfr.tensor
+  %4 = "tfr.cast"(%concat) : (!tfr.tensor) -> tensor<i32>
+  return %4 : tensor<i32>
+
+// CHECK: %[[tfcast:.*]] = "tf.RiscCast"(%arg0) {K = i32} : (tensor<f32>) -> tensor<*xi32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf.shape<>} : (tensor<*xi32>) -> tensor<i32>
+// CHECK: return %[[es]] : tensor<i32>
}
diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.cc b/tensorflow/compiler/mlir/tfr/utils/utils.cc
new file mode 100644
index 00000000000..6c08b682cb0
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/utils/utils.cc
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfr/utils/utils.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFR {
+
+std::string GetComposeFuncName(StringRef tf_op_name) {
+  std::string compose_func_name;
+  for (int i = 0; i < tf_op_name.size(); ++i) {
+    if (tf_op_name[i] == '_') {
+      // The TF op name must not contain "_"s: "_Arg" and "_RetVal" are
+      // special op names, so return an empty string to skip the
+      // decomposition.
+      return {};
+    }
+    if (tf_op_name[i] == '.') {
+      compose_func_name.push_back('_');
+    } else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') {
+      compose_func_name.push_back('_');
+      compose_func_name.push_back(tf_op_name[i] + 'a' - 'A');
+    } else {
+      compose_func_name.push_back(tf_op_name[i]);
+    }
+  }
+  return compose_func_name;
+}
+
+std::string GetTFOpName(StringRef compose_func_name) {
+  std::string tf_op_name;
+  bool after_underscore = false;
+  for (int i = 0; i < compose_func_name.size(); ++i) {
+    if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') {
+      // The TFR function name must not contain uppercase letters.
+      return {};
+    }
+    if (after_underscore) {
+      if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') {
+        tf_op_name.push_back(compose_func_name[i] + 'A' - 'a');
+        after_underscore = false;
+      } else {
+        // The character after a "_" must be a lowercase letter.
+        return {};
+      }
+    } else if (compose_func_name[i] == '_') {  // first '_' seen
+      if (i + 1 < compose_func_name.size() && compose_func_name[i + 1] == '_') {
+        // A "__" maps back to the "." separator, e.g. tf__pack => tf.Pack.
+        tf_op_name.push_back('.');
+        i++;
+      }
+      after_underscore = true;
+    } else {
+      tf_op_name.push_back(compose_func_name[i]);
+    }
+  }
+  if (after_underscore) {
+    // A trailing "_" is malformed.
+    return {};
+  }
+  return tf_op_name;
+}
+
+}  // namespace TFR
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.h b/tensorflow/compiler/mlir/tfr/utils/utils.h
new file mode 100644
index 00000000000..26c7250d95a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/utils/utils.h
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
+
+#include <string>
+
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFR {
+
+// This is a hardcoded rule for mapping a TF op name to the corresponding
+// TFR function name. Examples:
+//   tf.Pack => tf__pack
+//   tf.ConcatV2 => tf__concat_v2
+std::string GetComposeFuncName(StringRef tf_op_name);
+
+// This is a hardcoded rule for mapping a TFR function name back to the
+// corresponding TF op name. Examples:
+//   tf__pack => tf.Pack
+//   tf__concat_v2 => tf.ConcatV2
+std::string GetTFOpName(StringRef compose_func_name);
+
+}  // namespace TFR
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
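
Illustration only, not part of the patch: a minimal sketch of how the two name-mapping helpers behave, built from the examples documented in utils.h above. The standalone main() driver and the cassert checks are assumptions for demonstration, not code in this change.

  #include <cassert>
  #include <string>

  #include "tensorflow/compiler/mlir/tfr/utils/utils.h"

  int main() {
    // A CamelCase TF op name maps to a snake_case TFR function name:
    // '.' becomes '_', and each uppercase letter becomes '_' plus its
    // lowercase form, so "tf." turns into the "tf__" prefix.
    assert(mlir::TFR::GetComposeFuncName("tf.ConcatV2") == "tf__concat_v2");
    // The inverse mapping restores "__" to "." and re-capitalizes the
    // letter after each remaining '_'.
    assert(mlir::TFR::GetTFOpName("tf__concat_v2") == "tf.ConcatV2");
    // Internal op names such as "_Arg" contain '_' and are rejected with an
    // empty result, which skips the decomposition for them.
    assert(mlir::TFR::GetComposeFuncName("_Arg").empty());
    return 0;
  }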