From 19156f6e40ad7f84b871903c5aea5bcfc03675a4 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Mon, 30 Sep 2019 16:21:56 -0700 Subject: [PATCH] [TF/XLA Bridge] [NFC] Reduce the amount of boilerplate required to create GraphOptimizationPassOptions PiperOrigin-RevId: 272090359 --- tensorflow/compiler/jit/BUILD | 1 + .../compiler/jit/build_xla_ops_pass_test.cc | 9 ++++--- .../compiler/jit/cluster_scoping_pass_test.cc | 14 ++++------ .../jit/encapsulate_subgraphs_pass_test.cc | 27 +++++++------------ .../jit/partially_decluster_pass_test.cc | 13 ++++----- tensorflow/compiler/jit/test_util.h | 24 +++++++++++++++++ 6 files changed, 49 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 9bb87a77340..b08588f7332 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -747,6 +747,7 @@ tf_cc_test( ":encapsulate_util", ":flags", ":node_matchers", + ":test_util", ":xla_cluster_util", ":xla_cpu_device", ":xla_gpu_device", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index f434feb18a4..9a3863a7615 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/compiler/jit/test_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -72,11 +73,11 @@ Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib, FixupSourceAndSinkEdges(graph.get()); - SessionOptions session_options; - GraphOptimizationPassOptions opt_options; - opt_options.session_options = &session_options; + GraphOptimizationPassWrapper wrapper; + GraphOptimizationPassOptions opt_options = + wrapper.CreateGraphOptimizationPassOptions(&graph); opt_options.flib_def = &flib_def; - opt_options.graph = &graph; + BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); TF_RETURN_IF_ERROR(pass.Run(opt_options)); VLOG(3) << graph->ToGraphDefDebug().DebugString(); diff --git a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc index b3e63b8c298..5798d519bd7 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/test_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" @@ -33,17 +34,12 @@ namespace { Status ClusterScoping(std::unique_ptr* graph) { FixupSourceAndSinkEdges(graph->get()); - GraphOptimizationPassOptions opt_options; - opt_options.graph = graph; - FunctionDefLibrary fdef_lib; - FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); - opt_options.flib_def = &flib_def; - SessionOptions session_options; - session_options.env = Env::Default(); - session_options.config.mutable_graph_options() + GraphOptimizationPassWrapper wrapper; + wrapper.session_options.config.mutable_graph_options() ->mutable_optimizer_options() ->set_global_jit_level(OptimizerOptions::ON_2); - opt_options.session_options = &session_options; + GraphOptimizationPassOptions opt_options = + wrapper.CreateGraphOptimizationPassOptions(graph); ClusterScopingPass pass; return pass.Run(opt_options); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index d3d6cd96f97..0f7cee518d4 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" + #include #include -#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" - #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" +#include "tensorflow/compiler/jit/test_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" @@ -2600,13 +2601,9 @@ TEST(EncapsulateSubgraphsTest, RefVariablesMarked) { auto graph = absl::make_unique(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(graph.get())); - SessionOptions session_options; - session_options.env = Env::Default(); - GraphOptimizationPassOptions options; - options.session_options = &session_options; - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - options.flib_def = &library; - options.graph = &graph; + GraphOptimizationPassWrapper wrapper; + GraphOptimizationPassOptions options = + wrapper.CreateGraphOptimizationPassOptions(&graph); EncapsulateSubgraphsPass pass; TF_ASSERT_OK(pass.Run(options)); @@ -2634,15 +2631,9 @@ TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) { auto graph = absl::make_unique(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(graph.get())); - // TODO(cheshire): reduce boilerplate for creating - // GraphOptimizationPassOptions here and elsewhere, probably using a macro. - SessionOptions session_options; - session_options.env = Env::Default(); - GraphOptimizationPassOptions options; - options.session_options = &session_options; - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - options.flib_def = &library; - options.graph = &graph; + GraphOptimizationPassWrapper wrapper; + GraphOptimizationPassOptions options = + wrapper.CreateGraphOptimizationPassOptions(&graph); EncapsulateSubgraphsPass pass; TF_ASSERT_OK(pass.Run(options)); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index a9c44fb1cb7..d352ec8977b 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/test_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -90,14 +91,10 @@ Status PartiallyDecluster(std::unique_ptr* graph) { } } - GraphOptimizationPassOptions opt_options; - opt_options.graph = graph; - FunctionDefLibrary fdef_lib; - FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); - opt_options.flib_def = &flib_def; - SessionOptions session_options; - session_options.env = Env::Default(); - opt_options.session_options = &session_options; + GraphOptimizationPassWrapper wrapper; + GraphOptimizationPassOptions opt_options = + wrapper.CreateGraphOptimizationPassOptions(graph); + PartiallyDeclusterPass pass; return pass.Run(opt_options); } diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h index 0c9fee8f244..b5982c490df 100644 --- a/tensorflow/compiler/jit/test_util.h +++ b/tensorflow/compiler/jit/test_util.h @@ -23,10 +23,13 @@ limitations under the License. #include #include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -38,6 +41,27 @@ Status ShapeAnnotationsMatch( const Graph& graph, const GraphShapeInfo& shape_info, std::map> expected_shapes); +// A helper object to create GraphOptimizationPassOptions. +struct GraphOptimizationPassWrapper { + explicit GraphOptimizationPassWrapper() : library(OpRegistry::Global(), {}) { + session_options.env = Env::Default(); + } + + // Create GraphOptimizationPassOptions with a graph passed in constructor and + // sensible options. + GraphOptimizationPassOptions CreateGraphOptimizationPassOptions( + std::unique_ptr* graph) { + GraphOptimizationPassOptions options; + options.session_options = &session_options; + options.flib_def = &library; + options.graph = graph; + return options; + } + + FunctionLibraryDefinition library; + SessionOptions session_options; +}; + } // namespace tensorflow