Give a warning if the model has a placeholder that is not in the input_arrays specified by the user.

PiperOrigin-RevId: 251771255
Renjie Liu 2019-06-05 19:54:52 -07:00 committed by TensorFlower Gardener
parent 9f49c5b09f
commit 42df993ba5
6 changed files with 133 additions and 88 deletions
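In outline, the change records, while importing the GraphDef, each Placeholder node whose name is not listed in the user-supplied input_arrays, and the TFLite exporter then rejects the model if any of those placeholders survived as real arrays. The sketch below is a simplified, self-contained illustration of that two-step check, not the actual TOCO code; names such as ModelSketch, RecordPlaceholder, and CheckInvalidInputArrays are invented for the example.

// Minimal sketch (assumed names, not the real TOCO API) of the check this
// commit introduces: import-side bookkeeping plus an export-side error.
#include <iostream>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

struct ModelSketch {
  // Arrays that exist in the model, keyed by name (stands in for Model::arrays).
  std::set<std::string> arrays;
  // Placeholders seen during import that were not listed in input_arrays.
  std::unordered_set<std::string> invalid_input_arrays;
};

// Import side: mirrors the new logic added to ConvertPlaceholderOperator.
void RecordPlaceholder(const std::string& node_name,
                       const std::vector<std::string>& input_arrays,
                       ModelSketch* model) {
  bool inside_input_arrays = false;
  for (const auto& name : input_arrays) {
    if (name == node_name) {
      inside_input_arrays = true;
      break;
    }
  }
  if (!inside_input_arrays) {
    model->invalid_input_arrays.insert(node_name);
  }
  model->arrays.insert(node_name);
}

// Export side: mirrors the new check added to the TFLite Export() path.
bool CheckInvalidInputArrays(const ModelSketch& model, std::string* error) {
  for (const auto& name : model.invalid_input_arrays) {
    if (model.arrays.count(name)) {
      *error = "Placeholder " + name + " should be specified by input_arrays.";
      return false;
    }
  }
  return true;
}

int main() {
  ModelSketch model;
  RecordPlaceholder("inputA", {"inputA"}, &model);  // listed in input_arrays: fine
  RecordPlaceholder("inputB", {"inputA"}, &model);  // missing from input_arrays
  std::string error;
  if (!CheckInvalidInputArrays(model, &error)) {
    std::cout << error << "\n";  // reports the offending placeholder
  }
  return 0;
}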

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python import keras
from tensorflow.python.client import session
@ -1190,15 +1191,17 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
input_arrays=['inputA'],
input_shapes={'inputA': [1, 16, 16, 3]})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Since we only partially specify the input, this is not allowed.
with self.assertRaises(ConverterError):
_ = converter.convert()
# Check case where input shape is None.
converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Since we only partially specify the input, this is not allowed.
with self.assertRaises(ConverterError):
_ = converter.convert()
def testSimpleModelTocoConverter(self):
"""Test a SavedModel with deprecated TocoConverter."""

View File

@ -564,7 +564,7 @@ void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
tensorflow::Status ConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Const");
const auto& tensor = GetTensorAttr(node, "value");
const auto dtype = GetDataTypeAttr(node, "dtype");
@ -616,7 +616,7 @@ tensorflow::Status ConvertConstOperator(
tensorflow::Status ConvertConvOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Conv2D");
TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
@ -691,7 +691,7 @@ tensorflow::Status ConvertConvOperator(
tensorflow::Status ConvertDepthwiseConvOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "DepthwiseConv2dNative");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -762,7 +762,7 @@ tensorflow::Status ConvertDepthwiseConvOperator(
tensorflow::Status ConvertDepthToSpaceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "DepthToSpace");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
@ -778,7 +778,7 @@ tensorflow::Status ConvertDepthToSpaceOperator(
tensorflow::Status ConvertSpaceToDepthOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SpaceToDepth");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
@ -801,7 +801,7 @@ tensorflow::Status ConvertSpaceToDepthOperator(
tensorflow::Status ConvertBiasAddOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "BiasAdd");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -818,7 +818,7 @@ tensorflow::Status ConvertBiasAddOperator(
tensorflow::Status ConvertRandomUniform(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "RandomUniform");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
@ -836,7 +836,7 @@ tensorflow::Status ConvertRandomUniform(
tensorflow::Status ConvertIdentityOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
node.op() == "Snapshot");
@ -859,7 +859,7 @@ tensorflow::Status ConvertIdentityOperator(
tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
auto* op = new FakeQuantOperator;
@ -880,7 +880,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
tensorflow::Status ConvertFakeQuantWithMinMaxVars(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
const int num_inputs = GetInputsCount(node, tf_import_flags);
QCHECK(num_inputs == 3 || num_inputs == 4)
@ -902,7 +902,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars(
tensorflow::Status ConvertSqueezeOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Squeeze");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
auto* op = new SqueezeOperator;
@ -923,7 +923,7 @@ tensorflow::Status ConvertSqueezeOperator(
tensorflow::Status ConvertSplitOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Split");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowSplitOperator;
@ -941,7 +941,7 @@ tensorflow::Status ConvertSplitOperator(
tensorflow::Status ConvertSplitVOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SplitV");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new TensorFlowSplitVOperator;
@ -960,7 +960,7 @@ tensorflow::Status ConvertSplitVOperator(
tensorflow::Status ConvertSwitchOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Switch");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowSwitchOperator;
@ -975,7 +975,7 @@ tensorflow::Status ConvertSwitchOperator(
tensorflow::Status ConvertSoftmaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Softmax");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
@ -991,7 +991,7 @@ tensorflow::Status ConvertSoftmaxOperator(
tensorflow::Status ConvertLRNOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "LRN");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
@ -1008,7 +1008,7 @@ tensorflow::Status ConvertLRNOperator(
tensorflow::Status ConvertMaxPoolOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "MaxPool");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
@ -1051,7 +1051,7 @@ tensorflow::Status ConvertMaxPoolOperator(
tensorflow::Status ConvertAvgPoolOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "AvgPool");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
@ -1090,7 +1090,7 @@ tensorflow::Status ConvertAvgPoolOperator(
tensorflow::Status ConvertBatchMatMulOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* batch_matmul = new BatchMatMulOperator;
@ -1113,7 +1113,7 @@ tensorflow::Status ConvertBatchMatMulOperator(
tensorflow::Status ConvertMatMulOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
CHECK(!HasAttr(node, "adjoint_a") ||
@ -1137,7 +1137,7 @@ tensorflow::Status ConvertMatMulOperator(
tensorflow::Status ConvertConcatOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
Operator* op = nullptr;
if (node.op() == "Concat") {
op = new TensorFlowConcatOperator;
@ -1162,7 +1162,7 @@ tensorflow::Status ConvertConcatOperator(
tensorflow::Status ConvertMirrorPadOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
if (node.op() != "MirrorPad") {
LOG(FATAL) << "Expected MirrorPad.";
}
@ -1197,7 +1197,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk };
template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
tensorflow::Status ConvertSimpleOperatorGeneric(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
if (NumInputs != kAnyNumInputs) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
}
@ -1225,18 +1225,18 @@ tensorflow::Status ConvertSimpleOperatorGeneric(
template <typename Op, int NumInputs, int NumOutputs>
tensorflow::Status ConvertSimpleOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
node, tf_import_flags, model);
node, tf_import_flags, model_flags, model);
}
// Convert a simple operator which is valid as a flex op.
template <typename Op, int NumInputs, int NumOutputs>
tensorflow::Status ConvertSimpleOperatorFlexOk(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
node, tf_import_flags, model);
node, tf_import_flags, model_flags, model);
}
void GetOutputNamesFromNodeDef(const NodeDef& node,
@ -1325,7 +1325,7 @@ void GetOutputTypesFromNodeDef(const NodeDef& node,
tensorflow::Status ConvertUnsupportedOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
// Names of special attributes in TF graph that are used by Toco.
static constexpr char kAttrOutputQuantized[] = "_output_quantized";
static constexpr char kAttrOutputTypes[] = "_output_types";
@ -1416,14 +1416,15 @@ tensorflow::Status ConvertUnsupportedOperator(
// expensive copies of the protocol buffers downstream in the flex delegate.
tensorflow::Status ConditionallyConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
// We avoid incomplete and zero shapes because the resulting arrays
// are not completely compatible with Eager/TensorFlow.
const auto& tensor = GetTensorAttr(node, "value");
const auto& shape = tensor.tensor_shape();
for (const auto& dim : shape.dim()) {
if (dim.size() <= 0) {
return ConvertUnsupportedOperator(node, tf_import_flags, model);
return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
model);
}
}
@ -1435,15 +1436,16 @@ tensorflow::Status ConditionallyConvertConstOperator(
case DT_STRING:
case DT_BOOL:
case DT_COMPLEX64:
return ConvertConstOperator(node, tf_import_flags, model);
return ConvertConstOperator(node, tf_import_flags, model_flags, model);
default:
return ConvertUnsupportedOperator(node, tf_import_flags, model);
return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
model);
}
}
tensorflow::Status ConvertStridedSliceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "StridedSlice");
// TODO(soroosh): The 4th input (strides) should be optional, to be
// consistent with TF.
@ -1472,11 +1474,24 @@ tensorflow::Status ConvertStridedSliceOperator(
tensorflow::Status ConvertPlaceholderOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
if (node.op() == "Placeholder") {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
}
bool inside_input_arrays = false;
for (const auto& input_array : model_flags.input_arrays()) {
if (node.name() == input_array.name()) {
inside_input_arrays = true;
break;
}
}
if (!inside_input_arrays) {
model->AddInvalidInputArray(node.name());
}
auto& array = model->GetOrCreateArray(node.name());
if (node.attr().count("dtype")) {
array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
@ -1499,13 +1514,13 @@ tensorflow::Status ConvertPlaceholderOperator(
tensorflow::Status ConvertNoOpOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
return tensorflow::Status::OK();
}
tensorflow::Status ConvertCastOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Cast");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
@ -1521,7 +1536,7 @@ tensorflow::Status ConvertCastOperator(
tensorflow::Status ConvertFloorOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Floor");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T");
@ -1535,7 +1550,7 @@ tensorflow::Status ConvertFloorOperator(
tensorflow::Status ConvertCeilOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Ceil");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T");
@ -1549,7 +1564,7 @@ tensorflow::Status ConvertCeilOperator(
tensorflow::Status ConvertRoundOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Round");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T");
@ -1563,7 +1578,7 @@ tensorflow::Status ConvertRoundOperator(
tensorflow::Status ConvertGatherOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK(node.op() == "Gather" || node.op() == "GatherV2");
if (node.op() == "Gather")
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -1592,7 +1607,7 @@ tensorflow::Status ConvertGatherOperator(
tensorflow::Status ConvertGatherNdOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "GatherNd");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
@ -1608,7 +1623,7 @@ tensorflow::Status ConvertGatherNdOperator(
template <typename Op>
tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@ -1628,21 +1643,23 @@ tensorflow::Status ConvertArgMinMaxOperator(
tensorflow::Status ConvertArgMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ArgMax");
return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model);
return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags,
model_flags, model);
}
tensorflow::Status ConvertArgMinOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ArgMin");
return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model);
return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags,
model_flags, model);
}
tensorflow::Status ConvertResizeBilinearOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ResizeBilinear");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new ResizeBilinearOperator;
@ -1661,7 +1678,7 @@ tensorflow::Status ConvertResizeBilinearOperator(
tensorflow::Status ConvertResizeNearestNeighborOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ResizeNearestNeighbor");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new ResizeNearestNeighborOperator;
@ -1680,7 +1697,7 @@ tensorflow::Status ConvertResizeNearestNeighborOperator(
tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
@ -1730,7 +1747,7 @@ tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
tensorflow::Status ConvertFusedBatchNormOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "FusedBatchNorm");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
@ -1783,7 +1800,7 @@ tensorflow::Status ConvertFusedBatchNormOperator(
tensorflow::Status ConvertSpaceToBatchNDOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SpaceToBatchND");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
@ -1799,7 +1816,7 @@ tensorflow::Status ConvertSpaceToBatchNDOperator(
tensorflow::Status ConvertBatchToSpaceNDOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "BatchToSpaceND");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
@ -1816,7 +1833,7 @@ tensorflow::Status ConvertBatchToSpaceNDOperator(
template <typename T>
tensorflow::Status ConvertReduceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new T;
op->inputs.push_back(node.input(0));
@ -1833,7 +1850,7 @@ tensorflow::Status ConvertReduceOperator(
tensorflow::Status ConvertSvdfOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Svdf");
const int input_size = GetInputsCount(node, tf_import_flags);
QCHECK(input_size == 3 || input_size == 4)
@ -1862,7 +1879,7 @@ tensorflow::Status ConvertSvdfOperator(
// This is just bare bones support to get the shapes to propagate.
tensorflow::Status ConvertTransposeConvOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Conv2DBackpropInput");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new TransposeConvOperator;
@ -1933,7 +1950,7 @@ tensorflow::Status ConvertTransposeConvOperator(
tensorflow::Status ConvertRangeOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Range");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new RangeOperator;
@ -1958,7 +1975,7 @@ tensorflow::Status ConvertRangeOperator(
// not directly related to tf.stack() usage.
tensorflow::Status ConvertPackOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Pack");
auto op = absl::make_unique<PackOperator>();
const int num_inputs = GetInputsCount(node, tf_import_flags);
@ -1980,7 +1997,7 @@ tensorflow::Status ConvertPackOperator(
tensorflow::Status ConvertUnpackOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Unpack");
auto op = absl::make_unique<UnpackOperator>();
const int num_inputs = GetInputsCount(node, tf_import_flags);
@ -2010,7 +2027,7 @@ tensorflow::Status ConvertUnpackOperator(
// graph visualization.
tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
// At the moment, the only type of operator special-cased in this way is
// NextIteration, occurring only in control-flow cycles.
CHECK_EQ(node.op(), "NextIteration");
@ -2029,7 +2046,7 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
tensorflow::Status ConvertShapeOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Shape");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto out_type =
@ -2045,7 +2062,7 @@ tensorflow::Status ConvertShapeOperator(
tensorflow::Status ConvertReverseSequenceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ReverseSequence");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto op = absl::make_unique<ReverseSequenceOperator>();
@ -2206,7 +2223,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::Status ConvertTopKV2Operator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
auto op = absl::make_unique<TopKV2Operator>();
op->inputs.push_back(node.input(0));
@ -2228,7 +2245,7 @@ tensorflow::Status ConvertTopKV2Operator(
tensorflow::Status ConvertDynamicPartitionOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
auto op = absl::make_unique<DynamicPartitionOperator>();
CHECK(HasAttr(node, "num_partitions"));
op->num_partitions = GetIntAttr(node, "num_partitions");
@ -2246,7 +2263,7 @@ tensorflow::Status ConvertDynamicPartitionOperator(
tensorflow::Status ConvertDynamicStitchOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
// The parallel and non-parallel variants are the same besides whether they
// have a parallel loop; there are no behavioral differences.
CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
@ -2265,7 +2282,7 @@ tensorflow::Status ConvertDynamicStitchOperator(
tensorflow::Status ConvertSparseToDenseOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SparseToDense");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
@ -2284,7 +2301,7 @@ tensorflow::Status ConvertSparseToDenseOperator(
tensorflow::Status ConvertOneHotOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "OneHot");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
@ -2305,7 +2322,7 @@ tensorflow::Status ConvertOneHotOperator(
tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -2335,7 +2352,7 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
// with TfLite OpHint API.
tensorflow::Status ConvertUnidirectionalSequenceLstm(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
auto* op = new UnidirectionalSequenceLstmOperator();
@ -2375,7 +2392,7 @@ tensorflow::Status ConvertUnidirectionalSequenceLstm(
tensorflow::Status ConvertLeakyReluOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "LeakyRelu");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
@ -2390,7 +2407,7 @@ tensorflow::Status ConvertLeakyReluOperator(
tensorflow::Status ConvertUnidirectionalSequenceRnn(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
const ModelFlags& model_flags, Model* model) {
DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
auto* op = new UnidirectionalSequenceRnnOperator();
@ -2415,7 +2432,7 @@ namespace internal {
using ConverterType = tensorflow::Status (*)(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model);
const ModelFlags& model_flags, Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
@ -2568,13 +2585,14 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
tensorflow::Status ImportTensorFlowNode(
const tensorflow::NodeDef& node,
const TensorFlowImportFlags& tf_import_flags, Model* model,
const ConverterMapType& converter_map) {
const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags,
Model* model, const ConverterMapType& converter_map) {
auto converter = converter_map.find(node.op());
if (converter == converter_map.end()) {
return ConvertUnsupportedOperator(node, tf_import_flags, model);
return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
model);
} else {
return converter->second(node, tf_import_flags, model);
return converter->second(node, tf_import_flags, model_flags, model);
}
}
} // namespace internal
@ -2614,8 +2632,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model,
converter_map);
auto status = internal::ImportTensorFlowNode(
node, tf_import_flags, model_flags, model, converter_map);
CHECK(status.ok()) << status.error_message();
}

View File

@ -44,28 +44,29 @@ using ::testing::ElementsAre;
namespace internal {
using ConverterType = tensorflow::Status (*)(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model);
const ModelFlags& model_flags, Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
ConverterMapType GetTensorFlowNodeConverterMap();
ConverterMapType GetTensorFlowNodeConverterMapForFlex();
Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
Model*, const ConverterMapType&);
const ModelFlags& model_flags, Model*,
const ConverterMapType&);
} // namespace internal
namespace {
Status ImportNode(const NodeDef& node, Model* model) {
const auto converter = internal::GetTensorFlowNodeConverterMap();
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
converter);
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
ModelFlags(), model, converter);
}
Status ImportFlexNode(const NodeDef& node, Model* model) {
// Empty converter => all nodes are flex nodes.
const auto converter = internal::ConverterMapType();
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
converter);
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
ModelFlags(), model, converter);
}
Status ImportNode(const NodeDef& node) {
@ -170,7 +171,7 @@ TEST(FlexImportTest, ConditionalConst) {
const auto converter = internal::GetTensorFlowNodeConverterMapForFlex();
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
&model, converter);
ModelFlags(), &model, converter);
};
EXPECT_TRUE(build_and_import_node("Known", {1, 2, 3}, DT_INT32, 6).ok());

View File

@ -2335,6 +2335,14 @@ class Model {
int64 ArithmeticOpsCount() const { return ops_count; }
void AddInvalidInputArray(string invalid_input_array) {
invalid_input_arrays_.insert(invalid_input_array);
}
const std::unordered_set<string>& GetInvalidInputArrays() const {
return invalid_input_arrays_;
}
// Optional arrays are used for optional tensors; these tensors do not have
// data, but have reserved names as op inputs.
std::set<string> optional_arrays;
@ -2361,6 +2369,9 @@ class Model {
// The Operator's refer to these Array's by their name strings, not by their
// addresses. See Operator::inputs, Operator::outputs.
std::unordered_map<string, std::unique_ptr<Array>> arrays;
// Invalid input arrays.
std::unordered_set<string> invalid_input_arrays_;
};
// OperatorSignature contains the information required for making versioning

View File

@ -459,6 +459,13 @@ tensorflow::Status Export(
const Model& model, string* output_file_contents,
const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
for (const string& input_array : model.GetInvalidInputArrays()) {
if (model.HasArray(input_array)) {
return tensorflow::errors::InvalidArgument(absl::StrCat(
"Placeholder ", input_array, " should be specied by input_arrays."));
}
}
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
details::TensorsMap tensors_map;

View File

@ -133,6 +133,11 @@ TEST(TocoTest, TransientStringTensors) {
// input array must have a shape.
toco_flags.set_output_format(TFLITE);
toco::InputArray* input_1 = model_flags.add_input_arrays();
input_1->set_name("input1");
toco::InputArray* indices_1 = model_flags.add_input_arrays();
indices_1->set_name("indices1");
model_flags.add_output_arrays("output1");
string input = R"GraphDef(
node {