Adjust TPU build dependencies.
PiperOrigin-RevId: 323453889
Change-Id: I6b33b57830d3414d45aaa066735c3b52f217e739
Commit 9872ddd15a (parent 20cd718248)
@@ -225,6 +225,8 @@ cc_library(
         "xla_device_context.h",
         "xla_device_ops.h",
     ],
+    # Public visibility is needed for external TF/XLA backends.
+    visibility = ["//visibility:public"],
     deps = XLA_DEVICE_DEPS,
 )
 
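For context, the hunk above makes what appears to be the xla_device library in tensorflow/compiler/jit/BUILD publicly visible, so an out-of-tree XLA backend can depend on it directly. A minimal sketch of such a consumer; the external package, target, and source names are hypothetical, only the //tensorflow/compiler/jit:xla_device label is assumed from the hunk:

    # BUILD file of a hypothetical out-of-tree TF/XLA backend.
    cc_library(
        name = "my_xla_backend",
        srcs = ["my_xla_backend.cc"],
        deps = [
            # Resolvable from outside the package now that the library
            # above carries visibility = ["//visibility:public"].
            "//tensorflow/compiler/jit:xla_device",
        ],
    )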
@@ -295,19 +295,6 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
       << diff << "\nActual: " << actual.DebugString(); \
   } while (false)
 
-// These dummy Op registrations are here because the real Op registrations live
-// in contrib and there can't be a dependence from this test to contrib.
-REGISTER_OP("XlaHostCompute")
-    .Input("inputs: Tinputs")
-    .Output("outputs: Toutputs")
-    .Attr("Tinputs: list(type) >= 0")
-    .Attr("Toutputs: list(type) >= 0")
-    .Attr("ancestors: list(string) >= 0")
-    .Attr("key: string")
-    .Attr("shape_inference_graph: func")
-    .Attr("shapes: list(shape) >= 0")
-    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
-
 REGISTER_OP("InputTest")
     .Output("o: float")
     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
@@ -947,6 +934,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const DataType>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -1114,6 +1103,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O2"},
           {"shape_inference_graph", shape_inference_graph2},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const DataType>({})},
           {"_outside_compilation_subgraph", "O2"},
           {"_xla_token_input_nodes",
@@ -1130,6 +1121,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph1},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const DataType>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -1266,6 +1259,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes",
            absl::Span<const TensorShapeProto>({shape_proto_expected})},
           {"_outside_compilation_subgraph", "O1"},
@@ -1295,6 +1290,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F2_F2_O1"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes",
            absl::Span<const TensorShapeProto>({shape_proto_expected})},
           {"_outside_compilation_subgraph", "O1"},
@@ -1428,6 +1425,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes",
            absl::Span<const TensorShapeProto>({shape_proto_expected})},
           {"_outside_compilation_subgraph", "O1"},
@@ -1454,6 +1453,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F2_F2_O1"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes",
            absl::Span<const TensorShapeProto>({shape_proto_expected})},
           {"_outside_compilation_subgraph", "O1"},
@@ -1566,6 +1567,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes",
            absl::Span<const TensorShapeProto>({shape_proto_expected})},
           {"_outside_compilation_subgraph", "O1"},
@@ -1658,6 +1661,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes",
            absl::Span<const TensorShapeProto>({shape_proto_expected})},
           {"_outside_compilation_subgraph", "O1"},
@@ -1765,6 +1770,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -1875,6 +1882,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -2009,6 +2018,8 @@ TEST(EncapsulateSubgraphsTest,
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph1},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -2023,6 +2034,8 @@ TEST(EncapsulateSubgraphsTest,
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O2"},
           {"shape_inference_graph", shape_inference_graph2},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O2"},
           {"_xla_token_input_nodes",
@@ -2153,6 +2166,8 @@ TEST(EncapsulateSubgraphsTest,
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O2"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O2"},
           {"_xla_token_input_nodes",
@@ -2169,6 +2184,8 @@ TEST(EncapsulateSubgraphsTest,
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -2296,6 +2313,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -2310,6 +2329,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O2"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O2"},
           {"_xla_token_input_nodes",
@@ -2325,6 +2346,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O3"},
           {"shape_inference_graph", NameAttrList()},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O3"},
           {"_xla_token_input_nodes",
@@ -2451,6 +2474,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -2567,6 +2592,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
           {"ancestors", absl::Span<const string>({})},
           {"key", "host_compute_channel_F1_F1_O1"},
           {"shape_inference_graph", shape_inference_graph},
+          {"tpu_core", 0},
+          {"cost_estimate_ns", 1000000},
           {"shapes", absl::Span<const DataType>({})},
           {"_outside_compilation_subgraph", "O1"},
           {"_xla_token_input_nodes",
@@ -2420,6 +2420,7 @@ Status ExtractOutsideCompilationForFunction(
     auto updated_fdef = absl::make_unique<FunctionDef>();
     TF_RETURN_IF_ERROR(
         GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
+    updated_fdef->mutable_signature()->set_is_stateful(true);
     const FunctionDef* original_fdef = fld->Find(func_name);
     if (original_fdef) {
       for (const auto& attr : original_fdef->attr()) {
@@ -422,19 +422,6 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
   EXPECT_EQ(fld.Find("host_graph"), nullptr);
 }
 
-REGISTER_OP("XlaSendToHost")
-    .Input("input: Tinput")
-    .Attr("Tinput: type")
-    .Attr("key: string")
-    .SetIsStateful();
-
-REGISTER_OP("XlaRecvFromHost")
-    .Output("output: Toutput")
-    .Attr("Toutput: type")
-    .Attr("shape: shape")
-    .Attr("key: string")
-    .SetIsStateful();
-
 TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
   // Build the XLA computation func.
   // "const0" (bool)
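These dummy registrations are deletable because the real ops are now linkable from //tensorflow/core/tpu/ops. A hedged sketch of how a test target would pull in the real registrations instead; the target and source names are illustrative, and only the host_compute_ops label comes from this change:

    tf_cc_test(
        name = "outside_compilation_ops_test",  # illustrative
        srcs = ["outside_compilation_ops_test.cc"],  # illustrative
        deps = [
            "//tensorflow/core:test",
            "//tensorflow/core:test_main",
            # Real XlaSendToHost/XlaRecvFromHost registrations, replacing
            # the deleted local REGISTER_OP stubs.
            "//tensorflow/core/tpu/ops:host_compute_ops",
        ],
    )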
@@ -880,6 +880,7 @@ cc_library(
         ":tpu_outfeed_ops_op_lib",
         ":tpu_ordinal_selector_ops_op_lib",
         ":tpu_replication_ops_op_lib",
+        "//tensorflow/core/tpu/ops",
     ],
 ) + if_mkl([
     ":mkl_array_ops_op_lib",
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "DataServiceDataset"
+}
@@ -0,0 +1,20 @@
+op {
+  graph_op_name: "KthOrderStatistic"
+  summary: "Computes the Kth order statistic of a data set. The current"
+  description: <<END
+implementation uses a binary search requiring exactly 32 passes over
+the input data. The running time is linear with respect to input
+size. The median-of-medians algorithm is probably faster, but is
+difficult to implement efficiently in XLA. The implementation imposes
+a total ordering on floats. The ordering is consistent with the usual
+partial order. Positive NaNs are greater than positive
+infinity. Negative NaNs are less than negative infinity. NaNs with
+distinct payloads are treated as distinct. Subnormal numbers are
+preserved (not flushed to zero). Positive infinity is greater than all
+numbers. Negative infinity is less than all numbers. Positive is
+greater than negative zero. There are less than k values greater than
+the kth order statistic. There are at least k values greater than or
+equal to the Kth order statistic. The semantics are not the same as
+top_k_unique.
+END
+}
tensorflow/core/api_def/base_api/api_def_MakeUnique.pbtxt (new file, +10)
@@ -0,0 +1,10 @@
+op {
+  graph_op_name: "MakeUnique"
+  summary: "Make all elements in the non-Batch dimension unique, but \"close\" to"
+  description: <<END
+their initial value. Never returns a sub-normal number. Never returns
+zero. The sign of each input element is always identical to the sign
+of the corresponding output element. Behavior for infinite elements is
+undefined. Behavior for subnormal elements is undefined.
+END
+}
tensorflow/core/api_def/base_api/api_def_TPUCompile.pbtxt (new file, +21)
@@ -0,0 +1,21 @@
+op {
+  graph_op_name: "TPUCompile"
+  summary: "Compiles a computations for execution on one or more TPU devices."
+  description: <<END
+For the internal use of the distributed TPU compiler.
+
+'num_computations' is the number of computations to be compiled.
+'function' is a function containing the computation to compile.
+'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
+known statically at TPUReplication rewrite time.
+'guaranteed_constants' is a list of tensors which have been guaranteed to not
+change their values during the session lifetime. These contain tensors marked as
+constant using the GuaranteeConstOp.
+'metadata' is a serialized TPUCompileMetadataProto describing
+the shapes and types of the inputs to the computation, as well as a mapping onto
+the TPU pod topology.
+Each 'program' output is a string key that is passed to the _TPUExecute op and
+used to look up the program in the compilation cache.
+'may_modify_variables' indicates whether variables may be modified.
+END
+}
@@ -0,0 +1,9 @@
+op {
+  graph_op_name: "TPUCompileSucceededAssert"
+  summary: "Asserts that compilation succeeded. This op produces no output and closes the"
+  description: <<END
+device during failure to ensure all pending device interactions fail.
+
+'compilation_status' is a serialized CompilationResultProto.
+END
+}
@@ -0,0 +1,7 @@
+op {
+  graph_op_name: "TPUExecute"
+  summary: "Op that loads and executes a TPU program on a TPU device."
+  description: <<END
+For the internal use of the distributed TPU compiler.
+END
+}
@@ -0,0 +1,13 @@
+op {
+  graph_op_name: "TPUExecuteAndUpdateVariables"
+  summary: "Op that executes a program with optional in-place variable updates."
+  description: <<END
+It (optionally) reads device variables, loads and executes a TPU program on a
+TPU device, and then (optionally) in-place updates variables using the program
+outputs, as specified in attributes device_var_reads_indices (program input
+indices from directly reading variables) and device_var_updates_indices (program
+output indices used to update variables, -1 means no-update/read-only). Such
+program outputs are consumed by these variables will not appear in the op
+output. For the internal use of the distributed TPU compiler.
+END
+}
@@ -0,0 +1,23 @@
+op {
+  graph_op_name: "TPUPartitionedInput"
+  in_arg {
+    name: "inputs"
+    description: <<END
+A list of partitioned inputs which must have the same shape.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+A handle which represents the full shape of partitioned tensors.
+END
+  }
+  attr {
+    name: "partition_dim"
+    description: <<END
+An integer describles which dimension is partitioned. -1 means
+those inputs are replicated.
+END
+  }
+  summary: "An op that groups a list of partitioned inputs together. This op"
+}
@@ -0,0 +1,25 @@
+op {
+  graph_op_name: "TPUPartitionedOutput"
+  in_arg {
+    name: "inputs"
+    description: <<END
+A tensor which represents the full shape of partitioned tensors.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+A list of partitioned inputs which must have the same shape.
+END
+  }
+  attr {
+    name: "partition_dim"
+    description: <<END
+An integer describles which dimension is partitioned.
+END
+  }
+  summary: "An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned"
+  description: <<END
+outputs outside the XLA computation.
+END
+}
tensorflow/core/api_def/base_api/api_def_TopKUnique.pbtxt (new file, +18)
@@ -0,0 +1,18 @@
+op {
+  graph_op_name: "TopKUnique"
+  summary: "Returns the TopK unique values in the array in sorted order. The"
+  description: <<END
+running time is proportional to the product of K and the input
+size. Sorting the whole array is more efficient for sufficiently large
+values of K. The median-of-medians algorithm is probably faster, but
+difficult to implement efficiently in XLA. If there are fewer than K
+unique numbers (not NANs), the results are padded with negative
+infinity. NaNs are never returned. Subnormal numbers are flushed to
+zero. If an element appears at multiple indices, the highest index is
+returned. If a TopK element never appears in the input due to padding
+values, the indices are padded with negative one. If a padding value
+appears in the input and padding is needed, the highest index of the
+padding value will be returned. The semantics are not the same as
+kth_order_statistic.
+END
+}
@@ -0,0 +1,10 @@
+op {
+  graph_op_name: "TopKWithUnique"
+  summary: "Returns the TopK values in the array in sorted order. This is a combination"
+  description: <<END
+of MakeUnique and TopKUnique. The returned top-K will have its lower bits
+replaced by iota, thus it will be close to the original value but not exactly
+the same. The running time is proportional to the product of K and the input
+size. NaNs are never returned. Subnormal numbers are flushed to zero.
+END
+}
@@ -0,0 +1,66 @@
+op {
+  graph_op_name: "XlaHostCompute"
+  in_arg {
+    name: "inputs"
+    description: <<END
+A list of tensors that will be sent to the host.
+END
+  }
+  out_arg {
+    name: "outputs"
+    description: <<END
+A list of tensors that will be returned to the device.
+END
+  }
+  attr {
+    name: "Tinputs"
+    description: <<END
+The element types of each element in `inputs`.
+END
+  }
+  attr {
+    name: "Toutputs"
+    description: <<END
+The element types of each element in `outputs`.
+END
+  }
+  attr {
+    name: "ancestors"
+    description: <<END
+A list of names of HostCompute computations that must be
+sequenced before this computation.
+END
+  }
+  attr {
+    name: "shapes"
+    description: <<END
+If shape_inference_graph is empty, a list of the shapes of `outputs`.
+END
+  }
+  attr {
+    name: "shape_inference_graph"
+    description: <<END
+If non-empty, a serialized GraphDef representing a graph
+that must be analyzed at compile time to determine the shapes of the outputs.
+END
+  }
+  attr {
+    name: "key"
+    description: <<END
+A unique identifier for this region used to match up host transfers.
+END
+  }
+  attr {
+    name: "cost_estimate_ns"
+    description: <<END
+Estimated duration of the host computation in nanoseconds.
+END
+  }
+  attr {
+    name: "tpu_core"
+    description: <<END
+Default core to use for host to device transfers.
+END
+  }
+  summary: "A pseudo-op to represent host-side computation in an XLA program."
+}
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "XlaRecvFromHost"
+}
tensorflow/core/api_def/base_api/api_def_XlaSendToHost.pbtxt (new file, +12)
@@ -0,0 +1,12 @@
+op {
+  graph_op_name: "XlaSendToHost"
+  in_arg: {
+    name: "input"
+  }
+  attr {
+    name: "Tinput"
+  }
+  attr {
+    name: "key"
+  }
+}
@@ -1,6 +1,7 @@
 load(
     "//tensorflow:tensorflow.bzl",
     "cc_header_only_library",
+    "if_tpu",
     "tf_cc_test",
     "tf_cc_test_mkl",
     "tf_cc_tests",
@@ -91,7 +92,7 @@ cc_library(
         ":core_cpu",
         "//tensorflow/core/common_runtime/gpu:gpu_runtime",
         "//tensorflow/core/common_runtime/sycl:sycl_runtime",
-    ],
+    ] + if_tpu(["//tensorflow/core/tpu:tpu_runtime"]),
 )
 
 filegroup(
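The if_tpu macro loaded above from //tensorflow:tensorflow.bzl is a conventional select() wrapper. A hedged sketch of the shape it plausibly takes; the config_setting label "with_tpu_support" is an assumption, not shown in this diff:

    # Hypothetical sketch of if_tpu in //tensorflow:tensorflow.bzl.
    def if_tpu(if_true, if_false = []):
        """Shorthand for a select() on whether TPU support is enabled."""
        return select({
            "//tensorflow:with_tpu_support": if_true,  # assumed setting name
            "//conditions:default": if_false,
        })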
@@ -78,8 +78,7 @@ cc_library(
     srcs = ["tpu_compile_interface.cc"],
     hdrs = ["tpu_compile_interface.h"],
     deps = [
-        "//tensorflow/core/platform:fingerprint",
-        "//tensorflow/core/platform:logging",
+        "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -146,8 +145,7 @@ cc_library(
         ":tpu_api",
         ":tpu_config_c_api",
         ":tpu_library_init_fns",
-        "//tensorflow/core/platform:errors",
-        "//tensorflow/core/platform:status",
+        "//tensorflow/core:lib",
         "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
         "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
         "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
@@ -155,22 +153,7 @@ cc_library(
         "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
         "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
         "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
-    ] + select({
-        "//tensorflow:oss": [
-            ":tpu_compilation_device",
-            ":tpu_node_device",
-            ":tpu_system_device",
-            "//tensorflow/core/tpu/ops:host_compute_ops",
-            "//tensorflow/core/tpu/ops:topk_ops",
-            "//tensorflow/core/tpu/ops:tpu_compile_op",
-            "//tensorflow/core/tpu/ops:tpu_execute_op",
-            "//tensorflow/core/tpu/ops:tpu_partitioned_ops",
-            "//tensorflow/stream_executor/tpu:tpu_executor",
-            "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
-            "//tensorflow/core/tpu:tpu_on_demand_compiler",
-        ],
-        "//conditions:default": [],
-    }),
+    ],
 )
 
 cc_library(
@@ -193,17 +176,12 @@ cc_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:session_options",
-        "//tensorflow/core/common_runtime:copy_tensor",
-        "//tensorflow/core/common_runtime:device",
-        "//tensorflow/core/common_runtime:device_factory",
-        "//tensorflow/core/common_runtime:dma_helper",
-        "//tensorflow/core/framework:kernel_def_proto_cc",
-        "//tensorflow/core/lib/core:status",
-        "//tensorflow/core/platform:status",
-        "//tensorflow/core/tpu/kernels:tpu_configuration_ops",
-        "//tensorflow/core/tpu/kernels:tpu_util",
         "//tensorflow/stream_executor/tpu:c_api_conversions",
         "//tensorflow/stream_executor/tpu:status_helper",
         "//tensorflow/stream_executor/tpu:tpu_node_context",
@@ -219,11 +197,10 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":virtual_device",
+        "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
-        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
         "//tensorflow/core:session_options",
-        "//tensorflow/core/common_runtime:device_factory",
-        "//tensorflow/core/lib/core:status",
         "//tensorflow/stream_executor/tpu:tpu_executor_base",
     ],
 )
@@ -245,7 +222,6 @@ cc_library(
     hdrs = ["tpu_execute.h"],
     deps = [
         ":tpu_api",
-        "//tensorflow/compiler/jit:xla_device",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:shape_layout",
         "//tensorflow/compiler/xla:shape_util",
@@ -303,3 +279,20 @@ cc_library(
     ],
     alwayslink = True,
 )
+
+cc_library(
+    name = "tpu_runtime",
+    srcs = [],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tpu_api_dlsym_initializer",
+        ":tpu_compilation_device",
+        ":tpu_node_device",
+        ":tpu_system_device",
+        "//tensorflow/core/tpu:tpu_on_demand_compiler",
+        "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
+        "//tensorflow/core/tpu/ops",
+        "//tensorflow/stream_executor/tpu:tpu_executor",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
+    ],
+)
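With the new tpu_runtime umbrella target in place, pulling the whole TPU runtime into a build reduces to one conditional dependency. A sketch of a consumer; the binary and source names are illustrative, and only the two dependency labels come from this diff:

    cc_binary(
        name = "my_tf_server",  # illustrative
        srcs = ["my_tf_server_main.cc"],  # illustrative
        deps = [
            "//tensorflow/core:core_cpu",
        ] + if_tpu(["//tensorflow/core/tpu:tpu_runtime"]),
    )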
@@ -25,7 +25,14 @@ package(
 tf_kernel_library(
     name = "kernels",
     visibility = ["//visibility:public"],
-    deps = [":tpu_configuration_ops"],
+    deps = [
+        ":cross_replica_ops",
+        ":host_compute_ops",
+        ":topk_ops",
+        ":tpu_compile_op",
+        ":tpu_configuration_ops",
+        ":tpu_execute_op",
+    ],
 )
 
 cc_library(
@@ -347,7 +354,6 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/platform:casts",  # buildcleaner: keep
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/tpu:tpu_api",
         "@com_google_absl//absl/base:core_headers",
@@ -383,8 +389,10 @@ cc_library(
         "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
-        "//tensorflow/core/platform:refcount",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
         "@com_google_absl//absl/container:node_hash_map",
@@ -398,7 +406,7 @@ cc_library(
     name = "tpu_compilation_metrics_hdrs",
     hdrs = ["tpu_compilation_metrics.h"],
     deps = [
-        "//tensorflow/core/platform:types",
+        "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -553,20 +561,20 @@ cc_library(
         DEFAULT: [],
     }),
     deps = [
+        ":tpu_compilation_cache_key",
+        ":tpu_compile_c_api_hdrs",
+        ":tpu_compile_op_common",
+        ":tpu_compile_op_support",
+        ":tpu_compile_proto_cc",
+        ":tpu_mesh_state_c_api_hdrs",
+        ":tpu_program_group",
+        ":tpu_program_group_interface",
+        ":tpu_util",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
-        "//tensorflow/core/tpu/kernels:tpu_compilation_cache_key",
-        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
-        "//tensorflow/core/tpu/kernels:tpu_compile_op_common",
-        "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
-        "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc",
-        "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
-        "//tensorflow/core/tpu/kernels:tpu_program_group",
-        "//tensorflow/core/tpu/kernels:tpu_program_group_interface",
-        "//tensorflow/core/tpu/kernels:tpu_util",
         "//tensorflow/stream_executor/tpu:tpu_executor",
         "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
         "@com_google_absl//absl/types:variant",
@@ -612,7 +620,7 @@ cc_library(
         ":tpu_compilation_cache_lookup",
         ":tpu_executable_info_proto_cc",
         ":tpu_op_consts",
-        "//tensorflow/compiler/jit:xla_device",
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
         "//tensorflow/compiler/jit:xla_launch_util",
         "//tensorflow/compiler/jit:xla_tensor",
         "//tensorflow/compiler/tf2xla:common",
@@ -628,12 +636,12 @@ cc_library(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/tpu:tpu_configuration",
        "//tensorflow/core/tpu:tpu_defs",
         "//tensorflow/core/tpu:tpu_execute",
-        "//tensorflow/stream_executor:device_memory_allocator",
+        "//tensorflow/stream_executor",
         "//tensorflow/stream_executor/tpu:tpu_node_context",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
@@ -654,7 +662,6 @@ cc_library(
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
-        "//tensorflow/core:graph",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
     ],
@@ -40,7 +40,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
@@ -364,16 +363,16 @@ struct OutputBuffers {
         memory_allocator(allocator) {}
 
   ~OutputBuffers() {
-    buffers.buffers().ForEachElement([&](const xla::ShapeIndex& index,
-                                         const se::DeviceMemoryBase& buffer) {
+    buffers.buffers().ForEachElement(
+        [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
           if (owned_buffers.element(index) && !buffer.is_null()) {
             Status status =
                 memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
            if (!status.ok()) {
              LOG(ERROR) << "Error deallocating buffer " << status;
            }
          }
        });
   }
 
   // Which of the buffers do we own?
@@ -3,12 +3,26 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+cc_library(
+    name = "ops",
+    linkstatic = 1,
+    deps = [
+        ":host_compute_ops",
+        ":topk_ops",
+        ":tpu_compile_op",
+        ":tpu_execute_op",
+        ":tpu_partitioned_ops",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "tpu_partitioned_ops",
     srcs = [
         "tpu_partitioned_input_op.cc",
         "tpu_partitioned_output_op.cc",
     ],
+    linkstatic = 1,
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
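The alwayslink = 1 on the new ops aggregate matters because REGISTER_OP works through static initializers: nothing references those objects by symbol, so without alwayslink the linker could drop them and the op registrations would silently disappear. A sketch of a consumer that needs the registrations present at runtime; the test target and source names are hypothetical:

    tf_cc_test(
        name = "tpu_ops_registration_test",  # illustrative
        srcs = ["tpu_ops_registration_test.cc"],  # illustrative
        deps = [
            "//tensorflow/core:framework",
            "//tensorflow/core:test",
            "//tensorflow/core:test_main",
            # alwayslink keeps the REGISTER_OP initializers in the binary.
            "//tensorflow/core/tpu/ops",
        ],
    )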
@@ -22,6 +36,7 @@ cc_library(
     srcs = [
         "tpu_compile_op.cc",
     ],
+    linkstatic = 1,
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
@@ -35,6 +50,7 @@ cc_library(
     srcs = [
         "tpu_execute_op.cc",
     ],
+    linkstatic = 1,
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
@@ -48,6 +64,7 @@ cc_library(
     srcs = [
         "host_compute_ops.cc",
     ],
+    linkstatic = 1,
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -61,6 +78,7 @@ cc_library(
     srcs = [
         "topk_ops.cc",
     ],
+    linkstatic = 1,
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -30,17 +30,10 @@ REGISTER_OP("_HostComputeMlir")
     .Attr("Toutputs: list(type) >= 0")
     .Attr("key: string")
     .Attr("tpu_core: int = 0")
-    .SetIsStateful()
-    .Doc(R"doc(
-A host-side computation called from a TPU device.
-
-inputs: A list of tensors that will be sent to the host.
-outputs: A list of tensors that will be returned to the device.
-Tinputs: The element types of each element in `inputs`.
-Toutputs: The element types of each element in `outputs`.
-key: A unique identifier for this region used to match up host transfers.
-tpu_core: Default core to use for host to device transfers.
-)doc");
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      return ::tensorflow::shape_inference::UnknownShape(c);
+    })
+    .SetIsStateful();
 
 REGISTER_OP("XlaHostCompute")
     .Input("inputs: Tinputs")
@@ -78,36 +71,16 @@ REGISTER_OP("XlaHostCompute")
           // statically known.
           return ::tensorflow::shape_inference::UnknownShape(c);
         }
-      })
-      .Doc(R"doc(
-A pseudo-op to represent host-side computation in an XLA program.
-
-inputs: A list of tensors that will be sent to the host.
-outputs: A list of tensors that will be returned to the device.
-Tinputs: The element types of each element in `inputs`.
-Toutputs: The element types of each element in `outputs`.
-ancestors: A list of names of HostCompute computations that must be
-sequenced before this computation.
-shape_inference_graph: If non-empty, a serialized GraphDef representing a graph
-that must be analyzed at compile time to determine the shapes of the outputs.
-shapes: If shape_inference_graph is empty, a list of the shapes of `outputs`.
-key: A unique identifier for this region used to match up host transfers.
-cost_estimate_ns: Estimated duration of the host computation in nanoseconds.
-tpu_core: Default core to use for host to device transfers.
-)doc");
+      });
 
 REGISTER_OP("XlaSendToHost")
     .Input("input: Tinput")
     .Attr("Tinput: type")
     .Attr("key: string")
-    .SetIsStateful()
-    .Doc(R"doc(
-An op to send a tensor to the host.
-
-input: the tensor that will be sent to the host.
-Tinput: element type for input.
-key: A unique identifier for this region used to match up host transfers.
-)doc");
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      return ::tensorflow::shape_inference::UnknownShape(c);
+    })
+    .SetIsStateful();
 
 REGISTER_OP("XlaRecvFromHost")
     .Output("output: Toutput")
@@ -127,14 +100,6 @@ REGISTER_OP("XlaRecvFromHost")
           c->MakeShapeFromShapeProto(shape_attr->shape(), &handle));
       c->set_output(0, handle);
       return Status::OK();
-    })
-    .Doc(R"doc(
-An op to receive a tensor from the host.
-
-output: the tensor that will be received from the host.
-Toutput: element type for output.
-shape: shape for output.
-key: A unique identifier for this region used to match up host transfers.
-)doc");
+    });
 
 } // namespace tensorflow
@@ -33,24 +33,7 @@ REGISTER_OP("KthOrderStatistic")
       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
      c->set_output(0, s);
      return Status::OK();
-    })
-    .Doc(R"doc(
-Computes the Kth order statistic of a data set. The current
-implementation uses a binary search requiring exactly 32 passes over
-the input data. The running time is linear with respect to input
-size. The median-of-medians algorithm is probably faster, but is
-difficult to implement efficiently in XLA. The implementation imposes
-a total ordering on floats. The ordering is consistent with the usual
-partial order. Positive NaNs are greater than positive
-infinity. Negative NaNs are less than negative infinity. NaNs with
-distinct payloads are treated as distinct. Subnormal numbers are
-preserved (not flushed to zero). Positive infinity is greater than all
-numbers. Negative infinity is less than all numbers. Positive is
-greater than negative zero. There are less than k values greater than
-the kth order statistic. There are at least k values greater than or
-equal to the Kth order statistic. The semantics are not the same as
-top_k_unique.
-)doc");
+    });
 
 REGISTER_OP("TopKUnique")
     .Input("input: float32")
@@ -69,22 +52,7 @@ REGISTER_OP("TopKUnique")
       c->set_output(0, s);
       c->set_output(1, s);
       return Status::OK();
-    })
-    .Doc(R"doc(
-Returns the TopK unique values in the array in sorted order. The
-running time is proportional to the product of K and the input
-size. Sorting the whole array is more efficient for sufficiently large
-values of K. The median-of-medians algorithm is probably faster, but
-difficult to implement efficiently in XLA. If there are fewer than K
-unique numbers (not NANs), the results are padded with negative
-infinity. NaNs are never returned. Subnormal numbers are flushed to
-zero. If an element appears at multiple indices, the highest index is
-returned. If a TopK element never appears in the input due to padding
-values, the indices are padded with negative one. If a padding value
-appears in the input and padding is needed, the highest index of the
-padding value will be returned. The semantics are not the same as
-kth_order_statistic.
-)doc");
+    });
 
 REGISTER_OP("MakeUnique")
     .Input("input: float32")
@@ -94,14 +62,7 @@ REGISTER_OP("MakeUnique")
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
       c->set_output(0, input);
       return Status::OK();
-    })
-    .Doc(R"doc(
-Make all elements in the non-Batch dimension unique, but \"close\" to
-their initial value. Never returns a sub-normal number. Never returns
-zero. The sign of each input element is always identical to the sign
-of the corresponding output element. Behavior for infinite elements is
-undefined. Behavior for subnormal elements is undefined.
-)doc");
+    });
 
 REGISTER_OP("TopKWithUnique")
     .Input("input: float32")
@@ -120,11 +81,5 @@ REGISTER_OP("TopKWithUnique")
       c->set_output(0, s);
       c->set_output(1, s);
       return Status::OK();
-    })
-    .Doc(R"doc(
-Returns the TopK values in the array in sorted order. This is a combination
-of MakeUnique and TopKUnique. The returned top-K will have its lower bits
-replaced by iota, thus it will be close to the original value but not exactly
-the same. The running time is proportional to the product of K and the input
-size. NaNs are never returned. Subnormal numbers are flushed to zero.)doc");
+    });
 } // namespace tensorflow
@ -43,23 +43,7 @@ REGISTER_OP("_TPUCompileMlir")
|
|||||||
c->set_output(i + 1, c->Vector(2));
|
c->set_output(i + 1, c->Vector(2));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
});
|
||||||
.Doc(
|
|
||||||
R"(
|
|
||||||
Compiles a computations for execution on one or more TPU devices.
|
|
||||||
For the internal use of the distributed TPU compiler. Note that currently only
|
|
||||||
single TPU device is supported.
|
|
||||||
|
|
||||||
'mlir_module' is a serialized MLIR module with a `main` function that contains
|
|
||||||
target computation.
|
|
||||||
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
|
|
||||||
known statically at TPUReplication rewrite time.
|
|
||||||
'metadata' is a serialized TPUCompileMetadataProto describing
|
|
||||||
the shapes and types of the inputs to the computation, as well as a mapping onto
|
|
||||||
the TPU pod topology.
|
|
||||||
'program' output is a string key that is passed to the _TPUExecute op and
|
|
||||||
used to look up the program in the compilation cache.
|
|
||||||
)");
|
|
||||||
|
|
||||||
REGISTER_OP("TPUCompile")
|
REGISTER_OP("TPUCompile")
|
||||||
.Attr("num_computations: int >= 0")
|
.Attr("num_computations: int >= 0")
|
||||||
@ -91,39 +75,13 @@ REGISTER_OP("TPUCompile")
|
|||||||
c->set_output(num_computations + i + 1, c->Scalar());
|
c->set_output(num_computations + i + 1, c->Scalar());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
});
|
||||||
.Doc(
|
|
||||||
R"(
|
|
||||||
Compiles a computations for execution on one or more TPU devices.
|
|
||||||
For the internal use of the distributed TPU compiler.
|
|
||||||
|
|
||||||
'num_computations' is the number of computations to be compiled.
|
|
||||||
'function' is a function containing the computation to compile.
|
|
||||||
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
|
|
||||||
known statically at TPUReplication rewrite time.
|
|
||||||
'guaranteed_constants' is a list of tensors which have been guaranteed to not
|
|
||||||
change their values during the session lifetime. These contain tensors marked as
|
|
||||||
constant using the GuaranteeConstOp.
|
|
||||||
'metadata' is a serialized TPUCompileMetadataProto describing
|
|
||||||
the shapes and types of the inputs to the computation, as well as a mapping onto
|
|
||||||
the TPU pod topology.
|
|
||||||
Each 'program' output is a string key that is passed to the _TPUExecute op and
|
|
||||||
used to look up the program in the compilation cache.
|
|
||||||
'may_modify_variables' indicates whether variables may be modified.
|
|
||||||
)");
|
|
||||||
|
|
||||||
REGISTER_OP("TPUCompileSucceededAssert")
|
REGISTER_OP("TPUCompileSucceededAssert")
|
||||||
.Input("compilation_status: string")
|
.Input("compilation_status: string")
|
||||||
// Do not optimize me away. Read the comment on TPUCompileOp for more
|
// Do not optimize me away. Read the comment on TPUCompileOp for more
|
||||||
// details.
|
// details.
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::NoOutputs)
|
.SetShapeFn(shape_inference::NoOutputs);
|
||||||
.Doc(
|
|
||||||
R"(
|
|
||||||
Asserts that compilation succeeded. This op produces no output and closes the
|
|
||||||
device during failure to ensure all pending device interactions fail.
|
|
||||||
|
|
||||||
'compilation_status' is a serialized CompilationResultProto.
|
|
||||||
)");
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
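Both compile ops above emit a 'program' string key that the execute ops later use to fetch the compiled binary from a compilation cache. A hypothetical, much-simplified C++ sketch of that key/lookup handshake (CompiledProgram, CompilationCacheSketch, and the key format are invented for illustration; the real cache lives in the TPU runtime):

    #include <cstdint>
    #include <map>
    #include <memory>
    #include <string>
    #include <utility>

    struct CompiledProgram {};  // stand-in for a compiled TPU executable

    class CompilationCacheSketch {
     public:
      // Compile side: store the program and hand back the string key that a
      // TPUCompile-style op would emit as its 'program' output.
      std::string Insert(std::unique_ptr<CompiledProgram> program) {
        std::string key = "tpu_program_" + std::to_string(next_id_++);
        cache_[key] = std::move(program);
        return key;
      }
      // Execute side: a TPUExecute-style op looks the program up by key.
      CompiledProgram* Lookup(const std::string& key) {
        auto it = cache_.find(key);
        return it == cache_.end() ? nullptr : it->second.get();
      }

     private:
      int64_t next_id_ = 0;
      std::map<std::string, std::unique_ptr<CompiledProgram>> cache_;
    };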
@@ -35,10 +35,7 @@ REGISTER_OP("TPUExecute")
         c->set_output(i, c->UnknownShape());
       }
       return Status::OK();
-    })
-    .Doc(R"(
-Op that loads and executes a TPU program on a TPU device.
-For the internal use of the distributed TPU compiler.)");
+    });
 
 REGISTER_OP("TPUExecuteAndUpdateVariables")
     .Input("args: Targs")
@@ -58,14 +55,6 @@ REGISTER_OP("TPUExecuteAndUpdateVariables")
         c->set_output(i, c->UnknownShape());
       }
       return Status::OK();
-    })
-    .Doc(R"(Op that executes a program with optional in-place variable updates.
-It (optionally) reads device variables, loads and executes a TPU program on a
-TPU device, and then (optionally) in-place updates variables using the program
-outputs, as specified in attributes device_var_reads_indices (program input
-indices from directly reading variables) and device_var_updates_indices (program
-output indices used to update variables, -1 means no-update/read-only). Such
-program outputs are consumed by these variables will not appear in the op
-output. For the internal use of the distributed TPU compiler.)");
+    });
 
 }  // namespace tensorflow
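The doc text removed above spells out how TPUExecuteAndUpdateVariables pairs each variable with a program input and, optionally, a program output. A small self-contained illustration of that pairing under the documented convention (the index values here are made up):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      // device_var_reads_indices[v]: program input index fed from variable v.
      // device_var_updates_indices[v]: program output index written back to
      // variable v, or -1 for a read-only variable.
      std::vector<int> device_var_reads_indices = {0, 1};
      std::vector<int> device_var_updates_indices = {0, -1};
      for (std::size_t v = 0; v < device_var_reads_indices.size(); ++v) {
        std::printf("variable %zu -> program input %d, ", v,
                    device_var_reads_indices[v]);
        if (device_var_updates_indices[v] >= 0) {
          std::printf("updated from program output %d\n",
                      device_var_updates_indices[v]);
        } else {
          std::printf("read-only\n");
        }
      }
      return 0;
    }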
@@ -94,14 +94,6 @@ REGISTER_OP("TPUPartitionedInput")
       }
 
       return Status::OK();
-    })
-    .Doc(R"doc(
-An op that groups a list of partitioned inputs together. This op
-
-inputs: A list of partitioned inputs which must have the same shape.
-output: A handle which represents the full shape of partitioned tensors.
-partition_dim: An integer describles which dimension is partitioned. -1 means
-those inputs are replicated.
-)doc");
+    });
 
 }  // namespace tensorflow
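Per the removed doc text, TPUPartitionedInput groups N same-shaped pieces into one logical tensor: along partition_dim the full extent is N times the piece extent, and partition_dim = -1 means the pieces are replicas. A hedged shape-rule sketch (FullShapeSketch is an invented name, not part of the op):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Shape-only illustration: the full shape implied by num_inputs equally
    // shaped pieces partitioned along partition_dim (-1 = replicated).
    std::vector<int64_t> FullShapeSketch(const std::vector<int64_t>& piece_shape,
                                         int num_inputs, int partition_dim) {
      std::vector<int64_t> full = piece_shape;
      if (partition_dim >= 0) {
        assert(partition_dim < static_cast<int>(full.size()));
        full[partition_dim] *= num_inputs;  // pieces concatenate along this dim
      }
      return full;  // replicated inputs share the per-piece shape
    }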
@@ -23,7 +23,6 @@ namespace tensorflow {
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
-
 REGISTER_OP("TPUPartitionedOutput")
     .Input("inputs: T")
     .Output("output: num_splits * T")
@@ -53,14 +52,6 @@ REGISTER_OP("TPUPartitionedOutput")
         c->set_output(i, newoutput0);
       }
       return Status::OK();
-    })
-    .Doc(R"doc(
-An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
-outputs outside the XLA computation.
-
-inputs: A tensor which represents the full shape of partitioned tensors.
-output: A list of partitioned inputs which must have the same shape.
-partition_dim: An integer describles which dimension is partitioned.
-)doc");
+    });
 
 }  // namespace tensorflow
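TPUPartitionedOutput is the inverse: it splits the full tensor into num_splits pieces along partition_dim. The corresponding shape rule, again as an invented-name sketch under the assumption that the dimension divides evenly:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Shape-only illustration: per-output shape after splitting full_shape
    // into num_splits pieces along partition_dim.
    std::vector<int64_t> PieceShapeSketch(const std::vector<int64_t>& full_shape,
                                          int num_splits, int partition_dim) {
      std::vector<int64_t> piece = full_shape;
      assert(partition_dim >= 0 &&
             partition_dim < static_cast<int>(piece.size()));
      assert(piece[partition_dim] % num_splits == 0);  // assume even split
      piece[partition_dim] /= num_splits;
      return piece;
    }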
@@ -23,7 +23,6 @@ limitations under the License.
 
 #include "absl/base/casts.h"
 #include "absl/memory/memory.h"
-#include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -419,10 +418,6 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
 
   xla::Backend* backend = node_context->backend();
 
-  XlaDevice* device =
-      tensorflow::down_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
-  TF_RET_CHECK(device);
-
   // Create a HostTransferManager to handle Send/Recv operations from the TPU.
   std::shared_ptr<HostTransferManager> host_transfer_manager =
       std::make_shared<HostTransferManager>(node_context, backend);
@@ -15,7 +15,6 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/graph/types.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/tpu/virtual_device.h"
@@ -3,7 +3,7 @@
 # ":platform" - Low-level and platform-specific Python code.
 
 load("//tensorflow:tensorflow.bzl", "py_strict_library")
-load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_tpu", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cc_test", "tf_cuda_library", "tf_enable_mlir_bridge", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cc_test", "tf_cuda_library", "tf_enable_mlir_bridge", "tf_gen_op_wrapper_py")
 
 # buildifier: disable=same-origin-load
 load("//tensorflow:tensorflow.bzl", "tf_monitoring_python_deps")
@@ -6093,8 +6093,6 @@ pywrap_tensorflow_macro(
         "@ngraph_tf//:ngraph_tf",
     ]) + if_xla_available([
         "//tensorflow/compiler/aot:tfcompile_lib",
-    ]) + if_tpu([
-        "//tensorflow/core/tpu:tpu_api_dlsym_initializer",
     ]) + if_static(extra_deps = ["//tensorflow/core/platform:tf32_utils"]),
 )
 
@@ -68,7 +68,7 @@ cc_library(
     deps = [
         ":c_api_decl",
         ":tpu_executor_c_api_hdrs",
-        "//tensorflow/core/platform:status",
+        "//tensorflow/core:lib",
         "//tensorflow/core/tpu:tpu_api",
     ],
 )
@@ -103,9 +103,7 @@ cc_library(
         ":tpu_executor_interface",
         ":tpu_platform_interface",
        ":tpu_stream_interface",
-        "//tensorflow/core/platform:casts",
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:types",
+        "//tensorflow/core:lib",
         "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/stream_executor",
         "//tensorflow/stream_executor/lib",
@@ -131,8 +129,6 @@ cc_library(
         ":status_helper",
         ":tpu_executor_c_api_hdrs",
         "//tensorflow/core:lib",
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:types",
         "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/stream_executor",
         "//tensorflow/stream_executor/lib",
@@ -164,8 +160,6 @@ cc_library(
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/core:lib",
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:types",
         "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/stream_executor",
         "//tensorflow/stream_executor/lib",
@@ -274,10 +268,8 @@ cc_library(
     hdrs = ["tpu_platform_interface.h"],
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/core/platform:mutex",
-        "//tensorflow/core/platform:types",
-        "//tensorflow/stream_executor:multi_platform_manager",
-        "//tensorflow/stream_executor:stream_executor_headers",
+        "//tensorflow/core:lib",
+        "//tensorflow/stream_executor",
     ],
 )
 
@@ -334,7 +326,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":c_api_decl",
-        "//tensorflow/core/platform:types",
+        "//tensorflow/core:lib",
         "//tensorflow/core/tpu:tpu_api",
     ],
 )
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
 
 namespace ApiConverter {
+
 xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
   xla::Shape xla_on_host_shape = ApiConverter::FromC(&c_buffer->on_host_shape);
   xla::Shape xla_on_device_shape =
@@ -114,6 +115,7 @@ SE_DeviceMemoryAllocator ToC(
   };
   return se_allocator;
 }
+
 SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceMemory* mem) {
   SE_MaybeOwningDeviceMemory se_mem;
   se_mem.device_ordinal = mem->device_ordinal();
@@ -349,7 +349,6 @@ struct TfTpu_ExecutorApiFn {
 
   TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_New);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Free);
-
   TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunHloPasses);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunBackend);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);