[TF/XLA Bridge] [NFC] Reduce the amount of boilerplate required to create GraphOptimizationPassOptions

PiperOrigin-RevId: 272090359
This commit is contained in:
George Karpenkov 2019-09-30 16:21:56 -07:00 committed by TensorFlower Gardener
parent b81191bc81
commit 19156f6e40
6 changed files with 49 additions and 39 deletions

View File

@ -747,6 +747,7 @@ tf_cc_test(
":encapsulate_util",
":flags",
":node_matchers",
":test_util",
":xla_cluster_util",
":xla_cpu_device",
":xla_gpu_device",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -72,11 +73,11 @@ Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib,
FixupSourceAndSinkEdges(graph.get());
SessionOptions session_options;
GraphOptimizationPassOptions opt_options;
opt_options.session_options = &session_options;
GraphOptimizationPassWrapper wrapper;
GraphOptimizationPassOptions opt_options =
wrapper.CreateGraphOptimizationPassOptions(&graph);
opt_options.flib_def = &flib_def;
opt_options.graph = &graph;
BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
TF_RETURN_IF_ERROR(pass.Run(opt_options));
VLOG(3) << graph->ToGraphDefDebug().DebugString();

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
@ -33,17 +34,12 @@ namespace {
Status ClusterScoping(std::unique_ptr<Graph>* graph) {
FixupSourceAndSinkEdges(graph->get());
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
opt_options.flib_def = &flib_def;
SessionOptions session_options;
session_options.env = Env::Default();
session_options.config.mutable_graph_options()
GraphOptimizationPassWrapper wrapper;
wrapper.session_options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2);
opt_options.session_options = &session_options;
GraphOptimizationPassOptions opt_options =
wrapper.CreateGraphOptimizationPassOptions(graph);
ClusterScopingPass pass;
return pass.Run(opt_options);

View File

@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include <memory>
#include <utility>
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
@ -2600,13 +2601,9 @@ TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
SessionOptions session_options;
session_options.env = Env::Default();
GraphOptimizationPassOptions options;
options.session_options = &session_options;
FunctionLibraryDefinition library(OpRegistry::Global(), {});
options.flib_def = &library;
options.graph = &graph;
GraphOptimizationPassWrapper wrapper;
GraphOptimizationPassOptions options =
wrapper.CreateGraphOptimizationPassOptions(&graph);
EncapsulateSubgraphsPass pass;
TF_ASSERT_OK(pass.Run(options));
@ -2634,15 +2631,9 @@ TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
// TODO(cheshire): reduce boilerplate for creating
// GraphOptimizationPassOptions here and elsewhere, probably using a macro.
SessionOptions session_options;
session_options.env = Env::Default();
GraphOptimizationPassOptions options;
options.session_options = &session_options;
FunctionLibraryDefinition library(OpRegistry::Global(), {});
options.flib_def = &library;
options.graph = &graph;
GraphOptimizationPassWrapper wrapper;
GraphOptimizationPassOptions options =
wrapper.CreateGraphOptimizationPassOptions(&graph);
EncapsulateSubgraphsPass pass;
TF_ASSERT_OK(pass.Run(options));

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -90,14 +91,10 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
}
}
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
opt_options.flib_def = &flib_def;
SessionOptions session_options;
session_options.env = Env::Default();
opt_options.session_options = &session_options;
GraphOptimizationPassWrapper wrapper;
GraphOptimizationPassOptions opt_options =
wrapper.CreateGraphOptimizationPassOptions(graph);
PartiallyDeclusterPass pass;
return pass.Run(opt_options);
}

View File

@ -23,10 +23,13 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@ -38,6 +41,27 @@ Status ShapeAnnotationsMatch(
const Graph& graph, const GraphShapeInfo& shape_info,
std::map<string, std::vector<PartialTensorShape>> expected_shapes);
// A helper object to create GraphOptimizationPassOptions.
struct GraphOptimizationPassWrapper {
explicit GraphOptimizationPassWrapper() : library(OpRegistry::Global(), {}) {
session_options.env = Env::Default();
}
// Create GraphOptimizationPassOptions with a graph passed in constructor and
// sensible options.
GraphOptimizationPassOptions CreateGraphOptimizationPassOptions(
std::unique_ptr<Graph>* graph) {
GraphOptimizationPassOptions options;
options.session_options = &session_options;
options.flib_def = &library;
options.graph = graph;
return options;
}
FunctionLibraryDefinition library;
SessionOptions session_options;
};
} // namespace tensorflow