Fix the bug that flatbuffer importer doesn't import dynamic shape tensor correctly.

PiperOrigin-RevId: 308198515
Change-Id: I734386cbb59dc0d57665c7d12c8e40530f905b56
This commit is contained in:
Chuan He 2020-04-23 22:46:32 -07:00 committed by TensorFlower Gardener
parent b713165f1e
commit ecfcb090c5
4 changed files with 27 additions and 13 deletions

View File

@ -599,13 +599,13 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
std::vector<int32_t> shape;
std::vector<int32_t> shape_signature;
auto* inst = value.getDefiningOp();
if (type.hasStaticShape()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} else if (auto* inst = value.getDefiningOp()) {
if (IsConst(inst)) {
} else if (inst && IsConst(inst)) {
// Const op can have a result of dynamic shaped type (e.g. due to constant
// folding), but we can still derive the shape of a constant tensor for
// its attribute type.
@ -615,7 +615,6 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
} else if (type.hasRank()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@ -627,7 +626,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
if (auto* inst = value.getDefiningOp()) {
if (inst) {
if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
// CreateSparsityParameters(cst.s_param());
} else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {

View File

@ -183,6 +183,12 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
return RankedTensorType::get({}, elem_type);
}
if (!tensor.shape_signature.empty()) {
llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
tensor.shape_signature.end());
return RankedTensorType::get(shape, elem_type);
}
if (!tensor.shape.empty()) {
llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
tensor.shape.end());

View File

@ -0,0 +1,9 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32>
func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> {
%cst = constant dense<1.0> : tensor<4xf32>
%cst_3 = constant dense<2.0> : tensor<4x3x3x3xf32>
%0 = "tfl.conv_2d"(%arg0, %cst_3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<?x19x19x3xf32>, tensor<4x3x3x3xf32>, tensor<4xf32>) -> tensor<?x9x9x4xf32>
return %0 : tensor<?x9x9x4xf32>
}

View File

@ -462,7 +462,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
3] == input_details[0]['shape_signature']).all())
output_details = interpreter.get_output_details()
self.assertTrue(([1, 16, 16,
self.assertTrue(([1, -1, 16,
3] == output_details[0]['shape_signature']).all())
def testBatchSizeValid(self):