TPU rewrite pass refactoring.
PiperOrigin-RevId: 321237265 Change-Id: I33c4966fc038816be9d9e9efd936e68a8400cc1a
This commit is contained in:
		
							parent
							
								
									4f3a8d84cf
								
							
						
					
					
						commit
						bdf8b75c14
					
				| @ -27,6 +27,7 @@ package_group( | ||||
|         "//tensorflow/compiler/mlir/...", | ||||
|         "//tensorflow/compiler/tests/...", | ||||
|         "//tensorflow/compiler/tf2xla/...", | ||||
|         "//tensorflow/core/tpu/...", | ||||
|         "//tensorflow/python/compiler/...", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| @ -73,6 +73,17 @@ cc_library( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "tpu_compile_interface", | ||||
|     srcs = ["tpu_compile_interface.cc"], | ||||
|     hdrs = ["tpu_compile_interface.h"], | ||||
|     deps = [ | ||||
|         "//tensorflow/core/platform:fingerprint", | ||||
|         "//tensorflow/core/platform:logging", | ||||
|         "@com_google_absl//absl/strings", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "tpu_defs", | ||||
|     srcs = ["tpu_defs.cc"], | ||||
|  | ||||
| @ -18,6 +18,7 @@ cc_library( | ||||
|     srcs = ["tpu_rewrite_pass_registration.cc"], | ||||
|     deps = [ | ||||
|         ":distributed_tpu_configuration_rewrite_pass", | ||||
|         ":encapsulate_tpu_computations_pass", | ||||
|         ":variable_merger_pass", | ||||
|         "//tensorflow/core:core_cpu", | ||||
|     ], | ||||
| @ -93,3 +94,56 @@ tf_cc_test( | ||||
|         "//tensorflow/core:test_main", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "encapsulate_tpu_computations_pass", | ||||
|     srcs = [ | ||||
|         "encapsulate_tpu_computations_pass.cc", | ||||
|     ], | ||||
|     hdrs = [ | ||||
|         "encapsulate_tpu_computations_pass.h", | ||||
|     ], | ||||
|     deps = [ | ||||
|         "//tensorflow/compiler/jit:compilation_passes", | ||||
|         "//tensorflow/compiler/jit:encapsulate_util", | ||||
|         "//tensorflow/compiler/tf2xla:side_effect_util", | ||||
|         "//tensorflow/compiler/tf2xla:tf2xla_util", | ||||
|         "//tensorflow/compiler/xla:status_macros", | ||||
|         "//tensorflow/core:core_cpu_lib", | ||||
|         "//tensorflow/core:framework", | ||||
|         "//tensorflow/core:lib", | ||||
|         "//tensorflow/core:lib_internal", | ||||
|         "//tensorflow/core:protos_all_cc", | ||||
|         "//tensorflow/core:session_options", | ||||
|         "//tensorflow/core/tpu:tpu_compile_interface", | ||||
|         "//tensorflow/core/tpu:tpu_defs", | ||||
|         "@com_google_absl//absl/container:flat_hash_map", | ||||
|         "@com_google_absl//absl/container:flat_hash_set", | ||||
|         "@com_google_absl//absl/container:node_hash_map", | ||||
|         "@com_google_absl//absl/memory", | ||||
|         "@com_google_absl//absl/strings", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| tf_cc_test( | ||||
|     name = "encapsulate_tpu_computations_pass_test", | ||||
|     srcs = ["encapsulate_tpu_computations_pass_test.cc"], | ||||
|     deps = [ | ||||
|         ":encapsulate_tpu_computations_pass", | ||||
|         "//tensorflow/cc:cc_ops", | ||||
|         "//tensorflow/cc:function_ops", | ||||
|         "//tensorflow/cc:resource_variable_ops", | ||||
|         "//tensorflow/cc:tpu_ops", | ||||
|         "//tensorflow/compiler/jit:compilation_passes", | ||||
|         "//tensorflow/compiler/tf2xla:test_util", | ||||
|         "//tensorflow/core:core_cpu", | ||||
|         "//tensorflow/core:core_cpu_impl", | ||||
|         "//tensorflow/core:framework", | ||||
|         "//tensorflow/core:framework_internal", | ||||
|         "//tensorflow/core:lib_internal", | ||||
|         "//tensorflow/core:ops", | ||||
|         "//tensorflow/core:test", | ||||
|         "//tensorflow/core:test_main", | ||||
|         "//tensorflow/core/tpu:tpu_defs", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,73 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // Rewrites computations generated by the tpu.replicate() Python code into
 | ||||
| // TPUReplicate operators.
 | ||||
| //
 | ||||
| // The tpu.replicate() does two main things:
 | ||||
| // a) marks operators that make up a TPU computation with the attribute
 | ||||
| //    _tpu_replicate=XYZ, where XYZ is a unique key.
 | ||||
| // b) adds TPUReplicatedInput and TPUReplicatedOutput nodes to represent
 | ||||
| //    replicated inputs. These nodes are not marked with the _tpu_replicate
 | ||||
| //    attribute.
 | ||||
| 
 | ||||
| #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITES_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_ | ||||
| #define TENSORFLOW_CORE_TPU_GRAPH_REWRITES_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_ | ||||
| 
 | ||||
| #include "tensorflow/compiler/jit/encapsulate_util.h" | ||||
| #include "tensorflow/core/common_runtime/optimization_registry.h" | ||||
| #include "tensorflow/core/graph/graph.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| 
 | ||||
| // Encapsulates nodes marked with the _tpu_replicate attribute into
 | ||||
| // TPUReplicate operators.
 | ||||
| class EncapsulateTPUComputationsPass : public GraphOptimizationPass { | ||||
|  public: | ||||
|   Status Run(const GraphOptimizationPassOptions& options) override; | ||||
| 
 | ||||
|   // The following methods are public only for unit tests.
 | ||||
| 
 | ||||
|   // This pass has two stages:
 | ||||
|   // a) first, we call the EncapsulateSubgraphsPass to encapsulate all nodes
 | ||||
|   //    marked with the same _tpu_replicate attribute into functions. These
 | ||||
|   //    functions contain the computations to be passed to TPUReplicate. During
 | ||||
|   //    encapsulation, we sort the arguments into the order expected by
 | ||||
|   //    TPUReplicate.
 | ||||
|   static Status Encapsulate(std::unique_ptr<Graph>* graph, | ||||
|                             FunctionLibraryDefinition* flib_def); | ||||
| 
 | ||||
|   // b) we rewrite the function calls generated in phase (a) into TPUReplicate
 | ||||
|   //    operators. We also flatten the TPUReplicatedInput and
 | ||||
|   //    TPUReplicatedOutput replicated input and output nodes of the function
 | ||||
|   //    call into the replicated input and outputs of the TPUReplicate operator.
 | ||||
|   static Status BuildTPUReplicateOps(Graph* graph); | ||||
| }; | ||||
| 
 | ||||
| // Graph optimization pass that calls `ExtractOutsideCompilation` for all XLA
 | ||||
| // computation nodes.
 | ||||
| class ExtractOutsideCompilationPass : public GraphOptimizationPass { | ||||
|  public: | ||||
|   Status Run(const GraphOptimizationPassOptions& options) override; | ||||
| 
 | ||||
|   static Status ProcessHeadTailOutsideCompilation( | ||||
|       const string& outside_compilation_attr_name, int* lifted_arg_count, | ||||
|       std::unordered_map<string, XlaClusterInfo>* clusters, Graph* g, | ||||
|       FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tensorflow
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITES_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_
 | ||||
| @ -0,0 +1,810 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h" | ||||
| 
 | ||||
| #include "tensorflow/cc/ops/array_ops.h" | ||||
| #include "tensorflow/cc/ops/function_ops.h" | ||||
| #include "tensorflow/cc/ops/parsing_ops.h" | ||||
| #include "tensorflow/cc/ops/resource_variable_ops.h" | ||||
| #include "tensorflow/cc/ops/standard_ops.h" | ||||
| #include "tensorflow/cc/ops/tpu_replication_ops.h" | ||||
| #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" | ||||
| #include "tensorflow/compiler/tf2xla/test_util.h" | ||||
| #include "tensorflow/core/common_runtime/device_factory.h" | ||||
| #include "tensorflow/core/common_runtime/graph_constructor.h" | ||||
| #include "tensorflow/core/common_runtime/process_function_library_runtime.h" | ||||
| #include "tensorflow/core/framework/graph_to_functiondef.h" | ||||
| #include "tensorflow/core/lib/core/status_test_util.h" | ||||
| #include "tensorflow/core/lib/hash/hash.h" | ||||
| #include "tensorflow/core/lib/strings/proto_serialization.h" | ||||
| #include "tensorflow/core/platform/test.h" | ||||
| #include "tensorflow/core/public/session_options.h" | ||||
| #include "tensorflow/core/public/version.h" | ||||
| #include "tensorflow/core/tpu/tpu_defs.h" | ||||
| #include "tensorflow/core/util/equal_graph_def.h" | ||||
| #include "tensorflow/core/util/ptr_util.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| 
 | ||||
| static std::unique_ptr<Graph> MakeOuterGraph( | ||||
|     const FunctionLibraryDefinition& flib_def, const string& function) { | ||||
|   Scope scope = Scope::NewRootScope().ExitOnError(); | ||||
|   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); | ||||
| 
 | ||||
|   int num_replicas = 2; | ||||
| 
 | ||||
|   auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); | ||||
|   auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); | ||||
|   auto b0 = ops::Placeholder(scope.WithOpName("B0"), DT_FLOAT); | ||||
|   auto b1 = ops::Placeholder(scope.WithOpName("B1"), DT_FLOAT); | ||||
|   auto u0 = ops::Placeholder(scope.WithOpName("U0"), DT_RESOURCE); | ||||
|   auto u1 = ops::Placeholder(scope.WithOpName("U1"), DT_RESOURCE); | ||||
|   auto z = ops::Placeholder(scope.WithOpName("Z"), DT_RESOURCE); | ||||
|   auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); | ||||
|   auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); | ||||
|   auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); | ||||
|   auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); | ||||
|   auto x = ops::GuaranteeConst( | ||||
|       scope.WithOpName("X"), | ||||
|       ops::Placeholder(scope.WithOpName("X_Holder"), DT_DOUBLE)); | ||||
|   auto y = ops::GuaranteeConst( | ||||
|       scope.WithOpName("Y"), | ||||
|       ops::Placeholder(scope.WithOpName("Y_Holder"), DT_DOUBLE)); | ||||
| 
 | ||||
|   auto in0 = ops::TPUReplicatedInput(scope.WithOpName("In0"), | ||||
|                                      std::initializer_list<Input>{a0, a1}); | ||||
|   auto in1 = ops::TPUReplicatedInput(scope.WithOpName("In1"), | ||||
|                                      std::initializer_list<Input>{b0, b1}); | ||||
|   auto in2 = ops::TPUReplicatedInput(scope.WithOpName("In2"), | ||||
|                                      std::initializer_list<Input>{u0, u1}); | ||||
|   auto in3 = ops::TPUReplicatedInput(scope.WithOpName("In3"), | ||||
|                                      std::initializer_list<Input>{z}); | ||||
|   in3.node()->AddAttr("is_packed", true); | ||||
| 
 | ||||
|   NodeDef def; | ||||
|   TF_CHECK_OK(NodeDefBuilder("replicate0", function, &flib_def) | ||||
|                   .Input(in0.node()->name(), 0, DT_INT32) | ||||
|                   .Input(in1.node()->name(), 0, DT_FLOAT) | ||||
|                   .Input(in2.node()->name(), 0, DT_RESOURCE) | ||||
|                   .Input(in3.node()->name(), 0, DT_RESOURCE) | ||||
|                   .Input(c.node()->name(), 0, DT_INT32) | ||||
|                   .Input(d.node()->name(), 0, DT_FLOAT) | ||||
|                   .Input(v.node()->name(), 0, DT_RESOURCE) | ||||
|                   .Input(w.node()->name(), 0, DT_RESOURCE) | ||||
|                   .Input(x.node()->name(), 0, DT_DOUBLE) | ||||
|                   .Input(y.node()->name(), 0, DT_DOUBLE) | ||||
|                   .Attr(kTPUReplicateAttr, "replicate0") | ||||
|                   .Attr("num_replicas", num_replicas) | ||||
|                   .Attr("num_cores_per_replica", 6) | ||||
|                   .Attr("topology", "") | ||||
|                   .Attr("use_tpu", true) | ||||
|                   .Attr("device_assignment", std::vector<int>()) | ||||
|                   .Attr("host_compute_core", std::vector<string>()) | ||||
|                   .Attr("padding_map", std::vector<string>()) | ||||
|                   .Attr("_variable_start_index", 6) | ||||
|                   .Attr("_guaranteed_const_start_index", 8) | ||||
|                   .Attr("allow_soft_placement", false) | ||||
|                   .Attr("step_marker_location", "STEP_MARK_AT_ENTRY") | ||||
|                   .Attr("use_spmd_for_xla_partitioning", false) | ||||
|                   .Finalize(&def)); | ||||
| 
 | ||||
|   Status status; | ||||
|   Node* replicate = scope.graph()->AddNode(def, &status); | ||||
|   TF_CHECK_OK(status); | ||||
|   TF_CHECK_OK(scope.DoShapeInference(replicate)); | ||||
|   scope.graph()->AddEdge(in0.node(), 0, replicate, 0); | ||||
|   scope.graph()->AddEdge(in1.node(), 0, replicate, 1); | ||||
|   scope.graph()->AddEdge(in2.node(), 0, replicate, 2); | ||||
|   scope.graph()->AddEdge(in3.node(), 0, replicate, 3); | ||||
|   scope.graph()->AddEdge(c.node(), 0, replicate, 4); | ||||
|   scope.graph()->AddEdge(d.node(), 0, replicate, 5); | ||||
|   scope.graph()->AddEdge(v.node(), 0, replicate, 6); | ||||
|   scope.graph()->AddEdge(w.node(), 0, replicate, 7); | ||||
|   scope.graph()->AddEdge(x.node(), 0, replicate, 8); | ||||
|   scope.graph()->AddEdge(y.node(), 0, replicate, 9); | ||||
| 
 | ||||
|   auto out0 = ops::TPUReplicatedOutput(scope.WithOpName("Out0"), | ||||
|                                        Output(replicate, 0), num_replicas); | ||||
|   auto out1 = ops::TPUReplicatedOutput(scope.WithOpName("Out1"), | ||||
|                                        Output(replicate, 1), num_replicas); | ||||
|   auto out2 = ops::TPUReplicatedOutput(scope.WithOpName("Out2"), | ||||
|                                        Output(replicate, 2), num_replicas); | ||||
|   auto out3 = ops::TPUReplicatedOutput(scope.WithOpName("Out3"), | ||||
|                                        Output(replicate, 3), num_replicas); | ||||
|   auto out4 = ops::TPUReplicatedOutput(scope.WithOpName("Out4"), | ||||
|                                        Output(replicate, 4), num_replicas); | ||||
| 
 | ||||
|   auto consumer0_0a = ops::Identity(scope.WithOpName("consumer0_0a"), out0[0]); | ||||
|   auto consumer0_0b = ops::Identity(scope.WithOpName("consumer0_0b"), out0[0]); | ||||
|   auto consumer0_1 = ops::Identity(scope.WithOpName("consumer0_1"), out0[1]); | ||||
|   auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1[1]); | ||||
|   auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2[0]); | ||||
|   auto consumer3a = ops::Identity(scope.WithOpName("consumer3a"), out3[0]); | ||||
|   auto consumer3b = ops::Identity(scope.WithOpName("consumer3b"), out3[1]); | ||||
|   auto consumer4a = ops::Identity(scope.WithOpName("consumer4a"), out4[0]); | ||||
|   auto consumer4b = ops::Identity(scope.WithOpName("consumer4b"), out4[1]); | ||||
| 
 | ||||
|   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); | ||||
|   TF_CHECK_OK(scope.ToGraph(graph.get())); | ||||
|   return graph; | ||||
| } | ||||
| 
 | ||||
