Emit a warning if the model has a placeholder that is not in the input_arrays specified by the user.

PiperOrigin-RevId: 251771255
This commit is contained in:
Renjie Liu 2019-06-05 19:54:52 -07:00 committed by TensorFlower Gardener
parent 9f49c5b09f
commit 42df993ba5
6 changed files with 133 additions and 88 deletions

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.lite.python import lite from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants from tensorflow.lite.python import lite_constants
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.interpreter import Interpreter from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.client import session from tensorflow.python.client import session
@ -1190,15 +1191,17 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
input_arrays=['inputA'], input_arrays=['inputA'],
input_shapes={'inputA': [1, 16, 16, 3]}) input_shapes={'inputA': [1, 16, 16, 3]})
tflite_model = converter.convert() # Since we only partially specify the input, this is not allowed.
self.assertTrue(tflite_model) with self.assertRaises(ConverterError):
_ = converter.convert()
# Check case where input shape is None. # Check case where input shape is None.
converter = lite.TFLiteConverter.from_saved_model( converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None}) saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
tflite_model = converter.convert() # Since we only partially specify the input, this is not allowed.
self.assertTrue(tflite_model) with self.assertRaises(ConverterError):
_ = converter.convert()
def testSimpleModelTocoConverter(self): def testSimpleModelTocoConverter(self):
"""Test a SavedModel with deprecated TocoConverter.""" """Test a SavedModel with deprecated TocoConverter."""

View File

