Fix the bug that flatbuffer importer doesn't import dynamic shape tensor correctly.
PiperOrigin-RevId: 308198515 Change-Id: I734386cbb59dc0d57665c7d12c8e40530f905b56
This commit is contained in:
parent
b713165f1e
commit
ecfcb090c5
|
@ -599,23 +599,22 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||||
|
|
||||||
std::vector<int32_t> shape;
|
std::vector<int32_t> shape;
|
||||||
std::vector<int32_t> shape_signature;
|
std::vector<int32_t> shape_signature;
|
||||||
|
auto* inst = value.getDefiningOp();
|
||||||
if (type.hasStaticShape()) {
|
if (type.hasStaticShape()) {
|
||||||
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
||||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||||
|
|
||||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||||
} else if (auto* inst = value.getDefiningOp()) {
|
} else if (inst && IsConst(inst)) {
|
||||||
if (IsConst(inst)) {
|
// Const op can have a result of dynamic shaped type (e.g. due to constant
|
||||||
// Const op can have a result of dynamic shaped type (e.g. due to constant
|
// folding), but we can still derive the shape of a constant tensor for
|
||||||
// folding), but we can still derive the shape of a constant tensor for
|
// its attribute type.
|
||||||
// its attribute type.
|
mlir::Attribute tensor_attr = inst->getAttr("value");
|
||||||
mlir::Attribute tensor_attr = inst->getAttr("value");
|
llvm::ArrayRef<int64_t> shape_ref =
|
||||||
llvm::ArrayRef<int64_t> shape_ref =
|
tensor_attr.getType().cast<TensorType>().getShape();
|
||||||
tensor_attr.getType().cast<TensorType>().getShape();
|
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
|
||||||
|
|
||||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||||
}
|
|
||||||
} else if (type.hasRank()) {
|
} else if (type.hasRank()) {
|
||||||
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
||||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||||
|
@ -627,7 +626,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||||
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* inst = value.getDefiningOp()) {
|
if (inst) {
|
||||||
if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
|
if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
|
||||||
// CreateSparsityParameters(cst.s_param());
|
// CreateSparsityParameters(cst.s_param());
|
||||||
} else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
|
} else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
|
||||||
|
|
|
@ -183,6 +183,12 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
|
||||||
return RankedTensorType::get({}, elem_type);
|
return RankedTensorType::get({}, elem_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!tensor.shape_signature.empty()) {
|
||||||
|
llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
|
||||||
|
tensor.shape_signature.end());
|
||||||
|
return RankedTensorType::get(shape, elem_type);
|
||||||
|
}
|
||||||
|
|
||||||
if (!tensor.shape.empty()) {
|
if (!tensor.shape.empty()) {
|
||||||
llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
|
llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
|
||||||
tensor.shape.end());
|
tensor.shape.end());
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||||
|
|
||||||
|
// CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32>
|
||||||
|
func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> {
|
||||||
|
%cst = constant dense<1.0> : tensor<4xf32>
|
||||||
|
%cst_3 = constant dense<2.0> : tensor<4x3x3x3xf32>
|
||||||
|
%0 = "tfl.conv_2d"(%arg0, %cst_3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<?x19x19x3xf32>, tensor<4x3x3x3xf32>, tensor<4xf32>) -> tensor<?x9x9x4xf32>
|
||||||
|
return %0 : tensor<?x9x9x4xf32>
|
||||||
|
}
|
|
@ -462,7 +462,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||||
3] == input_details[0]['shape_signature']).all())
|
3] == input_details[0]['shape_signature']).all())
|
||||||
|
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
self.assertTrue(([1, 16, 16,
|
self.assertTrue(([1, -1, 16,
|
||||||
3] == output_details[0]['shape_signature']).all())
|
3] == output_details[0]['shape_signature']).all())
|
||||||
|
|
||||||
def testBatchSizeValid(self):
|
def testBatchSizeValid(self):
|
||||||
|
|
Loading…
Reference in New Issue