| // Makes an encapsulate body graph for use in tests.
 | ||||
| static std::unique_ptr<Graph> MakeBodyGraph() { | ||||
|   Scope scope = Scope::NewRootScope().ExitOnError(); | ||||
| 
 | ||||
|   auto arg0 = ops::_Arg(scope.WithOpName("in0_0_arg"), DT_INT32, 0); | ||||
|   auto arg1 = ops::_Arg(scope.WithOpName("in1_0_arg"), DT_FLOAT, 1); | ||||
|   auto arg2 = ops::_Arg(scope.WithOpName("in2_0_arg"), DT_RESOURCE, 2); | ||||
|   auto arg3 = ops::_Arg(scope.WithOpName("in3_0_arg"), DT_RESOURCE, 3); | ||||
|   auto arg4 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 4); | ||||
|   auto arg5 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 5); | ||||
|   auto arg6 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 6); | ||||
|   auto arg7 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 7); | ||||
| 
 | ||||
|   auto add_attrs = [](Node* node) { | ||||
|     node->AddAttr(kTPUReplicateAttr, "replicate0"); | ||||
|   }; | ||||
| 
 | ||||
|   string device = | ||||
|       tensorflow::strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE); | ||||
| 
 | ||||
|   auto in1_identity = | ||||
|       ops::Identity(scope.WithOpName("In1_identity").WithDevice(device), arg1); | ||||
| 
 | ||||
