Give out warning if model has place holder but not in the input_array specified by the user.
PiperOrigin-RevId: 251771255
This commit is contained in:
parent
9f49c5b09f
commit
42df993ba5
@ -25,6 +25,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.python.convert import ConverterError
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.client import session
|
||||
@ -1190,15 +1191,17 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
|
||||
input_arrays=['inputA'],
|
||||
input_shapes={'inputA': [1, 16, 16, 3]})
|
||||
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
# Since we only partially specify the input, this is not allowed.
|
||||
with self.assertRaises(ConverterError):
|
||||
_ = converter.convert()
|
||||
|
||||
# Check case where input shape is None.
|
||||
converter = lite.TFLiteConverter.from_saved_model(
|
||||
saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
|
||||
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
# Since we only partially specify the input, this is not allowed.
|
||||
with self.assertRaises(ConverterError):
|
||||
_ = converter.convert()
|
||||
|
||||
def testSimpleModelTocoConverter(self):
|
||||
"""Test a SavedModel with deprecated TocoConverter."""
|
||||
|
@ -564,7 +564,7 @@ void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
|
||||
|
||||
tensorflow::Status ConvertConstOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Const");
|
||||
const auto& tensor = GetTensorAttr(node, "value");
|
||||
const auto dtype = GetDataTypeAttr(node, "dtype");
|
||||
@ -616,7 +616,7 @@ tensorflow::Status ConvertConstOperator(
|
||||
|
||||
tensorflow::Status ConvertConvOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Conv2D");
|
||||
TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
|
||||
|
||||
@ -691,7 +691,7 @@ tensorflow::Status ConvertConvOperator(
|
||||
|
||||
tensorflow::Status ConvertDepthwiseConvOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "DepthwiseConv2dNative");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
|
||||
@ -762,7 +762,7 @@ tensorflow::Status ConvertDepthwiseConvOperator(
|
||||
|
||||
tensorflow::Status ConvertDepthToSpaceOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "DepthToSpace");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
|
||||
@ -778,7 +778,7 @@ tensorflow::Status ConvertDepthToSpaceOperator(
|
||||
|
||||
tensorflow::Status ConvertSpaceToDepthOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "SpaceToDepth");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
|
||||
@ -801,7 +801,7 @@ tensorflow::Status ConvertSpaceToDepthOperator(
|
||||
|
||||
tensorflow::Status ConvertBiasAddOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "BiasAdd");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
|
||||
@ -818,7 +818,7 @@ tensorflow::Status ConvertBiasAddOperator(
|
||||
|
||||
tensorflow::Status ConvertRandomUniform(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "RandomUniform");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
|
||||
@ -836,7 +836,7 @@ tensorflow::Status ConvertRandomUniform(
|
||||
|
||||
tensorflow::Status ConvertIdentityOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
|
||||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
|
||||
node.op() == "Snapshot");
|
||||
@ -859,7 +859,7 @@ tensorflow::Status ConvertIdentityOperator(
|
||||
|
||||
tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
auto* op = new FakeQuantOperator;
|
||||
@ -880,7 +880,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
|
||||
|
||||
tensorflow::Status ConvertFakeQuantWithMinMaxVars(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
|
||||
const int num_inputs = GetInputsCount(node, tf_import_flags);
|
||||
QCHECK(num_inputs == 3 || num_inputs == 4)
|
||||
@ -902,7 +902,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars(
|
||||
|
||||
tensorflow::Status ConvertSqueezeOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Squeeze");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
auto* op = new SqueezeOperator;
|
||||
@ -923,7 +923,7 @@ tensorflow::Status ConvertSqueezeOperator(
|
||||
|
||||
tensorflow::Status ConvertSplitOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Split");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
auto* op = new TensorFlowSplitOperator;
|
||||
@ -941,7 +941,7 @@ tensorflow::Status ConvertSplitOperator(
|
||||
|
||||
tensorflow::Status ConvertSplitVOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "SplitV");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
|
||||
auto* op = new TensorFlowSplitVOperator;
|
||||
@ -960,7 +960,7 @@ tensorflow::Status ConvertSplitVOperator(
|
||||
|
||||
tensorflow::Status ConvertSwitchOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Switch");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
auto* op = new TensorFlowSwitchOperator;
|
||||
@ -975,7 +975,7 @@ tensorflow::Status ConvertSwitchOperator(
|
||||
|
||||
tensorflow::Status ConvertSoftmaxOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Softmax");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto& input_name = node.input(0);
|
||||
@ -991,7 +991,7 @@ tensorflow::Status ConvertSoftmaxOperator(
|
||||
|
||||
tensorflow::Status ConvertLRNOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "LRN");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto& input_name = node.input(0);
|
||||
@ -1008,7 +1008,7 @@ tensorflow::Status ConvertLRNOperator(
|
||||
|
||||
tensorflow::Status ConvertMaxPoolOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "MaxPool");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto& input_name = node.input(0);
|
||||
@ -1051,7 +1051,7 @@ tensorflow::Status ConvertMaxPoolOperator(
|
||||
|
||||
tensorflow::Status ConvertAvgPoolOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "AvgPool");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto& input_name = node.input(0);
|
||||
@ -1090,7 +1090,7 @@ tensorflow::Status ConvertAvgPoolOperator(
|
||||
|
||||
tensorflow::Status ConvertBatchMatMulOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
|
||||
auto* batch_matmul = new BatchMatMulOperator;
|
||||
@ -1113,7 +1113,7 @@ tensorflow::Status ConvertBatchMatMulOperator(
|
||||
|
||||
tensorflow::Status ConvertMatMulOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
|
||||
CHECK(!HasAttr(node, "adjoint_a") ||
|
||||
@ -1137,7 +1137,7 @@ tensorflow::Status ConvertMatMulOperator(
|
||||
|
||||
tensorflow::Status ConvertConcatOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
Operator* op = nullptr;
|
||||
if (node.op() == "Concat") {
|
||||
op = new TensorFlowConcatOperator;
|
||||
@ -1162,7 +1162,7 @@ tensorflow::Status ConvertConcatOperator(
|
||||
|
||||
tensorflow::Status ConvertMirrorPadOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
if (node.op() != "MirrorPad") {
|
||||
LOG(FATAL) << "Expected MirrorPad.";
|
||||
}
|
||||
@ -1197,7 +1197,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk };
|
||||
template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
|
||||
tensorflow::Status ConvertSimpleOperatorGeneric(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
if (NumInputs != kAnyNumInputs) {
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
|
||||
}
|
||||
@ -1225,18 +1225,18 @@ tensorflow::Status ConvertSimpleOperatorGeneric(
|
||||
template <typename Op, int NumInputs, int NumOutputs>
|
||||
tensorflow::Status ConvertSimpleOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
|
||||
node, tf_import_flags, model);
|
||||
node, tf_import_flags, model_flags, model);
|
||||
}
|
||||
|
||||
// Convert a simple operator which is valid as a flex op.
|
||||
template <typename Op, int NumInputs, int NumOutputs>
|
||||
tensorflow::Status ConvertSimpleOperatorFlexOk(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
|
||||
node, tf_import_flags, model);
|
||||
node, tf_import_flags, model_flags, model);
|
||||
}
|
||||
|
||||
void GetOutputNamesFromNodeDef(const NodeDef& node,
|
||||
@ -1325,7 +1325,7 @@ void GetOutputTypesFromNodeDef(const NodeDef& node,
|
||||
|
||||
tensorflow::Status ConvertUnsupportedOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
// Names of special attributes in TF graph that are used by Toco.
|
||||
static constexpr char kAttrOutputQuantized[] = "_output_quantized";
|
||||
static constexpr char kAttrOutputTypes[] = "_output_types";
|
||||
@ -1416,14 +1416,15 @@ tensorflow::Status ConvertUnsupportedOperator(
|
||||
// expensive copies of the protocol buffers downstream in the flex delegate.
|
||||
tensorflow::Status ConditionallyConvertConstOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
// We avoid incomplete and zero shapes because the resulting arrays
|
||||
// are not completely compatible with Eager/TensorFlow.
|
||||
const auto& tensor = GetTensorAttr(node, "value");
|
||||
const auto& shape = tensor.tensor_shape();
|
||||
for (const auto& dim : shape.dim()) {
|
||||
if (dim.size() <= 0) {
|
||||
return ConvertUnsupportedOperator(node, tf_import_flags, model);
|
||||
return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
|
||||
model);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1435,15 +1436,16 @@ tensorflow::Status ConditionallyConvertConstOperator(
|
||||
case DT_STRING:
|
||||
case DT_BOOL:
|
||||
case DT_COMPLEX64:
|
||||
return ConvertConstOperator(node, tf_import_flags, model);
|
||||
return ConvertConstOperator(node, tf_import_flags, model_flags, model);
|
||||
default:
|
||||
return ConvertUnsupportedOperator(node, tf_import_flags, model);
|
||||
return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
|
||||
model);
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertStridedSliceOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "StridedSlice");
|
||||
// TODO(soroosh): The 4th input (strides) should be e optional, to be
|
||||
// consistent with TF.
|
||||
@ -1472,11 +1474,24 @@ tensorflow::Status ConvertStridedSliceOperator(
|
||||
|
||||
tensorflow::Status ConvertPlaceholderOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
|
||||
if (node.op() == "Placeholder") {
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
|
||||
}
|
||||
|
||||
bool inside_input_arrays = false;
|
||||
for (const auto& input_array : model_flags.input_arrays()) {
|
||||
if (node.name() == input_array.name()) {
|
||||
inside_input_arrays = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!inside_input_arrays) {
|
||||
model->AddInvalidInputArray(node.name());
|
||||
}
|
||||
|
||||
auto& array = model->GetOrCreateArray(node.name());
|
||||
if (node.attr().count("dtype")) {
|
||||
array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
|
||||
@ -1499,13 +1514,13 @@ tensorflow::Status ConvertPlaceholderOperator(
|
||||
|
||||
tensorflow::Status ConvertNoOpOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertCastOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Cast");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
|
||||
@ -1521,7 +1536,7 @@ tensorflow::Status ConvertCastOperator(
|
||||
|
||||
tensorflow::Status ConvertFloorOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Floor");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto data_type = GetDataTypeAttr(node, "T");
|
||||
@ -1535,7 +1550,7 @@ tensorflow::Status ConvertFloorOperator(
|
||||
|
||||
tensorflow::Status ConvertCeilOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Ceil");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto data_type = GetDataTypeAttr(node, "T");
|
||||
@ -1549,7 +1564,7 @@ tensorflow::Status ConvertCeilOperator(
|
||||
|
||||
tensorflow::Status ConvertRoundOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Round");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto data_type = GetDataTypeAttr(node, "T");
|
||||
@ -1563,7 +1578,7 @@ tensorflow::Status ConvertRoundOperator(
|
||||
|
||||
tensorflow::Status ConvertGatherOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK(node.op() == "Gather" || node.op() == "GatherV2");
|
||||
if (node.op() == "Gather")
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
@ -1592,7 +1607,7 @@ tensorflow::Status ConvertGatherOperator(
|
||||
|
||||
tensorflow::Status ConvertGatherNdOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "GatherNd");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
|
||||
@ -1608,7 +1623,7 @@ tensorflow::Status ConvertGatherNdOperator(
|
||||
template <typename Op>
|
||||
tensorflow::Status ConvertArgMinMaxOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
const auto axis_data_type =
|
||||
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
|
||||
@ -1628,21 +1643,23 @@ tensorflow::Status ConvertArgMinMaxOperator(
|
||||
|
||||
tensorflow::Status ConvertArgMaxOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "ArgMax");
|
||||
return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model);
|
||||
return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags,
|
||||
model_flags, model);
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertArgMinOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "ArgMin");
|
||||
return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model);
|
||||
return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags,
|
||||
model_flags, model);
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertResizeBilinearOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "ResizeBilinear");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
auto* op = new ResizeBilinearOperator;
|
||||
@ -1661,7 +1678,7 @@ tensorflow::Status ConvertResizeBilinearOperator(
|
||||
|
||||
tensorflow::Status ConvertResizeNearestNeighborOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "ResizeNearestNeighbor");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
auto* op = new ResizeNearestNeighborOperator;
|
||||
@ -1680,7 +1697,7 @@ tensorflow::Status ConvertResizeNearestNeighborOperator(
|
||||
|
||||
tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
|
||||
|
||||
@ -1730,7 +1747,7 @@ tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
|
||||
|
||||
tensorflow::Status ConvertFusedBatchNormOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "FusedBatchNorm");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
|
||||
|
||||
@ -1783,7 +1800,7 @@ tensorflow::Status ConvertFusedBatchNormOperator(
|
||||
|
||||
tensorflow::Status ConvertSpaceToBatchNDOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "SpaceToBatchND");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
|
||||
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
|
||||
@ -1799,7 +1816,7 @@ tensorflow::Status ConvertSpaceToBatchNDOperator(
|
||||
|
||||
tensorflow::Status ConvertBatchToSpaceNDOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "BatchToSpaceND");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
|
||||
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
|
||||
@ -1816,7 +1833,7 @@ tensorflow::Status ConvertBatchToSpaceNDOperator(
|
||||
template <typename T>
|
||||
tensorflow::Status ConvertReduceOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
auto* op = new T;
|
||||
op->inputs.push_back(node.input(0));
|
||||
@ -1833,7 +1850,7 @@ tensorflow::Status ConvertReduceOperator(
|
||||
|
||||
tensorflow::Status ConvertSvdfOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Svdf");
|
||||
const int input_size = GetInputsCount(node, tf_import_flags);
|
||||
QCHECK(input_size == 3 || input_size == 4)
|
||||
@ -1862,7 +1879,7 @@ tensorflow::Status ConvertSvdfOperator(
|
||||
// This is just bare bones support to get the shapes to propagate.
|
||||
tensorflow::Status ConvertTransposeConvOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Conv2DBackpropInput");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
|
||||
auto* op = new TransposeConvOperator;
|
||||
@ -1933,7 +1950,7 @@ tensorflow::Status ConvertTransposeConvOperator(
|
||||
|
||||
tensorflow::Status ConvertRangeOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Range");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
|
||||
auto* op = new RangeOperator;
|
||||
@ -1958,7 +1975,7 @@ tensorflow::Status ConvertRangeOperator(
|
||||
// not directly related to tf.stack() usage.
|
||||
tensorflow::Status ConvertPackOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Pack");
|
||||
auto op = absl::make_unique<PackOperator>();
|
||||
const int num_inputs = GetInputsCount(node, tf_import_flags);
|
||||
@ -1980,7 +1997,7 @@ tensorflow::Status ConvertPackOperator(
|
||||
|
||||
tensorflow::Status ConvertUnpackOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Unpack");
|
||||
auto op = absl::make_unique<UnpackOperator>();
|
||||
const int num_inputs = GetInputsCount(node, tf_import_flags);
|
||||
@ -2010,7 +2027,7 @@ tensorflow::Status ConvertUnpackOperator(
|
||||
// graph visualization.
|
||||
tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
// At the moment, the only type of operator special-cased in this way is
|
||||
// NextIteration, occurring only in control-flow cycles.
|
||||
CHECK_EQ(node.op(), "NextIteration");
|
||||
@ -2029,7 +2046,7 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
|
||||
|
||||
tensorflow::Status ConvertShapeOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "Shape");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
const auto out_type =
|
||||
@ -2045,7 +2062,7 @@ tensorflow::Status ConvertShapeOperator(
|
||||
|
||||
tensorflow::Status ConvertReverseSequenceOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "ReverseSequence");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
auto op = absl::make_unique<ReverseSequenceOperator>();
|
||||
@ -2206,7 +2223,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
|
||||
|
||||
tensorflow::Status ConvertTopKV2Operator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
|
||||
auto op = absl::make_unique<TopKV2Operator>();
|
||||
op->inputs.push_back(node.input(0));
|
||||
@ -2228,7 +2245,7 @@ tensorflow::Status ConvertTopKV2Operator(
|
||||
|
||||
tensorflow::Status ConvertDynamicPartitionOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
auto op = absl::make_unique<DynamicPartitionOperator>();
|
||||
CHECK(HasAttr(node, "num_partitions"));
|
||||
op->num_partitions = GetIntAttr(node, "num_partitions");
|
||||
@ -2246,7 +2263,7 @@ tensorflow::Status ConvertDynamicPartitionOperator(
|
||||
|
||||
tensorflow::Status ConvertDynamicStitchOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
// The parallel and non-parallel variants are the same besides whether they
|
||||
// have a parallel loop; there are no behavioral differences.
|
||||
CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
|
||||
@ -2265,7 +2282,7 @@ tensorflow::Status ConvertDynamicStitchOperator(
|
||||
|
||||
tensorflow::Status ConvertSparseToDenseOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "SparseToDense");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
|
||||
|
||||
@ -2284,7 +2301,7 @@ tensorflow::Status ConvertSparseToDenseOperator(
|
||||
|
||||
tensorflow::Status ConvertOneHotOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "OneHot");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
|
||||
|
||||
@ -2305,7 +2322,7 @@ tensorflow::Status ConvertOneHotOperator(
|
||||
|
||||
tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
|
||||
|
||||
@ -2335,7 +2352,7 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
|
||||
// with TfLite OpHint API.
|
||||
tensorflow::Status ConvertUnidirectionalSequenceLstm(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
|
||||
|
||||
auto* op = new UnidirectionalSequenceLstmOperator();
|
||||
@ -2375,7 +2392,7 @@ tensorflow::Status ConvertUnidirectionalSequenceLstm(
|
||||
|
||||
tensorflow::Status ConvertLeakyReluOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK_EQ(node.op(), "LeakyRelu");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
|
||||
CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
|
||||
@ -2390,7 +2407,7 @@ tensorflow::Status ConvertLeakyReluOperator(
|
||||
|
||||
tensorflow::Status ConvertUnidirectionalSequenceRnn(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
|
||||
|
||||
auto* op = new UnidirectionalSequenceRnnOperator();
|
||||
@ -2415,7 +2432,7 @@ namespace internal {
|
||||
|
||||
using ConverterType = tensorflow::Status (*)(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model);
|
||||
const ModelFlags& model_flags, Model* model);
|
||||
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
|
||||
|
||||
ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
|
||||
@ -2568,13 +2585,14 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
|
||||
tensorflow::Status ImportTensorFlowNode(
|
||||
const tensorflow::NodeDef& node,
|
||||
const TensorFlowImportFlags& tf_import_flags, Model* model,
|
||||
const ConverterMapType& converter_map) {
|
||||
const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags,
|
||||
Model* model, const ConverterMapType& converter_map) {
|
||||
auto converter = converter_map.find(node.op());
|
||||
if (converter == converter_map.end()) {
|
||||
return ConvertUnsupportedOperator(node, tf_import_flags, model);
|
||||
return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
|
||||
model);
|
||||
} else {
|
||||
return converter->second(node, tf_import_flags, model);
|
||||
return converter->second(node, tf_import_flags, model_flags, model);
|
||||
}
|
||||
}
|
||||
} // namespace internal
|
||||
@ -2614,8 +2632,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
||||
|
||||
for (auto node : inlined_graph.node()) {
|
||||
StripZeroOutputIndexFromInputs(&node);
|
||||
auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model,
|
||||
converter_map);
|
||||
auto status = internal::ImportTensorFlowNode(
|
||||
node, tf_import_flags, model_flags, model, converter_map);
|
||||
CHECK(status.ok()) << status.error_message();
|
||||
}
|
||||
|
||||
|
@ -44,28 +44,29 @@ using ::testing::ElementsAre;
|
||||
namespace internal {
|
||||
using ConverterType = tensorflow::Status (*)(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model);
|
||||
const ModelFlags& model_flags, Model* model);
|
||||
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
|
||||
|
||||
ConverterMapType GetTensorFlowNodeConverterMap();
|
||||
ConverterMapType GetTensorFlowNodeConverterMapForFlex();
|
||||
Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
|
||||
Model*, const ConverterMapType&);
|
||||
const ModelFlags& model_flags, Model*,
|
||||
const ConverterMapType&);
|
||||
} // namespace internal
|
||||
|
||||
namespace {
|
||||
|
||||
Status ImportNode(const NodeDef& node, Model* model) {
|
||||
const auto converter = internal::GetTensorFlowNodeConverterMap();
|
||||
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
|
||||
converter);
|
||||
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
|
||||
ModelFlags(), model, converter);
|
||||
}
|
||||
|
||||
Status ImportFlexNode(const NodeDef& node, Model* model) {
|
||||
// Empty converter => all nodes are flex nodes.
|
||||
const auto converter = internal::ConverterMapType();
|
||||
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
|
||||
converter);
|
||||
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
|
||||
ModelFlags(), model, converter);
|
||||
}
|
||||
|
||||
Status ImportNode(const NodeDef& node) {
|
||||
@ -170,7 +171,7 @@ TEST(FlexImportTest, ConditionalConst) {
|
||||
|
||||
const auto converter = internal::GetTensorFlowNodeConverterMapForFlex();
|
||||
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
|
||||
&model, converter);
|
||||
ModelFlags(), &model, converter);
|
||||
};
|
||||
|
||||
EXPECT_TRUE(build_and_import_node("Known", {1, 2, 3}, DT_INT32, 6).ok());
|
||||
|
@ -2335,6 +2335,14 @@ class Model {
|
||||
|
||||
int64 ArithmeticOpsCount() const { return ops_count; }
|
||||
|
||||
void AddInvalidInputArray(string invalid_input_array) {
|
||||
invalid_input_arrays_.insert(invalid_input_array);
|
||||
}
|
||||
|
||||
const std::unordered_set<string>& GetInvalidInputArrays() const {
|
||||
return invalid_input_arrays_;
|
||||
}
|
||||
|
||||
// Optional arrays are used for optional tensors,
|
||||
// these tensors do not have data, but with reserved names as op inputs.
|
||||
std::set<string> optional_arrays;
|
||||
@ -2361,6 +2369,9 @@ class Model {
|
||||
// The Operator's refer to these Array's by their name strings, not by their
|
||||
// addresses. See Operator::inputs, Operator::outputs.
|
||||
std::unordered_map<string, std::unique_ptr<Array>> arrays;
|
||||
|
||||
// Invalid input arrays.
|
||||
std::unordered_set<string> invalid_input_arrays_;
|
||||
};
|
||||
|
||||
// OperatorSignature contains the information required to making versioning
|
||||
|
@ -459,6 +459,13 @@ tensorflow::Status Export(
|
||||
const Model& model, string* output_file_contents,
|
||||
const ExportParams& params,
|
||||
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
|
||||
for (const string& input_array : model.GetInvalidInputArrays()) {
|
||||
if (model.HasArray(input_array)) {
|
||||
return tensorflow::errors::InvalidArgument(absl::StrCat(
|
||||
"Placeholder ", input_array, " should be specied by input_arrays."));
|
||||
}
|
||||
}
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
|
||||
|
||||
details::TensorsMap tensors_map;
|
||||
|
@ -133,6 +133,11 @@ TEST(TocoTest, TransientStringTensors) {
|
||||
// input array must have a shape.
|
||||
toco_flags.set_output_format(TFLITE);
|
||||
|
||||
toco::InputArray* input_1 = model_flags.add_input_arrays();
|
||||
input_1->set_name("input1");
|
||||
toco::InputArray* indices_1 = model_flags.add_input_arrays();
|
||||
indices_1->set_name("indices1");
|
||||
|
||||
model_flags.add_output_arrays("output1");
|
||||
string input = R"GraphDef(
|
||||
node {
|
||||
|
Loading…
Reference in New Issue
Block a user