Move TF Broadcast op legalization process to the prepare_tf stage

This change is to get benefits from the constant folding logic from TF dialect.

PiperOrigin-RevId: 326174654
Change-Id: Icb25f11a6ac0df9904a94831f4969f5b259723a7
This commit is contained in:
Jaesung Chung 2020-08-11 23:03:14 -07:00 committed by TensorFlower Gardener
parent 57f61ed3fd
commit af9cb379b6
8 changed files with 239 additions and 137 deletions

View File

@ -237,6 +237,28 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "constant_utils",
srcs = [
"utils/constant_utils.cc",
],
hdrs = [
"utils/constant_utils.h",
],
copts = ["-std=c++14"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "lstm_utils",
srcs = [
@ -347,6 +369,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
":constant_utils",
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",

View File

@ -1,12 +1,11 @@
// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
return %0: tensor<3x3xbf16>
// CHECK-LABEL: broadcast_to_bf16
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: return [[MUL]] : tensor<3x3xbf16>
}

View File

@ -1482,28 +1482,6 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
// CHECK: }
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>

View File

@ -595,4 +595,24 @@ func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
// CHECK: return %[[RES]]
}
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
}

View File

@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
@ -137,7 +138,6 @@ DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Reciprocal);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(BroadcastTo);
#undef DECL_CONVERT_OP
@ -483,89 +483,6 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite(
return success();
}
StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
Location loc,
ShapedType shaped_type,
int value) {
Type element_type = shaped_type.getElementType();
ShapedType scalar_type = RankedTensorType::get({}, element_type);
Attribute attr;
switch (element_type.getKind()) {
case mlir::StandardTypes::F16: {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
break;
}
case mlir::StandardTypes::Complex: {
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value =
std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
break;
}
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
}
case mlir::StandardTypes::Integer: {
const auto& itype = element_type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
break;
}
default:
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
}
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}
LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
@ -586,31 +503,6 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
return success();
}
LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
auto output_type = tf_broadcast_to_op.output().getType();
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
if (!status_or_const_op.ok()) {
return failure();
}
auto tfl_fill_op = rewriter.create<TFL::FillOp>(
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
status_or_const_op.ValueOrDie());
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::MulOp>(
op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
fused_activation_function);
return success();
}
// Legalize unidirectional sequence lstm.
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
@ -751,7 +643,7 @@ void LegalizeTF::runOnFunction() {
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
ConvertTFRandomUniformOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,

View File

@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
@ -686,6 +687,48 @@ struct ConvertTFStridedSlice : public RewritePattern {
}
};
struct ConvertTFBroadcastTo : public RewritePattern {
explicit ConvertTFBroadcastTo(MLIRContext *context)
: RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
Type element_type = input_type.getElementType();
// Allow lowering when low dimension inputs are given and its type is F32 or
// I32.
if (!((output_type.hasRank() && output_type.getRank() <= 5) ||
(shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
shape_type.getDimSize(0) <= 5)))
return failure();
if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
(element_type.getKind() == mlir::StandardTypes::BF16) ||
(element_type.getKind() == mlir::StandardTypes::Integer &&
element_type.cast<mlir::IntegerType>().getWidth() == 32)))
return failure();
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
if (!status_or_const_op.ok()) {
return failure();
}
auto tf_fill_op = rewriter.create<TF::FillOp>(
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
status_or_const_op.ValueOrDie());
auto mul_op = rewriter.create<TF::MulOp>(
op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
rewriter.replaceOp(op, mul_op.getResult());
return success();
}
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
// Returns success if all the operations in the `op`'s regions including `op`
@ -767,7 +810,7 @@ void PrepareTFPass::runOnFunction() {
patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
}
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
applyPatternsAndFoldGreedily(func, patterns);
}

View File

@ -0,0 +1,112 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/status.h"
namespace mlir {
namespace TFL {
stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
int value) {
Type element_type = shaped_type.getElementType();
ShapedType scalar_type = RankedTensorType::get({}, element_type);
Attribute attr;
switch (element_type.getKind()) {
case mlir::StandardTypes::F16: {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
break;
}
case mlir::StandardTypes::Complex: {
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value =
std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
break;
}
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
case mlir::StandardTypes::Integer: {
const auto& itype = element_type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
break;
}
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/stream_executor/lib/statusor.h"
namespace mlir {
namespace TFL {
// Returns a Constant op with a single value.
stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_