diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index f9739bfa626..7ca83a636a7 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -599,23 +599,22 @@ Optional> Translator::BuildTensor( std::vector shape; std::vector shape_signature; + auto* inst = value.getDefiningOp(); if (type.hasStaticShape()) { llvm::ArrayRef shape_ref = type.getShape(); if (mlir::failed(check_shape(shape_ref))) return llvm::None; shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (auto* inst = value.getDefiningOp()) { - if (IsConst(inst)) { - // Const op can have a result of dynamic shaped type (e.g. due to constant - // folding), but we can still derive the shape of a constant tensor for - // its attribute type. - mlir::Attribute tensor_attr = inst->getAttr("value"); - llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; + } else if (inst && IsConst(inst)) { + // Const op can have a result of dynamic shaped type (e.g. due to constant + // folding), but we can still derive the shape of a constant tensor for + // its attribute type. + mlir::Attribute tensor_attr = inst->getAttr("value"); + llvm::ArrayRef shape_ref = + tensor_attr.getType().cast().getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } + shape = std::vector(shape_ref.begin(), shape_ref.end()); } else if (type.hasRank()) { llvm::ArrayRef shape_ref = type.getShape(); if (mlir::failed(check_shape(shape_ref))) return llvm::None; @@ -627,7 +626,7 @@ Optional> Translator::BuildTensor( shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); } - if (auto* inst = value.getDefiningOp()) { + if (inst) { if (auto cst = dyn_cast(inst)) { // CreateSparsityParameters(cst.s_param()); } else if (auto cst = dyn_cast(inst)) { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 71e73411309..a13caf185ee 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -183,6 +183,12 @@ StatusOr GetTensorType(const TensorT& tensor, Builder builder, return RankedTensorType::get({}, elem_type); } + if (!tensor.shape_signature.empty()) { + llvm::SmallVector shape(tensor.shape_signature.begin(), + tensor.shape_signature.end()); + return RankedTensorType::get(shape, elem_type); + } + if (!tensor.shape.empty()) { llvm::SmallVector shape(tensor.shape.begin(), tensor.shape.end()); diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dynamic_shape.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dynamic_shape.mlir new file mode 100644 index 00000000000..76e277eddcf --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dynamic_shape.mlir @@ -0,0 +1,9 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s + +// CHECK: func @main(%arg0: tensor) -> tensor +func @main(%arg0: tensor) -> tensor { + %cst = constant dense<1.0> : tensor<4xf32> + %cst_3 = constant dense<2.0> : tensor<4x3x3x3xf32> + %0 = "tfl.conv_2d"(%arg0, %cst_3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor, tensor<4x3x3x3xf32>, tensor<4xf32>) -> tensor + return %0 : tensor +} diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 1fa2c6e17d6..e327aee1376 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -462,7 +462,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): 3] == input_details[0]['shape_signature']).all()) output_details = interpreter.get_output_details() - self.assertTrue(([1, 16, 16, + self.assertTrue(([1, -1, 16, 3] == output_details[0]['shape_signature']).all()) def testBatchSizeValid(self):