Handle IdentityN with 1 input in Toco
PiperOrigin-RevId: 259072753
This commit is contained in:
parent
ba1654087a
commit
fdc106e412
@ -870,7 +870,7 @@ def make_identity_tests(options):
|
||||
# Chose a set of parameters
|
||||
test_parameters = [{
|
||||
"input_shape": [[], [1], [3, 3]],
|
||||
"use_snapshot": [False, True],
|
||||
"op_to_use": ["identity", "identity_n", "snapshot"],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
@ -884,10 +884,13 @@ def make_identity_tests(options):
|
||||
# shape, this conversion still fails.
|
||||
# TODO(b/129197312), remove the walk-around code once the bug is fixed.
|
||||
input_doubled = input_tensor * 2.0
|
||||
if parameters["use_snapshot"]:
|
||||
identity_output = array_ops.snapshot(input_doubled)
|
||||
else:
|
||||
if parameters["op_to_use"] == "identity":
|
||||
identity_output = tf.identity(input_doubled)
|
||||
elif parameters["op_to_use"] == "identity_n":
|
||||
# Testing `IdentityN` with a single tensor.
|
||||
identity_output = tf.identity_n([input_doubled])[0]
|
||||
elif parameters["op_to_use"] == "snapshot":
|
||||
identity_output = array_ops.snapshot(input_doubled)
|
||||
return [input_tensor], [identity_output]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
|
@ -562,6 +562,178 @@ void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
|
||||
node.SerializeToString(&op->tensorflow_node_def);
|
||||
}
|
||||
|
||||
void GetOutputNamesFromNodeDef(const NodeDef& node,
|
||||
const tensorflow::OpDef& op_def,
|
||||
TensorFlowUnsupportedOperator* op) {
|
||||
int next_output = 0;
|
||||
auto add_output = [&node, &next_output, op]() {
|
||||
if (next_output == 0) {
|
||||
op->outputs.push_back(node.name()); // Implicit :0.
|
||||
} else {
|
||||
op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
|
||||
}
|
||||
++next_output;
|
||||
};
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
string multiples = op_def.output_arg(i).number_attr();
|
||||
if (!multiples.empty()) {
|
||||
CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
|
||||
int num_outputs = GetIntAttr(node, multiples);
|
||||
for (int j = 0; j < num_outputs; ++j) {
|
||||
add_output();
|
||||
}
|
||||
} else {
|
||||
string list = op_def.output_arg(i).type_list_attr();
|
||||
if (!list.empty()) {
|
||||
CHECK(HasAttr(node, list)) << "No attr named " << list;
|
||||
const AttrValue::ListValue& list_value = GetListAttr(node, list);
|
||||
for (int j = 0; j < list_value.type_size(); ++j) {
|
||||
add_output();
|
||||
}
|
||||
} else {
|
||||
add_output();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetOutputTypesFromNodeDef(const NodeDef& node,
|
||||
const tensorflow::OpDef& op_def,
|
||||
TensorFlowUnsupportedOperator* op) {
|
||||
// The given type to the op, or clear the types if invalid.
|
||||
auto add_type = [&node, op](tensorflow::DataType type) {
|
||||
if (type == tensorflow::DT_INVALID) {
|
||||
LOG(WARNING) << "Op node missing output type attribute: " << node.name();
|
||||
op->output_data_types.clear();
|
||||
} else {
|
||||
op->output_data_types.push_back(ConvertDataType(type));
|
||||
}
|
||||
};
|
||||
|
||||
// Retrieve the data type according to the OpDef definition: either the
|
||||
// "type" or "type_attr" field will be set.
|
||||
auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
|
||||
if (a.type() != tensorflow::DT_INVALID) {
|
||||
return a.type();
|
||||
} else if (HasAttr(node, a.type_attr())) {
|
||||
return GetDataTypeAttr(node, a.type_attr());
|
||||
} else {
|
||||
return tensorflow::DT_INVALID;
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
string multiples = op_def.output_arg(i).number_attr();
|
||||
if (!multiples.empty()) {
|
||||
CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
|
||||
int num_outputs = GetIntAttr(node, multiples);
|
||||
auto type = get_type(op_def.output_arg(i));
|
||||
for (int j = 0; j < num_outputs; ++j) {
|
||||
add_type(type);
|
||||
}
|
||||
} else {
|
||||
string list = op_def.output_arg(i).type_list_attr();
|
||||
if (!list.empty()) {
|
||||
CHECK(HasAttr(node, list)) << "No attr named " << list;
|
||||
const AttrValue::ListValue& list_value = GetListAttr(node, list);
|
||||
for (int j = 0; j < list_value.type_size(); ++j) {
|
||||
add_type(list_value.type(j));
|
||||
}
|
||||
} else {
|
||||
add_type(get_type(op_def.output_arg(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertUnsupportedOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
// Names of special attributes in TF graph that are used by Toco.
|
||||
static constexpr char kAttrOutputQuantized[] = "_output_quantized";
|
||||
static constexpr char kAttrOutputTypes[] = "_output_types";
|
||||
static constexpr char kAttrOutputShapes[] = "_output_shapes";
|
||||
static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
|
||||
"_support_output_type_float_in_quantized_op";
|
||||
|
||||
LOG(INFO) << "Converting unsupported operation: " << node.op();
|
||||
|
||||
auto* op = new TensorFlowUnsupportedOperator;
|
||||
op->tensorflow_op = node.op();
|
||||
|
||||
// For Flex mode. Please read the comments of the function.
|
||||
RetainTensorFlowNodeDef(node, op);
|
||||
|
||||
model->operators.emplace_back(op);
|
||||
|
||||
// Parse inputs.
|
||||
const int num_inputs = GetInputsCount(node, tf_import_flags);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
op->inputs.push_back(node.input(i));
|
||||
}
|
||||
|
||||
// Parse outputs. Name them after the node's name, plus an ordinal suffix.
|
||||
// Note that some outputs are to be multiplied by a named attribute.
|
||||
const tensorflow::OpDef* op_def = nullptr;
|
||||
if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
|
||||
GetOutputNamesFromNodeDef(node, *op_def, op);
|
||||
} else {
|
||||
op->outputs.push_back(node.name()); // Implicit :0.
|
||||
}
|
||||
|
||||
// Parse if the op supports quantization
|
||||
if (HasAttr(node, kAttrOutputQuantized)) {
|
||||
op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
|
||||
}
|
||||
// Parse if the quantized op allows output arrays of type float
|
||||
if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
|
||||
op->support_output_type_float_in_quantized_op =
|
||||
GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
|
||||
}
|
||||
|
||||
// Parse output type(s).
|
||||
if (HasAttr(node, kAttrOutputTypes)) {
|
||||
const auto& output_types = GetListAttr(node, kAttrOutputTypes);
|
||||
for (int i = 0; i < output_types.type_size(); ++i) {
|
||||
op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
|
||||
}
|
||||
} else if (HasAttr(node, "Tout")) {
|
||||
const auto& output_type = GetDataTypeAttr(node, "Tout");
|
||||
op->output_data_types.push_back(ConvertDataType(output_type));
|
||||
} else if (op_def != nullptr) {
|
||||
GetOutputTypesFromNodeDef(node, *op_def, op);
|
||||
} else {
|
||||
// TODO(b/113613439): Figure out how to propagate types for custom ops
|
||||
// that have no OpDef.
|
||||
LOG(INFO) << "Unable to determine output type for op: " << node.op();
|
||||
}
|
||||
|
||||
// Parse output shape(s).
|
||||
if (HasAttr(node, kAttrOutputShapes)) {
|
||||
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
|
||||
Shape output_shape;
|
||||
for (int i = 0; i < output_shapes.shape_size(); ++i) {
|
||||
const auto& shape = output_shapes.shape(i);
|
||||
// TOCO doesn't yet properly handle shapes with wildcard dimensions.
|
||||
// TODO(b/113613439): Handle shape inference for unsupported ops that have
|
||||
// shapes with wildcard dimensions.
|
||||
if (HasWildcardDimension(shape)) {
|
||||
LOG(INFO) << "Skipping wildcard output shape(s) for node: "
|
||||
<< node.name();
|
||||
op->output_shapes.clear();
|
||||
break;
|
||||
}
|
||||
const auto status =
|
||||
ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
op->output_shapes.push_back(output_shape);
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertConstOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
@ -839,7 +1011,15 @@ tensorflow::Status ConvertIdentityOperator(
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
|
||||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
|
||||
node.op() == "Snapshot");
|
||||
node.op() == "Snapshot" || node.op() == "IdentityN");
|
||||
|
||||
if (node.op() == "IdentityN" && node.input_size() != 1) {
|
||||
// When IdentityN doesn't have exactly 1 input, convert it as an unsupported
|
||||
// op so it's still possible to run with Flex runtime.
|
||||
return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
|
||||
model);
|
||||
}
|
||||
|
||||
auto* op = new TensorFlowIdentityOperator;
|
||||
// Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
|
||||
// identity nodes with multiple inputs, but the other inputs seem
|
||||
@ -1239,178 +1419,6 @@ tensorflow::Status ConvertSimpleOperatorFlexOk(
|
||||
node, tf_import_flags, model_flags, model);
|
||||
}
|
||||
|
||||
void GetOutputNamesFromNodeDef(const NodeDef& node,
|
||||
const tensorflow::OpDef& op_def,
|
||||
TensorFlowUnsupportedOperator* op) {
|
||||
int next_output = 0;
|
||||
auto add_output = [&node, &next_output, op]() {
|
||||
if (next_output == 0) {
|
||||
op->outputs.push_back(node.name()); // Implicit :0.
|
||||
} else {
|
||||
op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
|
||||
}
|
||||
++next_output;
|
||||
};
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
string multiples = op_def.output_arg(i).number_attr();
|
||||
if (!multiples.empty()) {
|
||||
CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
|
||||
int num_outputs = GetIntAttr(node, multiples);
|
||||
for (int j = 0; j < num_outputs; ++j) {
|
||||
add_output();
|
||||
}
|
||||
} else {
|
||||
string list = op_def.output_arg(i).type_list_attr();
|
||||
if (!list.empty()) {
|
||||
CHECK(HasAttr(node, list)) << "No attr named " << list;
|
||||
const AttrValue::ListValue& list_value = GetListAttr(node, list);
|
||||
for (int j = 0; j < list_value.type_size(); ++j) {
|
||||
add_output();
|
||||
}
|
||||
} else {
|
||||
add_output();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetOutputTypesFromNodeDef(const NodeDef& node,
|
||||
const tensorflow::OpDef& op_def,
|
||||
TensorFlowUnsupportedOperator* op) {
|
||||
// The given type to the op, or clear the types if invalid.
|
||||
auto add_type = [&node, op](tensorflow::DataType type) {
|
||||
if (type == tensorflow::DT_INVALID) {
|
||||
LOG(WARNING) << "Op node missing output type attribute: " << node.name();
|
||||
op->output_data_types.clear();
|
||||
} else {
|
||||
op->output_data_types.push_back(ConvertDataType(type));
|
||||
}
|
||||
};
|
||||
|
||||
// Retrieve the data type according to the OpDef definition: either the
|
||||
// "type" or "type_attr" field will be set.
|
||||
auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
|
||||
if (a.type() != tensorflow::DT_INVALID) {
|
||||
return a.type();
|
||||
} else if (HasAttr(node, a.type_attr())) {
|
||||
return GetDataTypeAttr(node, a.type_attr());
|
||||
} else {
|
||||
return tensorflow::DT_INVALID;
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
string multiples = op_def.output_arg(i).number_attr();
|
||||
if (!multiples.empty()) {
|
||||
CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
|
||||
int num_outputs = GetIntAttr(node, multiples);
|
||||
auto type = get_type(op_def.output_arg(i));
|
||||
for (int j = 0; j < num_outputs; ++j) {
|
||||
add_type(type);
|
||||
}
|
||||
} else {
|
||||
string list = op_def.output_arg(i).type_list_attr();
|
||||
if (!list.empty()) {
|
||||
CHECK(HasAttr(node, list)) << "No attr named " << list;
|
||||
const AttrValue::ListValue& list_value = GetListAttr(node, list);
|
||||
for (int j = 0; j < list_value.type_size(); ++j) {
|
||||
add_type(list_value.type(j));
|
||||
}
|
||||
} else {
|
||||
add_type(get_type(op_def.output_arg(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertUnsupportedOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
const ModelFlags& model_flags, Model* model) {
|
||||
// Names of special attributes in TF graph that are used by Toco.
|
||||
static constexpr char kAttrOutputQuantized[] = "_output_quantized";
|
||||
static constexpr char kAttrOutputTypes[] = "_output_types";
|
||||
static constexpr char kAttrOutputShapes[] = "_output_shapes";
|
||||
static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
|
||||
"_support_output_type_float_in_quantized_op";
|
||||
|
||||
LOG(INFO) << "Converting unsupported operation: " << node.op();
|
||||
|
||||
auto* op = new TensorFlowUnsupportedOperator;
|
||||
op->tensorflow_op = node.op();
|
||||
|
||||
// For Flex mode. Please read the comments of the function.
|
||||
RetainTensorFlowNodeDef(node, op);
|
||||
|
||||
model->operators.emplace_back(op);
|
||||
|
||||
// Parse inputs.
|
||||
const int num_inputs = GetInputsCount(node, tf_import_flags);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
op->inputs.push_back(node.input(i));
|
||||
}
|
||||
|
||||
// Parse outputs. Name them after the node's name, plus an ordinal suffix.
|
||||
// Note that some outputs are to be multiplied by a named attribute.
|
||||
const tensorflow::OpDef* op_def = nullptr;
|
||||
if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
|
||||
GetOutputNamesFromNodeDef(node, *op_def, op);
|
||||
} else {
|
||||
op->outputs.push_back(node.name()); // Implicit :0.
|
||||
}
|
||||
|
||||
// Parse if the op supports quantization
|
||||
if (HasAttr(node, kAttrOutputQuantized)) {
|
||||
op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
|
||||
}
|
||||
// Parse if the quantized op allows output arrays of type float
|
||||
if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
|
||||
op->support_output_type_float_in_quantized_op =
|
||||
GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
|
||||
}
|
||||
|
||||
// Parse output type(s).
|
||||
if (HasAttr(node, kAttrOutputTypes)) {
|
||||
const auto& output_types = GetListAttr(node, kAttrOutputTypes);
|
||||
for (int i = 0; i < output_types.type_size(); ++i) {
|
||||
op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
|
||||
}
|
||||
} else if (HasAttr(node, "Tout")) {
|
||||
const auto& output_type = GetDataTypeAttr(node, "Tout");
|
||||
op->output_data_types.push_back(ConvertDataType(output_type));
|
||||
} else if (op_def != nullptr) {
|
||||
GetOutputTypesFromNodeDef(node, *op_def, op);
|
||||
} else {
|
||||
// TODO(b/113613439): Figure out how to propagate types for custom ops
|
||||
// that have no OpDef.
|
||||
LOG(INFO) << "Unable to determine output type for op: " << node.op();
|
||||
}
|
||||
|
||||
// Parse output shape(s).
|
||||
if (HasAttr(node, kAttrOutputShapes)) {
|
||||
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
|
||||
Shape output_shape;
|
||||
for (int i = 0; i < output_shapes.shape_size(); ++i) {
|
||||
const auto& shape = output_shapes.shape(i);
|
||||
// TOCO doesn't yet properly handle shapes with wildcard dimensions.
|
||||
// TODO(b/113613439): Handle shape inference for unsupported ops that have
|
||||
// shapes with wildcard dimensions.
|
||||
if (HasWildcardDimension(shape)) {
|
||||
LOG(INFO) << "Skipping wildcard output shape(s) for node: "
|
||||
<< node.name();
|
||||
op->output_shapes.clear();
|
||||
break;
|
||||
}
|
||||
const auto status =
|
||||
ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
op->output_shapes.push_back(output_shape);
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
// Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if
|
||||
// the types are not supported. Converting Const operators here avoids
|
||||
// expensive copies of the protocol buffers downstream in the flex delegate.
|
||||
@ -2504,6 +2512,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"GreaterEqual",
|
||||
ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
|
||||
{"Identity", ConvertIdentityOperator},
|
||||
{"IdentityN", ConvertIdentityOperator},
|
||||
{"LRN", ConvertLRNOperator},
|
||||
{"LeakyRelu", ConvertLeakyReluOperator},
|
||||
{"LegacyFedInput", ConvertPlaceholderOperator},
|
||||
|
Loading…
Reference in New Issue
Block a user