|   auto read_u = ops::ReadVariableOp( | ||||
|       scope.WithOpName("ReadU").WithDevice(device), arg2, DT_FLOAT); | ||||
|   add_attrs(read_u.node()); | ||||
|   auto read_z = ops::ReadVariableOp( | ||||
|       scope.WithOpName("ReadZ").WithDevice(device), arg3, DT_FLOAT); | ||||
|   add_attrs(read_z.node()); | ||||
|   auto read_v = ops::ReadVariableOp( | ||||
|       scope.WithOpName("ReadV").WithDevice(device), arg6, DT_FLOAT); | ||||
|   add_attrs(read_v.node()); | ||||
|   auto read_w = ops::ReadVariableOp( | ||||
|       scope.WithOpName("ReadW").WithDevice(device), arg7, DT_FLOAT); | ||||
|   add_attrs(read_w.node()); | ||||
| 
 | ||||
|   auto e = ops::Add(scope.WithOpName("E").WithDevice(device), arg0, arg4); | ||||
|   add_attrs(e.node()); | ||||
|   auto f = ops::Add(scope.WithOpName("F").WithDevice(device), read_v, read_w); | ||||
|   add_attrs(f.node()); | ||||
|   auto g = ops::Add(scope.WithOpName("G").WithDevice(device), f, arg5); | ||||
|   add_attrs(g.node()); | ||||
| 
 | ||||
|   auto arg8 = ops::_Arg(scope.WithOpName("x_0_arg"), DT_DOUBLE, 8); | ||||
|   auto arg9 = ops::_Arg(scope.WithOpName("y_0_arg"), DT_DOUBLE, 9); | ||||
|   arg8.node()->AddAttr("_is_guaranteed_constant", true); | ||||
|   arg9.node()->AddAttr("_is_guaranteed_constant", true); | ||||
|   auto h = ops::Add(scope.WithOpName("H").WithDevice(device), arg8, arg9); | ||||
|   add_attrs(h.node()); | ||||
| 
 | ||||
|   auto out0 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 0); | ||||
|   auto out1 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 1); | ||||
|   auto out2 = ops::_Retval(scope.WithOpName("in1_identity_0_retval_RetVal"), | ||||
|                            in1_identity, 2); | ||||
|   auto out3 = | ||||
|       ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); | ||||
|   auto out4 = | ||||
|       ops::_Retval(scope.WithOpName("readz_0_retval_RetVal"), read_z, 4); | ||||
| 
 | ||||
|   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); | ||||
|   TF_CHECK_OK(scope.ToGraph(graph.get())); | ||||
|   return graph; | ||||
| } | ||||
| 
 | ||||
| TEST(EncapsulateTPUComputations, DeterministicEncapsulate) { | ||||
|   // Test that control edge insertion order doesn't affect the cache key
 | ||||
|   // (cluster name) generated by TPU encapsulate pass.
 | ||||
|   auto get_serialized_graph = [](bool control_input_reversed, | ||||
|                                  bool operand_reversed) -> string { | ||||
|     FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); | ||||
|     std::unique_ptr<Graph> graph(new Graph(&flib_def)); | ||||
|     { | ||||
|       Scope scope = Scope::NewRootScope().ExitOnError(); | ||||
|       auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); | ||||
|       auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); | ||||
| 
 | ||||
|       ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) | ||||
|                                     : ops::Add(scope.WithOpName("E"), a1, a0); | ||||
| 
 | ||||
|       auto metadata = ops::TPUReplicateMetadata(scope, /*num_replicas=*/2); | ||||
|       auto add_attrs = [](Node* node) { | ||||
|         node->AddAttr(kTPUReplicateAttr, "replicate0"); | ||||
|       }; | ||||
|       add_attrs(metadata.operation.node()); | ||||
|       add_attrs(e.node()); | ||||
| 
 | ||||
|       TF_CHECK_OK(scope.ToGraph(graph.get())); | ||||
|       auto get_node_in_graph = [&graph](Node* node) { | ||||
|         return graph->FindNodeId(node->id()); | ||||
|       }; | ||||
|       // Insert control edge in different order. The order should not affect
 | ||||
|       // the encapsulated or serialized graph.
 | ||||
|       if (!control_input_reversed) { | ||||
|         graph->AddControlEdge(get_node_in_graph(a0.node()), | ||||
|                               get_node_in_graph(e.node()), true); | ||||
|         graph->AddControlEdge(get_node_in_graph(a1.node()), | ||||
|                               get_node_in_graph(e.node()), true); | ||||
|       } else { | ||||
|         graph->AddControlEdge(get_node_in_graph(a1.node()), | ||||
|                               get_node_in_graph(e.node()), true); | ||||
|         graph->AddControlEdge(get_node_in_graph(a0.node()), | ||||
|                               get_node_in_graph(e.node()), true); | ||||
|       } | ||||
|     } | ||||
|     TF_CHECK_OK(EncapsulateTPUComputationsPass::Encapsulate(&graph, &flib_def)); | ||||
|     GraphDef gdef; | ||||
|     graph->ToGraphDef(&gdef); | ||||
|     // Before serialization, sort control inputs first to remove
 | ||||
