Fix the bug that flatbuffer importer doesn't import dynamic shape tensor correctly.
PiperOrigin-RevId: 308198515 Change-Id: I734386cbb59dc0d57665c7d12c8e40530f905b56
This commit is contained in:
parent
b713165f1e
commit
ecfcb090c5
|
@ -599,23 +599,22 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||
|
||||
std::vector<int32_t> shape;
|
||||
std::vector<int32_t> shape_signature;
|
||||
auto* inst = value.getDefiningOp();
|
||||
if (type.hasStaticShape()) {
|
||||
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||
|
||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||
} else if (auto* inst = value.getDefiningOp()) {
|
||||
if (IsConst(inst)) {
|
||||
// Const op can have a result of dynamic shaped type (e.g. due to constant
|
||||
// folding), but we can still derive the shape of a constant tensor for
|
||||
// its attribute type.
|
||||
mlir::Attribute tensor_attr = inst->getAttr("value");
|
||||
llvm::ArrayRef<int64_t> shape_ref =
|
||||
tensor_attr.getType().cast<TensorType>().getShape();
|
||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||
} else if (inst && IsConst(inst)) {
|
||||
// Const op can have a result of dynamic shaped type (e.g. due to constant
|
||||
// folding), but we can still derive the shape of a constant tensor for
|
||||
// its attribute type.
|
||||
mlir::Attribute tensor_attr = inst->getAttr("value");
|
||||
llvm::ArrayRef<int64_t> shape_ref =
|
||||
tensor_attr.getType().cast<TensorType>().getShape();
|
||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||
|
||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||
}
|
||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||
} else if (type.hasRank()) {
|
||||
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||
|
@ -627,7 +626,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||
}
|
||||
|
||||
if (auto* inst = value.getDefiningOp()) {
|
||||
if (inst) {
|
||||
if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
|
||||
// CreateSparsityParameters(cst.s_param());
|
||||
} else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
|
||||
|
|
|
@ -183,6 +183,12 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
|
|||
return RankedTensorType::get({}, elem_type);
|
||||
}
|
||||
|
||||
if (!tensor.shape_signature.empty()) {
|
||||
llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
|
||||
tensor.shape_signature.end());
|
||||
return RankedTensorType::get(shape, elem_type);
|
||||
}
|
||||
|
||||
if (!tensor.shape.empty()) {
|
||||
llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
|
||||
tensor.shape.end());
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
// CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32>
|
||||
func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> {
|
||||
%cst = constant dense<1.0> : tensor<4xf32>
|
||||
%cst_3 = constant dense<2.0> : tensor<4x3x3x3xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst_3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<?x19x19x3xf32>, tensor<4x3x3x3xf32>, tensor<4xf32>) -> tensor<?x9x9x4xf32>
|
||||
return %0 : tensor<?x9x9x4xf32>
|
||||
}
|
|
@ -462,7 +462,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||
3] == input_details[0]['shape_signature']).all())
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertTrue(([1, 16, 16,
|
||||
self.assertTrue(([1, -1, 16,
|
||||
3] == output_details[0]['shape_signature']).all())
|
||||
|
||||
def testBatchSizeValid(self):
|
||||
|
|
Loading…
Reference in New Issue