1) Previously, UpgradeLegacyGraph only functionalized control flow in the graph, not control flow in functions that are called from the graph, which caused problems because subsequent steps didn't expect and correctly handle remaining v1 control flow. Now such control flow is functionalized, too. 2) Refactored existing functionalization code. 3) Fixed bug in existing functionalization code: In certain cases the definition of a modified function (after its control flow was functionalized) could not be found when the calling graph node was rewritten. PiperOrigin-RevId: 326755095 Change-Id: I7163f7f0359f5de978ec5d764949978a2341cbaa
77 lines
3.4 KiB
C++
77 lines
3.4 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
|
|
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
|
|
|
|
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
|
#include "tensorflow/compiler/xla/status_macros.h"
|
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
|
#include "tensorflow/core/framework/function.h"
|
|
#include "tensorflow/core/graph/graph.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
// Transformation that converts tf.while_loop() loops into functional While
|
|
// operators and tf.cond() conditionals into function If operators, suitable for
|
|
// XLA compilation.
|
|
//
|
|
// If `node_filter` is defined, then only loops and conditions for whose
|
|
// nodes `node_filter` returns true are functionalized.
|
|
|
|
// If `include_functions` is true, then loops and conditions inside of functions
|
|
// that are associated with nodes in `graph` (e.g., a function called from a
|
|
// node in `graph`) are also functionalized, otherwise they are not.
|
|
// This also handles transitive cases, e.g., a function body will be
|
|
// functionalized when it is called in another function that is called by some
|
|
// node in `graph` (and so on). The node filter also applies here.
|
|
//
|
|
// Precondition:
|
|
// For any node in a loop or condition for which `node_filter` returns true,
|
|
// all nodes inside of the same loop or condition must also return true
|
|
// (including nodes in other nested loops and conditions inside of that loop or
|
|
// condition).
|
|
// This means that a "not to be functionalized" loop or condition is not allowed
|
|
// inside a "to be functionalized" loop or condition.
|
|
//
|
|
// The user of this function is responsible for using a node filter that
|
|
// satisfies the above conditions.
|
|
Status FunctionalizeControlFlow(Graph* graph,
|
|
FunctionLibraryDefinition* library,
|
|
const NodeFilter& node_filter = {},
|
|
bool include_functions = false);
|
|
|
|
Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
|
|
FunctionLibraryDefinition* library,
|
|
const NodeFilter& node_filter = {},
|
|
bool include_functions = false);
|
|
|
|
// This pass looks at the graph, and turns V1 control flow structure
|
|
// (Switch/Merge/etc.) into V2 control flow structure (If/While).
|
|
class FunctionalizeControlFlowPass : public GraphOptimizationPass {
|
|
public:
|
|
Status Run(const GraphOptimizationPassOptions& options) override;
|
|
};
|
|
|
|
// Same as the above but only modifies functions that will be executed by XLA.
|
|
class FunctionalizeControlFlowForXlaPass : public GraphOptimizationPass {
|
|
public:
|
|
Status Run(const GraphOptimizationPassOptions& options) override;
|
|
};
|
|
|
|
} // namespace tensorflow
|
|
|
|
#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
|