|     // nondeterminism.
 | ||||
|     SortControlInputs(&gdef); | ||||
|     string serialized; | ||||
|     SerializeToStringDeterministic(gdef, &serialized); | ||||
|     return serialized; | ||||
|   }; | ||||
| 
 | ||||
|   // Changing the order of control input shouldn't affect the graph generated.
 | ||||
|   EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, | ||||
|                                  /*operand_reversed=*/false), | ||||
|             get_serialized_graph(/*control_input_reversed=*/false, | ||||
|                                  /*operand_reversed=*/false)); | ||||
| 
 | ||||
|   // Changing the order of data input should affect the graph generated.
 | ||||
|   EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, | ||||
|                                  /*operand_reversed=*/true), | ||||
|             get_serialized_graph(/*control_input_reversed=*/false, | ||||
|                                  /*operand_reversed=*/false)); | ||||
| } | ||||
| 
 | ||||
| TEST(EncapsulateTPUComputations, Encapsulate) { | ||||
|   FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); | ||||
|   std::unique_ptr<Graph> graph(new Graph(&flib_def)); | ||||
|   { | ||||
|     Scope scope = Scope::NewRootScope().ExitOnError(); | ||||
|     auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); | ||||
|     auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); | ||||
|     auto b0 = ops::Placeholder(scope.WithOpName("B0"), DT_FLOAT); | ||||
|     auto b1 = ops::Placeholder(scope.WithOpName("B1"), DT_FLOAT); | ||||
|     auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); | ||||
|     auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); | ||||
|     auto u0 = ops::Placeholder(scope.WithOpName("U0"), DT_RESOURCE); | ||||
|     auto u1 = ops::Placeholder(scope.WithOpName("U1"), DT_RESOURCE); | ||||
|     auto z = ops::Placeholder(scope.WithOpName("Z"), DT_RESOURCE); | ||||
|     auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); | ||||
|     auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); | ||||
|     auto x = ops::GuaranteeConst( | ||||
|         scope.WithOpName("X"), | ||||
|         ops::Placeholder(scope.WithOpName("X_Holder"), DT_DOUBLE)); | ||||
|     auto y = ops::GuaranteeConst( | ||||
|         scope.WithOpName("Y"), | ||||
|         ops::Placeholder(scope.WithOpName("Y_Holder"), DT_DOUBLE)); | ||||
| 
 | ||||
|     auto in0 = ops::TPUReplicatedInput(scope.WithOpName("In0"), | ||||
|                                        std::initializer_list<Input>{a0, a1}); | ||||
|     auto in1 = ops::TPUReplicatedInput(scope.WithOpName("In1"), | ||||
|                                        std::initializer_list<Input>{b0, b1}); | ||||
|     auto in2 = ops::TPUReplicatedInput(scope.WithOpName("In2"), | ||||
|                                        std::initializer_list<Input>{u0, u1}); | ||||
|     auto in3 = ops::TPUReplicatedInput(scope.WithOpName("In3"), | ||||
|                                        std::initializer_list<Input>{z}); | ||||
|     in3.node()->AddAttr("is_packed", true); | ||||
| 
 | ||||
|     auto add_attrs = [](Node* node) { | ||||
|       node->AddAttr(kTPUReplicateAttr, "replicate0"); | ||||
|     }; | ||||
|     auto metadata = ops::TPUReplicateMetadata( | ||||
|         scope, /*num_replicas=*/2, | ||||
|         ops::TPUReplicateMetadata::ComputationShape({2, 3})); | ||||
|     add_attrs(metadata.operation.node()); | ||||
| 
 | ||||
|     auto in1_identity = ops::Identity(scope.WithOpName("In1_identity"), in1); | ||||
|     add_attrs(in1_identity.node()); | ||||
| 
 | ||||
|     auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), in2, DT_FLOAT); | ||||
|     add_attrs(read_u.node()); | ||||
|     auto read_z = ops::ReadVariableOp(scope.WithOpName("ReadZ"), in3, DT_FLOAT); | ||||
|     add_attrs(read_z.node()); | ||||
|     auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); | ||||
|     add_attrs(read_v.node()); | ||||
|     auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); | ||||
|     add_attrs(read_w.node()); | ||||
| 
 | ||||
|     auto e = ops::Add(scope.WithOpName("E"), in0, c); | ||||
|     add_attrs(e.node()); | ||||
|     auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); | ||||
|     add_attrs(f.node()); | ||||
|     auto g = ops::Add(scope.WithOpName("G"), f, d); | ||||
|     add_attrs(g.node()); | ||||
|     auto h = ops::Add(scope.WithOpName("H"), x, y); | ||||
|     add_attrs(h.node()); | ||||
| 
 | ||||
|     auto out0 = ops::TPUReplicatedOutput(scope.WithOpName("Out0"), e, 2); | ||||
|     auto out1 = ops::TPUReplicatedOutput(scope.WithOpName("Out1"), g, 2); | ||||
|     auto out2 = | ||||
|         ops::TPUReplicatedOutput(scope.WithOpName("Out2"), in1_identity, 2); | ||||
|     auto out3 = ops::TPUReplicatedOutput(scope.WithOpName("Out3"), read_u, 2); | ||||
|     auto out4 = ops::TPUReplicatedOutput(scope.WithOpName("Out4"), read_z, 2); | ||||
| 
 | ||||
|     auto consumer0_0a = | ||||
|         ops::Identity(scope.WithOpName("consumer0_0a"), out0[0]); | ||||
|     auto consumer0_0b = | ||||
|         ops::Identity(scope.WithOpName("consumer0_0b"), out0[0]); | ||||
|     auto consumer0_1 = ops::Identity(scope.WithOpName("consumer0_1"), out0[1]); | ||||
|     auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1[1]); | ||||
|     auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2[0]); | ||||
|     auto consumer3a = ops::Identity(scope.WithOpName("consumer3a"), out3[0]); | ||||
|     auto consumer3b = ops::Identity(scope.WithOpName("consumer3b"), out3[1]); | ||||
|     auto consumer4a = ops::Identity(scope.WithOpName("consumer4a"), out4[0]); | ||||
|     auto consumer4b = ops::Identity(scope.WithOpName("consumer4b"), out4[1]); | ||||
|     TF_ASSERT_OK(scope.ToGraph(graph.get())); | ||||
|   } | ||||
| 
 | ||||
|   std::unique_ptr<Graph> graph_copy(new Graph(&flib_def)); | ||||
|   CopyGraph(*graph, graph_copy.get()); | ||||
| 
 | ||||
|   TF_ASSERT_OK(EncapsulateTPUComputationsPass::Encapsulate(&graph, &flib_def)); | ||||
|   // Remove _xla_inferred_shapes attribute.
 | ||||
|   for (Node* n : graph->nodes()) { | ||||
|     n->ClearAttr("_xla_inferred_shapes"); | ||||
|   } | ||||
| 
 | ||||
|   std::unordered_map<string, Node*> index = graph->BuildNodeNameIndex(); | ||||
|   string function = index.at("replicate0")->type_string(); | ||||
| 
 | ||||
|   // Tests the outer graph is as expected.
 | ||||
|   { | ||||
|     std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function); | ||||
|     GraphDef expected_def; | ||||
|     outer->ToGraphDef(&expected_def); | ||||
| 
 | ||||
