Move TF Broadcast op legalization process to the prepare_tf stage
This change is to get benefits from the constant folding logic from TF dialect. PiperOrigin-RevId: 326174654 Change-Id: Icb25f11a6ac0df9904a94831f4969f5b259723a7
This commit is contained in:
parent
57f61ed3fd
commit
af9cb379b6
@ -237,6 +237,28 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant_utils",
|
||||
srcs = [
|
||||
"utils/constant_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/constant_utils.h",
|
||||
],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lstm_utils",
|
||||
srcs = [
|
||||
@ -347,6 +369,7 @@ cc_library(
|
||||
"transforms/passes.h",
|
||||
],
|
||||
deps = [
|
||||
":constant_utils",
|
||||
":lstm_utils",
|
||||
":stateful_ops_utils",
|
||||
":tensorflow_lite",
|
||||
|
@ -1,12 +1,11 @@
|
||||
// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
|
||||
// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
|
||||
|
||||
func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
|
||||
return %0: tensor<3x3xbf16>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_bf16
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xbf16>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xbf16>
|
||||
}
|
||||
|
@ -1482,28 +1482,6 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
|
||||
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
|
||||
// CHECK: }
|
||||
|
||||
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_f32
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
return %0: tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
|
||||
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
|
||||
|
@ -595,4 +595,24 @@ func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_f32
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
return %0: tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
@ -137,7 +138,6 @@ DECL_CONVERT_OP(StridedSlice);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
DECL_CONVERT_OP(Reciprocal);
|
||||
DECL_CONVERT_OP(RandomUniform);
|
||||
DECL_CONVERT_OP(BroadcastTo);
|
||||
|
||||
#undef DECL_CONVERT_OP
|
||||
|
||||
@ -483,89 +483,6 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
|
||||
Location loc,
|
||||
ShapedType shaped_type,
|
||||
int value) {
|
||||
Type element_type = shaped_type.getElementType();
|
||||
ShapedType scalar_type = RankedTensorType::get({}, element_type);
|
||||
Attribute attr;
|
||||
switch (element_type.getKind()) {
|
||||
case mlir::StandardTypes::F16: {
|
||||
auto floatType = mlir::FloatType::getF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::BF16: {
|
||||
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::F32: {
|
||||
attr =
|
||||
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr;
|
||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||
|
||||
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
shape->add_dim()->set_size(int64_t{1});
|
||||
std::string content;
|
||||
auto complex_value =
|
||||
std::complex<float>(static_cast<float>(value), 0.0f);
|
||||
content.assign(reinterpret_cast<const char*>(&complex_value),
|
||||
sizeof(complex_value));
|
||||
repr.set_tensor_content(content);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
||||
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
|
||||
break;
|
||||
}
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = element_type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 8:
|
||||
attr = DenseElementsAttr::get<int8_t>(scalar_type,
|
||||
static_cast<int8_t>(value));
|
||||
break;
|
||||
case 16:
|
||||
attr = DenseElementsAttr::get<int16_t>(scalar_type,
|
||||
static_cast<int16_t>(value));
|
||||
break;
|
||||
case 32:
|
||||
attr = DenseElementsAttr::get<int32_t>(scalar_type,
|
||||
static_cast<int32_t>(value));
|
||||
break;
|
||||
case 64:
|
||||
attr = DenseElementsAttr::get<int64_t>(scalar_type,
|
||||
static_cast<int64_t>(value));
|
||||
break;
|
||||
default:
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
|
||||
@ -586,31 +503,6 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
|
||||
auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
|
||||
auto output_type = tf_broadcast_to_op.output().getType();
|
||||
|
||||
auto status_or_const_op =
|
||||
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
|
||||
if (!status_or_const_op.ok()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto tfl_fill_op = rewriter.create<TFL::FillOp>(
|
||||
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
|
||||
status_or_const_op.ValueOrDie());
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
|
||||
rewriter.replaceOpWithNewOp<TFL::MulOp>(
|
||||
op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
|
||||
fused_activation_function);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Legalize unidirectional sequence lstm.
|
||||
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
|
||||
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
|
||||
@ -751,7 +643,7 @@ void LegalizeTF::runOnFunction() {
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
|
||||
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
|
||||
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
|
||||
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
|
||||
ConvertTFRandomUniformOp>(context);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
|
@ -57,6 +57,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
|
||||
@ -686,6 +687,48 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFBroadcastTo : public RewritePattern {
|
||||
explicit ConvertTFBroadcastTo(MLIRContext *context)
|
||||
: RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
|
||||
auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
|
||||
auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
|
||||
auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
|
||||
Type element_type = input_type.getElementType();
|
||||
|
||||
// Allow lowering when low dimension inputs are given and its type is F32 or
|
||||
// I32.
|
||||
if (!((output_type.hasRank() && output_type.getRank() <= 5) ||
|
||||
(shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
|
||||
shape_type.getDimSize(0) <= 5)))
|
||||
return failure();
|
||||
|
||||
if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
|
||||
(element_type.getKind() == mlir::StandardTypes::BF16) ||
|
||||
(element_type.getKind() == mlir::StandardTypes::Integer &&
|
||||
element_type.cast<mlir::IntegerType>().getWidth() == 32)))
|
||||
return failure();
|
||||
|
||||
auto status_or_const_op =
|
||||
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
|
||||
if (!status_or_const_op.ok()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto tf_fill_op = rewriter.create<TF::FillOp>(
|
||||
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
|
||||
status_or_const_op.ValueOrDie());
|
||||
|
||||
auto mul_op = rewriter.create<TF::MulOp>(
|
||||
op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
|
||||
rewriter.replaceOp(op, mul_op.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
|
||||
|
||||
// Returns success if all the operations in the `op`'s regions including `op`
|
||||
@ -767,7 +810,7 @@ void PrepareTFPass::runOnFunction() {
|
||||
patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
||||
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
|
||||
}
|
||||
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
|
||||
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
|
||||
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
112
tensorflow/compiler/mlir/lite/utils/constant_utils.cc
Normal file
112
tensorflow/compiler/mlir/lite/utils/constant_utils.cc
Normal file
@ -0,0 +1,112 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
|
||||
PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
|
||||
int value) {
|
||||
Type element_type = shaped_type.getElementType();
|
||||
ShapedType scalar_type = RankedTensorType::get({}, element_type);
|
||||
Attribute attr;
|
||||
switch (element_type.getKind()) {
|
||||
case mlir::StandardTypes::F16: {
|
||||
auto floatType = mlir::FloatType::getF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::BF16: {
|
||||
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::F32: {
|
||||
attr =
|
||||
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr;
|
||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||
|
||||
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
shape->add_dim()->set_size(int64_t{1});
|
||||
std::string content;
|
||||
auto complex_value =
|
||||
std::complex<float>(static_cast<float>(value), 0.0f);
|
||||
content.assign(reinterpret_cast<const char*>(&complex_value),
|
||||
sizeof(complex_value));
|
||||
repr.set_tensor_content(content);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
||||
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
|
||||
break;
|
||||
}
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = element_type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 8:
|
||||
attr = DenseElementsAttr::get<int8_t>(scalar_type,
|
||||
static_cast<int8_t>(value));
|
||||
break;
|
||||
case 16:
|
||||
attr = DenseElementsAttr::get<int16_t>(scalar_type,
|
||||
static_cast<int16_t>(value));
|
||||
break;
|
||||
case 32:
|
||||
attr = DenseElementsAttr::get<int32_t>(scalar_type,
|
||||
static_cast<int32_t>(value));
|
||||
break;
|
||||
case 64:
|
||||
attr = DenseElementsAttr::get<int64_t>(scalar_type,
|
||||
static_cast<int64_t>(value));
|
||||
break;
|
||||
default:
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
35
tensorflow/compiler/mlir/lite/utils/constant_utils.h
Normal file
35
tensorflow/compiler/mlir/lite/utils/constant_utils.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Returns a Constant op with a single value.
|
||||
stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
|
||||
PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
|
Loading…
x
Reference in New Issue
Block a user