@ -564,7 +564,7 @@ void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
tensorflow::Status ConvertConstOperator( tensorflow::Status ConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Const"); CHECK_EQ(node.op(), "Const");
const auto& tensor = GetTensorAttr(node, "value"); const auto& tensor = GetTensorAttr(node, "value");
const auto dtype = GetDataTypeAttr(node, "dtype"); const auto dtype = GetDataTypeAttr(node, "dtype");
@ -616,7 +616,7 @@ tensorflow::Status ConvertConstOperator(
tensorflow::Status ConvertConvOperator( tensorflow::Status ConvertConvOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Conv2D"); CHECK_EQ(node.op(), "Conv2D");
TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
@ -691,7 +691,7 @@ tensorflow::Status ConvertConvOperator(
tensorflow::Status ConvertDepthwiseConvOperator( tensorflow::Status ConvertDepthwiseConvOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "DepthwiseConv2dNative"); CHECK_EQ(node.op(), "DepthwiseConv2dNative");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -762,7 +762,7 @@ tensorflow::Status ConvertDepthwiseConvOperator(
tensorflow::Status ConvertDepthToSpaceOperator( tensorflow::Status ConvertDepthToSpaceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "DepthToSpace"); CHECK_EQ(node.op(), "DepthToSpace");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
@ -778,7 +778,7 @@ tensorflow::Status ConvertDepthToSpaceOperator(
tensorflow::Status ConvertSpaceToDepthOperator( tensorflow::Status ConvertSpaceToDepthOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SpaceToDepth"); CHECK_EQ(node.op(), "SpaceToDepth");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
@ -801,7 +801,7 @@ tensorflow::Status ConvertSpaceToDepthOperator(
tensorflow::Status ConvertBiasAddOperator( tensorflow::Status ConvertBiasAddOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "BiasAdd"); CHECK_EQ(node.op(), "BiasAdd");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -818,7 +818,7 @@ tensorflow::Status ConvertBiasAddOperator(
tensorflow::Status ConvertRandomUniform( tensorflow::Status ConvertRandomUniform(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "RandomUniform"); CHECK_EQ(node.op(), "RandomUniform");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
@ -836,7 +836,7 @@ tensorflow::Status ConvertRandomUniform(
tensorflow::Status ConvertIdentityOperator( tensorflow::Status ConvertIdentityOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" || node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
node.op() == "Snapshot"); node.op() == "Snapshot");
@ -859,7 +859,7 @@ tensorflow::Status ConvertIdentityOperator(
tensorflow::Status ConvertFakeQuantWithMinMaxArgs( tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
auto* op = new FakeQuantOperator; auto* op = new FakeQuantOperator;
@ -880,7 +880,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
tensorflow::Status ConvertFakeQuantWithMinMaxVars( tensorflow::Status ConvertFakeQuantWithMinMaxVars(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
const int num_inputs = GetInputsCount(node, tf_import_flags); const int num_inputs = GetInputsCount(node, tf_import_flags);
QCHECK(num_inputs == 3 || num_inputs == 4) QCHECK(num_inputs == 3 || num_inputs == 4)
@ -902,7 +902,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars(
tensorflow::Status ConvertSqueezeOperator( tensorflow::Status ConvertSqueezeOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Squeeze"); CHECK_EQ(node.op(), "Squeeze");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
auto* op = new SqueezeOperator; auto* op = new SqueezeOperator;
@ -923,7 +923,7 @@ tensorflow::Status ConvertSqueezeOperator(
tensorflow::Status ConvertSplitOperator( tensorflow::Status ConvertSplitOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Split"); CHECK_EQ(node.op(), "Split");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowSplitOperator; auto* op = new TensorFlowSplitOperator;
@ -941,7 +941,7 @@ tensorflow::Status ConvertSplitOperator(
tensorflow::Status ConvertSplitVOperator( tensorflow::Status ConvertSplitVOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SplitV"); CHECK_EQ(node.op(), "SplitV");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new TensorFlowSplitVOperator; auto* op = new TensorFlowSplitVOperator;
@ -960,7 +960,7 @@ tensorflow::Status ConvertSplitVOperator(
tensorflow::Status ConvertSwitchOperator( tensorflow::Status ConvertSwitchOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Switch"); CHECK_EQ(node.op(), "Switch");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowSwitchOperator; auto* op = new TensorFlowSwitchOperator;
@ -975,7 +975,7 @@ tensorflow::Status ConvertSwitchOperator(
tensorflow::Status ConvertSoftmaxOperator( tensorflow::Status ConvertSoftmaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Softmax"); CHECK_EQ(node.op(), "Softmax");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0); const auto& input_name = node.input(0);
@ -991,7 +991,7 @@ tensorflow::Status ConvertSoftmaxOperator(
tensorflow::Status ConvertLRNOperator( tensorflow::Status ConvertLRNOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "LRN"); CHECK_EQ(node.op(), "LRN");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0); const auto& input_name = node.input(0);
@ -1008,7 +1008,7 @@ tensorflow::Status ConvertLRNOperator(
tensorflow::Status ConvertMaxPoolOperator( tensorflow::Status ConvertMaxPoolOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "MaxPool"); CHECK_EQ(node.op(), "MaxPool");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0); const auto& input_name = node.input(0);
@ -1051,7 +1051,7 @@ tensorflow::Status ConvertMaxPoolOperator(
tensorflow::Status ConvertAvgPoolOperator( tensorflow::Status ConvertAvgPoolOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "AvgPool"); CHECK_EQ(node.op(), "AvgPool");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0); const auto& input_name = node.input(0);
@ -1090,7 +1090,7 @@ tensorflow::Status ConvertAvgPoolOperator(
tensorflow::Status ConvertBatchMatMulOperator( tensorflow::Status ConvertBatchMatMulOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* batch_matmul = new BatchMatMulOperator; auto* batch_matmul = new BatchMatMulOperator;
@ -1113,7 +1113,7 @@ tensorflow::Status ConvertBatchMatMulOperator(
tensorflow::Status ConvertMatMulOperator( tensorflow::Status ConvertMatMulOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
CHECK(!HasAttr(node, "adjoint_a") || CHECK(!HasAttr(node, "adjoint_a") ||
@ -1137,7 +1137,7 @@ tensorflow::Status ConvertMatMulOperator(
tensorflow::Status ConvertConcatOperator( tensorflow::Status ConvertConcatOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
Operator* op = nullptr; Operator* op = nullptr;
if (node.op() == "Concat") { if (node.op() == "Concat") {
op = new TensorFlowConcatOperator; op = new TensorFlowConcatOperator;
@ -1162,7 +1162,7 @@ tensorflow::Status ConvertConcatOperator(
tensorflow::Status ConvertMirrorPadOperator( tensorflow::Status ConvertMirrorPadOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
if (node.op() != "MirrorPad") { if (node.op() != "MirrorPad") {
LOG(FATAL) << "Expected MirrorPad."; LOG(FATAL) << "Expected MirrorPad.";
} }
@ -1197,7 +1197,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk };
template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex> template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
tensorflow::Status ConvertSimpleOperatorGeneric( tensorflow::Status ConvertSimpleOperatorGeneric(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
if (NumInputs != kAnyNumInputs) { if (NumInputs != kAnyNumInputs) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
} }
@ -1225,18 +1225,18 @@ tensorflow::Status ConvertSimpleOperatorGeneric(
template <typename Op, int NumInputs, int NumOutputs> template <typename Op, int NumInputs, int NumOutputs>
tensorflow::Status ConvertSimpleOperator( tensorflow::Status ConvertSimpleOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>( return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
node, tf_import_flags, model); node, tf_import_flags, model_flags, model);
} }
// Convert a simple operator which is valid as a flex op. // Convert a simple operator which is valid as a flex op.
template <typename Op, int NumInputs, int NumOutputs> template <typename Op, int NumInputs, int NumOutputs>
tensorflow::Status ConvertSimpleOperatorFlexOk( tensorflow::Status ConvertSimpleOperatorFlexOk(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>( return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
node, tf_import_flags, model); node, tf_import_flags, model_flags, model);
} }
void GetOutputNamesFromNodeDef(const NodeDef& node, void GetOutputNamesFromNodeDef(const NodeDef& node,
@ -1325,7 +1325,7 @@ void GetOutputTypesFromNodeDef(const NodeDef& node,
tensorflow::Status ConvertUnsupportedOperator( tensorflow::Status ConvertUnsupportedOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
// Names of special attributes in TF graph that are used by Toco. // Names of special attributes in TF graph that are used by Toco.
static constexpr char kAttrOutputQuantized[] = "_output_quantized"; static constexpr char kAttrOutputQuantized[] = "_output_quantized";
static constexpr char kAttrOutputTypes[] = "_output_types"; static constexpr char kAttrOutputTypes[] = "_output_types";
@ -1416,14 +1416,15 @@ tensorflow::Status ConvertUnsupportedOperator(
// expensive copies of the protocol buffers downstream in the flex delegate. // expensive copies of the protocol buffers downstream in the flex delegate.
tensorflow::Status ConditionallyConvertConstOperator( tensorflow::Status ConditionallyConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
// We avoid incomplete and zero shapes because the resulting arrays // We avoid incomplete and zero shapes because the resulting arrays
// are not completely compatible with Eager/TensorFlow. // are not completely compatible with Eager/TensorFlow.
const auto& tensor = GetTensorAttr(node, "value"); const auto& tensor = GetTensorAttr(node, "value");
const auto& shape = tensor.tensor_shape(); const auto& shape = tensor.tensor_shape();
for (const auto& dim : shape.dim()) { for (const auto& dim : shape.dim()) {
if (dim.size() <= 0) { if (dim.size() <= 0) {
return ConvertUnsupportedOperator(node, tf_import_flags, model); return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
model);
} }
} }
@ -1435,15 +1436,16 @@ tensorflow::Status ConditionallyConvertConstOperator(
case DT_STRING: case DT_STRING:
case DT_BOOL: case DT_BOOL:
case DT_COMPLEX64: case DT_COMPLEX64:
return ConvertConstOperator(node, tf_import_flags, model); return ConvertConstOperator(node, tf_import_flags, model_flags, model);
default: default:
return ConvertUnsupportedOperator(node, tf_import_flags, model); return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
model);
} }
} }
tensorflow::Status ConvertStridedSliceOperator( tensorflow::Status ConvertStridedSliceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "StridedSlice"); CHECK_EQ(node.op(), "StridedSlice");
// TODO(soroosh): The 4th input (strides) should be e optional, to be // TODO(soroosh): The 4th input (strides) should be e optional, to be
// consistent with TF. // consistent with TF.
@ -1472,11 +1474,24 @@ tensorflow::Status ConvertStridedSliceOperator(
tensorflow::Status ConvertPlaceholderOperator( tensorflow::Status ConvertPlaceholderOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
if (node.op() == "Placeholder") { if (node.op() == "Placeholder") {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
} }
bool inside_input_arrays = false;
for (const auto& input_array : model_flags.input_arrays()) {
if (node.name() == input_array.name()) {
inside_input_arrays = true;
break;
}
}
if (!inside_input_arrays) {
model->AddInvalidInputArray(node.name());
}
auto& array = model->GetOrCreateArray(node.name()); auto& array = model->GetOrCreateArray(node.name());
if (node.attr().count("dtype")) { if (node.attr().count("dtype")) {
array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype")); array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
@ -1499,13 +1514,13 @@ tensorflow::Status ConvertPlaceholderOperator(
tensorflow::Status ConvertNoOpOperator( tensorflow::Status ConvertNoOpOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
tensorflow::Status ConvertCastOperator( tensorflow::Status ConvertCastOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Cast"); CHECK_EQ(node.op(), "Cast");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
@ -1521,7 +1536,7 @@ tensorflow::Status ConvertCastOperator(
tensorflow::Status ConvertFloorOperator( tensorflow::Status ConvertFloorOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Floor"); CHECK_EQ(node.op(), "Floor");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T"); const auto data_type = GetDataTypeAttr(node, "T");
@ -1535,7 +1550,7 @@ tensorflow::Status ConvertFloorOperator(
tensorflow::Status ConvertCeilOperator( tensorflow::Status ConvertCeilOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Ceil"); CHECK_EQ(node.op(), "Ceil");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T"); const auto data_type = GetDataTypeAttr(node, "T");
@ -1549,7 +1564,7 @@ tensorflow::Status ConvertCeilOperator(
tensorflow::Status ConvertRoundOperator( tensorflow::Status ConvertRoundOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Round"); CHECK_EQ(node.op(), "Round");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T"); const auto data_type = GetDataTypeAttr(node, "T");
@ -1563,7 +1578,7 @@ tensorflow::Status ConvertRoundOperator(
tensorflow::Status ConvertGatherOperator( tensorflow::Status ConvertGatherOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK(node.op() == "Gather" || node.op() == "GatherV2"); CHECK(node.op() == "Gather" || node.op() == "GatherV2");
if (node.op() == "Gather") if (node.op() == "Gather")
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -1592,7 +1607,7 @@ tensorflow::Status ConvertGatherOperator(
tensorflow::Status ConvertGatherNdOperator( tensorflow::Status ConvertGatherNdOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "GatherNd"); CHECK_EQ(node.op(), "GatherNd");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
@ -1608,7 +1623,7 @@ tensorflow::Status ConvertGatherNdOperator(
template <typename Op> template <typename Op>
tensorflow::Status ConvertArgMinMaxOperator( tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type = const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@ -1628,21 +1643,23 @@ tensorflow::Status ConvertArgMinMaxOperator(
tensorflow::Status ConvertArgMaxOperator( tensorflow::Status ConvertArgMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ArgMax"); CHECK_EQ(node.op(), "ArgMax");
return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model); return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags,
model_flags, model);
} }
tensorflow::Status ConvertArgMinOperator( tensorflow::Status ConvertArgMinOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ArgMin"); CHECK_EQ(node.op(), "ArgMin");
return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model); return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags,
model_flags, model);
} }
tensorflow::Status ConvertResizeBilinearOperator( tensorflow::Status ConvertResizeBilinearOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ResizeBilinear"); CHECK_EQ(node.op(), "ResizeBilinear");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new ResizeBilinearOperator; auto* op = new ResizeBilinearOperator;
@ -1661,7 +1678,7 @@ tensorflow::Status ConvertResizeBilinearOperator(
tensorflow::Status ConvertResizeNearestNeighborOperator( tensorflow::Status ConvertResizeNearestNeighborOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ResizeNearestNeighbor"); CHECK_EQ(node.op(), "ResizeNearestNeighbor");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new ResizeNearestNeighborOperator; auto* op = new ResizeNearestNeighborOperator;
@ -1680,7 +1697,7 @@ tensorflow::Status ConvertResizeNearestNeighborOperator(
tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
@ -1730,7 +1747,7 @@ tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
tensorflow::Status ConvertFusedBatchNormOperator( tensorflow::Status ConvertFusedBatchNormOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "FusedBatchNorm"); CHECK_EQ(node.op(), "FusedBatchNorm");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
@ -1783,7 +1800,7 @@ tensorflow::Status ConvertFusedBatchNormOperator(
tensorflow::Status ConvertSpaceToBatchNDOperator( tensorflow::Status ConvertSpaceToBatchNDOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SpaceToBatchND"); CHECK_EQ(node.op(), "SpaceToBatchND");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
@ -1799,7 +1816,7 @@ tensorflow::Status ConvertSpaceToBatchNDOperator(
tensorflow::Status ConvertBatchToSpaceNDOperator( tensorflow::Status ConvertBatchToSpaceNDOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "BatchToSpaceND"); CHECK_EQ(node.op(), "BatchToSpaceND");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
@ -1816,7 +1833,7 @@ tensorflow::Status ConvertBatchToSpaceNDOperator(
template <typename T> template <typename T>
tensorflow::Status ConvertReduceOperator( tensorflow::Status ConvertReduceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new T; auto* op = new T;
op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(0));
@ -1833,7 +1850,7 @@ tensorflow::Status ConvertReduceOperator(
tensorflow::Status ConvertSvdfOperator( tensorflow::Status ConvertSvdfOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Svdf"); CHECK_EQ(node.op(), "Svdf");
const int input_size = GetInputsCount(node, tf_import_flags); const int input_size = GetInputsCount(node, tf_import_flags);
QCHECK(input_size == 3 || input_size == 4) QCHECK(input_size == 3 || input_size == 4)
@ -1862,7 +1879,7 @@ tensorflow::Status ConvertSvdfOperator(
// This is just bare bones support to get the shapes to propagate. // This is just bare bones support to get the shapes to propagate.
tensorflow::Status ConvertTransposeConvOperator( tensorflow::Status ConvertTransposeConvOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Conv2DBackpropInput"); CHECK_EQ(node.op(), "Conv2DBackpropInput");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new TransposeConvOperator; auto* op = new TransposeConvOperator;
@ -1933,7 +1950,7 @@ tensorflow::Status ConvertTransposeConvOperator(
tensorflow::Status ConvertRangeOperator( tensorflow::Status ConvertRangeOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Range"); CHECK_EQ(node.op(), "Range");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new RangeOperator; auto* op = new RangeOperator;
@ -1958,7 +1975,7 @@ tensorflow::Status ConvertRangeOperator(
// not directly related to tf.stack() usage. // not directly related to tf.stack() usage.
tensorflow::Status ConvertPackOperator( tensorflow::Status ConvertPackOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Pack"); CHECK_EQ(node.op(), "Pack");
auto op = absl::make_unique<PackOperator>(); auto op = absl::make_unique<PackOperator>();
const int num_inputs = GetInputsCount(node, tf_import_flags); const int num_inputs = GetInputsCount(node, tf_import_flags);
@ -1980,7 +1997,7 @@ tensorflow::Status ConvertPackOperator(
tensorflow::Status ConvertUnpackOperator( tensorflow::Status ConvertUnpackOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Unpack"); CHECK_EQ(node.op(), "Unpack");
auto op = absl::make_unique<UnpackOperator>(); auto op = absl::make_unique<UnpackOperator>();
const int num_inputs = GetInputsCount(node, tf_import_flags); const int num_inputs = GetInputsCount(node, tf_import_flags);
@ -2010,7 +2027,7 @@ tensorflow::Status ConvertUnpackOperator(
// graph visualization. // graph visualization.
tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
// At the moment, the only type of operator special-cased in this way is // At the moment, the only type of operator special-cased in this way is
// NextIteration, occurring only in control-flow cycles. // NextIteration, occurring only in control-flow cycles.
CHECK_EQ(node.op(), "NextIteration"); CHECK_EQ(node.op(), "NextIteration");
@ -2029,7 +2046,7 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
tensorflow::Status ConvertShapeOperator( tensorflow::Status ConvertShapeOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "Shape"); CHECK_EQ(node.op(), "Shape");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto out_type = const auto out_type =
@ -2045,7 +2062,7 @@ tensorflow::Status ConvertShapeOperator(
tensorflow::Status ConvertReverseSequenceOperator( tensorflow::Status ConvertReverseSequenceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "ReverseSequence"); CHECK_EQ(node.op(), "ReverseSequence");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto op = absl::make_unique<ReverseSequenceOperator>(); auto op = absl::make_unique<ReverseSequenceOperator>();
@ -2206,7 +2223,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::Status ConvertTopKV2Operator( tensorflow::Status ConvertTopKV2Operator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
auto op = absl::make_unique<TopKV2Operator>(); auto op = absl::make_unique<TopKV2Operator>();
op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(0));
@ -2228,7 +2245,7 @@ tensorflow::Status ConvertTopKV2Operator(
tensorflow::Status ConvertDynamicPartitionOperator( tensorflow::Status ConvertDynamicPartitionOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
auto op = absl::make_unique<DynamicPartitionOperator>(); auto op = absl::make_unique<DynamicPartitionOperator>();
CHECK(HasAttr(node, "num_partitions")); CHECK(HasAttr(node, "num_partitions"));
op->num_partitions = GetIntAttr(node, "num_partitions"); op->num_partitions = GetIntAttr(node, "num_partitions");
@ -2246,7 +2263,7 @@ tensorflow::Status ConvertDynamicPartitionOperator(
tensorflow::Status ConvertDynamicStitchOperator( tensorflow::Status ConvertDynamicStitchOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
// The parallel and non-parallel variants are the same besides whether they // The parallel and non-parallel variants are the same besides whether they
// have a parallel loop; there are no behavioral differences. // have a parallel loop; there are no behavioral differences.
CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch"); CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
@ -2265,7 +2282,7 @@ tensorflow::Status ConvertDynamicStitchOperator(
tensorflow::Status ConvertSparseToDenseOperator( tensorflow::Status ConvertSparseToDenseOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "SparseToDense"); CHECK_EQ(node.op(), "SparseToDense");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
@ -2284,7 +2301,7 @@ tensorflow::Status ConvertSparseToDenseOperator(
tensorflow::Status ConvertOneHotOperator( tensorflow::Status ConvertOneHotOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "OneHot"); CHECK_EQ(node.op(), "OneHot");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
@ -2305,7 +2322,7 @@ tensorflow::Status ConvertOneHotOperator(
tensorflow::Status ConvertCTCBeamSearchDecoderOperator( tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
@ -2335,7 +2352,7 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
// with TfLite OpHint API. // with TfLite OpHint API.
tensorflow::Status ConvertUnidirectionalSequenceLstm( tensorflow::Status ConvertUnidirectionalSequenceLstm(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm"); DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
auto* op = new UnidirectionalSequenceLstmOperator(); auto* op = new UnidirectionalSequenceLstmOperator();
@ -2375,7 +2392,7 @@ tensorflow::Status ConvertUnidirectionalSequenceLstm(
tensorflow::Status ConvertLeakyReluOperator( tensorflow::Status ConvertLeakyReluOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
CHECK_EQ(node.op(), "LeakyRelu"); CHECK_EQ(node.op(), "LeakyRelu");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
@ -2390,7 +2407,7 @@ tensorflow::Status ConvertLeakyReluOperator(
tensorflow::Status ConvertUnidirectionalSequenceRnn( tensorflow::Status ConvertUnidirectionalSequenceRnn(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { const ModelFlags& model_flags, Model* model) {
DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn"); DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
auto* op = new UnidirectionalSequenceRnnOperator(); auto* op = new UnidirectionalSequenceRnnOperator();
@ -2415,7 +2432,7 @@ namespace internal {
using ConverterType = tensorflow::Status (*)( using ConverterType = tensorflow::Status (*)(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model); const ModelFlags& model_flags, Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>; using ConverterMapType = std::unordered_map<std::string, ConverterType>;
ConverterMapType GetTensorFlowNodeConverterMapForFlex() { ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
@ -2568,13 +2585,14 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
tensorflow::Status ImportTensorFlowNode( tensorflow::Status ImportTensorFlowNode(
const tensorflow::NodeDef& node, const tensorflow::NodeDef& node,
const TensorFlowImportFlags& tf_import_flags, Model* model, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags,
const ConverterMapType& converter_map) { Model* model, const ConverterMapType& converter_map) {
auto converter = converter_map.find(node.op()); auto converter = converter_map.find(node.op());
if (converter == converter_map.end()) { if (converter == converter_map.end()) {
return ConvertUnsupportedOperator(node, tf_import_flags, model); return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
model);
} else { } else {
return converter->second(node, tf_import_flags, model); return converter->second(node, tf_import_flags, model_flags, model);
} }
} }
} // namespace internal } // namespace internal
@ -2614,8 +2632,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
for (auto node : inlined_graph.node()) { for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node); StripZeroOutputIndexFromInputs(&node);
auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model, auto status = internal::ImportTensorFlowNode(
converter_map); node, tf_import_flags, model_flags, model, converter_map);
CHECK(status.ok()) << status.error_message(); CHECK(status.ok()) << status.error_message();
} }

View File

@ -44,28 +44,29 @@ using ::testing::ElementsAre;
namespace internal { namespace internal {
using ConverterType = tensorflow::Status (*)( using ConverterType = tensorflow::Status (*)(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model); const ModelFlags& model_flags, Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>; using ConverterMapType = std::unordered_map<std::string, ConverterType>;
ConverterMapType GetTensorFlowNodeConverterMap(); ConverterMapType GetTensorFlowNodeConverterMap();
ConverterMapType GetTensorFlowNodeConverterMapForFlex(); ConverterMapType GetTensorFlowNodeConverterMapForFlex();
Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
Model*, const ConverterMapType&); const ModelFlags& model_flags, Model*,
const ConverterMapType&);
} // namespace internal } // namespace internal
namespace { namespace {
Status ImportNode(const NodeDef& node, Model* model) { Status ImportNode(const NodeDef& node, Model* model) {
const auto converter = internal::GetTensorFlowNodeConverterMap(); const auto converter = internal::GetTensorFlowNodeConverterMap();
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
converter); ModelFlags(), model, converter);
} }
Status ImportFlexNode(const NodeDef& node, Model* model) { Status ImportFlexNode(const NodeDef& node, Model* model) {
// Empty converter => all nodes are flex nodes. // Empty converter => all nodes are flex nodes.
const auto converter = internal::ConverterMapType(); const auto converter = internal::ConverterMapType();
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
converter); ModelFlags(), model, converter);
} }
Status ImportNode(const NodeDef& node) { Status ImportNode(const NodeDef& node) {
@ -170,7 +171,7 @@ TEST(FlexImportTest, ConditionalConst) {
const auto converter = internal::GetTensorFlowNodeConverterMapForFlex(); const auto converter = internal::GetTensorFlowNodeConverterMapForFlex();
return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
&model, converter); ModelFlags(), &model, converter);
}; };
EXPECT_TRUE(build_and_import_node("Known", {1, 2, 3}, DT_INT32, 6).ok()); EXPECT_TRUE(build_and_import_node("Known", {1, 2, 3}, DT_INT32, 6).ok());

View File

@ -2335,6 +2335,14 @@ class Model {
int64 ArithmeticOpsCount() const { return ops_count; } int64 ArithmeticOpsCount() const { return ops_count; }
void AddInvalidInputArray(string invalid_input_array) {
invalid_input_arrays_.insert(invalid_input_array);
}
const std::unordered_set<string>& GetInvalidInputArrays() const {
return invalid_input_arrays_;
}
// Optional arrays are used for optional tensors, // Optional arrays are used for optional tensors,
// these tensors do not have data, but with reserved names as op inputs. // these tensors do not have data, but with reserved names as op inputs.
std::set<string> optional_arrays; std::set<string> optional_arrays;
@ -2361,6 +2369,9 @@ class Model {
// The Operator's refer to these Array's by their name strings, not by their // The Operator's refer to these Array's by their name strings, not by their
// addresses. See Operator::inputs, Operator::outputs. // addresses. See Operator::inputs, Operator::outputs.
std::unordered_map<string, std::unique_ptr<Array>> arrays; std::unordered_map<string, std::unique_ptr<Array>> arrays;
// Invalid input arrays.
std::unordered_set<string> invalid_input_arrays_;
}; };
// OperatorSignature contains the information required to making versioning // OperatorSignature contains the information required to making versioning

View File

@ -459,6 +459,13 @@ tensorflow::Status Export(
const Model& model, string* output_file_contents, const Model& model, string* output_file_contents,
const ExportParams& params, const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
for (const string& input_array : model.GetInvalidInputArrays()) {
if (model.HasArray(input_array)) {
return tensorflow::errors::InvalidArgument(absl::StrCat(
"Placeholder ", input_array, " should be specied by input_arrays."));
}
}
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
details::TensorsMap tensors_map; details::TensorsMap tensors_map;

View File

@ -133,6 +133,11 @@ TEST(TocoTest, TransientStringTensors) {
// input array must have a shape. // input array must have a shape.
toco_flags.set_output_format(TFLITE); toco_flags.set_output_format(TFLITE);
toco::InputArray* input_1 = model_flags.add_input_arrays();
input_1->set_name("input1");
toco::InputArray* indices_1 = model_flags.add_input_arrays();
indices_1->set_name("indices1");
model_flags.add_output_arrays("output1"); model_flags.add_output_arrays("output1");
string input = R"GraphDef( string input = R"GraphDef(
node { node {