|     GraphDef actual_def; | ||||
|     graph->ToGraphDef(&actual_def); | ||||
|     TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); | ||||
|   } | ||||
| 
 | ||||
|   // Tests the encapsulated body graph is as expected.
 | ||||
|   { | ||||
|     std::unique_ptr<Graph> body = MakeBodyGraph(); | ||||
|     GraphDef expected_body_def; | ||||
|     body->ToGraphDef(&expected_body_def); | ||||
| 
 | ||||
|     InstantiationResultForTest result; | ||||
|     TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); | ||||
| 
 | ||||
|     EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_RESOURCE, DT_RESOURCE, | ||||
|                               DT_INT32, DT_FLOAT, DT_RESOURCE, DT_RESOURCE, | ||||
|                               DT_DOUBLE, DT_DOUBLE}), | ||||
|               result.arg_types); | ||||
|     EXPECT_EQ( | ||||
|         (DataTypeVector{DT_INT32, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}), | ||||
|         result.ret_types); | ||||
|     TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); | ||||
|   } | ||||
| 
 | ||||
|   // Encapsulates the same computation again, verifies we reuse the same
 | ||||
|   // function. Encapsulation should be deterministic to avoid recompilation.
 | ||||
|   TF_ASSERT_OK( | ||||
|       EncapsulateTPUComputationsPass::Encapsulate(&graph_copy, &flib_def)); | ||||
|   std::unordered_map<string, Node*> index_copy = | ||||
|       graph_copy->BuildNodeNameIndex(); | ||||
|   string function_copy = index_copy.at("replicate0")->type_string(); | ||||
|   EXPECT_EQ(function, function_copy); | ||||
| } | ||||
| 
 | ||||
// Verifies that BuildTPUReplicateOps rewrites the encapsulated cluster node
// in the outer graph into a _TPUReplicate op, by comparing the rewritten
// graph against a manually constructed expected graph.
TEST(EncapsulateTPUComputations, BuildTPUReplicateOps) {
  // Serialize the body graph as function "replicate0" so the outer graph's
  // replicate node can reference it by name.
  std::unique_ptr<Graph> body_graph = MakeBodyGraph();
  FunctionDefLibrary flib;
  TF_ASSERT_OK(
      GraphToFunctionDef(*body_graph, "replicate0", flib.add_function()));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  // Run the rewrite under test.
  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "replicate0");
  TF_ASSERT_OK(
      EncapsulateTPUComputationsPass::BuildTPUReplicateOps(graph.get()));

  // Build the expected graph for comparison.
  Scope scope = Scope::NewRootScope().ExitOnError();
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));

  auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
  auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
  auto b0 = ops::Placeholder(scope.WithOpName("B0"), DT_FLOAT);
  auto b1 = ops::Placeholder(scope.WithOpName("B1"), DT_FLOAT);
  auto u0 = ops::Placeholder(scope.WithOpName("U0"), DT_RESOURCE);
  auto u1 = ops::Placeholder(scope.WithOpName("U1"), DT_RESOURCE);
  auto z = ops::Placeholder(scope.WithOpName("Z"), DT_RESOURCE);
  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
  auto x =
      ops::Identity(scope.WithOpName("X"),
                    ops::Placeholder(scope.WithOpName("X_Holder"), DT_DOUBLE));
  auto y =
      ops::Identity(scope.WithOpName("Y"),
                    ops::Placeholder(scope.WithOpName("Y_Holder"), DT_DOUBLE));

  NameAttrList function;
  function.set_name("replicate0");
  // The four Input lists map to the _TPUReplicate op's input groups; two
  // replicas' worth of per-replica inputs are interleaved in the first list,
  // with z covered by NumDistributedVariables(1) — see the op definition.
  auto replicate = ops::_TPUReplicate(
      scope.WithOpName("replicate0"),
      std::initializer_list<Input>{a0, b0, u0, a1, b1, u1, z},
      std::initializer_list<Input>{c, d}, std::initializer_list<Input>{v, w},
      std::initializer_list<Input>{x, y}, function,
      /*num_replicas=*/2,
      {DT_INT32, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32, DT_FLOAT,
       DT_FLOAT, DT_FLOAT, DT_FLOAT},
      ops::_TPUReplicate::NumCoresPerReplica(6).NumDistributedVariables(1));

  // Consumers of the replicate outputs; output 0 has two consumers to check
  // that fan-out is preserved by the rewrite.
  auto consumer0_0a =
      ops::Identity(scope.WithOpName("consumer0_0a"), replicate.outputs[0]);
  auto consumer0_0b =
      ops::Identity(scope.WithOpName("consumer0_0b"), replicate.outputs[0]);
  auto consumer0_1 =
      ops::Identity(scope.WithOpName("consumer0_1"), replicate.outputs[5]);
  auto consumer1 =
      ops::Identity(scope.WithOpName("consumer1"), replicate.outputs[6]);
  auto consumer2 =
      ops::Identity(scope.WithOpName("consumer2"), replicate.outputs[2]);
  auto consumer3a =
      ops::Identity(scope.WithOpName("consumer3a"), replicate.outputs[3]);
  auto consumer3b =
      ops::Identity(scope.WithOpName("consumer3b"), replicate.outputs[8]);
  auto consumer4a =
      ops::Identity(scope.WithOpName("consumer4a"), replicate.outputs[4]);
  auto consumer4b =
      ops::Identity(scope.WithOpName("consumer4b"), replicate.outputs[9]);

  GraphDef expected_def;
  TF_ASSERT_OK(scope.ToGraphDef(&expected_def));

  // The rewritten graph must match the expected graph exactly.
  GraphDef actual_def;
  graph->ToGraphDef(&actual_def);
  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
}
| 
 | ||||
// Value-parameterized fixture: the bool parameter selects whether the TPU
// cluster's pivot node is created under a "scope/" name prefix.
class ExtractOutsideCompilationByScope : public ::testing::TestWithParam<bool> {
};
| 
 | ||||
| Status PivotControlExists(const Node* node, const Node* pivot) { | ||||
|   for (const Edge* edge : node->in_edges()) { | ||||
|     if (edge->IsControlEdge() && (edge->src() == pivot)) { | ||||
|       return Status::OK(); | ||||
|     } | ||||
|   } | ||||
|   return errors::NotFound("Control edge with pivot not found."); | ||||
| } | ||||
| 
 | ||||
