diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 8319942a226..4e4c26b3154 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -716,6 +716,48 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) + def testRankAndShapeWithResizeInput(self): + # This is a regression test to ensure `Rank` and `Shape` ops work well with + # the reize input tensor API. + input_tensor = array_ops.placeholder( + shape=[None, 4, 4, 3], dtype=dtypes.float32) + reshaped_tensor = array_ops.reshape(input_tensor, [1, -1]) + output_rank = array_ops.rank(reshaped_tensor, name='output_rank') + output_shape = array_ops.shape(reshaped_tensor, name='output_shape') + + sess = session.Session() + converter = lite.TFLiteConverter.from_session(sess, [input_tensor], + [output_rank, output_shape]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + interpreter = Interpreter(model_content=tflite_model) + + def verify_rank_and_shape(expected_shape): + expected_rank = len(expected_shape) + output_details = interpreter.get_output_details() + self.assertEqual(2, len(output_details)) + self.assertEqual('output_rank', output_details[0]['name']) + self.assertEqual(np.int32, output_details[0]['dtype']) + self.assertEqual('output_shape', output_details[1]['name']) + self.assertEqual(np.int32, output_details[1]['dtype']) + self.assertEqual( + expected_shape, + interpreter.get_tensor(output_details[1]['index']).tolist()) + self.assertEqual( + expected_rank, + interpreter.get_tensor(output_details[0]['index']).tolist()) + + interpreter.allocate_tensors() + interpreter.invoke() + verify_rank_and_shape([1, 48]) + + input_details = interpreter.get_input_details() + interpreter.resize_tensor_input(input_details[0]['index'], [1, 4, 4, 4]) + interpreter.allocate_tensors() + interpreter.invoke() + verify_rank_and_shape([1, 64]) + @test_util.run_v1_only('b/120545219') class FromFrozenGraphFile(test_util.TensorFlowTestCase): diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index d92733ba3b5..7af23a65e1e 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -199,6 +199,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRankOnlyForConstantInput) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index f5704bf0b73..bc85e43bba3 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1867,7 +1867,9 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { CHECK_EQ(op->inputs.size(), 4); const Array& output_shape_array = model->GetArray(op->inputs[1]); + // Return if the shape array has unknown shape or has no data. if (!output_shape_array.has_shape()) return; + if (!output_shape_array.buffer) return; CHECK_EQ(output_shape_array.shape().dimensions_count(), 1); // Output should not go over four dimensions. diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index f142719f8ca..c9cede60491 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -70,7 +70,8 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kReduceMin || type == OperatorType::kTransposeConv || type == OperatorType::kMatrixSetDiag || - type == OperatorType::kMatrixDiag; + type == OperatorType::kMatrixDiag || type == OperatorType::kRange || + type == OperatorType::kRank; } // The quantized op allows output arrays of type float using diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 00ab8588279..86b7cbb2c1e 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -19,9 +19,11 @@ limitations under the License. namespace toco { -::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model, - std::size_t op_index, - bool* modified) { +namespace { + +::tensorflow::Status ResolveConstantShapeOrRankImpl( + Model* model, std::size_t op_index, bool only_for_constant_input, + bool* modified) { *modified = false; const auto it = model->operators.begin() + op_index; const auto* op = it->get(); @@ -42,6 +44,10 @@ namespace toco { return ::tensorflow::Status::OK(); } + if (only_for_constant_input && !input_array.buffer) { + return ::tensorflow::Status::OK(); + } + if (!output_array.has_shape()) { // Yield until the output shape has been resolved. return ::tensorflow::Status::OK(); @@ -72,4 +78,23 @@ namespace toco { return ::tensorflow::Status::OK(); } +} // namespace + +::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model, + std::size_t op_index, + bool* modified) { + return ResolveConstantShapeOrRankImpl(model, op_index, + + /*only_for_constant_input=*/false, + modified + + ); +} + +::tensorflow::Status ResolveConstantShapeOrRankOnlyForConstantInput::Run( + Model* model, std::size_t op_index, bool* modified) { + return ResolveConstantShapeOrRankImpl( + model, op_index, /*only_for_constant_input=*/true, modified); +} + } // namespace toco diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index c66ef1db915..6680a7510c7 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -115,7 +115,6 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); transformations->Add(new ResolveReduceAttributes); - transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); transformations->Add(new UnpartitionEmbeddingLookup); transformations->Add(new ResolveGatherAttributes); @@ -264,6 +263,12 @@ tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags, GraphTransformationsSet transformations; MakeGeneralGraphTransformationsSet(&transformations); + + if (output_format == TFLITE) { + transformations.Add(new ResolveConstantShapeOrRankOnlyForConstantInput); + } else { + transformations.Add(new ResolveConstantShapeOrRank); + } auto* remove_trivial_reshape = new RemoveTrivialReshape; transformations.Add(remove_trivial_reshape); auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;