[TF/XLA Bridge] [NFC] Reduce the amount of boilerplate required to create GraphOptimizationPassOptions

PiperOrigin-RevId: 272090359
This commit is contained in:
George Karpenkov 2019-09-30 16:21:56 -07:00 committed by TensorFlower Gardener
parent b81191bc81
commit 19156f6e40
6 changed files with 49 additions and 39 deletions

View File

@ -747,6 +747,7 @@ tf_cc_test(
":encapsulate_util", ":encapsulate_util",
":flags", ":flags",
":node_matchers", ":node_matchers",
":test_util",
":xla_cluster_util", ":xla_cluster_util",
":xla_cpu_device", ":xla_cpu_device",
":xla_gpu_device", ":xla_gpu_device",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
@ -72,11 +73,11 @@ Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib,
FixupSourceAndSinkEdges(graph.get()); FixupSourceAndSinkEdges(graph.get());
SessionOptions session_options; GraphOptimizationPassWrapper wrapper;
GraphOptimizationPassOptions opt_options; GraphOptimizationPassOptions opt_options =
opt_options.session_options = &session_options; wrapper.CreateGraphOptimizationPassOptions(&graph);
opt_options.flib_def = &flib_def; opt_options.flib_def = &flib_def;
opt_options.graph = &graph;
BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
TF_RETURN_IF_ERROR(pass.Run(opt_options)); TF_RETURN_IF_ERROR(pass.Run(opt_options));
VLOG(3) << graph->ToGraphDefDebug().DebugString(); VLOG(3) << graph->ToGraphDefDebug().DebugString();

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
@ -33,17 +34,12 @@ namespace {
Status ClusterScoping(std::unique_ptr<Graph>* graph) { Status ClusterScoping(std::unique_ptr<Graph>* graph) {
FixupSourceAndSinkEdges(graph->get()); FixupSourceAndSinkEdges(graph->get());
GraphOptimizationPassOptions opt_options; GraphOptimizationPassWrapper wrapper;
opt_options.graph = graph; wrapper.session_options.config.mutable_graph_options()
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
opt_options.flib_def = &flib_def;
SessionOptions session_options;
session_options.env = Env::Default();
session_options.config.mutable_graph_options()
->mutable_optimizer_options() ->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2); ->set_global_jit_level(OptimizerOptions::ON_2);
opt_options.session_options = &session_options; GraphOptimizationPassOptions opt_options =
wrapper.CreateGraphOptimizationPassOptions(graph);
ClusterScopingPass pass; ClusterScopingPass pass;
return pass.Run(opt_options); return pass.Run(opt_options);

View File

@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
@ -2600,13 +2601,9 @@ TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global()); auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(root.ToGraph(graph.get()));
SessionOptions session_options; GraphOptimizationPassWrapper wrapper;
session_options.env = Env::Default(); GraphOptimizationPassOptions options =
GraphOptimizationPassOptions options; wrapper.CreateGraphOptimizationPassOptions(&graph);
options.session_options = &session_options;
FunctionLibraryDefinition library(OpRegistry::Global(), {});
options.flib_def = &library;
options.graph = &graph;
EncapsulateSubgraphsPass pass; EncapsulateSubgraphsPass pass;
TF_ASSERT_OK(pass.Run(options)); TF_ASSERT_OK(pass.Run(options));
@ -2634,15 +2631,9 @@ TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global()); auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(root.ToGraph(graph.get()));
// TODO(cheshire): reduce boilerplate for creating GraphOptimizationPassWrapper wrapper;
// GraphOptimizationPassOptions here and elsewhere, probably using a macro. GraphOptimizationPassOptions options =
SessionOptions session_options; wrapper.CreateGraphOptimizationPassOptions(&graph);
session_options.env = Env::Default();
GraphOptimizationPassOptions options;
options.session_options = &session_options;
FunctionLibraryDefinition library(OpRegistry::Global(), {});
options.flib_def = &library;
options.graph = &graph;
EncapsulateSubgraphsPass pass; EncapsulateSubgraphsPass pass;
TF_ASSERT_OK(pass.Run(options)); TF_ASSERT_OK(pass.Run(options));

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -90,14 +91,10 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
} }
} }
GraphOptimizationPassOptions opt_options; GraphOptimizationPassWrapper wrapper;
opt_options.graph = graph; GraphOptimizationPassOptions opt_options =
FunctionDefLibrary fdef_lib; wrapper.CreateGraphOptimizationPassOptions(graph);
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
opt_options.flib_def = &flib_def;
SessionOptions session_options;
session_options.env = Env::Default();
opt_options.session_options = &session_options;
PartiallyDeclusterPass pass; PartiallyDeclusterPass pass;
return pass.Run(opt_options); return pass.Run(opt_options);
} }

View File

@ -23,10 +23,13 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow { namespace tensorflow {
@ -38,6 +41,27 @@ Status ShapeAnnotationsMatch(
const Graph& graph, const GraphShapeInfo& shape_info, const Graph& graph, const GraphShapeInfo& shape_info,
std::map<string, std::vector<PartialTensorShape>> expected_shapes); std::map<string, std::vector<PartialTensorShape>> expected_shapes);
// A helper object to create GraphOptimizationPassOptions.
struct GraphOptimizationPassWrapper {
explicit GraphOptimizationPassWrapper() : library(OpRegistry::Global(), {}) {
session_options.env = Env::Default();
}
// Create GraphOptimizationPassOptions with a graph passed in constructor and
// sensible options.
GraphOptimizationPassOptions CreateGraphOptimizationPassOptions(
std::unique_ptr<Graph>* graph) {
GraphOptimizationPassOptions options;
options.session_options = &session_options;
options.flib_def = &library;
options.graph = graph;
return options;
}
FunctionLibraryDefinition library;
SessionOptions session_options;
};
} // namespace tensorflow } // namespace tensorflow