// Verifies that ProcessHeadTailOutsideCompilation moves the head ("as_int")
// and tail ("as_string") outside-compilation nodes of a TPU cluster into the
// host graph, once per replica, with a control dependency on the cluster's
// pivot node.
TEST_P(ExtractOutsideCompilationByScope,
       MoveHeadAndTailOutsideCompilationToHost) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), FunctionDefLibrary());

  // Create FunctionLibraryRuntime.
  SessionOptions session_options;
  std::vector<std::unique_ptr<Device>> devices;
  TF_CHECK_OK(DeviceFactory::AddDevices(
      session_options, "/job:localhost/replica:0/task:0", &devices));
  OptimizerOptions opts;
  auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, &fld, opts,
      /*default_thread_pool=*/nullptr);
  auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");

  {
    // Build TPU replicate function.
    // arg0 = _Arg[index = 0, T = DT_STRING]
    // arg1 = _Arg[index = 1, T = DT_INT32]
    // arg2 = _Arg[index = 2, T = DT_RESOURCE]
    // as_int = StringToNumber[out_type = DT_INT32](arg0)     (oc node)
    // add = Add(as_int, arg1)
    // as_string = AsString(add)                              (oc node)
    // read_var = ops::ReadVariableOp(arg2)
    // ret0 = _RetVal[index = 0, T = DT_STRING](as_string)
    // ret1 = _RetVal[index = 1, T = DT_INT32](add)
    // ret2 = _RetVal[index = 2, T = DT_FLOAT](read_var)
    Scope s = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(s.WithOpName("arg0"), DT_STRING, 0);
    auto arg1 = ops::_Arg(s.WithOpName("arg1"), DT_INT32, 1);
    auto arg2 = ops::_Arg(s.WithOpName("arg2"), DT_RESOURCE, 2);
    auto as_int = ops::StringToNumber(s.WithOpName("as_int"), arg0,
                                      ops::StringToNumber::OutType(DT_INT32));
    auto add = ops::Add(s.WithOpName("add"), as_int, arg1);
    auto as_string = ops::AsString(s.WithOpName("as_string"), add);
    auto read_var =
        ops::ReadVariableOp(s.WithOpName("ReadVar"), arg2, DT_FLOAT);
    auto ret0 = ops::_Retval(s.WithOpName("ret0"), as_string, 0);
    auto ret1 = ops::_Retval(s.WithOpName("ret1"), add, 1);
    auto ret2 = ops::_Retval(s.WithOpName("ret2"), read_var, 2);
    Graph g(OpRegistry::Global());
    TF_ASSERT_OK(s.ToGraph(&g));
    auto node_name_index = g.BuildNodeNameIndex();
    // Mark the head and tail nodes as belonging to outside compilation
    // cluster "0" under attribute name "oc" (the name passed to the pass).
    node_name_index["as_int"]->AddAttr("oc", "0");
    node_name_index["as_string"]->AddAttr("oc", "0");
    FunctionDef fdef;
    TF_ASSERT_OK(GraphToFunctionDef(g, "cluster", &fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(fdef));
  }

  // The test parameter decides whether the pivot node lives under a scope.
  string control_flow_scope = GetParam() ? "scope/" : "";
  string pivot_name = absl::StrCat(control_flow_scope, "tpu_replicate/pivot");
  Graph host_graph(OpRegistry::Global());
  NameAttrList function;
  function.set_name("cluster");
  {
    // Build host graph.
    // input00 = Placeholder[T = DT_STRING]
    // input01 = Placeholder[T = DT_INT32]
    // input10 = Placeholder[T = DT_STRING]
    // input11 = Placeholder[T = DT_INT32]
    // input2 = Placeholder[T = DT_RESOURCE]
    // tpu_replicate = _TPUReplicate(input00, input01, input10, input11,
    //                               input2)
    // output = IdentityN(tpu_replicate, tpu_replicate:1, tpu_replicate:2,
    //                    tpu_replicate:3, tpu_replicate:4, tpu_replicate:5)
    Scope s = Scope::NewRootScope().ExitOnError();
    auto pivot = ops::NoOp(s.WithOpName(pivot_name));
    pivot.operation.node()->AddAttr("_pivot_for_cluster", "cluster");
    auto input00 = ops::Placeholder(s.WithOpName("input00"), DT_STRING);
    auto input01 = ops::Placeholder(s.WithOpName("input01"), DT_INT32);
    auto input10 = ops::Placeholder(s.WithOpName("input10"), DT_STRING);
    auto input11 = ops::Placeholder(s.WithOpName("input11"), DT_INT32);
    auto input2 = ops::Placeholder(s.WithOpName("input2"), DT_RESOURCE);
    auto control_scope = s.WithControlDependencies({pivot});
    auto replicate = ops::_TPUReplicate(
        control_scope.WithOpName("tpu_replicate"),
        std::initializer_list<Input>{input00, input01, input10, input11,
                                     input2},
        std::initializer_list<Input>{}, std::initializer_list<Input>{},
        std::initializer_list<Input>{}, function,
        /*num_replicas=*/2,
        {DT_STRING, DT_INT32, DT_FLOAT, DT_STRING, DT_INT32, DT_FLOAT},
        ops::_TPUReplicate::NumCoresPerReplica(1).NumDistributedVariables(1));
    auto output = ops::IdentityN(
        s.WithOpName("output"),
        std::initializer_list<Input>{
            replicate.outputs[0], replicate.outputs[1], replicate.outputs[2],
            replicate.outputs[3], replicate.outputs[4], replicate.outputs[5]});
    TF_ASSERT_OK(s.ToGraph(&host_graph));
  }
  auto node_name_index = host_graph.BuildNodeNameIndex();
  Node* replicate_node = node_name_index["tpu_replicate"];

  std::unordered_map<string, XlaClusterInfo> clusters;
  clusters.emplace("cluster",
                   XlaClusterInfo{"cluster", function, replicate_node,
                                  std::map<string, int>{}});
  int lifted_arg_count = 0;
  TF_ASSERT_OK(ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
      "oc", &lifted_arg_count, &clusters, &host_graph, flr, &fld));
  // The pass rewrites host_graph in place; refresh the node index.
  node_name_index = host_graph.BuildNodeNameIndex();
  replicate_node = node_name_index["tpu_replicate"];

  {
    // Check host graph.
    const Edge* e;
    Node* pivot = node_name_index[pivot_name];
    // Check that we have input00 -> as_int/R0 -> tpu_replicate.
    Node* as_int_R0 = node_name_index["as_int_head_oc/R0"];
    EXPECT_NE(as_int_R0, nullptr);
    TF_ASSERT_OK(as_int_R0->input_edge(0, &e));
    EXPECT_EQ(e->src(), node_name_index["input00"]);
    TF_ASSERT_OK(replicate_node->input_edge(1, &e));
    EXPECT_EQ(e->src(), as_int_R0);
    // Check that as_int/R0 has pivot as control input.
    TF_EXPECT_OK(PivotControlExists(as_int_R0, pivot));
    // Check that we have input10 -> as_int/R1 -> tpu_replicate.
    Node* as_int_R1 = node_name_index["as_int_head_oc/R1"];
    EXPECT_NE(as_int_R1, nullptr);
    TF_ASSERT_OK(as_int_R1->input_edge(0, &e));
    EXPECT_EQ(e->src(), node_name_index["input10"]);
    TF_ASSERT_OK(replicate_node->input_edge(3, &e));
    EXPECT_EQ(e->src(), as_int_R1);
    // Check that as_int/R1 has pivot as control input.
    TF_EXPECT_OK(PivotControlExists(as_int_R1, pivot));
    // Check that we have tpu_replicate -> as_string/R0 -> output.
    Node* as_string_R0 = node_name_index["as_string_tail_oc/R0"];
    EXPECT_NE(as_string_R0, nullptr);
    TF_ASSERT_OK(as_string_R0->input_edge(0, &e));
    EXPECT_EQ(e->src(), replicate_node);
    TF_ASSERT_OK(node_name_index["output"]->input_edge(0, &e));
    EXPECT_EQ(e->src(), as_string_R0);
    // Check that as_string/R0 has pivot as control input.
    TF_EXPECT_OK(PivotControlExists(as_string_R0, pivot));
    // Check that we have tpu_replicate -> as_string/R1 -> output.
    Node* as_string_R1 = node_name_index["as_string_tail_oc/R1"];
    EXPECT_NE(as_string_R1, nullptr);
    TF_ASSERT_OK(as_string_R1->input_edge(0, &e));
    EXPECT_EQ(e->src(), replicate_node);
    TF_ASSERT_OK(node_name_index["output"]->input_edge(3, &e));
    EXPECT_EQ(e->src(), as_string_R1);
    // Check that as_string/R1 has pivot as control input.
    TF_EXPECT_OK(PivotControlExists(as_string_R1, pivot));
  }

  {
    // Check TPU graph.
    const FunctionDef* fdef = fld.Find("cluster");
    EXPECT_NE(fdef, nullptr);
    // Check its signature, should have 2 DT_INT32 inputs, 1 DT_RESOURCE input,
    // 2 DT_INT32 outputs and 1 DT_FLOAT output.
    EXPECT_EQ(fdef->signature().input_arg_size(), 3);
    EXPECT_EQ(fdef->signature().input_arg(0).type(), DT_INT32);
    EXPECT_EQ(fdef->signature().input_arg(1).type(), DT_INT32);
    EXPECT_EQ(fdef->signature().input_arg(2).type(), DT_RESOURCE);
    EXPECT_EQ(fdef->signature().output_arg_size(), 3);
    EXPECT_EQ(fdef->signature().output_arg(0).type(), DT_INT32);
    EXPECT_EQ(fdef->signature().output_arg(1).type(), DT_FLOAT);
    EXPECT_EQ(fdef->signature().output_arg(2).type(), DT_INT32);
    // Check that it has no StringToNumber/AsString op any more.
    for (const NodeDef& node_def : fdef->node_def()) {
      EXPECT_NE(node_def.op(), "StringToNumber");
      EXPECT_NE(node_def.op(), "AsString");
    }
  }
}
| 
 | ||||
