Adjust TPU build dependencies.
PiperOrigin-RevId: 323453889
Change-Id: I6b33b57830d3414d45aaa066735c3b52f217e739
Parent: 20cd718248
Commit: 9872ddd15a
@@ -225,6 +225,8 @@ cc_library(
        "xla_device_context.h",
        "xla_device_ops.h",
    ],
    # Public visibility is needed for external TF/XLA backends.
    visibility = ["//visibility:public"],
    deps = XLA_DEVICE_DEPS,
)
@@ -295,19 +295,6 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
      << diff << "\nActual: " << actual.DebugString(); \
  } while (false)

// These dummy Op registrations are here because the real Op registrations live
// in contrib and there can't be a dependence from this test to contrib.
REGISTER_OP("XlaHostCompute")
    .Input("inputs: Tinputs")
    .Output("outputs: Toutputs")
    .Attr("Tinputs: list(type) >= 0")
    .Attr("Toutputs: list(type) >= 0")
    .Attr("ancestors: list(string) >= 0")
    .Attr("key: string")
    .Attr("shape_inference_graph: func")
    .Attr("shapes: list(shape) >= 0")
    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);

REGISTER_OP("InputTest")
    .Output("o: float")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
@@ -947,6 +934,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const DataType>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -1114,6 +1103,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O2"},
        {"shape_inference_graph", shape_inference_graph2},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const DataType>({})},
        {"_outside_compilation_subgraph", "O2"},
        {"_xla_token_input_nodes",
@@ -1130,6 +1121,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph1},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const DataType>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -1266,6 +1259,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes",
         absl::Span<const TensorShapeProto>({shape_proto_expected})},
        {"_outside_compilation_subgraph", "O1"},
@@ -1295,6 +1290,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F2_F2_O1"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes",
         absl::Span<const TensorShapeProto>({shape_proto_expected})},
        {"_outside_compilation_subgraph", "O1"},
@@ -1428,6 +1425,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes",
         absl::Span<const TensorShapeProto>({shape_proto_expected})},
        {"_outside_compilation_subgraph", "O1"},
@@ -1454,6 +1453,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F2_F2_O1"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes",
         absl::Span<const TensorShapeProto>({shape_proto_expected})},
        {"_outside_compilation_subgraph", "O1"},
@@ -1566,6 +1567,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes",
         absl::Span<const TensorShapeProto>({shape_proto_expected})},
        {"_outside_compilation_subgraph", "O1"},
@@ -1658,6 +1661,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes",
         absl::Span<const TensorShapeProto>({shape_proto_expected})},
        {"_outside_compilation_subgraph", "O1"},
@@ -1765,6 +1770,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -1875,6 +1882,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -2009,6 +2018,8 @@ TEST(EncapsulateSubgraphsTest,
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph1},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -2023,6 +2034,8 @@ TEST(EncapsulateSubgraphsTest,
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O2"},
        {"shape_inference_graph", shape_inference_graph2},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O2"},
        {"_xla_token_input_nodes",
@@ -2153,6 +2166,8 @@ TEST(EncapsulateSubgraphsTest,
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O2"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O2"},
        {"_xla_token_input_nodes",
@@ -2169,6 +2184,8 @@ TEST(EncapsulateSubgraphsTest,
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -2296,6 +2313,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -2310,6 +2329,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O2"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O2"},
        {"_xla_token_input_nodes",
@@ -2325,6 +2346,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O3"},
        {"shape_inference_graph", NameAttrList()},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O3"},
        {"_xla_token_input_nodes",
@@ -2451,6 +2474,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const TensorShapeProto>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -2567,6 +2592,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
        {"ancestors", absl::Span<const string>({})},
        {"key", "host_compute_channel_F1_F1_O1"},
        {"shape_inference_graph", shape_inference_graph},
        {"tpu_core", 0},
        {"cost_estimate_ns", 1000000},
        {"shapes", absl::Span<const DataType>({})},
        {"_outside_compilation_subgraph", "O1"},
        {"_xla_token_input_nodes",
@@ -2420,6 +2420,7 @@ Status ExtractOutsideCompilationForFunction(
  auto updated_fdef = absl::make_unique<FunctionDef>();
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
  updated_fdef->mutable_signature()->set_is_stateful(true);
  const FunctionDef* original_fdef = fld->Find(func_name);
  if (original_fdef) {
    for (const auto& attr : original_fdef->attr()) {
@@ -422,19 +422,6 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
  EXPECT_EQ(fld.Find("host_graph"), nullptr);
}

REGISTER_OP("XlaSendToHost")
    .Input("input: Tinput")
    .Attr("Tinput: type")
    .Attr("key: string")
    .SetIsStateful();

REGISTER_OP("XlaRecvFromHost")
    .Output("output: Toutput")
    .Attr("Toutput: type")
    .Attr("shape: shape")
    .Attr("key: string")
    .SetIsStateful();

TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
  // Build the XLA computation func.
  // "const0" (bool)
@@ -880,6 +880,7 @@ cc_library(
        ":tpu_outfeed_ops_op_lib",
        ":tpu_ordinal_selector_ops_op_lib",
        ":tpu_replication_ops_op_lib",
        "//tensorflow/core/tpu/ops",
    ],
) + if_mkl([
    ":mkl_array_ops_op_lib",
@@ -0,0 +1,3 @@
op {
  graph_op_name: "DataServiceDataset"
}
@@ -0,0 +1,20 @@
op {
  graph_op_name: "KthOrderStatistic"
  summary: "Computes the Kth order statistic of a data set."
  description: <<END
The current implementation uses a binary search requiring exactly 32
passes over the input data. The running time is linear with respect to
input size. The median-of-medians algorithm is probably faster, but is
difficult to implement efficiently in XLA. The implementation imposes
a total ordering on floats. The ordering is consistent with the usual
partial order. Positive NaNs are greater than positive
infinity. Negative NaNs are less than negative infinity. NaNs with
distinct payloads are treated as distinct. Subnormal numbers are
preserved (not flushed to zero). Positive infinity is greater than all
numbers. Negative infinity is less than all numbers. Positive zero is
greater than negative zero. There are fewer than k values greater than
the kth order statistic. There are at least k values greater than or
equal to the Kth order statistic. The semantics are not the same as
top_k_unique.
END
}
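For illustration, a minimal C++ sketch of the total order described above (not part of this commit; the helper names are hypothetical):

#include <cstdint>
#include <cstring>

// Maps a float to an unsigned key whose natural order matches the total
// order described above: negative NaNs < -inf < negatives < -0.0 < +0.0
// < positives < +inf < positive NaNs; NaNs with distinct payloads map to
// distinct keys.
uint32_t TotalOrderKey(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // Negative floats: flip all bits (reverses their order and places them
  // below all positives). Positive floats: set the sign bit.
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

bool TotalOrderLess(float a, float b) {
  return TotalOrderKey(a) < TotalOrderKey(b);
}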
tensorflow/core/api_def/base_api/api_def_MakeUnique.pbtxt (new file, 10 lines)
@@ -0,0 +1,10 @@
op {
  graph_op_name: "MakeUnique"
  summary: "Make all elements in the non-Batch dimension unique, but \"close\" to their initial value."
  description: <<END
Never returns a sub-normal number. Never returns zero. The sign of each
input element is always identical to the sign of the corresponding
output element. Behavior for infinite elements is undefined. Behavior
for subnormal elements is undefined.
END
}
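One way to realize this contract, sketched in C++ on a single element (a hypothetical illustration, not the XLA implementation): overwrite the low mantissa bits with the element's column index, so duplicates become distinct but nearby values while the sign bit is untouched.

#include <cstdint>
#include <cstring>

// Replaces the low mantissa bits of x with a per-column index. Duplicated
// values in a row become distinct, the sign is preserved, and the result
// stays close to the original value.
float MakeUniqueScalar(float x, uint32_t column_index) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  const uint32_t kLowBits = 0xFFFFu;  // low 16 mantissa bits
  bits = (bits & ~kLowBits) | (column_index & kLowBits);
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}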
tensorflow/core/api_def/base_api/api_def_TPUCompile.pbtxt (new file, 21 lines)
@@ -0,0 +1,21 @@
op {
  graph_op_name: "TPUCompile"
  summary: "Compiles computations for execution on one or more TPU devices."
  description: <<END
For the internal use of the distributed TPU compiler.

'num_computations' is the number of computations to be compiled.
'function' is a function containing the computation to compile.
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
known statically at TPUReplication rewrite time.
'guaranteed_constants' is a list of tensors which have been guaranteed to not
change their values during the session lifetime. These contain tensors marked as
constant using the GuaranteeConstOp.
'metadata' is a serialized TPUCompileMetadataProto describing
the shapes and types of the inputs to the computation, as well as a mapping onto
the TPU pod topology.
Each 'program' output is a string key that is passed to the _TPUExecute op and
used to look up the program in the compilation cache.
'may_modify_variables' indicates whether variables may be modified.
END
}
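A hypothetical sketch of the handshake these outputs describe (the struct and field names are invented for illustration): the compile op yields a status proto plus one cache key per computation; the assert op consumes the status, and each execute op consumes a key.

#include <string>
#include <vector>

// Invented struct mirroring the outputs described above.
// 'compilation_status' is the serialized CompilationResultProto checked by
// TPUCompileSucceededAssert; each entry of 'program_keys' is the cache key
// a TPUExecute op uses to look up its compiled program.
struct TpuCompileOutputs {
  std::string compilation_status;
  std::vector<std::string> program_keys;  // size == num_computations
};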
@@ -0,0 +1,9 @@
op {
  graph_op_name: "TPUCompileSucceededAssert"
  summary: "Asserts that compilation succeeded."
  description: <<END
This op produces no output and closes the device during failure to
ensure all pending device interactions fail.

'compilation_status' is a serialized CompilationResultProto.
END
}
@@ -0,0 +1,7 @@
op {
  graph_op_name: "TPUExecute"
  summary: "Op that loads and executes a TPU program on a TPU device."
  description: <<END
For the internal use of the distributed TPU compiler.
END
}
@@ -0,0 +1,13 @@
op {
  graph_op_name: "TPUExecuteAndUpdateVariables"
  summary: "Op that executes a program with optional in-place variable updates."
  description: <<END
It (optionally) reads device variables, loads and executes a TPU program on a
TPU device, and then (optionally) in-place updates variables using the program
outputs, as specified in attributes device_var_reads_indices (program input
indices from directly reading variables) and device_var_updates_indices (program
output indices used to update variables, -1 means no-update/read-only). Program
outputs consumed by these variables will not appear in the op output. For the
internal use of the distributed TPU compiler.
END
}
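A small C++ sketch of the index mapping described above (illustrative only; the real op operates on device buffers, not host vectors):

#include <vector>

// vars[i] feeds program input device_var_reads_indices[i]; after execution,
// program output device_var_updates_indices[i] is written back to vars[i],
// unless that entry is -1 (read-only variable).
void ApplyVariableUpdates(std::vector<float>& vars,
                          const std::vector<float>& program_outputs,
                          const std::vector<int>& device_var_updates_indices) {
  for (size_t i = 0; i < vars.size(); ++i) {
    const int out_index = device_var_updates_indices[i];
    if (out_index >= 0) {
      vars[i] = program_outputs[out_index];  // in-place update
    }
  }
}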
@@ -0,0 +1,23 @@
op {
  graph_op_name: "TPUPartitionedInput"
  in_arg {
    name: "inputs"
    description: <<END
A list of partitioned inputs which must have the same shape.
END
  }
  out_arg {
    name: "output"
    description: <<END
A handle which represents the full shape of partitioned tensors.
END
  }
  attr {
    name: "partition_dim"
    description: <<END
An integer describing which dimension is partitioned. -1 means
those inputs are replicated.
END
  }
  summary: "An op that groups a list of partitioned inputs together."
}
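A sketch of the partition_dim arithmetic, under the assumption (from the descriptions above) that partitions are concatenated along that dimension:

#include <cstdint>
#include <vector>

// Given the per-partition shape, the number of partitions, and
// partition_dim, returns the full shape the output handle represents.
// partition_dim == -1 means the inputs are replicas, so the full shape
// equals the per-partition shape.
std::vector<int64_t> FullShape(std::vector<int64_t> per_partition_shape,
                               int64_t num_partitions, int partition_dim) {
  if (partition_dim >= 0) {
    per_partition_shape[partition_dim] *= num_partitions;
  }
  return per_partition_shape;
}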
@@ -0,0 +1,25 @@
op {
  graph_op_name: "TPUPartitionedOutput"
  in_arg {
    name: "inputs"
    description: <<END
A tensor which represents the full shape of partitioned tensors.
END
  }
  out_arg {
    name: "output"
    description: <<END
A list of partitioned outputs which must have the same shape.
END
  }
  attr {
    name: "partition_dim"
    description: <<END
An integer describing which dimension is partitioned.
END
  }
  summary: "An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned outputs outside the XLA computation."
}
tensorflow/core/api_def/base_api/api_def_TopKUnique.pbtxt (new file, 18 lines)
@@ -0,0 +1,18 @@
op {
  graph_op_name: "TopKUnique"
  summary: "Returns the TopK unique values in the array in sorted order."
  description: <<END
The running time is proportional to the product of K and the input
size. Sorting the whole array is more efficient for sufficiently large
values of K. The median-of-medians algorithm is probably faster, but
difficult to implement efficiently in XLA. If there are fewer than K
unique numbers (not NaNs), the results are padded with negative
infinity. NaNs are never returned. Subnormal numbers are flushed to
zero. If an element appears at multiple indices, the highest index is
returned. If a TopK element never appears in the input due to padding
values, the indices are padded with negative one. If a padding value
appears in the input and padding is needed, the highest index of the
padding value will be returned. The semantics are not the same as
kth_order_statistic.
END
}
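A reference sketch of the value semantics in C++ (a hypothetical host-side model; the XLA kernel also flushes subnormals and tracks indices, which this sketch omits):

#include <cmath>
#include <functional>
#include <limits>
#include <set>
#include <vector>

// Returns the top-k unique values in descending order, padded with -inf
// when fewer than k unique non-NaN values exist, mirroring the padding
// rule described above.
std::vector<float> TopKUniqueValues(const std::vector<float>& input,
                                    size_t k) {
  std::set<float, std::greater<float>> unique_desc;
  for (float v : input) {
    if (!std::isnan(v)) unique_desc.insert(v);  // NaNs are never returned
  }
  std::vector<float> result(unique_desc.begin(), unique_desc.end());
  // Truncate to k, or pad with -inf if there were fewer than k values.
  result.resize(k, -std::numeric_limits<float>::infinity());
  return result;
}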
@@ -0,0 +1,10 @@
op {
  graph_op_name: "TopKWithUnique"
  summary: "Returns the TopK values in the array in sorted order."
  description: <<END
This is a combination of MakeUnique and TopKUnique. The returned top-K will
have its lower bits replaced by iota, thus it will be close to the original
value but not exactly the same. The running time is proportional to the
product of K and the input size. NaNs are never returned. Subnormal numbers
are flushed to zero.
END
}
@@ -0,0 +1,66 @@
op {
  graph_op_name: "XlaHostCompute"
  in_arg {
    name: "inputs"
    description: <<END
A list of tensors that will be sent to the host.
END
  }
  out_arg {
    name: "outputs"
    description: <<END
A list of tensors that will be returned to the device.
END
  }
  attr {
    name: "Tinputs"
    description: <<END
The element types of each element in `inputs`.
END
  }
  attr {
    name: "Toutputs"
    description: <<END
The element types of each element in `outputs`.
END
  }
  attr {
    name: "ancestors"
    description: <<END
A list of names of HostCompute computations that must be
sequenced before this computation.
END
  }
  attr {
    name: "shapes"
    description: <<END
If shape_inference_graph is empty, a list of the shapes of `outputs`.
END
  }
  attr {
    name: "shape_inference_graph"
    description: <<END
If non-empty, a serialized GraphDef representing a graph
that must be analyzed at compile time to determine the shapes of the outputs.
END
  }
  attr {
    name: "key"
    description: <<END
A unique identifier for this region used to match up host transfers.
END
  }
  attr {
    name: "cost_estimate_ns"
    description: <<END
Estimated duration of the host computation in nanoseconds.
END
  }
  attr {
    name: "tpu_core"
    description: <<END
Default core to use for host to device transfers.
END
  }
  summary: "A pseudo-op to represent host-side computation in an XLA program."
}
@@ -0,0 +1,3 @@
op {
  graph_op_name: "XlaRecvFromHost"
}
tensorflow/core/api_def/base_api/api_def_XlaSendToHost.pbtxt (new file, 12 lines)
@@ -0,0 +1,12 @@
op {
  graph_op_name: "XlaSendToHost"
  in_arg: {
    name: "input"
  }
  attr {
    name: "Tinput"
  }
  attr {
    name: "key"
  }
}
@@ -1,6 +1,7 @@
load(
    "//tensorflow:tensorflow.bzl",
    "cc_header_only_library",
    "if_tpu",
    "tf_cc_test",
    "tf_cc_test_mkl",
    "tf_cc_tests",
@@ -91,7 +92,7 @@ cc_library(
            ":core_cpu",
            "//tensorflow/core/common_runtime/gpu:gpu_runtime",
            "//tensorflow/core/common_runtime/sycl:sycl_runtime",
        ],
    ] + if_tpu(["//tensorflow/core/tpu:tpu_runtime"]),
)

filegroup(
@@ -78,8 +78,7 @@ cc_library(
    srcs = ["tpu_compile_interface.cc"],
    hdrs = ["tpu_compile_interface.h"],
    deps = [
        "//tensorflow/core/platform:fingerprint",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)
@@ -146,8 +145,7 @@ cc_library(
        ":tpu_api",
        ":tpu_config_c_api",
        ":tpu_library_init_fns",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
@@ -155,22 +153,7 @@ cc_library(
        "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
    ] + select({
        "//tensorflow:oss": [
            ":tpu_compilation_device",
            ":tpu_node_device",
            ":tpu_system_device",
            "//tensorflow/core/tpu/ops:host_compute_ops",
            "//tensorflow/core/tpu/ops:topk_ops",
            "//tensorflow/core/tpu/ops:tpu_compile_op",
            "//tensorflow/core/tpu/ops:tpu_execute_op",
            "//tensorflow/core/tpu/ops:tpu_partitioned_ops",
            "//tensorflow/stream_executor/tpu:tpu_executor",
            "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
            "//tensorflow/core/tpu:tpu_on_demand_compiler",
        ],
        "//conditions:default": [],
    }),
    ],
)

cc_library(
@@ -193,17 +176,12 @@ cc_library(
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:session_options",
        "//tensorflow/core/common_runtime:copy_tensor",
        "//tensorflow/core/common_runtime:device",
        "//tensorflow/core/common_runtime:device_factory",
        "//tensorflow/core/common_runtime:dma_helper",
        "//tensorflow/core/framework:kernel_def_proto_cc",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/tpu/kernels:tpu_configuration_ops",
        "//tensorflow/core/tpu/kernels:tpu_util",
        "//tensorflow/stream_executor/tpu:c_api_conversions",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
@@ -219,11 +197,10 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":virtual_device",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:session_options",
        "//tensorflow/core/common_runtime:device_factory",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/stream_executor/tpu:tpu_executor_base",
    ],
)
@@ -245,7 +222,6 @@ cc_library(
    hdrs = ["tpu_execute.h"],
    deps = [
        ":tpu_api",
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/xla:executable_run_options",
        "//tensorflow/compiler/xla:shape_layout",
        "//tensorflow/compiler/xla:shape_util",
@@ -303,3 +279,20 @@ cc_library(
    ],
    alwayslink = True,
)

cc_library(
    name = "tpu_runtime",
    srcs = [],
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_api_dlsym_initializer",
        ":tpu_compilation_device",
        ":tpu_node_device",
        ":tpu_system_device",
        "//tensorflow/core/tpu:tpu_on_demand_compiler",
        "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
        "//tensorflow/core/tpu/ops",
        "//tensorflow/stream_executor/tpu:tpu_executor",
        "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
    ],
)
@@ -25,7 +25,14 @@ package(
tf_kernel_library(
    name = "kernels",
    visibility = ["//visibility:public"],
    deps = [":tpu_configuration_ops"],
    deps = [
        ":cross_replica_ops",
        ":host_compute_ops",
        ":topk_ops",
        ":tpu_compile_op",
        ":tpu_configuration_ops",
        ":tpu_execute_op",
    ],
)

cc_library(
@@ -347,7 +354,6 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:casts",  # buildcleaner: keep
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/tpu:tpu_api",
        "@com_google_absl//absl/base:core_headers",
@@ -383,8 +389,10 @@ cc_library(
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:refcount",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "@com_google_absl//absl/container:node_hash_map",
@@ -398,7 +406,7 @@ cc_library(
    name = "tpu_compilation_metrics_hdrs",
    hdrs = ["tpu_compilation_metrics.h"],
    deps = [
        "//tensorflow/core/platform:types",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)
@@ -553,20 +561,20 @@ cc_library(
        DEFAULT: [],
    }),
    deps = [
        ":tpu_compilation_cache_key",
        ":tpu_compile_c_api_hdrs",
        ":tpu_compile_op_common",
        ":tpu_compile_op_support",
        ":tpu_compile_proto_cc",
        ":tpu_mesh_state_c_api_hdrs",
        ":tpu_program_group",
        ":tpu_program_group_interface",
        ":tpu_util",
        "//tensorflow/compiler/jit:shape_inference",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core/tpu/kernels:tpu_compilation_cache_key",
        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_compile_op_common",
        "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
        "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc",
        "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_program_group",
        "//tensorflow/core/tpu/kernels:tpu_program_group_interface",
        "//tensorflow/core/tpu/kernels:tpu_util",
        "//tensorflow/stream_executor/tpu:tpu_executor",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "@com_google_absl//absl/types:variant",
@@ -612,7 +620,7 @@ cc_library(
        ":tpu_compilation_cache_lookup",
        ":tpu_executable_info_proto_cc",
        ":tpu_op_consts",
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
        "//tensorflow/compiler/jit:xla_launch_util",
        "//tensorflow/compiler/jit:xla_tensor",
        "//tensorflow/compiler/tf2xla:common",
@@ -628,12 +636,12 @@ cc_library(
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:stream_executor_no_cuda",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/tpu:tpu_configuration",
        "//tensorflow/core/tpu:tpu_defs",
        "//tensorflow/core/tpu:tpu_execute",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/memory",
@@ -654,7 +662,6 @@ cc_library(
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
@@ -40,7 +40,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
@@ -364,16 +363,16 @@ struct OutputBuffers {
        memory_allocator(allocator) {}

  ~OutputBuffers() {
    buffers.buffers().ForEachElement([&](const xla::ShapeIndex& index,
                                         const se::DeviceMemoryBase& buffer) {
      if (owned_buffers.element(index) && !buffer.is_null()) {
        Status status =
            memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
        if (!status.ok()) {
          LOG(ERROR) << "Error deallocating buffer " << status;
        }
      }
    });
    buffers.buffers().ForEachElement(
        [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
          if (owned_buffers.element(index) && !buffer.is_null()) {
            Status status =
                memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
            if (!status.ok()) {
              LOG(ERROR) << "Error deallocating buffer " << status;
            }
          }
        });
  }

  // Which of the buffers do we own?
@@ -3,12 +3,26 @@ package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "ops",
    linkstatic = 1,
    deps = [
        ":host_compute_ops",
        ":topk_ops",
        ":tpu_compile_op",
        ":tpu_execute_op",
        ":tpu_partitioned_ops",
    ],
    alwayslink = 1,
)

cc_library(
    name = "tpu_partitioned_ops",
    srcs = [
        "tpu_partitioned_input_op.cc",
        "tpu_partitioned_output_op.cc",
    ],
    linkstatic = 1,
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
@@ -22,6 +36,7 @@ cc_library(
    srcs = [
        "tpu_compile_op.cc",
    ],
    linkstatic = 1,
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
@@ -35,6 +50,7 @@ cc_library(
    srcs = [
        "tpu_execute_op.cc",
    ],
    linkstatic = 1,
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
@@ -48,6 +64,7 @@ cc_library(
    srcs = [
        "host_compute_ops.cc",
    ],
    linkstatic = 1,
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
@@ -61,6 +78,7 @@ cc_library(
    srcs = [
        "topk_ops.cc",
    ],
    linkstatic = 1,
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
@@ -30,17 +30,10 @@ REGISTER_OP("_HostComputeMlir")
    .Attr("Toutputs: list(type) >= 0")
    .Attr("key: string")
    .Attr("tpu_core: int = 0")
    .SetIsStateful()
    .Doc(R"doc(
A host-side computation called from a TPU device.

inputs: A list of tensors that will be sent to the host.
outputs: A list of tensors that will be returned to the device.
Tinputs: The element types of each element in `inputs`.
Toutputs: The element types of each element in `outputs`.
key: A unique identifier for this region used to match up host transfers.
tpu_core: Default core to use for host to device transfers.
)doc");
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      return ::tensorflow::shape_inference::UnknownShape(c);
    })
    .SetIsStateful();

REGISTER_OP("XlaHostCompute")
    .Input("inputs: Tinputs")
@@ -78,36 +71,16 @@ REGISTER_OP("XlaHostCompute")
        // statically known.
        return ::tensorflow::shape_inference::UnknownShape(c);
      }
    })
    .Doc(R"doc(
A pseudo-op to represent host-side computation in an XLA program.

inputs: A list of tensors that will be sent to the host.
outputs: A list of tensors that will be returned to the device.
Tinputs: The element types of each element in `inputs`.
Toutputs: The element types of each element in `outputs`.
ancestors: A list of names of HostCompute computations that must be
sequenced before this computation.
shape_inference_graph: If non-empty, a serialized GraphDef representing a graph
that must be analyzed at compile time to determine the shapes of the outputs.
shapes: If shape_inference_graph is empty, a list of the shapes of `outputs`.
key: A unique identifier for this region used to match up host transfers.
cost_estimate_ns: Estimated duration of the host computation in nanoseconds.
tpu_core: Default core to use for host to device transfers.
)doc");
    });

REGISTER_OP("XlaSendToHost")
    .Input("input: Tinput")
    .Attr("Tinput: type")
    .Attr("key: string")
    .SetIsStateful()
    .Doc(R"doc(
An op to send a tensor to the host.

input: the tensor that will be sent to the host.
Tinput: element type for input.
key: A unique identifier for this region used to match up host transfers.
)doc");
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      return ::tensorflow::shape_inference::UnknownShape(c);
    })
    .SetIsStateful();

REGISTER_OP("XlaRecvFromHost")
    .Output("output: Toutput")
@@ -127,14 +100,6 @@ REGISTER_OP("XlaRecvFromHost")
          c->MakeShapeFromShapeProto(shape_attr->shape(), &handle));
      c->set_output(0, handle);
      return Status::OK();
    })
    .Doc(R"doc(
An op to receive a tensor from the host.

output: the tensor that will be received from the host.
Toutput: element type for output.
shape: shape for output.
key: A unique identifier for this region used to match up host transfers.
)doc");
    });

}  // namespace tensorflow
@@ -33,24 +33,7 @@ REGISTER_OP("KthOrderStatistic")
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
Computes the Kth order statistic of a data set. The current
implementation uses a binary search requiring exactly 32 passes over
the input data. The running time is linear with respect to input
size. The median-of-medians algorithm is probably faster, but is
difficult to implement efficiently in XLA. The implementation imposes
a total ordering on floats. The ordering is consistent with the usual
partial order. Positive NaNs are greater than positive
infinity. Negative NaNs are less than negative infinity. NaNs with
distinct payloads are treated as distinct. Subnormal numbers are
preserved (not flushed to zero). Positive infinity is greater than all
numbers. Negative infinity is less than all numbers. Positive is
greater than negative zero. There are less than k values greater than
the kth order statistic. There are at least k values greater than or
equal to the Kth order statistic. The semantics are not the same as
top_k_unique.
)doc");
    });

REGISTER_OP("TopKUnique")
    .Input("input: float32")
@@ -69,22 +52,7 @@ REGISTER_OP("TopKUnique")
      c->set_output(0, s);
      c->set_output(1, s);
      return Status::OK();
    })
    .Doc(R"doc(
Returns the TopK unique values in the array in sorted order. The
running time is proportional to the product of K and the input
size. Sorting the whole array is more efficient for sufficiently large
values of K. The median-of-medians algorithm is probably faster, but
difficult to implement efficiently in XLA. If there are fewer than K
unique numbers (not NANs), the results are padded with negative
infinity. NaNs are never returned. Subnormal numbers are flushed to
zero. If an element appears at multiple indices, the highest index is
returned. If a TopK element never appears in the input due to padding
values, the indices are padded with negative one. If a padding value
appears in the input and padding is needed, the highest index of the
padding value will be returned. The semantics are not the same as
kth_order_statistic.
)doc");
    });

REGISTER_OP("MakeUnique")
    .Input("input: float32")
@@ -94,14 +62,7 @@ REGISTER_OP("MakeUnique")
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
      c->set_output(0, input);
      return Status::OK();
    })
    .Doc(R"doc(
Make all elements in the non-Batch dimension unique, but \"close\" to
their initial value. Never returns a sub-normal number. Never returns
zero. The sign of each input element is always identical to the sign
of the corresponding output element. Behavior for infinite elements is
undefined. Behavior for subnormal elements is undefined.
)doc");
    });

REGISTER_OP("TopKWithUnique")
    .Input("input: float32")
@@ -120,11 +81,5 @@ REGISTER_OP("TopKWithUnique")
      c->set_output(0, s);
      c->set_output(1, s);
      return Status::OK();
    })
    .Doc(R"doc(
Returns the TopK values in the array in sorted order. This is a combination
of MakeUnique and TopKUnique. The returned top-K will have its lower bits
replaced by iota, thus it will be close to the original value but not exactly
the same. The running time is proportional to the product of K and the input
size. NaNs are never returned. Subnormal numbers are flushed to zero.)doc");
    });

}  // namespace tensorflow
@@ -43,23 +43,7 @@ REGISTER_OP("_TPUCompileMlir")
        c->set_output(i + 1, c->Vector(2));
      }
      return Status::OK();
    })
    .Doc(
        R"(
Compiles a computations for execution on one or more TPU devices.
For the internal use of the distributed TPU compiler. Note that currently only
single TPU device is supported.

'mlir_module' is a serialized MLIR module with a `main` function that contains
target computation.
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
known statically at TPUReplication rewrite time.
'metadata' is a serialized TPUCompileMetadataProto describing
the shapes and types of the inputs to the computation, as well as a mapping onto
the TPU pod topology.
'program' output is a string key that is passed to the _TPUExecute op and
used to look up the program in the compilation cache.
)");
    });

REGISTER_OP("TPUCompile")
    .Attr("num_computations: int >= 0")
@@ -91,39 +75,13 @@ REGISTER_OP("TPUCompile")
        c->set_output(num_computations + i + 1, c->Scalar());
      }
      return Status::OK();
    })
    .Doc(
        R"(
Compiles a computations for execution on one or more TPU devices.
For the internal use of the distributed TPU compiler.

'num_computations' is the number of computations to be compiled.
'function' is a function containing the computation to compile.
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
known statically at TPUReplication rewrite time.
'guaranteed_constants' is a list of tensors which have been guaranteed to not
change their values during the session lifetime. These contain tensors marked as
constant using the GuaranteeConstOp.
'metadata' is a serialized TPUCompileMetadataProto describing
the shapes and types of the inputs to the computation, as well as a mapping onto
the TPU pod topology.
Each 'program' output is a string key that is passed to the _TPUExecute op and
used to look up the program in the compilation cache.
'may_modify_variables' indicates whether variables may be modified.
)");
    });

REGISTER_OP("TPUCompileSucceededAssert")
    .Input("compilation_status: string")
    // Do not optimize me away. Read the comment on TPUCompileOp for more
    // details.
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(
        R"(
Asserts that compilation succeeded. This op produces no output and closes the
device during failure to ensure all pending device interactions fail.

'compilation_status' is a serialized CompilationResultProto.
)");
    .SetShapeFn(shape_inference::NoOutputs);

}  // namespace tensorflow
@@ -35,10 +35,7 @@ REGISTER_OP("TPUExecute")
        c->set_output(i, c->UnknownShape());
      }
      return Status::OK();
    })
    .Doc(R"(
Op that loads and executes a TPU program on a TPU device.
For the internal use of the distributed TPU compiler.)");
    });

REGISTER_OP("TPUExecuteAndUpdateVariables")
    .Input("args: Targs")
@@ -58,14 +55,6 @@ REGISTER_OP("TPUExecuteAndUpdateVariables")
        c->set_output(i, c->UnknownShape());
      }
      return Status::OK();
    })
    .Doc(R"(Op that executes a program with optional in-place variable updates.
It (optionally) reads device variables, loads and executes a TPU program on a
TPU device, and then (optionally) in-place updates variables using the program
outputs, as specified in attributes device_var_reads_indices (program input
indices from directly reading variables) and device_var_updates_indices (program
output indices used to update variables, -1 means no-update/read-only). Such
program outputs are consumed by these variables will not appear in the op
output. For the internal use of the distributed TPU compiler.)");
    });

}  // namespace tensorflow
@@ -94,14 +94,6 @@ REGISTER_OP("TPUPartitionedInput")
      }

      return Status::OK();
    })
    .Doc(R"doc(
An op that groups a list of partitioned inputs together. This op

inputs: A list of partitioned inputs which must have the same shape.
output: A handle which represents the full shape of partitioned tensors.
partition_dim: An integer describles which dimension is partitioned. -1 means
those inputs are replicated.
)doc");
    });

}  // namespace tensorflow
@@ -23,7 +23,6 @@ namespace tensorflow {
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("TPUPartitionedOutput")
    .Input("inputs: T")
    .Output("output: num_splits * T")
@@ -53,14 +52,6 @@ REGISTER_OP("TPUPartitionedOutput")
        c->set_output(i, newoutput0);
      }
      return Status::OK();
    })
    .Doc(R"doc(
An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
outputs outside the XLA computation.

inputs: A tensor which represents the full shape of partitioned tensors.
output: A list of partitioned inputs which must have the same shape.
partition_dim: An integer describles which dimension is partitioned.
)doc");
    });

}  // namespace tensorflow
@@ -23,7 +23,6 @@ limitations under the License.

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -419,10 +418,6 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(

  xla::Backend* backend = node_context->backend();

  XlaDevice* device =
      tensorflow::down_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
  TF_RET_CHECK(device);

  // Create a HostTransferManager to handle Send/Recv operations from the TPU.
  std::shared_ptr<HostTransferManager> host_transfer_manager =
      std::make_shared<HostTransferManager>(node_context, backend);
@@ -15,7 +15,6 @@ limitations under the License.

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/virtual_device.h"
@@ -3,7 +3,7 @@
# ":platform" - Low-level and platform-specific Python code.

load("//tensorflow:tensorflow.bzl", "py_strict_library")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_tpu", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cc_test", "tf_cuda_library", "tf_enable_mlir_bridge", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cc_test", "tf_cuda_library", "tf_enable_mlir_bridge", "tf_gen_op_wrapper_py")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_monitoring_python_deps")
@@ -6093,8 +6093,6 @@ pywrap_tensorflow_macro(
        "@ngraph_tf//:ngraph_tf",
    ]) + if_xla_available([
        "//tensorflow/compiler/aot:tfcompile_lib",
    ]) + if_tpu([
        "//tensorflow/core/tpu:tpu_api_dlsym_initializer",
    ]) + if_static(extra_deps = ["//tensorflow/core/platform:tf32_utils"]),
)
@@ -68,7 +68,7 @@ cc_library(
    deps = [
        ":c_api_decl",
        ":tpu_executor_c_api_hdrs",
        "//tensorflow/core/platform:status",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_api",
    ],
)
@@ -103,9 +103,7 @@ cc_library(
        ":tpu_executor_interface",
        ":tpu_platform_interface",
        ":tpu_stream_interface",
        "//tensorflow/core/platform:casts",
        "//tensorflow/core/platform:mutex",
        "//tensorflow/core/platform:types",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor/lib",
@@ -131,8 +129,6 @@ cc_library(
        ":status_helper",
        ":tpu_executor_c_api_hdrs",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:mutex",
        "//tensorflow/core/platform:types",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor/lib",
@@ -164,8 +160,6 @@ cc_library(
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:mutex",
        "//tensorflow/core/platform:types",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor/lib",
@@ -274,10 +268,8 @@ cc_library(
    hdrs = ["tpu_platform_interface.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core/platform:mutex",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:multi_platform_manager",
        "//tensorflow/stream_executor:stream_executor_headers",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor",
    ],
)

@@ -334,7 +326,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":c_api_decl",
        "//tensorflow/core/platform:types",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_api",
    ],
)
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace ApiConverter {

xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
  xla::Shape xla_on_host_shape = ApiConverter::FromC(&c_buffer->on_host_shape);
  xla::Shape xla_on_device_shape =
@@ -114,6 +115,7 @@ SE_DeviceMemoryAllocator ToC(
  };
  return se_allocator;
}

SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceMemory* mem) {
  SE_MaybeOwningDeviceMemory se_mem;
  se_mem.device_ordinal = mem->device_ordinal();
@@ -349,7 +349,6 @@ struct TfTpu_ExecutorApiFn {

  TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_New);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Free);

  TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunHloPasses);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunBackend);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);