[TF/XLA Bridge] [NFC] Reduce the amount of boilerplate required to create GraphOptimizationPassOptions
PiperOrigin-RevId: 272090359
This commit is contained in:
parent
b81191bc81
commit
19156f6e40
@ -747,6 +747,7 @@ tf_cc_test(
|
|||||||
":encapsulate_util",
|
":encapsulate_util",
|
||||||
":flags",
|
":flags",
|
||||||
":node_matchers",
|
":node_matchers",
|
||||||
|
":test_util",
|
||||||
":xla_cluster_util",
|
":xla_cluster_util",
|
||||||
":xla_cpu_device",
|
":xla_cpu_device",
|
||||||
":xla_gpu_device",
|
":xla_gpu_device",
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/defs.h"
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||||
#include "tensorflow/compiler/jit/node_matchers.h"
|
#include "tensorflow/compiler/jit/node_matchers.h"
|
||||||
|
#include "tensorflow/compiler/jit/test_util.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -72,11 +73,11 @@ Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib,
|
|||||||
|
|
||||||
FixupSourceAndSinkEdges(graph.get());
|
FixupSourceAndSinkEdges(graph.get());
|
||||||
|
|
||||||
SessionOptions session_options;
|
GraphOptimizationPassWrapper wrapper;
|
||||||
GraphOptimizationPassOptions opt_options;
|
GraphOptimizationPassOptions opt_options =
|
||||||
opt_options.session_options = &session_options;
|
wrapper.CreateGraphOptimizationPassOptions(&graph);
|
||||||
opt_options.flib_def = &flib_def;
|
opt_options.flib_def = &flib_def;
|
||||||
opt_options.graph = &graph;
|
|
||||||
BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
|
BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
|
||||||
TF_RETURN_IF_ERROR(pass.Run(opt_options));
|
TF_RETURN_IF_ERROR(pass.Run(opt_options));
|
||||||
VLOG(3) << graph->ToGraphDefDebug().DebugString();
|
VLOG(3) << graph->ToGraphDefDebug().DebugString();
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/jit/defs.h"
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
|
#include "tensorflow/compiler/jit/test_util.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
@ -33,17 +34,12 @@ namespace {
|
|||||||
Status ClusterScoping(std::unique_ptr<Graph>* graph) {
|
Status ClusterScoping(std::unique_ptr<Graph>* graph) {
|
||||||
FixupSourceAndSinkEdges(graph->get());
|
FixupSourceAndSinkEdges(graph->get());
|
||||||
|
|
||||||
GraphOptimizationPassOptions opt_options;
|
GraphOptimizationPassWrapper wrapper;
|
||||||
opt_options.graph = graph;
|
wrapper.session_options.config.mutable_graph_options()
|
||||||
FunctionDefLibrary fdef_lib;
|
|
||||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
|
||||||
opt_options.flib_def = &flib_def;
|
|
||||||
SessionOptions session_options;
|
|
||||||
session_options.env = Env::Default();
|
|
||||||
session_options.config.mutable_graph_options()
|
|
||||||
->mutable_optimizer_options()
|
->mutable_optimizer_options()
|
||||||
->set_global_jit_level(OptimizerOptions::ON_2);
|
->set_global_jit_level(OptimizerOptions::ON_2);
|
||||||
opt_options.session_options = &session_options;
|
GraphOptimizationPassOptions opt_options =
|
||||||
|
wrapper.CreateGraphOptimizationPassOptions(graph);
|
||||||
|
|
||||||
ClusterScopingPass pass;
|
ClusterScopingPass pass;
|
||||||
return pass.Run(opt_options);
|
return pass.Run(opt_options);
|
||||||
|
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
|
||||||
|
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/cc/framework/ops.h"
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/compiler/jit/encapsulate_util.h"
|
#include "tensorflow/compiler/jit/encapsulate_util.h"
|
||||||
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
|
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
|
||||||
|
#include "tensorflow/compiler/jit/test_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
|
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
@ -2600,13 +2601,9 @@ TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
|
|||||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
SessionOptions session_options;
|
GraphOptimizationPassWrapper wrapper;
|
||||||
session_options.env = Env::Default();
|
GraphOptimizationPassOptions options =
|
||||||
GraphOptimizationPassOptions options;
|
wrapper.CreateGraphOptimizationPassOptions(&graph);
|
||||||
options.session_options = &session_options;
|
|
||||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
|
||||||
options.flib_def = &library;
|
|
||||||
options.graph = &graph;
|
|
||||||
|
|
||||||
EncapsulateSubgraphsPass pass;
|
EncapsulateSubgraphsPass pass;
|
||||||
TF_ASSERT_OK(pass.Run(options));
|
TF_ASSERT_OK(pass.Run(options));
|
||||||
@ -2634,15 +2631,9 @@ TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
|
|||||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
// TODO(cheshire): reduce boilerplate for creating
|
GraphOptimizationPassWrapper wrapper;
|
||||||
// GraphOptimizationPassOptions here and elsewhere, probably using a macro.
|
GraphOptimizationPassOptions options =
|
||||||
SessionOptions session_options;
|
wrapper.CreateGraphOptimizationPassOptions(&graph);
|
||||||
session_options.env = Env::Default();
|
|
||||||
GraphOptimizationPassOptions options;
|
|
||||||
options.session_options = &session_options;
|
|
||||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
|
||||||
options.flib_def = &library;
|
|
||||||
options.graph = &graph;
|
|
||||||
|
|
||||||
EncapsulateSubgraphsPass pass;
|
EncapsulateSubgraphsPass pass;
|
||||||
TF_ASSERT_OK(pass.Run(options));
|
TF_ASSERT_OK(pass.Run(options));
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/compiler/jit/defs.h"
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
|
#include "tensorflow/compiler/jit/test_util.h"
|
||||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
|
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
@ -90,14 +91,10 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphOptimizationPassOptions opt_options;
|
GraphOptimizationPassWrapper wrapper;
|
||||||
opt_options.graph = graph;
|
GraphOptimizationPassOptions opt_options =
|
||||||
FunctionDefLibrary fdef_lib;
|
wrapper.CreateGraphOptimizationPassOptions(graph);
|
||||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
|
||||||
opt_options.flib_def = &flib_def;
|
|
||||||
SessionOptions session_options;
|
|
||||||
session_options.env = Env::Default();
|
|
||||||
opt_options.session_options = &session_options;
|
|
||||||
PartiallyDeclusterPass pass;
|
PartiallyDeclusterPass pass;
|
||||||
return pass.Run(opt_options);
|
return pass.Run(opt_options);
|
||||||
}
|
}
|
||||||
|
@ -23,10 +23,13 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||||
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -38,6 +41,27 @@ Status ShapeAnnotationsMatch(
|
|||||||
const Graph& graph, const GraphShapeInfo& shape_info,
|
const Graph& graph, const GraphShapeInfo& shape_info,
|
||||||
std::map<string, std::vector<PartialTensorShape>> expected_shapes);
|
std::map<string, std::vector<PartialTensorShape>> expected_shapes);
|
||||||
|
|
||||||
|
// A helper object to create GraphOptimizationPassOptions.
|
||||||
|
struct GraphOptimizationPassWrapper {
|
||||||
|
explicit GraphOptimizationPassWrapper() : library(OpRegistry::Global(), {}) {
|
||||||
|
session_options.env = Env::Default();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create GraphOptimizationPassOptions with a graph passed in constructor and
|
||||||
|
// sensible options.
|
||||||
|
GraphOptimizationPassOptions CreateGraphOptimizationPassOptions(
|
||||||
|
std::unique_ptr<Graph>* graph) {
|
||||||
|
GraphOptimizationPassOptions options;
|
||||||
|
options.session_options = &session_options;
|
||||||
|
options.flib_def = &library;
|
||||||
|
options.graph = graph;
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionLibraryDefinition library;
|
||||||
|
SessionOptions session_options;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user