[TF/XLA Bridge] [NFC] Reduce the amount of boilerplate required to create GraphOptimizationPassOptions
PiperOrigin-RevId: 272090359
This commit is contained in:
parent
b81191bc81
commit
19156f6e40
@ -747,6 +747,7 @@ tf_cc_test(
|
||||
":encapsulate_util",
|
||||
":flags",
|
||||
":node_matchers",
|
||||
":test_util",
|
||||
":xla_cluster_util",
|
||||
":xla_cpu_device",
|
||||
":xla_gpu_device",
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/node_matchers.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -72,11 +73,11 @@ Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib,
|
||||
|
||||
FixupSourceAndSinkEdges(graph.get());
|
||||
|
||||
SessionOptions session_options;
|
||||
GraphOptimizationPassOptions opt_options;
|
||||
opt_options.session_options = &session_options;
|
||||
GraphOptimizationPassWrapper wrapper;
|
||||
GraphOptimizationPassOptions opt_options =
|
||||
wrapper.CreateGraphOptimizationPassOptions(&graph);
|
||||
opt_options.flib_def = &flib_def;
|
||||
opt_options.graph = &graph;
|
||||
|
||||
BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
|
||||
TF_RETURN_IF_ERROR(pass.Run(opt_options));
|
||||
VLOG(3) << graph->ToGraphDefDebug().DebugString();
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
@ -33,17 +34,12 @@ namespace {
|
||||
Status ClusterScoping(std::unique_ptr<Graph>* graph) {
|
||||
FixupSourceAndSinkEdges(graph->get());
|
||||
|
||||
GraphOptimizationPassOptions opt_options;
|
||||
opt_options.graph = graph;
|
||||
FunctionDefLibrary fdef_lib;
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
opt_options.flib_def = &flib_def;
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
session_options.config.mutable_graph_options()
|
||||
GraphOptimizationPassWrapper wrapper;
|
||||
wrapper.session_options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_global_jit_level(OptimizerOptions::ON_2);
|
||||
opt_options.session_options = &session_options;
|
||||
GraphOptimizationPassOptions opt_options =
|
||||
wrapper.CreateGraphOptimizationPassOptions(graph);
|
||||
|
||||
ClusterScopingPass pass;
|
||||
return pass.Run(opt_options);
|
||||
|
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_util.h"
|
||||
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -2600,13 +2601,9 @@ TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
GraphOptimizationPassOptions options;
|
||||
options.session_options = &session_options;
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
options.flib_def = &library;
|
||||
options.graph = &graph;
|
||||
GraphOptimizationPassWrapper wrapper;
|
||||
GraphOptimizationPassOptions options =
|
||||
wrapper.CreateGraphOptimizationPassOptions(&graph);
|
||||
|
||||
EncapsulateSubgraphsPass pass;
|
||||
TF_ASSERT_OK(pass.Run(options));
|
||||
@ -2634,15 +2631,9 @@ TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
// TODO(cheshire): reduce boilerplate for creating
|
||||
// GraphOptimizationPassOptions here and elsewhere, probably using a macro.
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
GraphOptimizationPassOptions options;
|
||||
options.session_options = &session_options;
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
options.flib_def = &library;
|
||||
options.graph = &graph;
|
||||
GraphOptimizationPassWrapper wrapper;
|
||||
GraphOptimizationPassOptions options =
|
||||
wrapper.CreateGraphOptimizationPassOptions(&graph);
|
||||
|
||||
EncapsulateSubgraphsPass pass;
|
||||
TF_ASSERT_OK(pass.Run(options));
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
@ -90,14 +91,10 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
|
||||
}
|
||||
}
|
||||
|
||||
GraphOptimizationPassOptions opt_options;
|
||||
opt_options.graph = graph;
|
||||
FunctionDefLibrary fdef_lib;
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
opt_options.flib_def = &flib_def;
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
opt_options.session_options = &session_options;
|
||||
GraphOptimizationPassWrapper wrapper;
|
||||
GraphOptimizationPassOptions opt_options =
|
||||
wrapper.CreateGraphOptimizationPassOptions(graph);
|
||||
|
||||
PartiallyDeclusterPass pass;
|
||||
return pass.Run(opt_options);
|
||||
}
|
||||
|
@ -23,10 +23,13 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -38,6 +41,27 @@ Status ShapeAnnotationsMatch(
|
||||
const Graph& graph, const GraphShapeInfo& shape_info,
|
||||
std::map<string, std::vector<PartialTensorShape>> expected_shapes);
|
||||
|
||||
// A helper object to create GraphOptimizationPassOptions.
|
||||
struct GraphOptimizationPassWrapper {
|
||||
explicit GraphOptimizationPassWrapper() : library(OpRegistry::Global(), {}) {
|
||||
session_options.env = Env::Default();
|
||||
}
|
||||
|
||||
// Create GraphOptimizationPassOptions with a graph passed in constructor and
|
||||
// sensible options.
|
||||
GraphOptimizationPassOptions CreateGraphOptimizationPassOptions(
|
||||
std::unique_ptr<Graph>* graph) {
|
||||
GraphOptimizationPassOptions options;
|
||||
options.session_options = &session_options;
|
||||
options.flib_def = &library;
|
||||
options.graph = graph;
|
||||
return options;
|
||||
}
|
||||
|
||||
FunctionLibraryDefinition library;
|
||||
SessionOptions session_options;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user