| INSTANTIATE_TEST_SUITE_P(All, ExtractOutsideCompilationByScope, | ||||
|                          ::testing::ValuesIn({true, false})); | ||||
| 
 | ||||
// Verifies that ProcessHeadTailOutsideCompilation removes unused _Arg nodes
// and passthrough _Arg/_Retval pairs from the TPU cluster, rewiring the
// corresponding host-graph inputs directly to the host-side consumers.
TEST(ExtractOutsideCompilation, RemoveArgRetvalPair) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), FunctionDefLibrary());

  // Create FunctionLibraryRuntime.
  SessionOptions session_options;
  std::vector<std::unique_ptr<Device>> devices;
  TF_CHECK_OK(DeviceFactory::AddDevices(
      session_options, "/job:localhost/replica:0/task:0", &devices));
  OptimizerOptions opts;
  auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, &fld, opts,
      /*default_thread_pool=*/nullptr);
  auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");

  {
    // Build TPU replicate function.
    // arg0 = _Arg[index = 0, T = DT_STRING]    (passthrough: returned by ret0)
    // arg1 = _Arg[index = 1, T = DT_FLOAT]     (unused)
    // arg2 = _Arg[index = 2, T = DT_INT32]
    // arg3 = _Arg[index = 3, T = DT_RESOURCE]  (passthrough: returned by ret3)
    // arg4 = _Arg[index = 4, T = DT_RESOURCE]
    // add = Add(arg2, arg2)
    // read = ReadVariableOp(arg4)
    // ret0 = _RetVal[index = 0, T = DT_STRING](arg0)
    // ret1 = _RetVal[index = 1, T = DT_INT32](add)
    // ret2 = _RetVal[index = 2, T = DT_FLOAT](read)
    // ret3 = _RetVal[index = 3, T = DT_RESOURCE](arg3)
    Scope s = Scope::NewRootScope().ExitOnError();
    auto arg0 = ops::_Arg(s.WithOpName("arg0"), DT_STRING, 0);
    auto arg1 = ops::_Arg(s.WithOpName("arg1"), DT_FLOAT, 1);
    auto arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
    auto arg3 = ops::_Arg(s.WithOpName("arg3"), DT_RESOURCE, 3);
    auto arg4 = ops::_Arg(s.WithOpName("arg4"), DT_RESOURCE, 4);
    auto add = ops::Add(s.WithOpName("add"), arg2, arg2);
    auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0);
    auto ret1 = ops::_Retval(s.WithOpName("ret1"), add, 1);
    auto read = ops::ReadVariableOp(s.WithOpName("read"), arg4, DT_FLOAT);
    auto ret2 = ops::_Retval(s.WithOpName("ret2"), read, 2);
    auto ret3 = ops::_Retval(s.WithOpName("ret3"), arg3, 3);
    Graph g(OpRegistry::Global());
    TF_ASSERT_OK(s.ToGraph(&g));
    FunctionDef fdef;
    TF_ASSERT_OK(GraphToFunctionDef(g, "cluster", &fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(fdef));
  }

  Graph host_graph(OpRegistry::Global());
  NameAttrList function;
  function.set_name("cluster");
  {
    // Build host graph.
    // input00 = Placeholder[T = DT_STRING]
    // input01 = Placeholder[T = DT_FLOAT]
    // input02 = Placeholder[T = DT_INT32]
    // input10 = Placeholder[T = DT_STRING]
    // input11 = Placeholder[T = DT_FLOAT]
    // input12 = Placeholder[T = DT_INT32]
    // input3 = Placeholder[T = DT_RESOURCE], distributed variable
    // input4 = Placeholder[T = DT_RESOURCE], distributed variable
    // tpu_replicate = _TPUReplicate(input00, input01, input02, input10,
    //                               input11, input12, input3, input4)
    // output = IdentityN(tpu_replicate, tpu_replicate:1, tpu_replicate:2,
    //                    tpu_replicate:3, tpu_replicate:4, tpu_replicate:5,
    //                    tpu_replicate:6, tpu_replicate:7)
    Scope s = Scope::NewRootScope().ExitOnError();
    auto input00 = ops::Placeholder(s.WithOpName("input00"), DT_STRING);
    auto input01 = ops::Placeholder(s.WithOpName("input01"), DT_FLOAT);
    auto input02 = ops::Placeholder(s.WithOpName("input02"), DT_INT32);
    auto input10 = ops::Placeholder(s.WithOpName("input10"), DT_STRING);
    auto input11 = ops::Placeholder(s.WithOpName("input11"), DT_FLOAT);
    auto input12 = ops::Placeholder(s.WithOpName("input12"), DT_INT32);
    auto input3 = ops::Placeholder(s.WithOpName("input3"), DT_RESOURCE);
    // NOTE(review): this reuses op name "input3" for the input4 variable; the
    // scope uniquifies the actual node name. The checks below only ever
    // expect node name "input3" — presumably intentional since both resource
    // inputs are distributed variables, but confirm.
    auto input4 = ops::Placeholder(s.WithOpName("input3"), DT_RESOURCE);
    auto replicate = ops::_TPUReplicate(
        s.WithOpName("tpu_replicate"),
        std::initializer_list<Input>{input00, input01, input02, input10,
                                     input11, input12, input3, input4},
        std::initializer_list<Input>{}, std::initializer_list<Input>{},
        std::initializer_list<Input>{}, function,
        /*num_replicas=*/2,
        {DT_STRING, DT_INT32, DT_FLOAT, DT_RESOURCE, DT_STRING, DT_INT32,
         DT_FLOAT, DT_RESOURCE},
        ops::_TPUReplicate::NumCoresPerReplica(1).NumDistributedVariables(2));
    auto output = ops::IdentityN(
        s.WithOpName("output"),
        std::initializer_list<Input>{
            replicate.outputs[0], replicate.outputs[1], replicate.outputs[2],
            replicate.outputs[3], replicate.outputs[4], replicate.outputs[5],
            replicate.outputs[6], replicate.outputs[7]});
    TF_ASSERT_OK(s.ToGraph(&host_graph));
  }
  auto node_name_index = host_graph.BuildNodeNameIndex();
  Node* replicate_node = node_name_index["tpu_replicate"];

  std::unordered_map<string, XlaClusterInfo> clusters;
  clusters.emplace("cluster",
                   XlaClusterInfo{"cluster", function, replicate_node,
                                  std::map<string, int>{}});
  int lifted_arg_count = 0;
  TF_ASSERT_OK(ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
      "oc", &lifted_arg_count, &clusters, &host_graph, flr, &fld));
  // The pass rewrites host_graph in place; refresh the node index.
  node_name_index = host_graph.BuildNodeNameIndex();
  replicate_node = node_name_index["tpu_replicate"];
  Node* output = node_name_index["output"];

  // After removal of the unused and passthrough args/retvals, the replicate
  // node should keep only the int inputs plus one resource input.
  EXPECT_EQ(replicate_node->num_inputs(), 3);
  const DataTypeVector expected_input_types = {DT_INT32, DT_INT32, DT_RESOURCE};
  EXPECT_EQ(replicate_node->input_types(), expected_input_types);
  EXPECT_EQ(replicate_node->num_outputs(), 4);
  const DataTypeVector expected_output_types = {DT_INT32, DT_FLOAT, DT_INT32,
                                                DT_FLOAT};
  EXPECT_EQ(replicate_node->output_types(), expected_output_types);

  {
    // Check host graph.
    Node* input_node;
    // Check that we have input00 -> output:0 (passthrough string, replica 0).
    TF_ASSERT_OK(output->input_node(0, &input_node));
    EXPECT_EQ(input_node->name(), "input00");
    // Check that we have input10 -> output:4 (passthrough string, replica 1).
    TF_ASSERT_OK(output->input_node(4, &input_node));
    EXPECT_EQ(input_node->name(), "input10");
    // Check that we have input3 -> output:3, output:7 (passthrough resource).
    TF_ASSERT_OK(output->input_node(3, &input_node));
    EXPECT_EQ(input_node->name(), "input3");
    TF_ASSERT_OK(output->input_node(7, &input_node));
    EXPECT_EQ(input_node->name(), "input3");
  }

  {
    // Check TPU graph.
    const FunctionDef* fdef = fld.Find("cluster");
    EXPECT_NE(fdef, nullptr);
    // Check its signature, should have 1 DT_INT32 input, 1 DT_RESOURCE input,
    // 1 DT_INT32 output and 1 DT_FLOAT output.
    EXPECT_EQ(fdef->signature().input_arg_size(), 2);
    EXPECT_EQ(fdef->signature().input_arg(0).type(), DT_INT32);
    EXPECT_EQ(fdef->signature().input_arg(1).type(), DT_RESOURCE);
    EXPECT_EQ(fdef->signature().output_arg_size(), 2);
    EXPECT_EQ(fdef->signature().output_arg(0).type(), DT_INT32);
    EXPECT_EQ(fdef->signature().output_arg(1).type(), DT_FLOAT);
  }
}
| 
 | ||||
