Move TF Broadcast op legalization process to the prepare_tf stage
This change is made to benefit from the constant folding logic of the TF dialect.

PiperOrigin-RevId: 326174654
Change-Id: Icb25f11a6ac0df9904a94831f4969f5b259723a7
parent 57f61ed3fd · commit af9cb379b6
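The rewrite now happens while the IR is still in the TF dialect, so the TF folder can collapse the materialized fill before TFLite legalization runs. A minimal sketch of the flow on the f32 test case from this commit (the intermediate form is inferred from the new ConvertTFBroadcastTo pattern below; the folded form is what the updated CHECK lines expect):

    // Input:
    %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>

    // After ConvertTFBroadcastTo in prepare-tf (still TF dialect):
    %cst = constant dense<1.000000e+00> : tensor<f32>
    %fill = "tf.Fill"(%arg1, %cst) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
    %mul = "tf.Mul"(%arg0, %fill) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>

    // After TF-dialect constant folding, the fill becomes a dense constant:
    %cst = constant dense<1.000000e+00> : tensor<3x3xf32>
    %mul = "tf.Mul"(%arg0, %cst) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>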
@@ -237,6 +237,28 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "constant_utils",
+    srcs = [
+        "utils/constant_utils.cc",
+    ],
+    hdrs = [
+        "utils/constant_utils.h",
+    ],
+    copts = ["-std=c++14"],
+    deps = [
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/stream_executor/lib",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
 cc_library(
     name = "lstm_utils",
     srcs = [
@@ -347,6 +369,7 @@ cc_library(
         "transforms/passes.h",
     ],
     deps = [
+        ":constant_utils",
         ":lstm_utils",
         ":stateful_ops_utils",
         ":tensorflow_lite",
@@ -1,12 +1,11 @@
-// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
+// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
 
 func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
   %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
   return %0: tensor<3x3xbf16>
 
 // CHECK-LABEL: broadcast_to_bf16
-// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
-// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
-// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
+// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xbf16>
+// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
 // CHECK: return [[MUL]] : tensor<3x3xbf16>
 }
@@ -1482,28 +1482,6 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
 // CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
 // CHECK: }
 
-func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
-  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
-  return %0: tensor<3x3xf32>
-
-// CHECK-LABEL: broadcast_to_f32
-// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
-// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
-// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
-// CHECK: return [[MUL]] : tensor<3x3xf32>
-}
-
-func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
-  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
-  return %0: tensor<3x3xi32>
-
-// CHECK-LABEL: broadcast_to_i32
-// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
-// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
-// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
-// CHECK: return [[MUL]] : tensor<3x3xi32>
-}
-
 func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
     (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
@@ -595,4 +595,24 @@ func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
 // CHECK: return %[[RES]]
 }
 
+func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+  return %0: tensor<3x3xf32>
+
+// CHECK-LABEL: broadcast_to_f32
+// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
+// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
+// CHECK: return [[MUL]] : tensor<3x3xf32>
+}
+
+func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
+  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
+  return %0: tensor<3x3xi32>
+
+// CHECK-LABEL: broadcast_to_i32
+// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
+// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
+// CHECK: return [[MUL]] : tensor<3x3xi32>
+}
+
 }
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
+#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
@@ -137,7 +138,6 @@ DECL_CONVERT_OP(StridedSlice);
 DECL_CONVERT_OP(Unpack);
 DECL_CONVERT_OP(Reciprocal);
 DECL_CONVERT_OP(RandomUniform);
-DECL_CONVERT_OP(BroadcastTo);
 
 #undef DECL_CONVERT_OP
 
@@ -483,89 +483,6 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite(
   return success();
 }
 
-StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
-                                                  Location loc,
-                                                  ShapedType shaped_type,
-                                                  int value) {
-  Type element_type = shaped_type.getElementType();
-  ShapedType scalar_type = RankedTensorType::get({}, element_type);
-  Attribute attr;
-  switch (element_type.getKind()) {
-    case mlir::StandardTypes::F16: {
-      auto floatType = mlir::FloatType::getF16(element_type.getContext());
-      auto floatAttr =
-          mlir::FloatAttr::get(floatType, static_cast<float>(value));
-      std::vector<Attribute> floatValues({floatAttr});
-      attr = DenseElementsAttr::get(scalar_type, floatValues);
-      break;
-    }
-    case mlir::StandardTypes::BF16: {
-      auto floatType = mlir::FloatType::getBF16(element_type.getContext());
-      auto floatAttr =
-          mlir::FloatAttr::get(floatType, static_cast<float>(value));
-      std::vector<Attribute> floatValues({floatAttr});
-      attr = DenseElementsAttr::get(scalar_type, floatValues);
-      break;
-    }
-    case mlir::StandardTypes::F32: {
-      attr =
-          DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
-      break;
-    }
-    case mlir::StandardTypes::Complex: {
-      auto etype = element_type.cast<mlir::ComplexType>().getElementType();
-      if (etype.isF32()) {
-        auto dialect = etype.getContext()->getRegisteredDialect("tf");
-        tensorflow::TensorProto repr;
-        repr.set_dtype(tensorflow::DT_COMPLEX64);
-
-        tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
-        shape->set_unknown_rank(false);
-        shape->add_dim()->set_size(int64_t{1});
-        std::string content;
-        auto complex_value =
-            std::complex<float>(static_cast<float>(value), 0.0f);
-        content.assign(reinterpret_cast<const char*>(&complex_value),
-                       sizeof(complex_value));
-        repr.set_tensor_content(content);
-        std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
-
-        attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
-        break;
-      }
-      return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
-    }
-    case mlir::StandardTypes::Integer: {
-      const auto& itype = element_type.cast<mlir::IntegerType>();
-      switch (itype.getWidth()) {
-        case 8:
-          attr = DenseElementsAttr::get<int8_t>(scalar_type,
-                                                static_cast<int8_t>(value));
-          break;
-        case 16:
-          attr = DenseElementsAttr::get<int16_t>(scalar_type,
-                                                 static_cast<int16_t>(value));
-          break;
-        case 32:
-          attr = DenseElementsAttr::get<int32_t>(scalar_type,
-                                                 static_cast<int32_t>(value));
-          break;
-        case 64:
-          attr = DenseElementsAttr::get<int64_t>(scalar_type,
-                                                 static_cast<int64_t>(value));
-          break;
-        default:
-          return Status(tensorflow::error::INVALID_ARGUMENT,
-                        "Unsupported type");
-      }
-      break;
-    }
-    default:
-      return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
-  }
-  return rewriter->create<ConstantOp>(loc, scalar_type, attr);
-}
-
 LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
@@ -586,31 +503,6 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
   return success();
 }
 
-LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
-  auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
-  auto output_type = tf_broadcast_to_op.output().getType();
-
-  auto status_or_const_op =
-      CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
-  if (!status_or_const_op.ok()) {
-    return failure();
-  }
-
-  auto tfl_fill_op = rewriter.create<TFL::FillOp>(
-      op->getLoc(), output_type, tf_broadcast_to_op.shape(),
-      status_or_const_op.ValueOrDie());
-
-  StringAttr fused_activation_function =
-      StringAttr::get("NONE", rewriter.getContext());
-
-  rewriter.replaceOpWithNewOp<TFL::MulOp>(
-      op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
-      fused_activation_function);
-  return success();
-}
-
 // Legalize unidirectional sequence lstm.
 struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
   explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
@@ -751,7 +643,7 @@ void LegalizeTF::runOnFunction() {
       ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
       ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
       ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
-      ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
+      ConvertTFRandomUniformOp>(context);
 
   // Ophint python converter converted tf node pattern.
   patterns.insert<LegalizeUnidirectionalSequenceLstm,
@@ -57,6 +57,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
+#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
@@ -686,6 +687,48 @@ struct ConvertTFStridedSlice : public RewritePattern {
   }
 };
 
+struct ConvertTFBroadcastTo : public RewritePattern {
+  explicit ConvertTFBroadcastTo(MLIRContext *context)
+      : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
+    auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
+    auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
+    auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
+    Type element_type = input_type.getElementType();
+
+    // Allow lowering only for low-dimensional inputs (rank at most 5) whose
+    // element type is F32, BF16, or 32-bit integer.
+    if (!((output_type.hasRank() && output_type.getRank() <= 5) ||
+          (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
+           shape_type.getDimSize(0) <= 5)))
+      return failure();
+
+    if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
+          (element_type.getKind() == mlir::StandardTypes::BF16) ||
+          (element_type.getKind() == mlir::StandardTypes::Integer &&
+           element_type.cast<mlir::IntegerType>().getWidth() == 32)))
+      return failure();
+
+    auto status_or_const_op =
+        CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
+    if (!status_or_const_op.ok()) {
+      return failure();
+    }
+
+    auto tf_fill_op = rewriter.create<TF::FillOp>(
+        op->getLoc(), output_type, tf_broadcast_to_op.shape(),
+        status_or_const_op.ValueOrDie());
+
+    auto mul_op = rewriter.create<TF::MulOp>(
+        op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
+    rewriter.replaceOp(op, mul_op.getResult());
+    return success();
+  }
+};
+
 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
 
 // Returns success if all the operations in the `op`'s regions including `op`
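To make the guards in ConvertTFBroadcastTo concrete, a hypothetical pair of inputs (operand names invented for this sketch): the first passes the rank and element-type checks and is rewritten into fill + multiply; the second has an i64 element type, so matchAndRewrite returns failure() and the op is left untouched for later handling.

    // Rewritten: rank-2 output (<= 5) and f32 element type.
    %a = "tf.BroadcastTo"(%x, %shape) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>

    // Not rewritten: i64 element type fails the F32/BF16/I32 check.
    %b = "tf.BroadcastTo"(%y, %shape) : (tensor<3xi64>, tensor<2xi32>) -> tensor<3x3xi64>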
@@ -767,7 +810,7 @@ void PrepareTFPass::runOnFunction() {
     patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                     TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
   }
-  patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
+  patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
                   ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
   applyPatternsAndFoldGreedily(func, patterns);
 }
tensorflow/compiler/mlir/lite/utils/constant_utils.cc (new file, 112 lines)
@@ -0,0 +1,112 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
+
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace mlir {
+namespace TFL {
+
+stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
+    PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
+    int value) {
+  Type element_type = shaped_type.getElementType();
+  ShapedType scalar_type = RankedTensorType::get({}, element_type);
+  Attribute attr;
+  switch (element_type.getKind()) {
+    case mlir::StandardTypes::F16: {
+      auto floatType = mlir::FloatType::getF16(element_type.getContext());
+      auto floatAttr =
+          mlir::FloatAttr::get(floatType, static_cast<float>(value));
+      std::vector<Attribute> floatValues({floatAttr});
+      attr = DenseElementsAttr::get(scalar_type, floatValues);
+      break;
+    }
+    case mlir::StandardTypes::BF16: {
+      auto floatType = mlir::FloatType::getBF16(element_type.getContext());
+      auto floatAttr =
+          mlir::FloatAttr::get(floatType, static_cast<float>(value));
+      std::vector<Attribute> floatValues({floatAttr});
+      attr = DenseElementsAttr::get(scalar_type, floatValues);
+      break;
+    }
+    case mlir::StandardTypes::F32: {
+      attr =
+          DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
+      break;
+    }
+    case mlir::StandardTypes::Complex: {
+      auto etype = element_type.cast<mlir::ComplexType>().getElementType();
+      if (etype.isF32()) {
+        auto dialect = etype.getContext()->getRegisteredDialect("tf");
+        tensorflow::TensorProto repr;
+        repr.set_dtype(tensorflow::DT_COMPLEX64);
+
+        tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
+        shape->set_unknown_rank(false);
+        shape->add_dim()->set_size(int64_t{1});
+        std::string content;
+        auto complex_value =
+            std::complex<float>(static_cast<float>(value), 0.0f);
+        content.assign(reinterpret_cast<const char*>(&complex_value),
+                       sizeof(complex_value));
+        repr.set_tensor_content(content);
+        std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
+
+        attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
+        break;
+      }
+      return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
+                                "Unsupported type");
+    }
+    case mlir::StandardTypes::Integer: {
+      const auto& itype = element_type.cast<mlir::IntegerType>();
+      switch (itype.getWidth()) {
+        case 8:
+          attr = DenseElementsAttr::get<int8_t>(scalar_type,
+                                                static_cast<int8_t>(value));
+          break;
+        case 16:
+          attr = DenseElementsAttr::get<int16_t>(scalar_type,
+                                                 static_cast<int16_t>(value));
+          break;
+        case 32:
+          attr = DenseElementsAttr::get<int32_t>(scalar_type,
+                                                 static_cast<int32_t>(value));
+          break;
+        case 64:
+          attr = DenseElementsAttr::get<int64_t>(scalar_type,
+                                                 static_cast<int64_t>(value));
+          break;
+        default:
+          return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
+                                    "Unsupported type");
+      }
+      break;
+    }
+    default:
+      return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
+                                "Unsupported type");
+  }
+  return rewriter->create<ConstantOp>(loc, scalar_type, attr);
+}
+
+}  // namespace TFL
+}  // namespace mlir
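For reference, the scalar splats this helper builds for value = 1 look like the following in the standard dialect (the bf16, f32, and i32 forms match the old test CHECK lines; the others follow the same shape):

    %f16 = constant dense<1.000000e+00> : tensor<f16>
    %bf16 = constant dense<1.000000e+00> : tensor<bf16>
    %f32 = constant dense<1.000000e+00> : tensor<f32>
    %i32 = constant dense<1> : tensor<i32>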
tensorflow/compiler/mlir/lite/utils/constant_utils.h (new file, 35 lines)
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace mlir {
+namespace TFL {
+
+// Returns a Constant op with a single value.
+stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
+    PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
+
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_