Handle IdentityN with 1 input in Toco

PiperOrigin-RevId: 259072753
Yu-Cheng Ling 2019-07-19 18:13:25 -07:00 committed by TensorFlower Gardener
parent ba1654087a
commit fdc106e412
2 changed files with 189 additions and 177 deletions


@@ -870,7 +870,7 @@ def make_identity_tests(options):
  # Choose a set of parameters.
  test_parameters = [{
      "input_shape": [[], [1], [3, 3]],
-      "use_snapshot": [False, True],
+      "op_to_use": ["identity", "identity_n", "snapshot"],
  }]

  def build_graph(parameters):
@@ -884,10 +884,13 @@ def make_identity_tests(options):
    # shape, this conversion still fails.
    # TODO(b/129197312): Remove the workaround code once the bug is fixed.
    input_doubled = input_tensor * 2.0
-    if parameters["use_snapshot"]:
-      identity_output = array_ops.snapshot(input_doubled)
-    else:
-      identity_output = tf.identity(input_doubled)
+    if parameters["op_to_use"] == "identity":
+      identity_output = tf.identity(input_doubled)
+    elif parameters["op_to_use"] == "identity_n":
+      # Testing `IdentityN` with a single tensor.
+      identity_output = tf.identity_n([input_doubled])[0]
+    elif parameters["op_to_use"] == "snapshot":
+      identity_output = array_ops.snapshot(input_doubled)
    return [input_tensor], [identity_output]

  def build_inputs(parameters, sess, inputs, outputs):

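As context for the new `identity_n` branch above: `tf.identity_n` is list-in/list-out even for a single tensor, which is why the test indexes `[0]`. A minimal eager-mode sketch:

    import tensorflow as tf

    # tf.identity_n always takes and returns a list of tensors, even for a
    # single input; indexing [0] recovers the lone IdentityN output.
    x = tf.constant([1.0, 2.0])
    outputs = tf.identity_n([x])
    assert isinstance(outputs, list) and len(outputs) == 1
    y = outputs[0]  # Same values as x, produced by one IdentityN op.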

@@ -562,6 +562,178 @@ void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
  node.SerializeToString(&op->tensorflow_node_def);
}

+void GetOutputNamesFromNodeDef(const NodeDef& node,
+                               const tensorflow::OpDef& op_def,
+                               TensorFlowUnsupportedOperator* op) {
+  int next_output = 0;
+  auto add_output = [&node, &next_output, op]() {
+    if (next_output == 0) {
+      op->outputs.push_back(node.name());  // Implicit :0.
+    } else {
+      op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
+    }
+    ++next_output;
+  };
+  for (int i = 0; i < op_def.output_arg_size(); ++i) {
+    string multiples = op_def.output_arg(i).number_attr();
+    if (!multiples.empty()) {
+      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
+      int num_outputs = GetIntAttr(node, multiples);
+      for (int j = 0; j < num_outputs; ++j) {
+        add_output();
+      }
+    } else {
+      string list = op_def.output_arg(i).type_list_attr();
+      if (!list.empty()) {
+        CHECK(HasAttr(node, list)) << "No attr named " << list;
+        const AttrValue::ListValue& list_value = GetListAttr(node, list);
+        for (int j = 0; j < list_value.type_size(); ++j) {
+          add_output();
+        }
+      } else {
+        add_output();
+      }
+    }
+  }
+}
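
The `add_output` lambda above follows TensorFlow's tensor-naming convention: a node's first output is addressed by the bare node name (an implicit `:0`), and later outputs by `name:1`, `name:2`, and so on. A small graph-mode sketch showing those names for a two-output `IdentityN` node (the node name `idn` is arbitrary):

    import tensorflow.compat.v1 as tf

    with tf.Graph().as_default():
      a = tf.constant([3.0])
      b = tf.constant([4.0])
      outs = tf.identity_n([a, b], name="idn")
      # The first tensor can also be addressed as plain "idn".
      print([t.name for t in outs])  # ['idn:0', 'idn:1']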
+void GetOutputTypesFromNodeDef(const NodeDef& node,
+                               const tensorflow::OpDef& op_def,
+                               TensorFlowUnsupportedOperator* op) {
+  // Add the given type to the op, or clear the types if invalid.
+  auto add_type = [&node, op](tensorflow::DataType type) {
+    if (type == tensorflow::DT_INVALID) {
+      LOG(WARNING) << "Op node missing output type attribute: " << node.name();
+      op->output_data_types.clear();
+    } else {
+      op->output_data_types.push_back(ConvertDataType(type));
+    }
+  };
+
+  // Retrieve the data type according to the OpDef definition: either the
+  // "type" or "type_attr" field will be set.
+  auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
+    if (a.type() != tensorflow::DT_INVALID) {
+      return a.type();
+    } else if (HasAttr(node, a.type_attr())) {
+      return GetDataTypeAttr(node, a.type_attr());
+    } else {
+      return tensorflow::DT_INVALID;
+    }
+  };
+
+  for (int i = 0; i < op_def.output_arg_size(); ++i) {
+    string multiples = op_def.output_arg(i).number_attr();
+    if (!multiples.empty()) {
+      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
+      int num_outputs = GetIntAttr(node, multiples);
+      auto type = get_type(op_def.output_arg(i));
+      for (int j = 0; j < num_outputs; ++j) {
+        add_type(type);
+      }
+    } else {
+      string list = op_def.output_arg(i).type_list_attr();
+      if (!list.empty()) {
+        CHECK(HasAttr(node, list)) << "No attr named " << list;
+        const AttrValue::ListValue& list_value = GetListAttr(node, list);
+        for (int j = 0; j < list_value.type_size(); ++j) {
+          add_type(list_value.type(j));
+        }
+      } else {
+        add_type(get_type(op_def.output_arg(i)));
+      }
+    }
+  }
+}
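
Both helpers above branch on how an OpDef declares output arity: a `number_attr` (N copies of one type), a `type_list_attr` (one output per listed type), or neither (a single output). A sketch inspecting two registered ops from Python, assuming a recent TF where `op_def_registry.get` is available (older releases expose `get_registered_ops()` instead):

    from tensorflow.python.framework import op_def_registry

    # IdentityN declares "output: T" with "T: list(type)", so its output
    # arg carries a type_list_attr -- one output per element of T.
    identity_n = op_def_registry.get("IdentityN")
    print(identity_n.output_arg[0].type_list_attr)  # 'T'

    # Split declares "output: num_split * T", so its output arg carries a
    # number_attr -- one type replicated num_split times.
    split = op_def_registry.get("Split")
    print(split.output_arg[0].number_attr)  # 'num_split'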
+tensorflow::Status ConvertUnsupportedOperator(
+    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+    const ModelFlags& model_flags, Model* model) {
+  // Names of special attributes in TF graph that are used by Toco.
+  static constexpr char kAttrOutputQuantized[] = "_output_quantized";
+  static constexpr char kAttrOutputTypes[] = "_output_types";
+  static constexpr char kAttrOutputShapes[] = "_output_shapes";
+  static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
+      "_support_output_type_float_in_quantized_op";
+
+  LOG(INFO) << "Converting unsupported operation: " << node.op();
+  auto* op = new TensorFlowUnsupportedOperator;
+  op->tensorflow_op = node.op();
+
+  // For Flex mode. Please read the comments of the function.
+  RetainTensorFlowNodeDef(node, op);
+
+  model->operators.emplace_back(op);
+
+  // Parse inputs.
+  const int num_inputs = GetInputsCount(node, tf_import_flags);
+  for (int i = 0; i < num_inputs; ++i) {
+    op->inputs.push_back(node.input(i));
+  }
+
+  // Parse outputs. Name them after the node's name, plus an ordinal suffix.
+  // Note that some outputs are to be multiplied by a named attribute.
+  const tensorflow::OpDef* op_def = nullptr;
+  if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+    GetOutputNamesFromNodeDef(node, *op_def, op);
+  } else {
+    op->outputs.push_back(node.name());  // Implicit :0.
+  }
+
+  // Parse if the op supports quantization
+  if (HasAttr(node, kAttrOutputQuantized)) {
+    op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
+  }
+  // Parse if the quantized op allows output arrays of type float
+  if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
+    op->support_output_type_float_in_quantized_op =
+        GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
+  }
+
+  // Parse output type(s).
+  if (HasAttr(node, kAttrOutputTypes)) {
+    const auto& output_types = GetListAttr(node, kAttrOutputTypes);
+    for (int i = 0; i < output_types.type_size(); ++i) {
+      op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
+    }
+  } else if (HasAttr(node, "Tout")) {
+    const auto& output_type = GetDataTypeAttr(node, "Tout");
+    op->output_data_types.push_back(ConvertDataType(output_type));
+  } else if (op_def != nullptr) {
+    GetOutputTypesFromNodeDef(node, *op_def, op);
+  } else {
+    // TODO(b/113613439): Figure out how to propagate types for custom ops
+    // that have no OpDef.
+    LOG(INFO) << "Unable to determine output type for op: " << node.op();
+  }
+
+  // Parse output shape(s).
+  if (HasAttr(node, kAttrOutputShapes)) {
+    const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
+    Shape output_shape;
+    for (int i = 0; i < output_shapes.shape_size(); ++i) {
+      const auto& shape = output_shapes.shape(i);
+      // TOCO doesn't yet properly handle shapes with wildcard dimensions.
+      // TODO(b/113613439): Handle shape inference for unsupported ops that
+      // have shapes with wildcard dimensions.
+      if (HasWildcardDimension(shape)) {
+        LOG(INFO) << "Skipping wildcard output shape(s) for node: "
+                  << node.name();
+        op->output_shapes.clear();
+        break;
+      }
+      const auto status =
+          ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
+      if (!status.ok()) {
+        return status;
+      }
+      op->output_shapes.push_back(output_shape);
+    }
+  }
+  return tensorflow::Status::OK();
+}
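
`ConvertUnsupportedOperator` takes its hints from optional NodeDef attributes. A hedged protobuf sketch of a custom-op node carrying all three hints (the op and node names are made up):

    from tensorflow.core.framework import node_def_pb2, types_pb2

    # Hypothetical custom op; only the attribute names below are real.
    node = node_def_pb2.NodeDef(name="my_custom", op="MyCustomOp")
    # Dtype hint, read through kAttrOutputTypes ("_output_types").
    node.attr["_output_types"].list.type.append(types_pb2.DT_FLOAT)
    # Shape hint; a -1 (wildcard) dimension would trip HasWildcardDimension
    # and make the importer skip output shapes entirely.
    shape = node.attr["_output_shapes"].list.shape.add()
    shape.dim.add(size=1)
    shape.dim.add(size=8)
    # Quantization hint, read through kAttrOutputQuantized.
    node.attr["_output_quantized"].b = False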
tensorflow::Status ConvertConstOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
@@ -839,7 +1011,15 @@ tensorflow::Status ConvertIdentityOperator(
    const ModelFlags& model_flags, Model* model) {
  CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
        node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
-        node.op() == "Snapshot");
+        node.op() == "Snapshot" || node.op() == "IdentityN");
+
+  if (node.op() == "IdentityN" && node.input_size() != 1) {
+    // When IdentityN doesn't have exactly 1 input, convert it as an
+    // unsupported op so it's still possible to run with the Flex runtime.
+    return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
+                                      model);
+  }
+
  auto* op = new TensorFlowIdentityOperator;
  // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
  // identity nodes with multiple inputs, but the other inputs seem
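
The effect of the new branch: a single-input IdentityN is converted like a plain Identity, while any other input count keeps the node as an unsupported op that only runs when select TF ops (Flex) are enabled. A TF2-style conversion sketch of the multi-input case; this commit predates the TF2 converter API, and newer converters may lower IdentityN natively, so treat it as an illustration of the Flex escape hatch rather than current behavior:

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec([4], tf.float32),
                                  tf.TensorSpec([4], tf.float32)])
    def multi(a, b):
      # IdentityN with two inputs: falls into ConvertUnsupportedOperator.
      return tf.identity_n([a, b])

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [multi.get_concrete_function()])
    # Allowing select TF ops lets the retained IdentityN run under Flex.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()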
@@ -1239,178 +1419,6 @@ tensorflow::Status ConvertSimpleOperatorFlexOk(
      node, tf_import_flags, model_flags, model);
}
-void GetOutputNamesFromNodeDef(const NodeDef& node,
-                               const tensorflow::OpDef& op_def,
-                               TensorFlowUnsupportedOperator* op) {
-  int next_output = 0;
-  auto add_output = [&node, &next_output, op]() {
-    if (next_output == 0) {
-      op->outputs.push_back(node.name());  // Implicit :0.
-    } else {
-      op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
-    }
-    ++next_output;
-  };
-  for (int i = 0; i < op_def.output_arg_size(); ++i) {
-    string multiples = op_def.output_arg(i).number_attr();
-    if (!multiples.empty()) {
-      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
-      int num_outputs = GetIntAttr(node, multiples);
-      for (int j = 0; j < num_outputs; ++j) {
-        add_output();
-      }
-    } else {
-      string list = op_def.output_arg(i).type_list_attr();
-      if (!list.empty()) {
-        CHECK(HasAttr(node, list)) << "No attr named " << list;
-        const AttrValue::ListValue& list_value = GetListAttr(node, list);
-        for (int j = 0; j < list_value.type_size(); ++j) {
-          add_output();
-        }
-      } else {
-        add_output();
-      }
-    }
-  }
-}
-void GetOutputTypesFromNodeDef(const NodeDef& node,
-                               const tensorflow::OpDef& op_def,
-                               TensorFlowUnsupportedOperator* op) {
-  // Add the given type to the op, or clear the types if invalid.
-  auto add_type = [&node, op](tensorflow::DataType type) {
-    if (type == tensorflow::DT_INVALID) {
-      LOG(WARNING) << "Op node missing output type attribute: " << node.name();
-      op->output_data_types.clear();
-    } else {
-      op->output_data_types.push_back(ConvertDataType(type));
-    }
-  };
-
-  // Retrieve the data type according to the OpDef definition: either the
-  // "type" or "type_attr" field will be set.
-  auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
-    if (a.type() != tensorflow::DT_INVALID) {
-      return a.type();
-    } else if (HasAttr(node, a.type_attr())) {
-      return GetDataTypeAttr(node, a.type_attr());
-    } else {
-      return tensorflow::DT_INVALID;
-    }
-  };
-
-  for (int i = 0; i < op_def.output_arg_size(); ++i) {
-    string multiples = op_def.output_arg(i).number_attr();
-    if (!multiples.empty()) {
-      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
-      int num_outputs = GetIntAttr(node, multiples);
-      auto type = get_type(op_def.output_arg(i));
-      for (int j = 0; j < num_outputs; ++j) {
-        add_type(type);
-      }
-    } else {
-      string list = op_def.output_arg(i).type_list_attr();
-      if (!list.empty()) {
-        CHECK(HasAttr(node, list)) << "No attr named " << list;
-        const AttrValue::ListValue& list_value = GetListAttr(node, list);
-        for (int j = 0; j < list_value.type_size(); ++j) {
-          add_type(list_value.type(j));
-        }
-      } else {
-        add_type(get_type(op_def.output_arg(i)));
-      }
-    }
-  }
-}
-tensorflow::Status ConvertUnsupportedOperator(
-    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
-    const ModelFlags& model_flags, Model* model) {
-  // Names of special attributes in TF graph that are used by Toco.
-  static constexpr char kAttrOutputQuantized[] = "_output_quantized";
-  static constexpr char kAttrOutputTypes[] = "_output_types";
-  static constexpr char kAttrOutputShapes[] = "_output_shapes";
-  static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
-      "_support_output_type_float_in_quantized_op";
-
-  LOG(INFO) << "Converting unsupported operation: " << node.op();
-  auto* op = new TensorFlowUnsupportedOperator;
-  op->tensorflow_op = node.op();
-
-  // For Flex mode. Please read the comments of the function.
-  RetainTensorFlowNodeDef(node, op);
-
-  model->operators.emplace_back(op);
-
-  // Parse inputs.
-  const int num_inputs = GetInputsCount(node, tf_import_flags);
-  for (int i = 0; i < num_inputs; ++i) {
-    op->inputs.push_back(node.input(i));
-  }
-
-  // Parse outputs. Name them after the node's name, plus an ordinal suffix.
-  // Note that some outputs are to be multiplied by a named attribute.
-  const tensorflow::OpDef* op_def = nullptr;
-  if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
-    GetOutputNamesFromNodeDef(node, *op_def, op);
-  } else {
-    op->outputs.push_back(node.name());  // Implicit :0.
-  }
-
-  // Parse if the op supports quantization
-  if (HasAttr(node, kAttrOutputQuantized)) {
-    op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
-  }
-  // Parse if the quantized op allows output arrays of type float
-  if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
-    op->support_output_type_float_in_quantized_op =
-        GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
-  }
-
-  // Parse output type(s).
-  if (HasAttr(node, kAttrOutputTypes)) {
-    const auto& output_types = GetListAttr(node, kAttrOutputTypes);
-    for (int i = 0; i < output_types.type_size(); ++i) {
-      op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
-    }
-  } else if (HasAttr(node, "Tout")) {
-    const auto& output_type = GetDataTypeAttr(node, "Tout");
-    op->output_data_types.push_back(ConvertDataType(output_type));
-  } else if (op_def != nullptr) {
-    GetOutputTypesFromNodeDef(node, *op_def, op);
-  } else {
-    // TODO(b/113613439): Figure out how to propagate types for custom ops
-    // that have no OpDef.
-    LOG(INFO) << "Unable to determine output type for op: " << node.op();
-  }
-
-  // Parse output shape(s).
-  if (HasAttr(node, kAttrOutputShapes)) {
-    const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
-    Shape output_shape;
-    for (int i = 0; i < output_shapes.shape_size(); ++i) {
-      const auto& shape = output_shapes.shape(i);
-      // TOCO doesn't yet properly handle shapes with wildcard dimensions.
-      // TODO(b/113613439): Handle shape inference for unsupported ops that
-      // have shapes with wildcard dimensions.
-      if (HasWildcardDimension(shape)) {
-        LOG(INFO) << "Skipping wildcard output shape(s) for node: "
-                  << node.name();
-        op->output_shapes.clear();
-        break;
-      }
-      const auto status =
-          ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
-      if (!status.ok()) {
-        return status;
-      }
-      op->output_shapes.push_back(output_shape);
-    }
-  }
-  return tensorflow::Status::OK();
-}
// Same as ConvertConstOperator, but falls back to ConvertUnsupportedOperator
// if the types are not supported. Converting Const operators here avoids
// expensive copies of the protocol buffers downstream in the Flex delegate.
@@ -2504,6 +2512,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
      {"GreaterEqual",
       ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
      {"Identity", ConvertIdentityOperator},
+      {"IdentityN", ConvertIdentityOperator},
      {"LRN", ConvertLRNOperator},
      {"LeakyRelu", ConvertLeakyReluOperator},
      {"LegacyFedInput", ConvertPlaceholderOperator},
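
With this registration, IdentityN reaches ConvertIdentityOperator, which handles the single-input case directly. The complement of the sketch above, again TF2-style and illustrative only: a one-input IdentityN converting with builtins alone:

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec([3, 3], tf.float32)])
    def single(x):
      # One input: imported as a plain Identity (often optimized away).
      return tf.identity_n([x * 2.0])[0]

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [single.get_concrete_function()])
    tflite_model = converter.convert()  # No SELECT_TF_OPS needed.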