Fix the bug that flatbuffer importer doesn't import dynamic shape tensor correctly.

PiperOrigin-RevId: 308198515
Change-Id: I734386cbb59dc0d57665c7d12c8e40530f905b56
This commit is contained in:
Chuan He 2020-04-23 22:46:32 -07:00 committed by TensorFlower Gardener
parent b713165f1e
commit ecfcb090c5
4 changed files with 27 additions and 13 deletions

View File

@ -599,23 +599,22 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
std::vector<int32_t> shape; std::vector<int32_t> shape;
std::vector<int32_t> shape_signature; std::vector<int32_t> shape_signature;
auto* inst = value.getDefiningOp();
if (type.hasStaticShape()) { if (type.hasStaticShape()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape(); llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None; if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} else if (auto* inst = value.getDefiningOp()) { } else if (inst && IsConst(inst)) {
if (IsConst(inst)) { // Const op can have a result of dynamic shaped type (e.g. due to constant
// Const op can have a result of dynamic shaped type (e.g. due to constant // folding), but we can still derive the shape of a constant tensor for
// folding), but we can still derive the shape of a constant tensor for // its attribute type.
// its attribute type. mlir::Attribute tensor_attr = inst->getAttr("value");
mlir::Attribute tensor_attr = inst->getAttr("value"); llvm::ArrayRef<int64_t> shape_ref =
llvm::ArrayRef<int64_t> shape_ref = tensor_attr.getType().cast<TensorType>().getShape();
tensor_attr.getType().cast<TensorType>().getShape(); if (mlir::failed(check_shape(shape_ref))) return llvm::None;
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
} else if (type.hasRank()) { } else if (type.hasRank()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape(); llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None; if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@ -627,7 +626,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} }
if (auto* inst = value.getDefiningOp()) { if (inst) {
if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) { if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
// CreateSparsityParameters(cst.s_param()); // CreateSparsityParameters(cst.s_param());
} else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) { } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {

View File

@ -183,6 +183,12 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
return RankedTensorType::get({}, elem_type); return RankedTensorType::get({}, elem_type);
} }
if (!tensor.shape_signature.empty()) {
llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
tensor.shape_signature.end());
return RankedTensorType::get(shape, elem_type);
}
if (!tensor.shape.empty()) { if (!tensor.shape.empty()) {
llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(), llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
tensor.shape.end()); tensor.shape.end());

View File

@ -0,0 +1,9 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32>
func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> {
%cst = constant dense<1.0> : tensor<4xf32>
%cst_3 = constant dense<2.0> : tensor<4x3x3x3xf32>
%0 = "tfl.conv_2d"(%arg0, %cst_3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<?x19x19x3xf32>, tensor<4x3x3x3xf32>, tensor<4xf32>) -> tensor<?x9x9x4xf32>
return %0 : tensor<?x9x9x4xf32>
}

View File

@ -462,7 +462,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
3] == input_details[0]['shape_signature']).all()) 3] == input_details[0]['shape_signature']).all())
output_details = interpreter.get_output_details() output_details = interpreter.get_output_details()
self.assertTrue(([1, 16, 16, self.assertTrue(([1, -1, 16,
3] == output_details[0]['shape_signature']).all()) 3] == output_details[0]['shape_signature']).all())
def testBatchSizeValid(self): def testBatchSizeValid(self):