diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
index 323fec6cf86..3a7611a6683 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -41,6 +41,97 @@ void PrintModelStats(const string& label, const Model& model) {
             << " quantized)";
 }
 
+// Some graphs have RNN back-edges that are discardable, having typically
+// been created by TensorFlow import rather than specified by the user.
+// Such graphs might have cycles (closed by RNN back-edges) that may be pruned.
+// Local graph transformations can't identify such global features,
+// so this function performs this global transformation.
+//
+// The other (and related) thing that is peculiar about RNN back-edges
+// is that they do not prevent the arrays that they touch from being
+// pruned. Thus, they may refer to array names which no longer exist.
+// The intent is for that to result in the eventual pruning of such
+// 'dangling' RNN back-edges. We perform this pruning at the end of this
+// function, as the pruning of connected components done here may leave
+// more RNN back-edges dangling.
+void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
+  // Identify the set of arrays that are in 'useful' connected components
+  // of the graph, which means connected to output arrays.
+  std::unordered_set<string> useful_arrays;
+  for (const string& output_array : model->flags.output_arrays()) {
+    useful_arrays.insert(output_array);
+  }
+  bool found_new_useful_arrays;
+  do {
+    found_new_useful_arrays = false;
+    for (const auto& op : model->operators) {
+      bool op_touches_useful_arrays = false;
+      for (const string& output : op->outputs) {
+        op_touches_useful_arrays |= useful_arrays.count(output);
+      }
+      if (op_touches_useful_arrays) {
+        for (const string& input : op->inputs) {
+          found_new_useful_arrays |= !useful_arrays.count(input);
+          useful_arrays.insert(input);
+        }
+        for (const string& output : op->outputs) {
+          found_new_useful_arrays |= !useful_arrays.count(output);
+          useful_arrays.insert(output);
+        }
+      }
+    }
+    for (const auto& rnn_state : model->flags.rnn_states()) {
+      bool rnn_back_edge_touches_useful_arrays =
+          useful_arrays.count(rnn_state.state_array());
+      if (rnn_back_edge_touches_useful_arrays) {
+        found_new_useful_arrays |=
+            !useful_arrays.count(rnn_state.back_edge_source_array());
+        useful_arrays.insert(rnn_state.back_edge_source_array());
+      }
+    }
+  } while (found_new_useful_arrays);
+  // Erase arrays that aren't useful and that are discardable.
+  for (auto it = model->arrays.begin(); it != model->arrays.end();) {
+    if (useful_arrays.count(it->first) ||
+        !IsDiscardableArray(*model, it->first)) {
+      ++it;
+    } else {
+      it = model->arrays.erase(it);
+    }
+  }
+  // Erase operators that do not produce a useful output array.
+  for (auto it = model->operators.begin(); it != model->operators.end();) {
+    // Only need to test the first output, as we simultaneously added all of
+    // an operator's outputs to the set of useful arrays.
+    if (useful_arrays.count((*it)->outputs[0])) {
+      ++it;
+    } else {
+      for (const string& output : (*it)->outputs) {
+        CHECK(!useful_arrays.count(output));
+      }
+      it = model->operators.erase(it);
+    }
+  }
+  // Erase RNN back-edges that are 'dangling', i.e. that touch an array
+  // that no longer exists. This should only happen for discardable RNN
+  // back-edges.
+  std::vector<RnnState> rnn_states_to_keep;
+  for (const auto& rnn_state : model->flags.rnn_states()) {
+    const bool dangling =
+        !model->arrays.count(rnn_state.back_edge_source_array()) ||
+        !model->arrays.count(rnn_state.state_array());
+    if (dangling) {
+      CHECK(rnn_state.discardable());
+    } else {
+      rnn_states_to_keep.push_back(rnn_state);
+    }
+  }
+  model->flags.clear_rnn_states();
+  for (const auto& rnn_state : rnn_states_to_keep) {
+    *model->flags.add_rnn_states() = rnn_state;
+  }
+}
+
 bool GraphTransformationsPass(int increment, Model* model,
                               const GraphTransformationsSet& transformations) {
   CHECK(increment == 1 || increment == -1);
@@ -86,6 +177,7 @@ bool GraphTransformationsPass(int increment, Model* model,
       op_index += increment;
     }
   }
+  DiscardUselessConnectedComponentsAndRNNBackEdges(model);
   return changed;
 }
 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
index b6037357047..23a5c857e8b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
@@ -57,7 +57,8 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
 
   // Drop trivial inputs.
   for (const string& input : trivial_inputs) {
-    if (CountOpsWithInput(*model, input) == 1) {
+    if (IsDiscardableArray(*model, input) &&
+        CountOpsWithInput(*model, input) == 1) {
       model->arrays.erase(input);
     }
   }
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
index 0ab301552ff..674a46815b1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -65,7 +65,12 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
     }
     for (const auto& rnn_state : model->flags.rnn_states()) {
       if (output == rnn_state.back_edge_source_array()) {
-        return false;
+        // The output is consumed by an RNN back-edge.
+        if (!IsDiscardableArray(*model, rnn_state.back_edge_source_array()) ||
+            !IsDiscardableArray(*model, rnn_state.state_array()) ||
+            CountOpsWithInput(*model, rnn_state.state_array())) {
+          return false;
+        }
       }
     }
     if (CountOpsWithInput(*model, output)) {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index cde5a936afd..9f72f9a1d35 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1380,6 +1380,31 @@ void ConvertSvdfOperator(const NodeDef& node,
   model->operators.emplace_back(op);
 }
 
+// Some TensorFlow ops only occur in graph cycles, representing
+// control flow. We do not currently support control flow, so we wouldn't
+// be able to fully support such graphs, including performing inference,
+// anyway. However, rather than erroring out early on cyclic graphs,
+// it helps to at least support these ops just enough to allow getting a
+// graph visualization. This is not trivial, as we require graphs to be
+// acyclic aside from RNN back-edges. The solution is to special-case
+// such ops as RNN back-edges, which is technically incorrect (it does not
+// allow representing the op's semantics) but good enough to get a
+// graph visualization.
+void ConvertOperatorSpecialCasedAsRNNBackEdge(
+    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+    Model* model) {
+  // At the moment, the only type of operator special-cased in this way is
+  // NextIteration, occurring only in control-flow cycles.
+  CHECK_EQ(node.op(), "NextIteration");
+  CHECK_EQ(node.input_size(), 1);
+  auto* rnn_state = model->flags.add_rnn_states();
+  // This RNN state is not explicitly created by the user, so it's
+  // OK for some later graph transformation to discard it.
+  rnn_state->set_discardable(true);
+  rnn_state->set_state_array(node.name());
+  rnn_state->set_back_edge_source_array(node.input(0));
+}
+
 void StripCaretFromArrayNames(Model* model) {
   for (auto& op : model->operators) {
     for (auto& input : op->inputs) {
@@ -1402,26 +1427,61 @@ void StripZeroOutputIndexFromInputs(NodeDef* node) {
   }
 }
 
-void AddExtraOutputsFedIntoOtherOps(Model* model) {
+// In TensorFlow GraphDef, when a node has multiple outputs, they are named
+// name:0, name:1, ...
+// where 'name' is the node's name(). Just 'name' is an equivalent shorthand
+// form for name:0.
+// A TensorFlow GraphDef does not explicitly list all the outputs of each node
+// (unlike inputs); they are implied by the node's name and operator type
+// (the latter implies the number of outputs).
+// This makes it non-trivial for us to reconstruct the list of all arrays
+// present in the graph and, for each operator, the list of its outputs.
+// We do that by taking advantage of the fact that each node does list
+// its inputs explicitly, so after we've loaded all nodes, we can use
+// that information.
+void AddExtraOutputs(Model* model) {
+  // Construct the list of all arrays consumed by anything in the graph.
+  std::vector<string> consumed_arrays;
+  // Add arrays consumed by an op.
   for (const auto& consumer_op : model->operators) {
     for (const string& input : consumer_op->inputs) {
-      const std::vector<string>& split = absl::StrSplit(input, ':');
-      if (split.size() != 2) {
-        continue;
-      }
-      int output_index = 0;
-      if (!absl::SimpleAtoi(split[1], &output_index)) {
-        continue;
-      }
-      auto* producer_op = GetOpWithOutput(*model, split[0]);
-      if (!producer_op) {
-        continue;
-      }
-      while (producer_op->outputs.size() <= output_index) {
-        using toco::port::StringF;
-        producer_op->outputs.push_back(
-            StringF("%s:%d", split[0], producer_op->outputs.size()));
-      }
+      consumed_arrays.push_back(input);
+    }
+  }
+  // Add global outputs of the model.
+  for (const string& output_array : model->flags.output_arrays()) {
+    consumed_arrays.push_back(output_array);
+  }
+  // Add arrays consumed by an RNN back-edge.
+  for (const auto& rnn_state : model->flags.rnn_states()) {
+    consumed_arrays.push_back(rnn_state.back_edge_source_array());
+  }
+  // Now add operator outputs so that all arrays that are consumed
+  // are also produced.
+  for (const string& consumed_array : consumed_arrays) {
+    // Split the consumed array name into the form name:output_index.
+    const std::vector<string>& split = absl::StrSplit(consumed_array, ':');
+    // If not of the form name:output_index, then this is not an additional
+    // output of a node with multiple outputs, so there is nothing to do here.
+    if (split.size() != 2) {
+      continue;
+    }
+    int output_index = 0;
+    if (!absl::SimpleAtoi(split[1], &output_index)) {
+      continue;
+    }
+    // Each op is initially recorded as producing at least the array that
+    // has its name. We use that to identify the producer node.
+    auto* producer_op = GetOpWithOutput(*model, split[0]);
+    if (!producer_op) {
+      continue;
+    }
+    // Add extra outputs to that producer node, all the way to the
+    // output_index.
+    while (producer_op->outputs.size() <= output_index) {
+      using toco::port::StringF;
+      producer_op->outputs.push_back(
+          StringF("%s:%d", split[0], producer_op->outputs.size()));
     }
   }
 }
@@ -1633,6 +1693,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
       ConvertMeanOperator(node, tf_import_flags, model);
     } else if (node.op() == "Svdf") {
      ConvertSvdfOperator(node, tf_import_flags, model);
+    } else if (node.op() == "NextIteration") {
+      ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
     } else {
       ConvertUnsupportedOperator(node, tf_import_flags, model);
     }
@@ -1641,7 +1703,7 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
 
   ResolveModelFlags(model_flags, model);
   StripCaretFromArrayNames(model);
-  AddExtraOutputsFedIntoOtherOps(model);
+  AddExtraOutputs(model);
   FixNoMissingArray(model);
   FixNoOrphanedArray(model);
   FixOperatorOrdering(model);
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
index d818a3632d1..05c48bc369b 100644
--- a/tensorflow/contrib/lite/toco/model_flags.proto
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -77,6 +77,25 @@ message InputArray {
   optional IODataType data_type = 5;
 }
 
+message RnnState {
+  optional string state_array = 1;
+  optional string back_edge_source_array = 2;
+  optional bool discardable = 5;
+  // TODO(benoitjacob): drop the 'size' field. It should be redundant with
+  // --input_shapes and shape propagation.
+  optional int32 size = 3;
+  // TODO(benoitjacob): manually_create is a temporary hack:
+  // due to discrepancies between the current toco dims tracking and
+  // TensorFlow shapes, for some models we need to manually create RNN state
+  // arrays with a specified shape.
+  // Maybe we should actually implement back-edges as operators of their own,
+  // which would remove the need for much special-casing, including here;
+  // we could probably consistently let PropagateFixedSizes handle state
+  // arrays.
+  // TODO(benoitjacob): we should really drop manually_create now.
+  optional bool manually_create = 4;
+}
+
 // ModelFlags encodes properties of a model that, depending on the file
 // format, may or may not be recorded in the model file. The purpose of
 // representing these properties in ModelFlags is to allow passing them
@@ -112,20 +131,6 @@ message ModelFlags {
   // the 'batch' field: at most one of these two fields can be set.
   optional bool variable_batch = 10;
 
-  message RnnState {
-    optional string state_array = 1;
-    optional string back_edge_source_array = 2;
-    optional int32 size = 3;
-    // TODO(benoitjacob): manually_create is a temporary hack:
-    // due to discrepancies between the current toco dims tracking and
-    // TensorFlow shapes, for some models we need to manually create RNN state
-    // arrays with a specified shape.
-    // Maybe we should actually implement back-edges as operators of their own,
-    // which would remove the need for much special-casing, including here,
-    // we could probably consistently let PropagateFixedSizes handle state
-    // arrays.
-    optional bool manually_create = 4;
-  }
   repeated RnnState rnn_states = 12;
 
   // Checks applied to the model, typically after toco's comprehensive
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 637287a9472..078afe79d01 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -573,8 +573,10 @@ void CheckNoMissingArray(const Model& model) {
         << "Output array not found: " << output_array;
   }
   for (const auto& rnn_state : model.flags.rnn_states()) {
-    CHECK(model.arrays.count(rnn_state.state_array()));
-    CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+    if (!rnn_state.discardable()) {
+      CHECK(model.arrays.count(rnn_state.state_array()));
+      CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+    }
   }
 }
 
@@ -596,12 +598,18 @@ void FixNoMissingArray(Model* model) {
       model->GetOrCreateArray(output_array);
     }
   }
+  for (const auto& rnn_state : model->flags.rnn_states()) {
+    model->GetOrCreateArray(rnn_state.state_array());
+    model->GetOrCreateArray(rnn_state.back_edge_source_array());
+  }
 }
 
 void CheckNoOrphanedArray(const Model& model) {
   std::unordered_set<string> arrays_without_known_use;
   for (const auto& array : model.arrays) {
-    arrays_without_known_use.insert(array.first);
+    if (IsDiscardableArray(model, array.first)) {
+      arrays_without_known_use.insert(array.first);
+    }
   }
   for (const auto& op : model.operators) {
     for (const auto& input : op->inputs) {
@@ -611,6 +619,10 @@ void CheckNoOrphanedArray(const Model& model) {
       arrays_without_known_use.erase(output);
     }
   }
+  for (const auto& rnn_state : model.flags.rnn_states()) {
+    arrays_without_known_use.erase(rnn_state.state_array());
+    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
+  }
   if (!arrays_without_known_use.empty()) {
     for (const auto& array : arrays_without_known_use) {
       LOG(INFO) << "Error: Orphaned array: " << array;
@@ -632,8 +644,14 @@ void FixNoOrphanedArray(Model* model) {
       arrays_without_known_use.erase(output);
     }
   }
+  for (const auto& rnn_state : model->flags.rnn_states()) {
+    arrays_without_known_use.erase(rnn_state.state_array());
+    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
+  }
   for (const auto& array : arrays_without_known_use) {
-    model->arrays.erase(array);
+    if (IsDiscardableArray(*model, array)) {
+      model->arrays.erase(array);
+    }
   }
 }
 
@@ -1042,16 +1060,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
 
 #undef RESOLVE_MODEL_FLAG
 
-  if (model->flags.rnn_states_size() == 0) {
+  if (!model_flags.rnn_states().empty()) {
     model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
-  } else {
-    CHECK_EQ(model->flags.rnn_states_size(), model_flags.rnn_states_size());
-    for (int i = 0; i < model->flags.rnn_states_size(); i++) {
-      CHECK_EQ(model->flags.rnn_states(i).state_array(),
-               model_flags.rnn_states(i).state_array());
-      CHECK_EQ(model->flags.rnn_states(i).back_edge_source_array(),
-               model_flags.rnn_states(i).back_edge_source_array());
-    }
   }
 
   if (model->flags.model_checks_size() == 0) {
@@ -1571,11 +1581,13 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
     }
   }
   for (const auto& rnn_state : model.flags.rnn_states()) {
-    if (array_name == rnn_state.state_array()) {
-      return false;
-    }
-    if (array_name == rnn_state.back_edge_source_array()) {
-      return false;
+    if (!rnn_state.discardable()) {
+      if (array_name == rnn_state.state_array()) {
+        return false;
+      }
+      if (array_name == rnn_state.back_edge_source_array()) {
+        return false;
+      }
     }
   }
   return true;
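
Note: the fixed-point reachability computation in DiscardUselessConnectedComponentsAndRNNBackEdges is the heart of this change. Below is a minimal, self-contained sketch of the same idea, for illustration only: the Op and BackEdge structs and the toy graph in main() are simplified stand-ins for toco's Operator and RnnState types, not part of this patch.

// Minimal sketch (not toco code): fixed-point computation of the set of
// 'useful' arrays, i.e. arrays reachable backwards from the model outputs,
// where an RNN back-edge contributes a state -> back-edge-source dependency.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Simplified stand-ins for toco's Operator and RnnState.
struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};
struct BackEdge {
  std::string source;  // corresponds to back_edge_source_array
  std::string state;   // corresponds to state_array
};

std::unordered_set<std::string> UsefulArrays(
    const std::vector<Op>& ops, const std::vector<BackEdge>& back_edges,
    const std::vector<std::string>& output_arrays) {
  std::unordered_set<std::string> useful(output_arrays.begin(),
                                         output_arrays.end());
  bool changed;
  do {
    changed = false;
    for (const Op& op : ops) {
      // An op is useful as soon as any of its outputs is useful.
      bool touches_useful = false;
      for (const std::string& out : op.outputs) {
        touches_useful |= useful.count(out) != 0;
      }
      if (!touches_useful) continue;
      // Mark every array the op touches; insert().second is true iff new.
      for (const std::string& name : op.inputs) {
        changed |= useful.insert(name).second;
      }
      for (const std::string& name : op.outputs) {
        changed |= useful.insert(name).second;
      }
    }
    // A back-edge whose state array is useful makes its source useful too.
    for (const BackEdge& edge : back_edges) {
      if (useful.count(edge.state)) {
        changed |= useful.insert(edge.source).second;
      }
    }
  } while (changed);
  return useful;
}

int main() {
  // a -> b -> out is live; c -> d is a dead component that would be pruned;
  // the back-edge keeps 'e' alive because its state array feeds a live op.
  const std::vector<Op> ops = {
      {{"a"}, {"b"}}, {{"b", "state"}, {"out"}}, {{"c"}, {"d"}}};
  const std::vector<BackEdge> back_edges = {{"e", "state"}};
  for (const std::string& name : UsefulArrays(ops, back_edges, {"out"})) {
    std::cout << name << "\n";
  }
  return 0;
}

As in the patch, the loop must run to a fixed point: a back-edge can only make its source array useful after the corresponding state array has been marked, which a single pass over the operators may miss.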