From 42df993ba540eca147029396ac2f32567cf7ffe7 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 5 Jun 2019 19:54:52 -0700 Subject: [PATCH] Give out warning if model has place holder but not in the input_array specified by the user. PiperOrigin-RevId: 251771255 --- tensorflow/lite/python/lite_test.py | 11 +- tensorflow/lite/toco/import_tensorflow.cc | 172 ++++++++++-------- .../lite/toco/import_tensorflow_test.cc | 15 +- tensorflow/lite/toco/model.h | 11 ++ tensorflow/lite/toco/tflite/export.cc | 7 + tensorflow/lite/toco/toco_convert_test.cc | 5 + 6 files changed, 133 insertions(+), 88 deletions(-) diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 036fd0d818d..a753f61a4dc 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.lite.python import lite from tensorflow.lite.python import lite_constants +from tensorflow.lite.python.convert import ConverterError from tensorflow.lite.python.interpreter import Interpreter from tensorflow.python import keras from tensorflow.python.client import session @@ -1190,15 +1191,17 @@ class FromSavedModelTest(test_util.TensorFlowTestCase): input_arrays=['inputA'], input_shapes={'inputA': [1, 16, 16, 3]}) - tflite_model = converter.convert() - self.assertTrue(tflite_model) + # Since we only partially specify the input, this is not allowed. + with self.assertRaises(ConverterError): + _ = converter.convert() # Check case where input shape is None. converter = lite.TFLiteConverter.from_saved_model( saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None}) - tflite_model = converter.convert() - self.assertTrue(tflite_model) + # Since we only partially specify the input, this is not allowed. + with self.assertRaises(ConverterError): + _ = converter.convert() def testSimpleModelTocoConverter(self): """Test a SavedModel with deprecated TocoConverter.""" diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index d1efcaeafc4..d0b9edbcc06 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -564,7 +564,7 @@ void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) { tensorflow::Status ConvertConstOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); @@ -616,7 +616,7 @@ tensorflow::Status ConvertConstOperator( tensorflow::Status ConvertConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Conv2D"); TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); @@ -691,7 +691,7 @@ tensorflow::Status ConvertConvOperator( tensorflow::Status ConvertDepthwiseConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "DepthwiseConv2dNative"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -762,7 +762,7 @@ tensorflow::Status ConvertDepthwiseConvOperator( tensorflow::Status ConvertDepthToSpaceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "DepthToSpace"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); @@ -778,7 +778,7 @@ tensorflow::Status ConvertDepthToSpaceOperator( tensorflow::Status ConvertSpaceToDepthOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SpaceToDepth"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); @@ -801,7 +801,7 @@ tensorflow::Status ConvertSpaceToDepthOperator( tensorflow::Status ConvertBiasAddOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BiasAdd"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -818,7 +818,7 @@ tensorflow::Status ConvertBiasAddOperator( tensorflow::Status ConvertRandomUniform( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "RandomUniform"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); @@ -836,7 +836,7 @@ tensorflow::Status ConvertRandomUniform( tensorflow::Status ConvertIdentityOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" || node.op() == "Snapshot"); @@ -859,7 +859,7 @@ tensorflow::Status ConvertIdentityOperator( tensorflow::Status ConvertFakeQuantWithMinMaxArgs( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new FakeQuantOperator; @@ -880,7 +880,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs( tensorflow::Status ConvertFakeQuantWithMinMaxVars( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); const int num_inputs = GetInputsCount(node, tf_import_flags); QCHECK(num_inputs == 3 || num_inputs == 4) @@ -902,7 +902,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars( tensorflow::Status ConvertSqueezeOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Squeeze"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new SqueezeOperator; @@ -923,7 +923,7 @@ tensorflow::Status ConvertSqueezeOperator( tensorflow::Status ConvertSplitOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Split"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSplitOperator; @@ -941,7 +941,7 @@ tensorflow::Status ConvertSplitOperator( tensorflow::Status ConvertSplitVOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SplitV"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TensorFlowSplitVOperator; @@ -960,7 +960,7 @@ tensorflow::Status ConvertSplitVOperator( tensorflow::Status ConvertSwitchOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Switch"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSwitchOperator; @@ -975,7 +975,7 @@ tensorflow::Status ConvertSwitchOperator( tensorflow::Status ConvertSoftmaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Softmax"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); @@ -991,7 +991,7 @@ tensorflow::Status ConvertSoftmaxOperator( tensorflow::Status ConvertLRNOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "LRN"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); @@ -1008,7 +1008,7 @@ tensorflow::Status ConvertLRNOperator( tensorflow::Status ConvertMaxPoolOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "MaxPool"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); @@ -1051,7 +1051,7 @@ tensorflow::Status ConvertMaxPoolOperator( tensorflow::Status ConvertAvgPoolOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "AvgPool"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); @@ -1090,7 +1090,7 @@ tensorflow::Status ConvertAvgPoolOperator( tensorflow::Status ConvertBatchMatMulOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* batch_matmul = new BatchMatMulOperator; @@ -1113,7 +1113,7 @@ tensorflow::Status ConvertBatchMatMulOperator( tensorflow::Status ConvertMatMulOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); CHECK(!HasAttr(node, "adjoint_a") || @@ -1137,7 +1137,7 @@ tensorflow::Status ConvertMatMulOperator( tensorflow::Status ConvertConcatOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { Operator* op = nullptr; if (node.op() == "Concat") { op = new TensorFlowConcatOperator; @@ -1162,7 +1162,7 @@ tensorflow::Status ConvertConcatOperator( tensorflow::Status ConvertMirrorPadOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { if (node.op() != "MirrorPad") { LOG(FATAL) << "Expected MirrorPad."; } @@ -1197,7 +1197,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk }; template tensorflow::Status ConvertSimpleOperatorGeneric( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { if (NumInputs != kAnyNumInputs) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); } @@ -1225,18 +1225,18 @@ tensorflow::Status ConvertSimpleOperatorGeneric( template tensorflow::Status ConvertSimpleOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { return ConvertSimpleOperatorGeneric( - node, tf_import_flags, model); + node, tf_import_flags, model_flags, model); } // Convert a simple operator which is valid as a flex op. template tensorflow::Status ConvertSimpleOperatorFlexOk( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { return ConvertSimpleOperatorGeneric( - node, tf_import_flags, model); + node, tf_import_flags, model_flags, model); } void GetOutputNamesFromNodeDef(const NodeDef& node, @@ -1325,7 +1325,7 @@ void GetOutputTypesFromNodeDef(const NodeDef& node, tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { // Names of special attributes in TF graph that are used by Toco. static constexpr char kAttrOutputQuantized[] = "_output_quantized"; static constexpr char kAttrOutputTypes[] = "_output_types"; @@ -1416,14 +1416,15 @@ tensorflow::Status ConvertUnsupportedOperator( // expensive copies of the protocol buffers downstream in the flex delegate. tensorflow::Status ConditionallyConvertConstOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { // We avoid incomplete and zero shapes because the resulting arrays // are not completely compatible with Eager/TensorFlow. const auto& tensor = GetTensorAttr(node, "value"); const auto& shape = tensor.tensor_shape(); for (const auto& dim : shape.dim()) { if (dim.size() <= 0) { - return ConvertUnsupportedOperator(node, tf_import_flags, model); + return ConvertUnsupportedOperator(node, tf_import_flags, model_flags, + model); } } @@ -1435,15 +1436,16 @@ tensorflow::Status ConditionallyConvertConstOperator( case DT_STRING: case DT_BOOL: case DT_COMPLEX64: - return ConvertConstOperator(node, tf_import_flags, model); + return ConvertConstOperator(node, tf_import_flags, model_flags, model); default: - return ConvertUnsupportedOperator(node, tf_import_flags, model); + return ConvertUnsupportedOperator(node, tf_import_flags, model_flags, + model); } } tensorflow::Status ConvertStridedSliceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "StridedSlice"); // TODO(soroosh): The 4th input (strides) should be e optional, to be // consistent with TF. @@ -1472,11 +1474,24 @@ tensorflow::Status ConvertStridedSliceOperator( tensorflow::Status ConvertPlaceholderOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); if (node.op() == "Placeholder") { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0)); } + + bool inside_input_arrays = false; + for (const auto& input_array : model_flags.input_arrays()) { + if (node.name() == input_array.name()) { + inside_input_arrays = true; + break; + } + } + + if (!inside_input_arrays) { + model->AddInvalidInputArray(node.name()); + } + auto& array = model->GetOrCreateArray(node.name()); if (node.attr().count("dtype")) { array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype")); @@ -1499,13 +1514,13 @@ tensorflow::Status ConvertPlaceholderOperator( tensorflow::Status ConvertNoOpOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { return tensorflow::Status::OK(); } tensorflow::Status ConvertCastOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Cast"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); @@ -1521,7 +1536,7 @@ tensorflow::Status ConvertCastOperator( tensorflow::Status ConvertFloorOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Floor"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1535,7 +1550,7 @@ tensorflow::Status ConvertFloorOperator( tensorflow::Status ConvertCeilOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Ceil"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1549,7 +1564,7 @@ tensorflow::Status ConvertCeilOperator( tensorflow::Status ConvertRoundOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Round"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1563,7 +1578,7 @@ tensorflow::Status ConvertRoundOperator( tensorflow::Status ConvertGatherOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK(node.op() == "Gather" || node.op() == "GatherV2"); if (node.op() == "Gather") TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -1592,7 +1607,7 @@ tensorflow::Status ConvertGatherOperator( tensorflow::Status ConvertGatherNdOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "GatherNd"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); @@ -1608,7 +1623,7 @@ tensorflow::Status ConvertGatherNdOperator( template tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; @@ -1628,21 +1643,23 @@ tensorflow::Status ConvertArgMinMaxOperator( tensorflow::Status ConvertArgMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ArgMax"); - return ConvertArgMinMaxOperator(node, tf_import_flags, model); + return ConvertArgMinMaxOperator(node, tf_import_flags, + model_flags, model); } tensorflow::Status ConvertArgMinOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ArgMin"); - return ConvertArgMinMaxOperator(node, tf_import_flags, model); + return ConvertArgMinMaxOperator(node, tf_import_flags, + model_flags, model); } tensorflow::Status ConvertResizeBilinearOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ResizeBilinear"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new ResizeBilinearOperator; @@ -1661,7 +1678,7 @@ tensorflow::Status ConvertResizeBilinearOperator( tensorflow::Status ConvertResizeNearestNeighborOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ResizeNearestNeighbor"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new ResizeNearestNeighborOperator; @@ -1680,7 +1697,7 @@ tensorflow::Status ConvertResizeNearestNeighborOperator( tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); @@ -1730,7 +1747,7 @@ tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( tensorflow::Status ConvertFusedBatchNormOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "FusedBatchNorm"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); @@ -1783,7 +1800,7 @@ tensorflow::Status ConvertFusedBatchNormOperator( tensorflow::Status ConvertSpaceToBatchNDOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SpaceToBatchND"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); @@ -1799,7 +1816,7 @@ tensorflow::Status ConvertSpaceToBatchNDOperator( tensorflow::Status ConvertBatchToSpaceNDOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BatchToSpaceND"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); @@ -1816,7 +1833,7 @@ tensorflow::Status ConvertBatchToSpaceNDOperator( template tensorflow::Status ConvertReduceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new T; op->inputs.push_back(node.input(0)); @@ -1833,7 +1850,7 @@ tensorflow::Status ConvertReduceOperator( tensorflow::Status ConvertSvdfOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Svdf"); const int input_size = GetInputsCount(node, tf_import_flags); QCHECK(input_size == 3 || input_size == 4) @@ -1862,7 +1879,7 @@ tensorflow::Status ConvertSvdfOperator( // This is just bare bones support to get the shapes to propagate. tensorflow::Status ConvertTransposeConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Conv2DBackpropInput"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TransposeConvOperator; @@ -1933,7 +1950,7 @@ tensorflow::Status ConvertTransposeConvOperator( tensorflow::Status ConvertRangeOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Range"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new RangeOperator; @@ -1958,7 +1975,7 @@ tensorflow::Status ConvertRangeOperator( // not directly related to tf.stack() usage. tensorflow::Status ConvertPackOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Pack"); auto op = absl::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1980,7 +1997,7 @@ tensorflow::Status ConvertPackOperator( tensorflow::Status ConvertUnpackOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Unpack"); auto op = absl::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -2010,7 +2027,7 @@ tensorflow::Status ConvertUnpackOperator( // graph visualization. tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { // At the moment, the only type of operator special-cased in this way is // NextIteration, occurring only in control-flow cycles. CHECK_EQ(node.op(), "NextIteration"); @@ -2029,7 +2046,7 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( tensorflow::Status ConvertShapeOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Shape"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto out_type = @@ -2045,7 +2062,7 @@ tensorflow::Status ConvertShapeOperator( tensorflow::Status ConvertReverseSequenceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ReverseSequence"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto op = absl::make_unique(); @@ -2206,7 +2223,7 @@ bool InlineAllFunctions(GraphDef* graphdef) { tensorflow::Status ConvertTopKV2Operator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); auto op = absl::make_unique(); op->inputs.push_back(node.input(0)); @@ -2228,7 +2245,7 @@ tensorflow::Status ConvertTopKV2Operator( tensorflow::Status ConvertDynamicPartitionOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { auto op = absl::make_unique(); CHECK(HasAttr(node, "num_partitions")); op->num_partitions = GetIntAttr(node, "num_partitions"); @@ -2246,7 +2263,7 @@ tensorflow::Status ConvertDynamicPartitionOperator( tensorflow::Status ConvertDynamicStitchOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { // The parallel and non-parallel variants are the same besides whether they // have a parallel loop; there are no behavioral differences. CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch"); @@ -2265,7 +2282,7 @@ tensorflow::Status ConvertDynamicStitchOperator( tensorflow::Status ConvertSparseToDenseOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SparseToDense"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); @@ -2284,7 +2301,7 @@ tensorflow::Status ConvertSparseToDenseOperator( tensorflow::Status ConvertOneHotOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "OneHot"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); @@ -2305,7 +2322,7 @@ tensorflow::Status ConvertOneHotOperator( tensorflow::Status ConvertCTCBeamSearchDecoderOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -2335,7 +2352,7 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator( // with TfLite OpHint API. tensorflow::Status ConvertUnidirectionalSequenceLstm( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm"); auto* op = new UnidirectionalSequenceLstmOperator(); @@ -2375,7 +2392,7 @@ tensorflow::Status ConvertUnidirectionalSequenceLstm( tensorflow::Status ConvertLeakyReluOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "LeakyRelu"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); @@ -2390,7 +2407,7 @@ tensorflow::Status ConvertLeakyReluOperator( tensorflow::Status ConvertUnidirectionalSequenceRnn( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { + const ModelFlags& model_flags, Model* model) { DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn"); auto* op = new UnidirectionalSequenceRnnOperator(); @@ -2415,7 +2432,7 @@ namespace internal { using ConverterType = tensorflow::Status (*)( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model); + const ModelFlags& model_flags, Model* model); using ConverterMapType = std::unordered_map; ConverterMapType GetTensorFlowNodeConverterMapForFlex() { @@ -2568,13 +2585,14 @@ ConverterMapType GetTensorFlowNodeConverterMap() { tensorflow::Status ImportTensorFlowNode( const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, Model* model, - const ConverterMapType& converter_map) { + const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, + Model* model, const ConverterMapType& converter_map) { auto converter = converter_map.find(node.op()); if (converter == converter_map.end()) { - return ConvertUnsupportedOperator(node, tf_import_flags, model); + return ConvertUnsupportedOperator(node, tf_import_flags, model_flags, + model); } else { - return converter->second(node, tf_import_flags, model); + return converter->second(node, tf_import_flags, model_flags, model); } } } // namespace internal @@ -2614,8 +2632,8 @@ std::unique_ptr ImportTensorFlowGraphDef( for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); - auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model, - converter_map); + auto status = internal::ImportTensorFlowNode( + node, tf_import_flags, model_flags, model, converter_map); CHECK(status.ok()) << status.error_message(); } diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc index b620ade756e..3e0c530290b 100644 --- a/tensorflow/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/lite/toco/import_tensorflow_test.cc @@ -44,28 +44,29 @@ using ::testing::ElementsAre; namespace internal { using ConverterType = tensorflow::Status (*)( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model); + const ModelFlags& model_flags, Model* model); using ConverterMapType = std::unordered_map; ConverterMapType GetTensorFlowNodeConverterMap(); ConverterMapType GetTensorFlowNodeConverterMapForFlex(); Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, - Model*, const ConverterMapType&); + const ModelFlags& model_flags, Model*, + const ConverterMapType&); } // namespace internal namespace { Status ImportNode(const NodeDef& node, Model* model) { const auto converter = internal::GetTensorFlowNodeConverterMap(); - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, - converter); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), + ModelFlags(), model, converter); } Status ImportFlexNode(const NodeDef& node, Model* model) { // Empty converter => all nodes are flex nodes. const auto converter = internal::ConverterMapType(); - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, - converter); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), + ModelFlags(), model, converter); } Status ImportNode(const NodeDef& node) { @@ -170,7 +171,7 @@ TEST(FlexImportTest, ConditionalConst) { const auto converter = internal::GetTensorFlowNodeConverterMapForFlex(); return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), - &model, converter); + ModelFlags(), &model, converter); }; EXPECT_TRUE(build_and_import_node("Known", {1, 2, 3}, DT_INT32, 6).ok()); diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 66558559120..77b8846f6e3 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -2335,6 +2335,14 @@ class Model { int64 ArithmeticOpsCount() const { return ops_count; } + void AddInvalidInputArray(string invalid_input_array) { + invalid_input_arrays_.insert(invalid_input_array); + } + + const std::unordered_set& GetInvalidInputArrays() const { + return invalid_input_arrays_; + } + // Optional arrays are used for optional tensors, // these tensors do not have data, but with reserved names as op inputs. std::set optional_arrays; @@ -2361,6 +2369,9 @@ class Model { // The Operator's refer to these Array's by their name strings, not by their // addresses. See Operator::inputs, Operator::outputs. std::unordered_map> arrays; + + // Invalid input arrays. + std::unordered_set invalid_input_arrays_; }; // OperatorSignature contains the information required to making versioning diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index e11544404c4..66319bd2ae1 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -459,6 +459,13 @@ tensorflow::Status Export( const Model& model, string* output_file_contents, const ExportParams& params, const std::map>& ops_by_type) { + for (const string& input_array : model.GetInvalidInputArrays()) { + if (model.HasArray(input_array)) { + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Placeholder ", input_array, " should be specied by input_arrays.")); + } + } + flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); details::TensorsMap tensors_map; diff --git a/tensorflow/lite/toco/toco_convert_test.cc b/tensorflow/lite/toco/toco_convert_test.cc index f430df6791c..270c7aadcab 100644 --- a/tensorflow/lite/toco/toco_convert_test.cc +++ b/tensorflow/lite/toco/toco_convert_test.cc @@ -133,6 +133,11 @@ TEST(TocoTest, TransientStringTensors) { // input array must have a shape. toco_flags.set_output_format(TFLITE); + toco::InputArray* input_1 = model_flags.add_input_arrays(); + input_1->set_name("input1"); + toco::InputArray* indices_1 = model_flags.add_input_arrays(); + indices_1->set_name("indices1"); + model_flags.add_output_arrays("output1"); string input = R"GraphDef( node {