From af9cb379b6728d11b4d929cb04dad90cf47fa408 Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Tue, 11 Aug 2020 23:03:14 -0700 Subject: [PATCH] Move TF Broadcast op legalization process to the prepare_tf stage This change is to get benefits from the constant folding logic from TF dialect. PiperOrigin-RevId: 326174654 Change-Id: Icb25f11a6ac0df9904a94831f4969f5b259723a7 --- tensorflow/compiler/mlir/lite/BUILD | 23 ++++ .../legalize-tf-no-runtime-verification.mlir | 7 +- .../compiler/mlir/lite/tests/legalize-tf.mlir | 22 ---- .../compiler/mlir/lite/tests/prepare-tf.mlir | 20 ++++ .../mlir/lite/transforms/legalize_tf.cc | 112 +----------------- .../mlir/lite/transforms/prepare_tf.cc | 45 ++++++- .../mlir/lite/utils/constant_utils.cc | 112 ++++++++++++++++++ .../compiler/mlir/lite/utils/constant_utils.h | 35 ++++++ 8 files changed, 239 insertions(+), 137 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/utils/constant_utils.cc create mode 100644 tensorflow/compiler/mlir/lite/utils/constant_utils.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 555c11779f5..bd1dcdf06ea 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -237,6 +237,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "constant_utils", + srcs = [ + "utils/constant_utils.cc", + ], + hdrs = [ + "utils/constant_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + "//tensorflow/stream_executor/lib", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "lstm_utils", srcs = [ @@ -347,6 +369,7 @@ cc_library( "transforms/passes.h", ], deps = [ + ":constant_utils", ":lstm_utils", ":stateful_ops_utils", ":tensorflow_lite", diff 
--git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index 90266b4e78e..3c390df74b4 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -1,12 +1,11 @@ -// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s +// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> { %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> return %0: tensor<3x3xbf16> // CHECK-LABEL: broadcast_to_bf16 -// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor -// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor) -> tensor<3x3xbf16> -// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> +// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xbf16> +// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> // CHECK: return [[MUL]] : tensor<3x3xbf16> } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 7cb9c4dd22c..d02e4e705f4 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1482,28 +1482,6 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) { // CHECK: return [[VAL_4]] : tensor<28x1x28xf32> // CHECK: } -func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> 
tensor<3x3xf32> - return %0: tensor<3x3xf32> - -// CHECK-LABEL: broadcast_to_f32 -// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor -// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> -} - -func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { - %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> - return %0: tensor<3x3xi32> - -// CHECK-LABEL: broadcast_to_i32 -// CHECK: [[CST:%.*]] = constant dense<1> : tensor -// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<3x3xi32> -// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> -// CHECK: return [[MUL]] : tensor<3x3xi32> -} - func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> { %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 066139e179b..6ee5b67d65e 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -595,4 +595,24 @@ func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> { // CHECK: return %[[RES]] } +func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { + %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %0: tensor<3x3xf32> + +// CHECK-LABEL: broadcast_to_f32 +// CHECK: [[CST:%.*]] = constant 
dense<1.000000e+00> : tensor<3x3xf32> +// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> +// CHECK: return [[MUL]] : tensor<3x3xf32> +} + +func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { + %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> + return %0: tensor<3x3xi32> + +// CHECK-LABEL: broadcast_to_i32 +// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32> +// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> +// CHECK: return [[MUL]] : tensor<3x3xi32> +} + } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 7a16e475ce3..297b1459fc5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -137,7 +138,6 @@ DECL_CONVERT_OP(StridedSlice); DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(Reciprocal); DECL_CONVERT_OP(RandomUniform); -DECL_CONVERT_OP(BroadcastTo); #undef DECL_CONVERT_OP @@ -483,89 +483,6 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite( return success(); } -StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, - Location loc, - ShapedType shaped_type, - int value) { - Type element_type = shaped_type.getElementType(); - ShapedType scalar_type = RankedTensorType::get({}, element_type); - Attribute attr; - switch 
(element_type.getKind()) { - case mlir::StandardTypes::F16: { - auto floatType = mlir::FloatType::getF16(element_type.getContext()); - auto floatAttr = - mlir::FloatAttr::get(floatType, static_cast(value)); - std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); - break; - } - case mlir::StandardTypes::BF16: { - auto floatType = mlir::FloatType::getBF16(element_type.getContext()); - auto floatAttr = - mlir::FloatAttr::get(floatType, static_cast(value)); - std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); - break; - } - case mlir::StandardTypes::F32: { - attr = - DenseElementsAttr::get(scalar_type, static_cast(value)); - break; - } - case mlir::StandardTypes::Complex: { - auto etype = element_type.cast().getElementType(); - if (etype.isF32()) { - auto dialect = etype.getContext()->getRegisteredDialect("tf"); - tensorflow::TensorProto repr; - repr.set_dtype(tensorflow::DT_COMPLEX64); - - tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); - shape->set_unknown_rank(false); - shape->add_dim()->set_size(int64_t{1}); - std::string content; - auto complex_value = - std::complex(static_cast(value), 0.0f); - content.assign(reinterpret_cast(&complex_value), - sizeof(complex_value)); - repr.set_tensor_content(content); - std::string mangled = tensorflow::mangling_util::MangleTensor(repr); - - attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); - break; - } - return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); - } - case mlir::StandardTypes::Integer: { - const auto& itype = element_type.cast(); - switch (itype.getWidth()) { - case 8: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 16: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 32: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 64: - attr = DenseElementsAttr::get(scalar_type, 
- static_cast(value)); - break; - default: - return Status(tensorflow::error::INVALID_ARGUMENT, - "Unsupported type"); - } - break; - } - default: - return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); - } - return rewriter->create(loc, scalar_type, attr); -} - LogicalResult ConvertTFReciprocalOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_reciprocal_op = cast(op); @@ -586,31 +503,6 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite( return success(); } -LogicalResult ConvertTFBroadcastToOp::matchAndRewrite( - Operation* op, PatternRewriter& rewriter) const { - auto tf_broadcast_to_op = cast(op); - auto element_type = tf_broadcast_to_op.input().getType().cast(); - auto output_type = tf_broadcast_to_op.output().getType(); - - auto status_or_const_op = - CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - auto tfl_fill_op = rewriter.create( - op->getLoc(), output_type, tf_broadcast_to_op.shape(), - status_or_const_op.ValueOrDie()); - - StringAttr fused_activation_function = - StringAttr::get("NONE", rewriter.getContext()); - - rewriter.replaceOpWithNewOp( - op, output_type, tf_broadcast_to_op.input(), tfl_fill_op, - fused_activation_function); - return success(); -} - // Legalize unidirectional sequence lstm. struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context) @@ -751,7 +643,7 @@ void LegalizeTF::runOnFunction() { ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp, - ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context); + ConvertTFRandomUniformOp>(context); // Ophint python converter converted tf node pattern. 
patterns.insert(op); + auto input_type = tf_broadcast_to_op.input().getType().cast(); + auto output_type = tf_broadcast_to_op.output().getType().cast(); + auto shape_type = tf_broadcast_to_op.shape().getType().cast(); + Type element_type = input_type.getElementType(); + + // Allow lowering only when the result rank is at most 5 and the element + // type is F32, BF16 or I32. + if (!((output_type.hasRank() && output_type.getRank() <= 5) || + (shape_type.hasStaticShape() && shape_type.getRank() == 1 && + shape_type.getDimSize(0) <= 5))) + return failure(); + + if (!((element_type.getKind() == mlir::StandardTypes::F32) || + (element_type.getKind() == mlir::StandardTypes::BF16) || + (element_type.getKind() == mlir::StandardTypes::Integer && + element_type.cast().getWidth() == 32))) + return failure(); + + auto status_or_const_op = + CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + auto tf_fill_op = rewriter.create( + op->getLoc(), output_type, tf_broadcast_to_op.shape(), + status_or_const_op.ValueOrDie()); + + auto mul_op = rewriter.create( + op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op); + rewriter.replaceOp(op, mul_op.getResult()); + return success(); + } +}; + #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" // Returns success if all the operations in the `op`'s regions including `op` @@ -767,7 +810,7 @@ void PrepareTFPass::runOnFunction() { patterns.insert, TF::ConvertTFBatchMatMulOp>(ctx); } - patterns.insert(ctx); applyPatternsAndFoldGreedily(func, patterns); } diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc new file mode 100644 index 00000000000..8562f623258 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -0,0 +1,112 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +namespace TFL { + +stream_executor::port::StatusOr CreateConstOpWithSingleValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, + int value) { + Type element_type = shaped_type.getElementType(); + ShapedType scalar_type = RankedTensorType::get({}, element_type); + Attribute attr; + switch (element_type.getKind()) { + case mlir::StandardTypes::F16: { + auto floatType = mlir::FloatType::getF16(element_type.getContext()); + auto floatAttr = + mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + break; + } + case mlir::StandardTypes::BF16: { + auto floatType = mlir::FloatType::getBF16(element_type.getContext()); + auto floatAttr = + mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + break; + } + case mlir::StandardTypes::F32: { + attr = + DenseElementsAttr::get(scalar_type, 
static_cast(value)); + break; + } + case mlir::StandardTypes::Complex: { + auto etype = element_type.cast().getElementType(); + if (etype.isF32()) { + auto dialect = etype.getContext()->getRegisteredDialect("tf"); + tensorflow::TensorProto repr; + repr.set_dtype(tensorflow::DT_COMPLEX64); + + tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); + shape->set_unknown_rank(false); + shape->add_dim()->set_size(int64_t{1}); + std::string content; + auto complex_value = + std::complex(static_cast(value), 0.0f); + content.assign(reinterpret_cast(&complex_value), + sizeof(complex_value)); + repr.set_tensor_content(content); + std::string mangled = tensorflow::mangling_util::MangleTensor(repr); + + attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); + break; + } + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + case mlir::StandardTypes::Integer: { + const auto& itype = element_type.cast(); + switch (itype.getWidth()) { + case 8: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 16: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 32: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 64: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + default: + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + break; + } + default: + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + return rewriter->create(loc, scalar_type, attr); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h new file mode 100644 index 00000000000..5c348021b5e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace mlir { +namespace TFL { + +// Returns a Constant op with a single value. +stream_executor::port::StatusOr CreateConstOpWithSingleValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); + +} // namespace TFL +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_