| }  // namespace tensorflow
 | ||||
| @ -15,6 +15,7 @@ limitations under the License. | ||||
| 
 | ||||
| #include "tensorflow/core/common_runtime/optimization_registry.h" | ||||
| #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h" | ||||
| #include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h" | ||||
| #include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| @ -25,8 +26,10 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20, | ||||
|                       DistributedTPUConfigurationRewritePass); | ||||
| REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20, | ||||
|                       DistributedTPUShutdownRewritePass); | ||||
| REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0, | ||||
|                       VariableMergerPass); | ||||
| REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 34, | ||||
|                       EncapsulateTPUComputationsPass); | ||||
| REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 39, | ||||
|                       ExtractOutsideCompilationPass); | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace tensorflow
 | ||||
|  | ||||
							
								
								
									
										38
									
								
								tensorflow/core/tpu/tpu_compile_interface.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								tensorflow/core/tpu/tpu_compile_interface.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,38 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/core/tpu/tpu_compile_interface.h" | ||||
| 
 | ||||
| #include "tensorflow/core/platform/fingerprint.h" | ||||
| #include "tensorflow/core/platform/logging.h" | ||||
| 
 | ||||
// Default implementation used when no alternate implementation has been
// registered via TpuCompileInterface::RegisterImplementation.
class TpuCompileInterfaceExternal : public TpuCompileInterface {
 public:
  // Fingerprints `str` by delegating to tensorflow::Fingerprint64.
  uint64_t FingerprintString(absl::string_view str) override {
    return ::tensorflow::Fingerprint64(str);
  }
};
| 
 | ||||
| static TpuCompileInterface* impl_ = new TpuCompileInterfaceExternal; | ||||
| TpuCompileInterface* TpuCompileInterface::Get() { return impl_; } | ||||
| 
 | ||||
| bool TpuCompileInterface::RegisterImplementation(TpuCompileInterface* impl) { | ||||
|   VLOG(1) << "Updating TpuCompileInterface."; | ||||
|   if (impl_ != nullptr) { | ||||
|     delete impl_; | ||||
|   } | ||||
|   impl_ = impl; | ||||
|   return true; | ||||
| } | ||||
							
								
								
									
										33
									
								
								tensorflow/core/tpu/tpu_compile_interface.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								tensorflow/core/tpu/tpu_compile_interface.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #ifndef TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_ | ||||
| #define TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_ | ||||
| 
 | ||||
| #include "absl/strings/string_view.h" | ||||
| 
 | ||||
| // Some legacy code requires different implementations for operations like
 | ||||
| // fingerprint/hashing during compilation and/or graph rewriting. These
 | ||||
| // alternate implementations can be registered (via a module initializer) to
 | ||||
| // change the default behavior.
 | ||||
class TpuCompileInterface {
 public:
  virtual ~TpuCompileInterface() {}
  // Returns the currently registered implementation; a default implementation
  // is installed at startup.
  static TpuCompileInterface* Get();
  // Replaces the current implementation with `impl`, taking ownership of it
  // (the previous implementation is deleted). Returns true, so the result can
  // initialize a static from a module initializer.
  static bool RegisterImplementation(TpuCompileInterface* impl);

  // Returns a 64-bit fingerprint of `str`.
  virtual uint64_t FingerprintString(absl::string_view str) = 0;
};
| 
 | ||||
| #endif  // TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_
 | ||||
| @ -24,4 +24,7 @@ const char* const DEVICE_TPU_XLA_JIT = "XLA_TPU_JIT"; | ||||
| const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR = | ||||
|     "_mirrored_variable_indices"; | ||||
| 
 | ||||
// Node attribute identifying nodes that belong to a TPU-replicated
// computation.
const char* const kTPUReplicateAttr = "_tpu_replicate";
// Node attribute marking nodes for XLA outside compilation.
const char* const kOutsideCompilationAttr = "_xla_outside_compilation";
| 
 | ||||
| }  // namespace tensorflow
 | ||||
|  | ||||
| @ -47,6 +47,9 @@ extern const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR; | ||||
| // variable.
 | ||||
| extern const char* const TPU_FAST_MEM_ATTR;  // "_TPU_FAST_MEM"
 | ||||
| 
 | ||||
// Node attribute identifying nodes that belong to a TPU-replicated
// computation ("_tpu_replicate").
extern const char* const kTPUReplicateAttr;
// Node attribute marking nodes for XLA outside compilation
// ("_xla_outside_compilation").
extern const char* const kOutsideCompilationAttr;
| 
 | ||||
| // Supported types for TPUs.
 | ||||
| static constexpr std::array<DataType, 11> kTpuAllTypes = { | ||||
|     {DT_INT32, DT_UINT32, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL, | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user