TPU rewrite pass refactoring.

PiperOrigin-RevId: 321237265
Change-Id: I33c4966fc038816be9d9e9efd936e68a8400cc1a
Russell Power 2020-07-14 14:35:39 -07:00 committed by TensorFlower Gardener
parent 4f3a8d84cf
commit bdf8b75c14
11 changed files with 4010 additions and 2 deletions


@@ -27,6 +27,7 @@ package_group(
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/tests/...",
"//tensorflow/compiler/tf2xla/...",
"//tensorflow/core/tpu/...",
"//tensorflow/python/compiler/...",
],
)

tensorflow/core/tpu/BUILD

@@ -73,6 +73,17 @@ cc_library(
],
)
cc_library(
name = "tpu_compile_interface",
srcs = ["tpu_compile_interface.cc"],
hdrs = ["tpu_compile_interface.h"],
deps = [
"//tensorflow/core/platform:fingerprint",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tpu_defs",
srcs = ["tpu_defs.cc"],

tensorflow/core/tpu/graph_rewrite/BUILD

@@ -18,6 +18,7 @@ cc_library(
srcs = ["tpu_rewrite_pass_registration.cc"],
deps = [
":distributed_tpu_configuration_rewrite_pass",
":encapsulate_tpu_computations_pass",
":variable_merger_pass",
"//tensorflow/core:core_cpu",
],
@@ -93,3 +94,56 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
cc_library(
name = "encapsulate_tpu_computations_pass",
srcs = [
"encapsulate_tpu_computations_pass.cc",
],
hdrs = [
"encapsulate_tpu_computations_pass.h",
],
deps = [
"//tensorflow/compiler/jit:compilation_passes",
"//tensorflow/compiler/jit:encapsulate_util",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core/tpu:tpu_compile_interface",
"//tensorflow/core/tpu:tpu_defs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "encapsulate_tpu_computations_pass_test",
srcs = ["encapsulate_tpu_computations_pass_test.cc"],
deps = [
":encapsulate_tpu_computations_pass",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:tpu_ops",
"//tensorflow/compiler/jit:compilation_passes",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_impl",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/tpu:tpu_defs",
],
)

File diff suppressed because it is too large (encapsulate_tpu_computations_pass.cc).

tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h

@@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Rewrites computations generated by the tpu.replicate() Python code into
// TPUReplicate operators.
//
// tpu.replicate() does two main things:
// a) marks operators that make up a TPU computation with the attribute
// _tpu_replicate=XYZ, where XYZ is a unique key.
// b) adds TPUReplicatedInput and TPUReplicatedOutput nodes to represent
// replicated inputs. These nodes are not marked with the _tpu_replicate
// attribute.
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITES_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITES_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Encapsulates nodes marked with the _tpu_replicate attribute into
// TPUReplicate operators.
class EncapsulateTPUComputationsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
// The following methods are public only for unit tests.
// This pass has two stages:
// a) first, we call the EncapsulateSubgraphsPass to encapsulate all nodes
// marked with the same _tpu_replicate attribute into functions. These
// functions contain the computations to be passed to TPUReplicate. During
// encapsulation, we sort the arguments into the order expected by
// TPUReplicate.
static Status Encapsulate(std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def);
// b) second, we rewrite the function calls generated in phase (a) into
// TPUReplicate operators. We also flatten the TPUReplicatedInput and
// TPUReplicatedOutput nodes attached to the function call into the
// replicated inputs and outputs of the TPUReplicate operator.
static Status BuildTPUReplicateOps(Graph* graph);
};
// Graph optimization pass that calls `ExtractOutsideCompilation` for all XLA
// computation nodes.
class ExtractOutsideCompilationPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
static Status ProcessHeadTailOutsideCompilation(
const string& outside_compilation_attr_name, int* lifted_arg_count,
std::unordered_map<string, XlaClusterInfo>* clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITES_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_
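
As a quick orientation, here is a minimal sketch of driving the two stages by hand (the helper name is hypothetical); this mirrors what the unit tests below do, while production graphs go through the optimization registry (see tpu_rewrite_pass_registration.cc further down):

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"

namespace tensorflow {

// Hypothetical driver: runs both stages on a graph whose nodes already carry
// the _tpu_replicate attribute (e.g. a graph produced by tpu.replicate()).
Status RunEncapsulationByHand(std::unique_ptr<Graph>* graph,
                              FunctionLibraryDefinition* flib_def) {
  // Stage (a): group nodes sharing a _tpu_replicate value into library
  // functions, with arguments sorted into the order TPUReplicate expects.
  TF_RETURN_IF_ERROR(
      EncapsulateTPUComputationsPass::Encapsulate(graph, flib_def));
  // Stage (b): rewrite the generated function calls into TPUReplicate
  // operators, flattening TPUReplicatedInput/TPUReplicatedOutput nodes.
  return EncapsulateTPUComputationsPass::BuildTPUReplicateOps(graph->get());
}

}  // namespace tensorflow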

tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc

@@ -0,0 +1,810 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/tpu_replication_ops.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/equal_graph_def.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
static std::unique_ptr<Graph> MakeOuterGraph(
const FunctionLibraryDefinition& flib_def, const string& function) {
Scope scope = Scope::NewRootScope().ExitOnError();
TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
int num_replicas = 2;
auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
auto b0 = ops::Placeholder(scope.WithOpName("B0"), DT_FLOAT);
auto b1 = ops::Placeholder(scope.WithOpName("B1"), DT_FLOAT);
auto u0 = ops::Placeholder(scope.WithOpName("U0"), DT_RESOURCE);
auto u1 = ops::Placeholder(scope.WithOpName("U1"), DT_RESOURCE);
auto z = ops::Placeholder(scope.WithOpName("Z"), DT_RESOURCE);
auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
auto x = ops::GuaranteeConst(
scope.WithOpName("X"),
ops::Placeholder(scope.WithOpName("X_Holder"), DT_DOUBLE));
auto y = ops::GuaranteeConst(
scope.WithOpName("Y"),
ops::Placeholder(scope.WithOpName("Y_Holder"), DT_DOUBLE));
auto in0 = ops::TPUReplicatedInput(scope.WithOpName("In0"),
std::initializer_list<Input>{a0, a1});
auto in1 = ops::TPUReplicatedInput(scope.WithOpName("In1"),
std::initializer_list<Input>{b0, b1});
auto in2 = ops::TPUReplicatedInput(scope.WithOpName("In2"),
std::initializer_list<Input>{u0, u1});
auto in3 = ops::TPUReplicatedInput(scope.WithOpName("In3"),
std::initializer_list<Input>{z});
in3.node()->AddAttr("is_packed", true);
NodeDef def;
TF_CHECK_OK(NodeDefBuilder("replicate0", function, &flib_def)
.Input(in0.node()->name(), 0, DT_INT32)
.Input(in1.node()->name(), 0, DT_FLOAT)
.Input(in2.node()->name(), 0, DT_RESOURCE)
.Input(in3.node()->name(), 0, DT_RESOURCE)
.Input(c.node()->name(), 0, DT_INT32)
.Input(d.node()->name(), 0, DT_FLOAT)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
.Input(x.node()->name(), 0, DT_DOUBLE)
.Input(y.node()->name(), 0, DT_DOUBLE)
.Attr(kTPUReplicateAttr, "replicate0")
.Attr("num_replicas", num_replicas)
.Attr("num_cores_per_replica", 6)
.Attr("topology", "")
.Attr("use_tpu", true)
.Attr("device_assignment", std::vector<int>())
.Attr("host_compute_core", std::vector<string>())
.Attr("padding_map", std::vector<string>())
.Attr("_variable_start_index", 6)
.Attr("_guaranteed_const_start_index", 8)
.Attr("allow_soft_placement", false)
.Attr("step_marker_location", "STEP_MARK_AT_ENTRY")
.Attr("use_spmd_for_xla_partitioning", false)
.Finalize(&def));
Status status;
Node* replicate = scope.graph()->AddNode(def, &status);
TF_CHECK_OK(status);
TF_CHECK_OK(scope.DoShapeInference(replicate));
scope.graph()->AddEdge(in0.node(), 0, replicate, 0);
scope.graph()->AddEdge(in1.node(), 0, replicate, 1);
scope.graph()->AddEdge(in2.node(), 0, replicate, 2);
scope.graph()->AddEdge(in3.node(), 0, replicate, 3);
scope.graph()->AddEdge(c.node(), 0, replicate, 4);
scope.graph()->AddEdge(d.node(), 0, replicate, 5);
scope.graph()->AddEdge(v.node(), 0, replicate, 6);
scope.graph()->AddEdge(w.node(), 0, replicate, 7);
scope.graph()->AddEdge(x.node(), 0, replicate, 8);
scope.graph()->AddEdge(y.node(), 0, replicate, 9);
auto out0 = ops::TPUReplicatedOutput(scope.WithOpName("Out0"),
Output(replicate, 0), num_replicas);
auto out1 = ops::TPUReplicatedOutput(scope.WithOpName("Out1"),
Output(replicate, 1), num_replicas);
auto out2 = ops::TPUReplicatedOutput(scope.WithOpName("Out2"),
Output(replicate, 2), num_replicas);
auto out3 = ops::TPUReplicatedOutput(scope.WithOpName("Out3"),
Output(replicate, 3), num_replicas);
auto out4 = ops::TPUReplicatedOutput(scope.WithOpName("Out4"),
Output(replicate, 4), num_replicas);
auto consumer0_0a = ops::Identity(scope.WithOpName("consumer0_0a"), out0[0]);
auto consumer0_0b = ops::Identity(scope.WithOpName("consumer0_0b"), out0[0]);
auto consumer0_1 = ops::Identity(scope.WithOpName("consumer0_1"), out0[1]);
auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1[1]);
auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2[0]);
auto consumer3a = ops::Identity(scope.WithOpName("consumer3a"), out3[0]);
auto consumer3b = ops::Identity(scope.WithOpName("consumer3b"), out3[1]);
auto consumer4a = ops::Identity(scope.WithOpName("consumer4a"), out4[0]);
auto consumer4b = ops::Identity(scope.WithOpName("consumer4b"), out4[1]);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_CHECK_OK(scope.ToGraph(graph.get()));
return graph;
}
// Makes an encapsulate body graph for use in tests.
static std::unique_ptr<Graph> MakeBodyGraph() {
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("in0_0_arg"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("in1_0_arg"), DT_FLOAT, 1);
auto arg2 = ops::_Arg(scope.WithOpName("in2_0_arg"), DT_RESOURCE, 2);
auto arg3 = ops::_Arg(scope.WithOpName("in3_0_arg"), DT_RESOURCE, 3);
auto arg4 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 4);
auto arg5 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 5);
auto arg6 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 6);
auto arg7 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 7);
auto add_attrs = [](Node* node) {
node->AddAttr(kTPUReplicateAttr, "replicate0");
};
string device =
tensorflow::strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE);
auto in1_identity =
ops::Identity(scope.WithOpName("In1_identity").WithDevice(device), arg1);
auto read_u = ops::ReadVariableOp(
scope.WithOpName("ReadU").WithDevice(device), arg2, DT_FLOAT);
add_attrs(read_u.node());
auto read_z = ops::ReadVariableOp(
scope.WithOpName("ReadZ").WithDevice(device), arg3, DT_FLOAT);
add_attrs(read_z.node());
auto read_v = ops::ReadVariableOp(
scope.WithOpName("ReadV").WithDevice(device), arg6, DT_FLOAT);
add_attrs(read_v.node());
auto read_w = ops::ReadVariableOp(
scope.WithOpName("ReadW").WithDevice(device), arg7, DT_FLOAT);
add_attrs(read_w.node());
auto e = ops::Add(scope.WithOpName("E").WithDevice(device), arg0, arg4);
add_attrs(e.node());
auto f = ops::Add(scope.WithOpName("F").WithDevice(device), read_v, read_w);
add_attrs(f.node());
auto g = ops::Add(scope.WithOpName("G").WithDevice(device), f, arg5);
add_attrs(g.node());
auto arg8 = ops::_Arg(scope.WithOpName("x_0_arg"), DT_DOUBLE, 8);
auto arg9 = ops::_Arg(scope.WithOpName("y_0_arg"), DT_DOUBLE, 9);
arg8.node()->AddAttr("_is_guaranteed_constant", true);
arg9.node()->AddAttr("_is_guaranteed_constant", true);
auto h = ops::Add(scope.WithOpName("H").WithDevice(device), arg8, arg9);
add_attrs(h.node());
auto out0 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 0);
auto out1 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 1);
auto out2 = ops::_Retval(scope.WithOpName("in1_identity_0_retval_RetVal"),
in1_identity, 2);
auto out3 =
ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
auto out4 =
ops::_Retval(scope.WithOpName("readz_0_retval_RetVal"), read_z, 4);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_CHECK_OK(scope.ToGraph(graph.get()));
return graph;
}
TEST(EncapsulateTPUComputations, DeterministicEncapsulate) {
// Test that control edge insertion order doesn't affect the cache key
// (cluster name) generated by the TPU encapsulate pass.
auto get_serialized_graph = [](bool control_input_reversed,
bool operand_reversed) -> string {
FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph(new Graph(&flib_def));
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
: ops::Add(scope.WithOpName("E"), a1, a0);
auto metadata = ops::TPUReplicateMetadata(scope, /*num_replicas=*/2);
auto add_attrs = [](Node* node) {
node->AddAttr(kTPUReplicateAttr, "replicate0");
};
add_attrs(metadata.operation.node());
add_attrs(e.node());
TF_CHECK_OK(scope.ToGraph(graph.get()));
auto get_node_in_graph = [&graph](Node* node) {
return graph->FindNodeId(node->id());
};
// Insert control edge in different order. The order should not affect
// the encapsulated or serialized graph.
if (!control_input_reversed) {
graph->AddControlEdge(get_node_in_graph(a0.node()),
get_node_in_graph(e.node()), true);
graph->AddControlEdge(get_node_in_graph(a1.node()),
get_node_in_graph(e.node()), true);
} else {
graph->AddControlEdge(get_node_in_graph(a1.node()),
get_node_in_graph(e.node()), true);
graph->AddControlEdge(get_node_in_graph(a0.node()),
get_node_in_graph(e.node()), true);
}
}
TF_CHECK_OK(EncapsulateTPUComputationsPass::Encapsulate(&graph, &flib_def));
GraphDef gdef;
graph->ToGraphDef(&gdef);
// Before serialization, sort control inputs first to remove
// nondeterminism.
SortControlInputs(&gdef);
string serialized;
SerializeToStringDeterministic(gdef, &serialized);
return serialized;
};
// Changing the order of control inputs shouldn't affect the generated graph.
EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
/*operand_reversed=*/false),
get_serialized_graph(/*control_input_reversed=*/false,
/*operand_reversed=*/false));
// Changing the order of data inputs should change the generated graph.
EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
/*operand_reversed=*/true),
get_serialized_graph(/*control_input_reversed=*/false,
/*operand_reversed=*/false));
}
TEST(EncapsulateTPUComputations, Encapsulate) {
FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph(new Graph(&flib_def));
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
auto b0 = ops::Placeholder(scope.WithOpName("B0"), DT_FLOAT);
auto b1 = ops::Placeholder(scope.WithOpName("B1"), DT_FLOAT);
auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
auto u0 = ops::Placeholder(scope.WithOpName("U0"), DT_RESOURCE);
auto u1 = ops::Placeholder(scope.WithOpName("U1"), DT_RESOURCE);
auto z = ops::Placeholder(scope.WithOpName("Z"), DT_RESOURCE);
auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
auto x = ops::GuaranteeConst(
scope.WithOpName("X"),
ops::Placeholder(scope.WithOpName("X_Holder"), DT_DOUBLE));
auto y = ops::GuaranteeConst(
scope.WithOpName("Y"),
ops::Placeholder(scope.WithOpName("Y_Holder"), DT_DOUBLE));
auto in0 = ops::TPUReplicatedInput(scope.WithOpName("In0"),
std::initializer_list<Input>{a0, a1});
auto in1 = ops::TPUReplicatedInput(scope.WithOpName("In1"),
std::initializer_list<Input>{b0, b1});
auto in2 = ops::TPUReplicatedInput(scope.WithOpName("In2"),
std::initializer_list<Input>{u0, u1});
auto in3 = ops::TPUReplicatedInput(scope.WithOpName("In3"),
std::initializer_list<Input>{z});
in3.node()->AddAttr("is_packed", true);
auto add_attrs = [](Node* node) {
node->AddAttr(kTPUReplicateAttr, "replicate0");
};
auto metadata = ops::TPUReplicateMetadata(
scope, /*num_replicas=*/2,
ops::TPUReplicateMetadata::ComputationShape({2, 3}));
add_attrs(metadata.operation.node());
auto in1_identity = ops::Identity(scope.WithOpName("In1_identity"), in1);
add_attrs(in1_identity.node());
auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), in2, DT_FLOAT);
add_attrs(read_u.node());
auto read_z = ops::ReadVariableOp(scope.WithOpName("ReadZ"), in3, DT_FLOAT);
add_attrs(read_z.node());
auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
add_attrs(read_v.node());
auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
add_attrs(read_w.node());
auto e = ops::Add(scope.WithOpName("E"), in0, c);
add_attrs(e.node());
auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
add_attrs(f.node());
auto g = ops::Add(scope.WithOpName("G"), f, d);
add_attrs(g.node());
auto h = ops::Add(scope.WithOpName("H"), x, y);
add_attrs(h.node());
auto out0 = ops::TPUReplicatedOutput(scope.WithOpName("Out0"), e, 2);
auto out1 = ops::TPUReplicatedOutput(scope.WithOpName("Out1"), g, 2);
auto out2 =
ops::TPUReplicatedOutput(scope.WithOpName("Out2"), in1_identity, 2);
auto out3 = ops::TPUReplicatedOutput(scope.WithOpName("Out3"), read_u, 2);
auto out4 = ops::TPUReplicatedOutput(scope.WithOpName("Out4"), read_z, 2);
auto consumer0_0a =
ops::Identity(scope.WithOpName("consumer0_0a"), out0[0]);
auto consumer0_0b =
ops::Identity(scope.WithOpName("consumer0_0b"), out0[0]);
auto consumer0_1 = ops::Identity(scope.WithOpName("consumer0_1"), out0[1]);
auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1[1]);
auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2[0]);
auto consumer3a = ops::Identity(scope.WithOpName("consumer3a"), out3[0]);
auto consumer3b = ops::Identity(scope.WithOpName("consumer3b"), out3[1]);
auto consumer4a = ops::Identity(scope.WithOpName("consumer4a"), out4[0]);
auto consumer4b = ops::Identity(scope.WithOpName("consumer4b"), out4[1]);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
}
std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
CopyGraph(*graph, graph_copy.get());
TF_ASSERT_OK(EncapsulateTPUComputationsPass::Encapsulate(&graph, &flib_def));
// Remove _xla_inferred_shapes attribute.
for (Node* n : graph->nodes()) {
n->ClearAttr("_xla_inferred_shapes");
}
std::unordered_map<string, Node*> index = graph->BuildNodeNameIndex();
string function = index.at("replicate0")->type_string();
// Tests the outer graph is as expected.
{
std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
GraphDef expected_def;
outer->ToGraphDef(&expected_def);
GraphDef actual_def;
graph->ToGraphDef(&actual_def);
TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
}
// Tests the encapsulated body graph is as expected.
{
std::unique_ptr<Graph> body = MakeBodyGraph();
GraphDef expected_body_def;
body->ToGraphDef(&expected_body_def);
InstantiationResultForTest result;
TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_RESOURCE, DT_RESOURCE,
DT_INT32, DT_FLOAT, DT_RESOURCE, DT_RESOURCE,
DT_DOUBLE, DT_DOUBLE}),
result.arg_types);
EXPECT_EQ(
(DataTypeVector{DT_INT32, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}),
result.ret_types);
TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
}
// Encapsulate the same computation again and verify that the same function
// is reused. Encapsulation must be deterministic to avoid recompilation.
TF_ASSERT_OK(
EncapsulateTPUComputationsPass::Encapsulate(&graph_copy, &flib_def));
std::unordered_map<string, Node*> index_copy =
graph_copy->BuildNodeNameIndex();
string function_copy = index_copy.at("replicate0")->type_string();
EXPECT_EQ(function, function_copy);
}
TEST(EncapsulateTPUComputations, BuildTPUReplicateOps) {
std::unique_ptr<Graph> body_graph = MakeBodyGraph();
FunctionDefLibrary flib;
TF_ASSERT_OK(
GraphToFunctionDef(*body_graph, "replicate0", flib.add_function()));
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "replicate0");
TF_ASSERT_OK(
EncapsulateTPUComputationsPass::BuildTPUReplicateOps(graph.get()));
Scope scope = Scope::NewRootScope().ExitOnError();
TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
auto b0 = ops::Placeholder(scope.WithOpName("B0"), DT_FLOAT);
auto b1 = ops::Placeholder(scope.WithOpName("B1"), DT_FLOAT);
auto u0 = ops::Placeholder(scope.WithOpName("U0"), DT_RESOURCE);
auto u1 = ops::Placeholder(scope.WithOpName("U1"), DT_RESOURCE);
auto z = ops::Placeholder(scope.WithOpName("Z"), DT_RESOURCE);
auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
auto x =
ops::Identity(scope.WithOpName("X"),
ops::Placeholder(scope.WithOpName("X_Holder"), DT_DOUBLE));
auto y =
ops::Identity(scope.WithOpName("Y"),
ops::Placeholder(scope.WithOpName("Y_Holder"), DT_DOUBLE));
NameAttrList function;
function.set_name("replicate0");
auto replicate = ops::_TPUReplicate(
scope.WithOpName("replicate0"),
std::initializer_list<Input>{a0, b0, u0, a1, b1, u1, z},
std::initializer_list<Input>{c, d}, std::initializer_list<Input>{v, w},
std::initializer_list<Input>{x, y}, function,
/*num_replicas=*/2,
{DT_INT32, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32, DT_FLOAT,
DT_FLOAT, DT_FLOAT, DT_FLOAT},
ops::_TPUReplicate::NumCoresPerReplica(6).NumDistributedVariables(1));
auto consumer0_0a =
ops::Identity(scope.WithOpName("consumer0_0a"), replicate.outputs[0]);
auto consumer0_0b =
ops::Identity(scope.WithOpName("consumer0_0b"), replicate.outputs[0]);
auto consumer0_1 =
ops::Identity(scope.WithOpName("consumer0_1"), replicate.outputs[5]);
auto consumer1 =
ops::Identity(scope.WithOpName("consumer1"), replicate.outputs[6]);
auto consumer2 =
ops::Identity(scope.WithOpName("consumer2"), replicate.outputs[2]);
auto consumer3a =
ops::Identity(scope.WithOpName("consumer3a"), replicate.outputs[3]);
auto consumer3b =
ops::Identity(scope.WithOpName("consumer3b"), replicate.outputs[8]);
auto consumer4a =
ops::Identity(scope.WithOpName("consumer4a"), replicate.outputs[4]);
auto consumer4b =
ops::Identity(scope.WithOpName("consumer4b"), replicate.outputs[9]);
GraphDef expected_def;
TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
GraphDef actual_def;
graph->ToGraphDef(&actual_def);
TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
}
class ExtractOutsideCompilationByScope : public ::testing::TestWithParam<bool> {
};
Status PivotControlExists(const Node* node, const Node* pivot) {
for (const Edge* edge : node->in_edges()) {
if (edge->IsControlEdge() && (edge->src() == pivot)) {
return Status::OK();
}
}
return errors::NotFound("Control edge with pivot not found.");
}
TEST_P(ExtractOutsideCompilationByScope,
MoveHeadAndTailOutsideCompilationToHost) {
FunctionLibraryDefinition fld(OpRegistry::Global(), FunctionDefLibrary());
// Create FunctionLibraryRuntime.
SessionOptions session_options;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
OptimizerOptions opts;
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &fld, opts,
/*default_thread_pool=*/nullptr);
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
{
// Build TPU replicate function.
// arg0 = _Arg[index = 0, T = DT_STRING]
// arg1 = _Arg[index = 1, T = DT_INT32]
// arg2 = _Arg[index = 2, T = DT_RESOURCE]
// as_int = StringToNumber[out_type = DT_INT32](arg0) (oc node)
// add = Add(as_int, arg1)
// as_string = AsString(add) (oc node)
// read_var = ops::ReadVariableOp(arg2)
// ret0 = _RetVal[index = 0, T = DT_STRING](as_string)
// ret1 = _RetVal[index = 1, T = DT_INT32](add)
// ret2 = _RetVal[index = 2, T = DT_FLOAT](read_var)
Scope s = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(s.WithOpName("arg0"), DT_STRING, 0);
auto arg1 = ops::_Arg(s.WithOpName("arg1"), DT_INT32, 1);
auto arg2 = ops::_Arg(s.WithOpName("arg2"), DT_RESOURCE, 2);
auto as_int = ops::StringToNumber(s.WithOpName("as_int"), arg0,
ops::StringToNumber::OutType(DT_INT32));
auto add = ops::Add(s.WithOpName("add"), as_int, arg1);
auto as_string = ops::AsString(s.WithOpName("as_string"), add);
auto read_var =
ops::ReadVariableOp(s.WithOpName("ReadVar"), arg2, DT_FLOAT);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), as_string, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), add, 1);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), read_var, 2);
Graph g(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(&g));
auto node_name_index = g.BuildNodeNameIndex();
node_name_index["as_int"]->AddAttr("oc", "0");
node_name_index["as_string"]->AddAttr("oc", "0");
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(g, "cluster", &fdef));
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
}
string control_flow_scope = GetParam() ? "scope/" : "";
string pivot_name = absl::StrCat(control_flow_scope, "tpu_replicate/pivot");
Graph host_graph(OpRegistry::Global());
NameAttrList function;
function.set_name("cluster");
{
// Build host graph.
// input00 = Placeholder[T = DT_STRING]
// input01 = Placeholder[T = DT_INT32]
// input10 = Placeholder[T = DT_STRING]
// input11 = Placeholder[T = DT_INT32]
// input2 = Placeholder[T = DT_RESOURCE]
// tpu_replicate = _TPUReplicate(input00, input01, input10, input11)
// output = IdentityN(tpu_replicate, tpu_replicate:1, tpu_replicate:2,
// tpu_replicate:3, tpu_replicate:4, tpu_replicate:5)
Scope s = Scope::NewRootScope().ExitOnError();
auto pivot = ops::NoOp(s.WithOpName(pivot_name));
pivot.operation.node()->AddAttr("_pivot_for_cluster", "cluster");
auto input00 = ops::Placeholder(s.WithOpName("input00"), DT_STRING);
auto input01 = ops::Placeholder(s.WithOpName("input01"), DT_INT32);
auto input10 = ops::Placeholder(s.WithOpName("input10"), DT_STRING);
auto input11 = ops::Placeholder(s.WithOpName("input11"), DT_INT32);
auto input2 = ops::Placeholder(s.WithOpName("input2"), DT_RESOURCE);
auto control_scope = s.WithControlDependencies({pivot});
auto replicate = ops::_TPUReplicate(
control_scope.WithOpName("tpu_replicate"),
std::initializer_list<Input>{input00, input01, input10, input11,
input2},
std::initializer_list<Input>{}, std::initializer_list<Input>{},
std::initializer_list<Input>{}, function,
/*num_replicas=*/2,
{DT_STRING, DT_INT32, DT_FLOAT, DT_STRING, DT_INT32, DT_FLOAT},
ops::_TPUReplicate::NumCoresPerReplica(1).NumDistributedVariables(1));
auto output = ops::IdentityN(
s.WithOpName("output"),
std::initializer_list<Input>{
replicate.outputs[0], replicate.outputs[1], replicate.outputs[2],
replicate.outputs[3], replicate.outputs[4], replicate.outputs[5]});
TF_ASSERT_OK(s.ToGraph(&host_graph));
}
auto node_name_index = host_graph.BuildNodeNameIndex();
Node* replicate_node = node_name_index["tpu_replicate"];
std::unordered_map<string, XlaClusterInfo> clusters;
clusters.emplace("cluster",
XlaClusterInfo{"cluster", function, replicate_node,
std::map<string, int>{}});
int lifted_arg_count = 0;
TF_ASSERT_OK(ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
"oc", &lifted_arg_count, &clusters, &host_graph, flr, &fld));
node_name_index = host_graph.BuildNodeNameIndex();
replicate_node = node_name_index["tpu_replicate"];
{
// Check host graph.
const Edge* e;
Node* pivot = node_name_index[pivot_name];
// Check that we have input00 -> as_int/R0 -> tpu_replicate.
Node* as_int_R0 = node_name_index["as_int_head_oc/R0"];
EXPECT_NE(as_int_R0, nullptr);
TF_ASSERT_OK(as_int_R0->input_edge(0, &e));
EXPECT_EQ(e->src(), node_name_index["input00"]);
TF_ASSERT_OK(replicate_node->input_edge(1, &e));
EXPECT_EQ(e->src(), as_int_R0);
// Check that as_int/R0 has pivot as control input
TF_EXPECT_OK(PivotControlExists(as_int_R0, pivot));
// Check that we have input10 -> as_int/R1 -> tpu_replicate.
Node* as_int_R1 = node_name_index["as_int_head_oc/R1"];
EXPECT_NE(as_int_R1, nullptr);
TF_ASSERT_OK(as_int_R1->input_edge(0, &e));
EXPECT_EQ(e->src(), node_name_index["input10"]);
TF_ASSERT_OK(replicate_node->input_edge(3, &e));
EXPECT_EQ(e->src(), as_int_R1);
// Check that as_int/R1 has pivot as control input
TF_EXPECT_OK(PivotControlExists(as_int_R1, pivot));
// Check that we have tpu_replicate -> as_string/R0 -> output.
Node* as_string_R0 = node_name_index["as_string_tail_oc/R0"];
EXPECT_NE(as_string_R0, nullptr);
TF_ASSERT_OK(as_string_R0->input_edge(0, &e));
EXPECT_EQ(e->src(), replicate_node);
TF_ASSERT_OK(node_name_index["output"]->input_edge(0, &e));
EXPECT_EQ(e->src(), as_string_R0);
// Check that as_string/R0 has pivot as control input
TF_EXPECT_OK(PivotControlExists(as_string_R0, pivot));
// Check that we have tpu_replicate -> as_string/R1 -> output.
Node* as_string_R1 = node_name_index["as_string_tail_oc/R1"];
EXPECT_NE(as_string_R1, nullptr);
TF_ASSERT_OK(as_string_R1->input_edge(0, &e));
EXPECT_EQ(e->src(), replicate_node);
TF_ASSERT_OK(node_name_index["output"]->input_edge(3, &e));
EXPECT_EQ(e->src(), as_string_R1);
// Check that as_string/R1 has pivot as control input
TF_EXPECT_OK(PivotControlExists(as_string_R1, pivot));
}
{
// Check TPU graph.
const FunctionDef* fdef = fld.Find("cluster");
EXPECT_NE(fdef, nullptr);
// Check its signature: it should have 2 DT_INT32 inputs, 1 DT_RESOURCE
// input, 2 DT_INT32 outputs, and 1 DT_FLOAT output.
EXPECT_EQ(fdef->signature().input_arg_size(), 3);
EXPECT_EQ(fdef->signature().input_arg(0).type(), DT_INT32);
EXPECT_EQ(fdef->signature().input_arg(1).type(), DT_INT32);
EXPECT_EQ(fdef->signature().input_arg(2).type(), DT_RESOURCE);
EXPECT_EQ(fdef->signature().output_arg_size(), 3);
EXPECT_EQ(fdef->signature().output_arg(0).type(), DT_INT32);
EXPECT_EQ(fdef->signature().output_arg(1).type(), DT_FLOAT);
EXPECT_EQ(fdef->signature().output_arg(2).type(), DT_INT32);
// Check that it no longer contains any StringToNumber or AsString ops.
for (const NodeDef& node_def : fdef->node_def()) {
EXPECT_NE(node_def.op(), "StringToNumber");
EXPECT_NE(node_def.op(), "AsString");
}
}
}
INSTANTIATE_TEST_SUITE_P(All, ExtractOutsideCompilationByScope,
::testing::ValuesIn({true, false}));
TEST(ExtractOutsideCompilation, RemoveArgRetvalPair) {
FunctionLibraryDefinition fld(OpRegistry::Global(), FunctionDefLibrary());
// Create FunctionLibraryRuntime.
SessionOptions session_options;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
OptimizerOptions opts;
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &fld, opts,
/*default_thread_pool=*/nullptr);
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
{
// Build TPU replicate function.
// arg0 = _Arg[index = 0, T = DT_STRING]
// arg1 = _Arg[index = 1, T = DT_FLOAT]
// arg2 = _Arg[index = 2, T = DT_INT32]
// arg3 = _Arg[index = 3, T = DT_RESOURCE]
// arg4 = _Arg[index = 4, T = DT_RESOURCE]
// add = Add(arg2, arg2)
// read = ReadVariableOp(arg4)
// ret0 = _RetVal[index = 0, T = DT_STRING](arg0)
// ret1 = _RetVal[index = 1, T = DT_INT32](add)
// ret2 = _RetVal[index = 2, T = DT_FLOAT](read)
// ret3 = _RetVal[index = 3, T = DT_RESOURCE](arg3)
Scope s = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(s.WithOpName("arg0"), DT_STRING, 0);
auto arg1 = ops::_Arg(s.WithOpName("arg1"), DT_FLOAT, 1);
auto arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
auto arg3 = ops::_Arg(s.WithOpName("arg3"), DT_RESOURCE, 3);
auto arg4 = ops::_Arg(s.WithOpName("arg4"), DT_RESOURCE, 4);
auto add = ops::Add(s.WithOpName("add"), arg2, arg2);
auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0);
auto ret1 = ops::_Retval(s.WithOpName("ret1"), add, 1);
auto read = ops::ReadVariableOp(s.WithOpName("read"), arg4, DT_FLOAT);
auto ret2 = ops::_Retval(s.WithOpName("ret2"), read, 2);
auto ret3 = ops::_Retval(s.WithOpName("ret3"), arg3, 3);
Graph g(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(&g));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(g, "cluster", &fdef));
TF_ASSERT_OK(fld.AddFunctionDef(fdef));
}
Graph host_graph(OpRegistry::Global());
NameAttrList function;
function.set_name("cluster");
{
// Build host graph.
// input00 = Placeholder[T = DT_STRING]
// input01 = Placeholder[T = DT_FLOAT]
// input02 = Placeholder[T = DT_INT32]
// input10 = Placeholder[T = DT_STRING]
// input11 = Placeholder[T = DT_FLOAT]
// input12 = Placeholder[T = DT_INT32]
// input3 = Placeholder[T = DT_RESOURCE], distributed variable
// input4 = Placeholder[T = DT_RESOURCE], distributed variable
// tpu_replicate = _TPUReplicate(input00, input01, input02, input10,
// input11, input12, input3, input4)
// output = IdentityN(tpu_replicate, tpu_replicate:1, tpu_replicate:2,
// tpu_replicate:3, tpu_replicate:4, tpu_replicate:5,
// tpu_replicate:6, tpu_replicate:7)
Scope s = Scope::NewRootScope().ExitOnError();
auto input00 = ops::Placeholder(s.WithOpName("input00"), DT_STRING);
auto input01 = ops::Placeholder(s.WithOpName("input01"), DT_FLOAT);
auto input02 = ops::Placeholder(s.WithOpName("input02"), DT_INT32);
auto input10 = ops::Placeholder(s.WithOpName("input10"), DT_STRING);
auto input11 = ops::Placeholder(s.WithOpName("input11"), DT_FLOAT);
auto input12 = ops::Placeholder(s.WithOpName("input12"), DT_INT32);
auto input3 = ops::Placeholder(s.WithOpName("input3"), DT_RESOURCE);
auto input4 = ops::Placeholder(s.WithOpName("input4"), DT_RESOURCE);
auto replicate = ops::_TPUReplicate(
s.WithOpName("tpu_replicate"),
std::initializer_list<Input>{input00, input01, input02, input10,
input11, input12, input3, input4},
std::initializer_list<Input>{}, std::initializer_list<Input>{},
std::initializer_list<Input>{}, function,
/*num_replicas=*/2,
{DT_STRING, DT_INT32, DT_FLOAT, DT_RESOURCE, DT_STRING, DT_INT32,
DT_FLOAT, DT_RESOURCE},
ops::_TPUReplicate::NumCoresPerReplica(1).NumDistributedVariables(2));
auto output = ops::IdentityN(
s.WithOpName("output"),
std::initializer_list<Input>{
replicate.outputs[0], replicate.outputs[1], replicate.outputs[2],
replicate.outputs[3], replicate.outputs[4], replicate.outputs[5],
replicate.outputs[6], replicate.outputs[7]});
TF_ASSERT_OK(s.ToGraph(&host_graph));
}
auto node_name_index = host_graph.BuildNodeNameIndex();
Node* replicate_node = node_name_index["tpu_replicate"];
std::unordered_map<string, XlaClusterInfo> clusters;
clusters.emplace("cluster",
XlaClusterInfo{"cluster", function, replicate_node,
std::map<string, int>{}});
int lifted_arg_count = 0;
TF_ASSERT_OK(ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
"oc", &lifted_arg_count, &clusters, &host_graph, flr, &fld));
node_name_index = host_graph.BuildNodeNameIndex();
replicate_node = node_name_index["tpu_replicate"];
Node* output = node_name_index["output"];
EXPECT_EQ(replicate_node->num_inputs(), 3);
const DataTypeVector expected_input_types = {DT_INT32, DT_INT32, DT_RESOURCE};
EXPECT_EQ(replicate_node->input_types(), expected_input_types);
EXPECT_EQ(replicate_node->num_outputs(), 4);
const DataTypeVector expected_output_types = {DT_INT32, DT_FLOAT, DT_INT32,
DT_FLOAT};
EXPECT_EQ(replicate_node->output_types(), expected_output_types);
{
// Check host graph.
Node* input_node;
// Check that we have input00 -> output:0.
TF_ASSERT_OK(output->input_node(0, &input_node));
EXPECT_EQ(input_node->name(), "input00");
// Check that we have input10 -> output:4.
TF_ASSERT_OK(output->input_node(4, &input_node));
EXPECT_EQ(input_node->name(), "input10");
// Check that we have input3 -> output:3, output:7.
TF_ASSERT_OK(output->input_node(3, &input_node));
EXPECT_EQ(input_node->name(), "input3");
TF_ASSERT_OK(output->input_node(7, &input_node));
EXPECT_EQ(input_node->name(), "input3");
}
{
// Check TPU graph.
const FunctionDef* fdef = fld.Find("cluster");
EXPECT_NE(fdef, nullptr);
// Check its signature: it should have 1 DT_INT32 input, 1 DT_RESOURCE
// input, 1 DT_INT32 output, and 1 DT_FLOAT output.
EXPECT_EQ(fdef->signature().input_arg_size(), 2);
EXPECT_EQ(fdef->signature().input_arg(0).type(), DT_INT32);
EXPECT_EQ(fdef->signature().input_arg(1).type(), DT_RESOURCE);
EXPECT_EQ(fdef->signature().output_arg_size(), 2);
EXPECT_EQ(fdef->signature().output_arg(0).type(), DT_INT32);
EXPECT_EQ(fdef->signature().output_arg(1).type(), DT_FLOAT);
}
}
} // namespace tensorflow

tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h"
namespace tensorflow {
@@ -25,8 +26,10 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
DistributedTPUConfigurationRewritePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
DistributedTPUShutdownRewritePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0,
VariableMergerPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 34,
EncapsulateTPUComputationsPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 39,
ExtractOutsideCompilationPass);
} // namespace
} // namespace tensorflow
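
The phase numbers order passes within a grouping: inside PRE_PLACEMENT, the configuration rewrites (20) run before encapsulation (34), which runs before outside-compilation extraction (39). A minimal sketch of slotting a hypothetical pass between the latter two (the pass name and body are made up):

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
namespace {

// Hypothetical no-op pass, for illustration only.
class MyTpuCleanupPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    return Status::OK();
  }
};

// Phase 35 lands after EncapsulateTPUComputationsPass (34) and before
// ExtractOutsideCompilationPass (39) in the PRE_PLACEMENT grouping.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 35,
                      MyTpuCleanupPass);

}  // namespace
}  // namespace tensorflow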

tensorflow/core/tpu/tpu_compile_interface.cc

@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
class TpuCompileInterfaceExternal : public TpuCompileInterface {
public:
uint64_t FingerprintString(absl::string_view str) override {
return ::tensorflow::Fingerprint64(str);
}
};
static TpuCompileInterface* impl_ = new TpuCompileInterfaceExternal;
TpuCompileInterface* TpuCompileInterface::Get() { return impl_; }
bool TpuCompileInterface::RegisterImplementation(TpuCompileInterface* impl) {
VLOG(1) << "Updating TpuCompileInterface.";
if (impl_ != nullptr) {
delete impl_;
}
impl_ = impl;
return true;
}
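
RegisterImplementation replaces the fingerprinting behavior process-wide; the last registration wins and is what Get() returns. A minimal sketch of installing an alternate implementation from a module initializer (the class and its trivial fingerprint are hypothetical):

#include "tensorflow/core/tpu/tpu_compile_interface.h"

// Hypothetical alternate implementation, for illustration only.
class CustomCompileInterface : public TpuCompileInterface {
 public:
  uint64_t FingerprintString(absl::string_view str) override {
    return str.size();  // Placeholder; a real override would hash str.
  }
};

// Module initializer: swaps out the default implementation at startup.
static bool registered = TpuCompileInterface::RegisterImplementation(
    new CustomCompileInterface);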

tensorflow/core/tpu/tpu_compile_interface.h

@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_
#define TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_
#include "absl/strings/string_view.h"
// Some legacy code requires different implementations for operations like
// fingerprint/hashing during compilation and/or graph rewriting. These
// alternate implementations can be registered (via a module initializer) to
// change the default behavior.
class TpuCompileInterface {
public:
virtual ~TpuCompileInterface() {}
static TpuCompileInterface* Get();
static bool RegisterImplementation(TpuCompileInterface* impl);
virtual uint64_t FingerprintString(absl::string_view str) = 0;
};
#endif // TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_
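
A call-site sketch (the helper is made up): consumers fetch the process-wide instance via Get() and fingerprint strings with whatever implementation is currently registered (by default tensorflow::Fingerprint64, per tpu_compile_interface.cc above):

#include <cstdint>

#include "absl/strings/string_view.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"

// Hypothetical helper: fingerprint a cluster name through the registered
// implementation.
uint64_t FingerprintCluster(absl::string_view cluster_name) {
  return TpuCompileInterface::Get()->FingerprintString(cluster_name);
}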

tensorflow/core/tpu/tpu_defs.cc

@@ -24,4 +24,7 @@ const char* const DEVICE_TPU_XLA_JIT = "XLA_TPU_JIT";
const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR =
"_mirrored_variable_indices";
const char* const kTPUReplicateAttr = "_tpu_replicate";
const char* const kOutsideCompilationAttr = "_xla_outside_compilation";
} // namespace tensorflow
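
The new kOutsideCompilationAttr constant names the attribute that tags nodes for host-side (outside compilation) execution, much as the tests above tag nodes with a custom "oc" attribute. A minimal sketch of applying it (the function and cluster id are hypothetical):

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {

// Hypothetical helper: mark a node as outside-compiled; the value names the
// outside-compilation cluster the node belongs to.
void MarkForOutsideCompilation(Node* node) {
  node->AddAttr(kOutsideCompilationAttr, "oc_cluster0");
}

}  // namespace tensorflow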

tensorflow/core/tpu/tpu_defs.h

@@ -47,6 +47,9 @@ extern const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR;
// variable.
extern const char* const TPU_FAST_MEM_ATTR; // "_TPU_FAST_MEM"
extern const char* const kTPUReplicateAttr;
extern const char* const kOutsideCompilationAttr;
// Supported types for TPUs.
static constexpr std::array<DataType, 11> kTpuAllTypes = {
{DT_INT32, DT_UINT32, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL,