Allow partial handling of some cyclic graphs, at least to
get graph visualizations. TensorFlow graphs may be cyclic, generally corresponding to the encoding of control flow (e.g. TensorFlow while loops) in the graph. TensorFlow Lite does not currently support such control flow, and toco currently relies extensively on the assumption that graphs are acyclic. The case of RNNs is handled by special-casing some edges as 'RNN back-edges' so as to keep the graph itself (i.e. without these back-edges) acyclic. This assumption is guarded by CheckInvariants, so cyclic graphs result in early failures. The issue with that is that at the moment, given a cyclic TensorFlow graph, toco is not even useful to get a mere graph-visualization of it. Leaving alone actually supporting control flow, it would be nice to at least support getting a graphviz. Indeed, a good graphviz is often the first step toward reasoning about what a graph really is doing and how that could be modified to avoid involving control flow. This change achieves that as follows. In import_tensorflow.cc, NextIteration nodes are special-cased: instead of being imported as (unsupported) operators, they are imported as RNN back-edges. As NextIteration nodes are characteristic of control flow graphs which we do not support anyway, special-casing them is acceptable. This alone results in imported graphs that are no longer cyclic, the cycles being only closed by RNN back edges (maintained separately from the graph). So that alone already removes the CheckInvariants failures. However, another problem appears at this point: the resulting graph visualizations are too large, as the graphs are not correctly pruned. This is because the cycles (involving RNN back-edges) keep themselves alive from the point of view of graph-pruning transformations (remove_unused_op). Our graph transformations, which are local, cannot see that sometimes a whole connected component of the graph is disconnected from --input_arrays and --output_arrays, thus should be dropped. That can only be done by a global tranformation. So we add such a global transformation, running once at the end of each graph-transformations pass (i.e. infrequently): DiscardUselessConnectedComponents This, however, raises another question. Discarding unused cycles involving RNN back-edges implies, in particular, discarding RNN back-edges. So far, RNN back-edges were always explicitily specified by the user on the command-line, and we never discard things explicitly specified by the user (that would in particular make TF->TF transformations not idempotent). What changes here is that now a RNN back-edge needs not be explicitly specified by the user anymore, it may instead be internally constructed by import_tensorflow encountering a NextIteration node. So we need to distinguish between these two cases. We add a 'discardable' bool flag on RNN back-edges. It is important to allow discarding an array if it only occurs as a vertex touching a discardable RNN back-edge, otherwise graph pruning stops prematurely. That implies that RNN back-edges may be dangling, i.e. may point to an array name that doesn't actually exist (anymore). PiperOrigin-RevId: 178296359
This commit is contained in:
parent
2ea11416c9
commit
daab44f7af
@ -41,6 +41,97 @@ void PrintModelStats(const string& label, const Model& model) {
|
|||||||
<< " quantized)";
|
<< " quantized)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Some graphs have RNN back-edges that are discardable, having been
|
||||||
|
// created typically by TensorFlow import rather than specified by the user.
|
||||||
|
// Such graphs might have cycles (closed by RNN back-edges) that may be pruned.
|
||||||
|
// Local graph transformations can't identify such global features,
|
||||||
|
// so this function performs this global transformation.
|
||||||
|
//
|
||||||
|
// The other (and related) thing that is peculiar about RNN back-edges
|
||||||
|
// is that they do not prevent the arrays that they touch, from being
|
||||||
|
// pruned. Thus, they may refer to array names which no longer exist.
|
||||||
|
// The intent is for that to result in the eventual pruning of such
|
||||||
|
// 'dangling' RNN back-edges. We perform this pruning at the end of this
|
||||||
|
// function, as the pruning of connected components done here may leave
|
||||||
|
// more RNN back-edges dangling.
|
||||||
|
void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
|
||||||
|
// Identify the set of arrays that are in 'useful' connected components
|
||||||
|
// of the graph, which means connected to output arrays.
|
||||||
|
std::unordered_set<string> useful_arrays;
|
||||||
|
for (const string& output_array : model->flags.output_arrays()) {
|
||||||
|
useful_arrays.insert(output_array);
|
||||||
|
}
|
||||||
|
bool found_new_useful_arrays;
|
||||||
|
do {
|
||||||
|
found_new_useful_arrays = false;
|
||||||
|
for (const auto& op : model->operators) {
|
||||||
|
bool op_touches_useful_arrays = false;
|
||||||
|
for (const string& output : op->outputs) {
|
||||||
|
op_touches_useful_arrays |= useful_arrays.count(output);
|
||||||
|
}
|
||||||
|
if (op_touches_useful_arrays) {
|
||||||
|
for (const string& input : op->inputs) {
|
||||||
|
found_new_useful_arrays |= !useful_arrays.count(input);
|
||||||
|
useful_arrays.insert(input);
|
||||||
|
}
|
||||||
|
for (const string& output : op->outputs) {
|
||||||
|
found_new_useful_arrays |= !useful_arrays.count(output);
|
||||||
|
useful_arrays.insert(output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||||
|
bool rnn_back_edge_touches_useful_arrays =
|
||||||
|
useful_arrays.count(rnn_state.state_array());
|
||||||
|
if (rnn_back_edge_touches_useful_arrays) {
|
||||||
|
found_new_useful_arrays |=
|
||||||
|
!useful_arrays.count(rnn_state.back_edge_source_array());
|
||||||
|
useful_arrays.insert(rnn_state.back_edge_source_array());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} while (found_new_useful_arrays);
|
||||||
|
// Erase arrays that aren't useful, and that are discardable.
|
||||||
|
for (auto it = model->arrays.begin(); it != model->arrays.end();) {
|
||||||
|
if (useful_arrays.count(it->first) ||
|
||||||
|
!IsDiscardableArray(*model, it->first)) {
|
||||||
|
++it;
|
||||||
|
} else {
|
||||||
|
it = model->arrays.erase(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Erase operators that do not produce a useful output array.
|
||||||
|
for (auto it = model->operators.begin(); it != model->operators.end();) {
|
||||||
|
// Only need to test the first output, as we simultaneously added all of
|
||||||
|
// an operator's outputs to the list of output arrays.
|
||||||
|
if (useful_arrays.count((*it)->outputs[0])) {
|
||||||
|
++it;
|
||||||
|
} else {
|
||||||
|
for (const string& output : (*it)->outputs) {
|
||||||
|
CHECK(!useful_arrays.count(output));
|
||||||
|
}
|
||||||
|
it = model->operators.erase(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Erase RNN back-edges that are 'dangling' i.e. that touch an array
|
||||||
|
// that no longer exists. This should only happen for discardable RNN
|
||||||
|
// back-edges.
|
||||||
|
std::vector<RnnState> rnn_states_to_keep;
|
||||||
|
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||||
|
const bool dangling =
|
||||||
|
!model->arrays.count(rnn_state.back_edge_source_array()) ||
|
||||||
|
!model->arrays.count(rnn_state.state_array());
|
||||||
|
if (dangling) {
|
||||||
|
CHECK(rnn_state.discardable());
|
||||||
|
} else {
|
||||||
|
rnn_states_to_keep.push_back(rnn_state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model->flags.clear_rnn_states();
|
||||||
|
for (const auto& rnn_state : rnn_states_to_keep) {
|
||||||
|
*model->flags.add_rnn_states() = rnn_state;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool GraphTransformationsPass(int increment, Model* model,
|
bool GraphTransformationsPass(int increment, Model* model,
|
||||||
const GraphTransformationsSet& transformations) {
|
const GraphTransformationsSet& transformations) {
|
||||||
CHECK(increment == 1 || increment == -1);
|
CHECK(increment == 1 || increment == -1);
|
||||||
@ -86,6 +177,7 @@ bool GraphTransformationsPass(int increment, Model* model,
|
|||||||
op_index += increment;
|
op_index += increment;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
DiscardUselessConnectedComponentsAndRNNBackEdges(model);
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,7 +57,8 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
|
|||||||
|
|
||||||
// Drop trivial inputs.
|
// Drop trivial inputs.
|
||||||
for (const string& input : trivial_inputs) {
|
for (const string& input : trivial_inputs) {
|
||||||
if (CountOpsWithInput(*model, input) == 1) {
|
if (IsDiscardableArray(*model, input) &&
|
||||||
|
CountOpsWithInput(*model, input) == 1) {
|
||||||
model->arrays.erase(input);
|
model->arrays.erase(input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -65,9 +65,14 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
|
|||||||
}
|
}
|
||||||
for (const auto& rnn_state : model->flags.rnn_states()) {
|
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||||
if (output == rnn_state.back_edge_source_array()) {
|
if (output == rnn_state.back_edge_source_array()) {
|
||||||
|
// The output is consumed by a RNN back-edge..
|
||||||
|
if (!IsDiscardableArray(*model, rnn_state.back_edge_source_array()) ||
|
||||||
|
!IsDiscardableArray(*model, rnn_state.state_array()) ||
|
||||||
|
CountOpsWithInput(*model, rnn_state.state_array())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if (CountOpsWithInput(*model, output)) {
|
if (CountOpsWithInput(*model, output)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -1380,6 +1380,31 @@ void ConvertSvdfOperator(const NodeDef& node,
|
|||||||
model->operators.emplace_back(op);
|
model->operators.emplace_back(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Some TensorFlow ops only occur in graph cycles, representing
|
||||||
|
// control flow. We do not currently support control flow, so we wouldn't
|
||||||
|
// be able to fully support such graphs, including performing inference,
|
||||||
|
// anyway. However, rather than erroring out early on graphs being cyclic,
|
||||||
|
// it helps to at least support these just enough to allow getting a
|
||||||
|
// graph visualization. This is not trivial, as we require graphs to be
|
||||||
|
// acyclic aside from RNN back-edges. The solution is to special-case
|
||||||
|
// such ops as RNN back-edges, which is technically incorrect (does not
|
||||||
|
// allow representing the op's semantics) but good enough to get a
|
||||||
|
// graph visualization.
|
||||||
|
void ConvertOperatorSpecialCasedAsRNNBackEdge(
|
||||||
|
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||||
|
Model* model) {
|
||||||
|
// At the moment, the only type of operator special-cased in this way is
|
||||||
|
// NextIteration, occuring only in control-flow cycles.
|
||||||
|
CHECK_EQ(node.op(), "NextIteration");
|
||||||
|
CHECK_EQ(node.input_size(), 1);
|
||||||
|
auto* rnn_state = model->flags.add_rnn_states();
|
||||||
|
// This RNN state is not explicitly created by the user, so it's
|
||||||
|
// OK for some later graph transformation to discard it.
|
||||||
|
rnn_state->set_discardable(true);
|
||||||
|
rnn_state->set_state_array(node.name());
|
||||||
|
rnn_state->set_back_edge_source_array(node.input(0));
|
||||||
|
}
|
||||||
|
|
||||||
void StripCaretFromArrayNames(Model* model) {
|
void StripCaretFromArrayNames(Model* model) {
|
||||||
for (auto& op : model->operators) {
|
for (auto& op : model->operators) {
|
||||||
for (auto& input : op->inputs) {
|
for (auto& input : op->inputs) {
|
||||||
@ -1402,10 +1427,42 @@ void StripZeroOutputIndexFromInputs(NodeDef* node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddExtraOutputsFedIntoOtherOps(Model* model) {
|
// In TensorFlow GraphDef, when a node has multiple outputs, they are named
|
||||||
|
// name:0, name:1, ...
|
||||||
|
// where 'name' is the node's name(). Just 'name' is an equivalent shorthand
|
||||||
|
// form for name:0.
|
||||||
|
// A TensorFlow GraphDef does not explicitly list all the outputs of each node
|
||||||
|
// (unlike inputs), it being implied by the node's name and operator type
|
||||||
|
// (the latter implies the number of outputs).
|
||||||
|
// This makes it non-trivial for us to reconstruct the list of all arrays
|
||||||
|
// present in the graph and, for each operator, the list of its outputs.
|
||||||
|
// We do that by taking advantage of the fact that
|
||||||
|
// at least each node lists explicitly its inputs, so after we've loaded
|
||||||
|
// all nodes, we can use that information.
|
||||||
|
void AddExtraOutputs(Model* model) {
|
||||||
|
// Construct the list of all arrays consumed by anything in the graph.
|
||||||
|
std::vector<string> consumed_arrays;
|
||||||
|
// Add arrays consumed by an op.
|
||||||
for (const auto& consumer_op : model->operators) {
|
for (const auto& consumer_op : model->operators) {
|
||||||
for (const string& input : consumer_op->inputs) {
|
for (const string& input : consumer_op->inputs) {
|
||||||
const std::vector<string>& split = absl::StrSplit(input, ':');
|
consumed_arrays.push_back(input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Add global outputs of the model.
|
||||||
|
for (const string& output_array : model->flags.output_arrays()) {
|
||||||
|
consumed_arrays.push_back(output_array);
|
||||||
|
}
|
||||||
|
// Add arrays consumed by a RNN back-edge.
|
||||||
|
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||||
|
consumed_arrays.push_back(rnn_state.back_edge_source_array());
|
||||||
|
}
|
||||||
|
// Now add operator outputs so that all arrays that are consumed,
|
||||||
|
// are produced.
|
||||||
|
for (const string& consumed_array : consumed_arrays) {
|
||||||
|
// Split the consumed array name into the form name:output_index.
|
||||||
|
const std::vector<string>& split = absl::StrSplit(consumed_array, ':');
|
||||||
|
// If not of the form name:output_index, then this is not an additional
|
||||||
|
// output of a node with multiple outputs, so nothing to do here.
|
||||||
if (split.size() != 2) {
|
if (split.size() != 2) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -1413,17 +1470,20 @@ void AddExtraOutputsFedIntoOtherOps(Model* model) {
|
|||||||
if (!absl::SimpleAtoi(split[1], &output_index)) {
|
if (!absl::SimpleAtoi(split[1], &output_index)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
// Each op is initially recorded as producing at least the array that
|
||||||
|
// has its name. We use that to identify the producer node.
|
||||||
auto* producer_op = GetOpWithOutput(*model, split[0]);
|
auto* producer_op = GetOpWithOutput(*model, split[0]);
|
||||||
if (!producer_op) {
|
if (!producer_op) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
// Add extra outputs to that producer node, all the way to the
|
||||||
|
// output_index.
|
||||||
while (producer_op->outputs.size() <= output_index) {
|
while (producer_op->outputs.size() <= output_index) {
|
||||||
using toco::port::StringF;
|
using toco::port::StringF;
|
||||||
producer_op->outputs.push_back(
|
producer_op->outputs.push_back(
|
||||||
StringF("%s:%d", split[0], producer_op->outputs.size()));
|
StringF("%s:%d", split[0], producer_op->outputs.size()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool InlineAllFunctions(GraphDef* graphdef) {
|
bool InlineAllFunctions(GraphDef* graphdef) {
|
||||||
@ -1633,6 +1693,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
|||||||
ConvertMeanOperator(node, tf_import_flags, model);
|
ConvertMeanOperator(node, tf_import_flags, model);
|
||||||
} else if (node.op() == "Svdf") {
|
} else if (node.op() == "Svdf") {
|
||||||
ConvertSvdfOperator(node, tf_import_flags, model);
|
ConvertSvdfOperator(node, tf_import_flags, model);
|
||||||
|
} else if (node.op() == "NextIteration") {
|
||||||
|
ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
|
||||||
} else {
|
} else {
|
||||||
ConvertUnsupportedOperator(node, tf_import_flags, model);
|
ConvertUnsupportedOperator(node, tf_import_flags, model);
|
||||||
}
|
}
|
||||||
@ -1641,7 +1703,7 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
|||||||
ResolveModelFlags(model_flags, model);
|
ResolveModelFlags(model_flags, model);
|
||||||
|
|
||||||
StripCaretFromArrayNames(model);
|
StripCaretFromArrayNames(model);
|
||||||
AddExtraOutputsFedIntoOtherOps(model);
|
AddExtraOutputs(model);
|
||||||
FixNoMissingArray(model);
|
FixNoMissingArray(model);
|
||||||
FixNoOrphanedArray(model);
|
FixNoOrphanedArray(model);
|
||||||
FixOperatorOrdering(model);
|
FixOperatorOrdering(model);
|
||||||
|
@ -77,6 +77,25 @@ message InputArray {
|
|||||||
optional IODataType data_type = 5;
|
optional IODataType data_type = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message RnnState {
|
||||||
|
optional string state_array = 1;
|
||||||
|
optional string back_edge_source_array = 2;
|
||||||
|
optional bool discardable = 5;
|
||||||
|
// TODO(benoitjacob): drop the 'size' field. Should be redundant with
|
||||||
|
// --input_shapes and shapes propagation.
|
||||||
|
optional int32 size = 3;
|
||||||
|
// TODO(benoitjacob): manually_create is a temporary hack:
|
||||||
|
// due to discrepancies between the current toco dims tracking and
|
||||||
|
// TensorFlow shapes, for some models we need to manually create RNN state
|
||||||
|
// arrays with a specified shape.
|
||||||
|
// Maybe we should actually implement back-edges as operators of their own,
|
||||||
|
// which would remove the need for much special-casing, including here,
|
||||||
|
// we could probably consistently let PropagateFixedSizes handle state
|
||||||
|
// arrays.
|
||||||
|
// TODO(benoitjacob): should really drop manually_create now.
|
||||||
|
optional bool manually_create = 4;
|
||||||
|
}
|
||||||
|
|
||||||
// ModelFlags encodes properties of a model that, depending on the file
|
// ModelFlags encodes properties of a model that, depending on the file
|
||||||
// format, may or may not be recorded in the model file. The purpose of
|
// format, may or may not be recorded in the model file. The purpose of
|
||||||
// representing these properties in ModelFlags is to allow passing them
|
// representing these properties in ModelFlags is to allow passing them
|
||||||
@ -112,20 +131,6 @@ message ModelFlags {
|
|||||||
// the 'batch' field: at most one of these two fields can be set.
|
// the 'batch' field: at most one of these two fields can be set.
|
||||||
optional bool variable_batch = 10;
|
optional bool variable_batch = 10;
|
||||||
|
|
||||||
message RnnState {
|
|
||||||
optional string state_array = 1;
|
|
||||||
optional string back_edge_source_array = 2;
|
|
||||||
optional int32 size = 3;
|
|
||||||
// TODO(benoitjacob): manually_create is a temporary hack:
|
|
||||||
// due to discrepancies between the current toco dims tracking and
|
|
||||||
// TensorFlow shapes, for some models we need to manually create RNN state
|
|
||||||
// arrays with a specified shape.
|
|
||||||
// Maybe we should actually implement back-edges as operators of their own,
|
|
||||||
// which would remove the need for much special-casing, including here,
|
|
||||||
// we could probably consistently let PropagateFixedSizes handle state
|
|
||||||
// arrays.
|
|
||||||
optional bool manually_create = 4;
|
|
||||||
}
|
|
||||||
repeated RnnState rnn_states = 12;
|
repeated RnnState rnn_states = 12;
|
||||||
|
|
||||||
// Checks applied to the model, typically after toco's comprehensive
|
// Checks applied to the model, typically after toco's comprehensive
|
||||||
|
@ -573,9 +573,11 @@ void CheckNoMissingArray(const Model& model) {
|
|||||||
<< "Output array not found: " << output_array;
|
<< "Output array not found: " << output_array;
|
||||||
}
|
}
|
||||||
for (const auto& rnn_state : model.flags.rnn_states()) {
|
for (const auto& rnn_state : model.flags.rnn_states()) {
|
||||||
|
if (!rnn_state.discardable()) {
|
||||||
CHECK(model.arrays.count(rnn_state.state_array()));
|
CHECK(model.arrays.count(rnn_state.state_array()));
|
||||||
CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
|
CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FixNoMissingArray(Model* model) {
|
void FixNoMissingArray(Model* model) {
|
||||||
@ -596,13 +598,19 @@ void FixNoMissingArray(Model* model) {
|
|||||||
model->GetOrCreateArray(output_array);
|
model->GetOrCreateArray(output_array);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||||
|
model->GetOrCreateArray(rnn_state.state_array());
|
||||||
|
model->GetOrCreateArray(rnn_state.back_edge_source_array());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckNoOrphanedArray(const Model& model) {
|
void CheckNoOrphanedArray(const Model& model) {
|
||||||
std::unordered_set<string> arrays_without_known_use;
|
std::unordered_set<string> arrays_without_known_use;
|
||||||
for (const auto& array : model.arrays) {
|
for (const auto& array : model.arrays) {
|
||||||
|
if (IsDiscardableArray(model, array.first)) {
|
||||||
arrays_without_known_use.insert(array.first);
|
arrays_without_known_use.insert(array.first);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for (const auto& op : model.operators) {
|
for (const auto& op : model.operators) {
|
||||||
for (const auto& input : op->inputs) {
|
for (const auto& input : op->inputs) {
|
||||||
arrays_without_known_use.erase(input);
|
arrays_without_known_use.erase(input);
|
||||||
@ -611,6 +619,10 @@ void CheckNoOrphanedArray(const Model& model) {
|
|||||||
arrays_without_known_use.erase(output);
|
arrays_without_known_use.erase(output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (const auto& rnn_state : model.flags.rnn_states()) {
|
||||||
|
arrays_without_known_use.erase(rnn_state.state_array());
|
||||||
|
arrays_without_known_use.erase(rnn_state.back_edge_source_array());
|
||||||
|
}
|
||||||
if (!arrays_without_known_use.empty()) {
|
if (!arrays_without_known_use.empty()) {
|
||||||
for (const auto& array : arrays_without_known_use) {
|
for (const auto& array : arrays_without_known_use) {
|
||||||
LOG(INFO) << "Error: Orphaned array: " << array;
|
LOG(INFO) << "Error: Orphaned array: " << array;
|
||||||
@ -632,9 +644,15 @@ void FixNoOrphanedArray(Model* model) {
|
|||||||
arrays_without_known_use.erase(output);
|
arrays_without_known_use.erase(output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||||
|
arrays_without_known_use.erase(rnn_state.state_array());
|
||||||
|
arrays_without_known_use.erase(rnn_state.back_edge_source_array());
|
||||||
|
}
|
||||||
for (const auto& array : arrays_without_known_use) {
|
for (const auto& array : arrays_without_known_use) {
|
||||||
|
if (IsDiscardableArray(*model, array)) {
|
||||||
model->arrays.erase(array);
|
model->arrays.erase(array);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckArrayFieldsConsistent(const Model& model) {
|
void CheckArrayFieldsConsistent(const Model& model) {
|
||||||
@ -1042,16 +1060,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
|
|||||||
|
|
||||||
#undef RESOLVE_MODEL_FLAG
|
#undef RESOLVE_MODEL_FLAG
|
||||||
|
|
||||||
if (model->flags.rnn_states_size() == 0) {
|
if (!model_flags.rnn_states().empty()) {
|
||||||
model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
|
model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
|
||||||
} else {
|
|
||||||
CHECK_EQ(model->flags.rnn_states_size(), model_flags.rnn_states_size());
|
|
||||||
for (int i = 0; i < model->flags.rnn_states_size(); i++) {
|
|
||||||
CHECK_EQ(model->flags.rnn_states(i).state_array(),
|
|
||||||
model_flags.rnn_states(i).state_array());
|
|
||||||
CHECK_EQ(model->flags.rnn_states(i).back_edge_source_array(),
|
|
||||||
model_flags.rnn_states(i).back_edge_source_array());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model->flags.model_checks_size() == 0) {
|
if (model->flags.model_checks_size() == 0) {
|
||||||
@ -1571,6 +1581,7 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const auto& rnn_state : model.flags.rnn_states()) {
|
for (const auto& rnn_state : model.flags.rnn_states()) {
|
||||||
|
if (!rnn_state.discardable()) {
|
||||||
if (array_name == rnn_state.state_array()) {
|
if (array_name == rnn_state.state_array()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1578,6 +1589,7 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user