Introduce additional TPU infeed and outfeed ops
PiperOrigin-RevId: 325542225
Change-Id: Ie972e60d6c5639b71719837c500ecc716eda2ebd

parent 769155a21e
commit 3ba0deba91
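All of the kernels added in this change build on a shared asynchronous transfer base class (see transfer_ops.h / transfer_ops.cc below). As a rough orientation, a hypothetical kernel using that base class would look like the sketch below; the class name ExampleTransferOp, the op name "ExampleTransfer", and the transfer_type string are illustrative only and are not part of this change.

// Minimal sketch (not part of this commit): a hypothetical transfer kernel
// built on the TpuTransferAsyncOpKernel base class introduced below.
#include "tensorflow/core/tpu/kernels/transfer_ops.h"

namespace tensorflow {

class ExampleTransferOp : public TpuTransferAsyncOpKernel {
 public:
  explicit ExampleTransferOp(OpKernelConstruction* ctx)
      : TpuTransferAsyncOpKernel(ctx, /*transfer_type=*/"example_transfer",
                                 /*number_of_threads=*/1) {}

  // DoWork runs on the kernel's transfer thread pool; the base class resolves
  // the StreamExecutor from the XLA device or the device_ordinal attribute.
  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override {
    // Move data through the infeed/outfeed interface here, e.g. via
    // transfer_manager->TransferLiteralToInfeed(...).
    return Status::OK();
  }
};

// Hypothetical registration, mirroring the registrations in the new kernels.
REGISTER_KERNEL_BUILDER(Name("ExampleTransfer").Device(DEVICE_CPU),
                        ExampleTransferOp);

}  // namespace tensorflow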
@@ -88,7 +88,13 @@ cc_library(
     name = "tpu_defs",
     srcs = ["tpu_defs.cc"],
     hdrs = ["tpu_defs.h"],
-    deps = ["//tensorflow/core:protos_all_cc"],
+    deps = [
+        ":tpu_api",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/stream_executor/tpu:c_api_conversions",
+        "//tensorflow/stream_executor/tpu:c_api_decl",
+    ],
 )
 
 cc_library(

@@ -28,10 +28,16 @@ tf_kernel_library(
     deps = [
         ":cross_replica_ops",
         ":host_compute_ops",
+        ":image_resize_ops",
+        ":infeed_ops",
+        ":outfeed_ops",
+        ":replication_ops",
         ":topk_ops",
         ":tpu_compile_op",
         ":tpu_configuration_ops",
         ":tpu_execute_op",
+        ":tpu_handle_to_key_op",
+        ":transfer_ops",
     ],
 )
 
@@ -684,3 +690,104 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "infeed_ops",
+    srcs = ["infeed_ops.cc"],
+    hdrs = ["infeed_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":transfer_ops",
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/common_runtime:dma_helper",
+        "//tensorflow/core/framework:protos_all_cc",
+        "//tensorflow/core/kernels:transpose_functor",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/stream_executor:multi_platform_manager",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_base",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "transfer_ops",
+    srcs = ["transfer_ops.cc"],
+    hdrs = ["transfer_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels:ops_util",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/stream_executor:multi_platform_manager",
+        "//tensorflow/stream_executor/tpu:tpu_node_context",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "outfeed_ops",
+    srcs = ["outfeed_ops.cc"],
+    hdrs = ["outfeed_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":transfer_ops",
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/framework:protos_all_cc",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/stream_executor:multi_platform_manager",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "image_resize_ops",
+    srcs = ["image_resize_ops.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "replication_ops",
+    srcs = ["replication_ops.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_defs",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "tpu_handle_to_key_op",
+    srcs = ["tpu_handle_to_key_op.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tpu_compilation_cache_interface",
+        ":tpu_op_consts",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_configuration",
+    ],
+    alwayslink = True,
+)
tensorflow/core/tpu/kernels/image_resize_ops.cc (new file, 155 lines)
@@ -0,0 +1,155 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {

class TpuCustomResizeOp : public XlaOpKernel {
 public:
  explicit TpuCustomResizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

  xla::Shape GetOutputShape(XlaOpKernelContext* ctx) const {
    std::vector<int64> out_size;
    auto status = ctx->ConstantInputAsIntVector(1, &out_size);
    CHECK_EQ(out_size.size(), 2) << status.ToString();
    xla::Shape output_shape =
        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
    output_shape.mutable_dimensions()[1] = out_size[0];
    output_shape.mutable_dimensions()[2] = out_size[1];
    return output_shape;
  }

  string OpaqueField() const {
    return absl::StrCat("\"", align_corners_, half_pixel_centers_, "\"");
  }

  void CompileGrad(XlaOpKernelContext* ctx, const char* target,
                   const xla::Shape& output_shape) {
    auto input_shape =
        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
    if (ctx->InputShape(1).dim_sizes() == ctx->InputShape(0).dim_sizes()) {
      ctx->SetOutput(
          0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
      return;
    }
    // The gradient should be done in two phases for large resizes.
    auto input = ctx->Input(0);
    if (input_shape.dimensions(1) / output_shape.dimensions(1) > 3 &&
        input_shape.dimensions(2) / output_shape.dimensions(2) > 3) {
      auto intermediate_shape = output_shape;
      intermediate_shape.mutable_dimensions()[1] = input_shape.dimensions(1);
      input = xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
                              intermediate_shape, OpaqueField());
    }
    ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {input},
                                      output_shape, OpaqueField()));
  }

  void CompileForward(XlaOpKernelContext* ctx, const char* target) {
    auto output_shape = GetOutputShape(ctx);
    if (ctx->InputShape(0).dim_size(1) == output_shape.dimensions(1) &&
        ctx->InputShape(0).dim_size(2) == output_shape.dimensions(2)) {
      ctx->SetOutput(
          0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
      return;
    }
    if (ctx->InputShape(0).dim_size(1) == 1 &&
        ctx->InputShape(0).dim_size(2) == 1) {
      ctx->SetOutput(0,
                     ctx->Input(0) + xla::Zeros(ctx->builder(), output_shape));
      return;
    }
    ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
                                      output_shape, OpaqueField()));
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;
};

class TpuResizeNearestNeighborOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeNearestNeighborOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    CompileForward(ctx, "ResizeNearest");
  }
};

class TpuResizeBilinearOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeBilinearOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    CompileForward(ctx, "ResizeBilinear");
  }
};

class TpuResizeNearestNeighborGradOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeNearestNeighborGradOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    CompileGrad(ctx, "ResizeNearestGrad", GetOutputShape(ctx));
  }
};

class TpuResizeBilinearGradOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeBilinearGradOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    auto output_shape =
        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(1));
    CompileGrad(ctx, "ResizeBilinearGrad", output_shape);
  }
};

REGISTER_XLA_OP(Name("ResizeNearestNeighbor")
                    .CompileTimeConstantInput("size")
                    .Device(DEVICE_TPU_XLA_JIT),
                TpuResizeNearestNeighborOp);

REGISTER_XLA_OP(Name("ResizeNearestNeighborGrad")
                    .CompileTimeConstantInput("size")
                    .Device(DEVICE_TPU_XLA_JIT),
                TpuResizeNearestNeighborGradOp);

REGISTER_XLA_OP(Name("ResizeBilinear")
                    .CompileTimeConstantInput("size")
                    .Device(DEVICE_TPU_XLA_JIT),
                TpuResizeBilinearOp);

REGISTER_XLA_OP(Name("ResizeBilinearGrad").Device(DEVICE_TPU_XLA_JIT),
                TpuResizeBilinearGradOp);

}  // namespace tensorflow
tensorflow/core/tpu/kernels/infeed_ops.cc (new file, 529 lines)
@@ -0,0 +1,529 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/infeed_ops.h"

#include <algorithm>
#include <vector>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef tensorflow::tpu::NoncopyableBuffer LinearizerBuffer;
typedef std::deque<LinearizerBuffer> LinearizerBufferList;

// Transposes the given tensor using the tensorflow C++ transpose implementation
// to obtain a XLA literal for the host tensor laid out as the given layout. The
// returned tensor is normalized to the dim0major layout -- F32[10,20,30]{2,0,1}
// is returned as F32[20,10,30]{2,1,0}.
xla::StatusOr<Tensor> TransposeTensor(OpKernelContext* ctx,
                                      const Tensor& input_tensor,
                                      const xla::Shape& xla_shape) {
  profiler::TraceMe trace_me("TransposeTensor", /*level=*/2);
  const int64 rank = xla_shape.rank();
  std::vector<int32> permutation(rank);
  std::vector<int64> transposed_shapes(rank);
  for (int64 i = 0; i < rank; ++i) {
    permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i);
    transposed_shapes[i] = xla_shape.dimensions(permutation[i]);
  }

  Tensor transposed_tensor;

  // If this is a trivial transpose (i.e., bitcast), just create an aliased
  // tensor with the transposed shape.
  if (xla::LayoutUtil::IsMonotonicWithDim0Major(
          xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) {
    TensorShape shape;
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(transposed_shapes, &shape));
    TF_RETURN_IF_ERROR(transposed_tensor.BitcastFrom(
        input_tensor, input_tensor.dtype(), shape));
    return transposed_tensor;
  }

  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_temp(input_tensor.dtype(),
                                        TensorShape(transposed_shapes),
                                        &transposed_tensor, alloc_attr));
  // Eigen Transpose fails with SIGFPE if there is a dimension of size 0.
  if (input_tensor.NumElements() > 0) {
    TF_RETURN_IF_ERROR(DoTranspose<CPUDevice>(ctx->eigen_device<CPUDevice>(),
                                              input_tensor, permutation,
                                              &transposed_tensor));
  }
  return transposed_tensor;
}

xla::StatusOr<bool> GetLayoutOverride(OpKernelConstruction* ctx,
                                      const char* attrn_name,
                                      std::vector<int64>* minor_to_major) {
  if (!ctx->HasAttr(attrn_name)) {
    return false;
  }
  TF_RETURN_IF_ERROR(ctx->GetAttr(attrn_name, minor_to_major));
  return !minor_to_major->empty();
}

Status GetInfeedShapeWithLayout(OpKernelConstruction* ctx,
                                const char* attrn_name,
                                const xla::Shape& input_shape,
                                xla::Shape* output_shape) {
  std::vector<int64> minor_to_major;
  TF_ASSIGN_OR_RETURN(bool has_override,
                      GetLayoutOverride(ctx, attrn_name, &minor_to_major));
  if (!has_override) {
    *output_shape = input_shape;
    if (output_shape->IsTuple()) {
      int64 tuple_elements = xla::ShapeUtil::TupleElementCount(*output_shape);
      for (int64 i = 0; i < tuple_elements; ++i) {
        xla::Shape* sub_shape =
            xla::ShapeUtil::GetMutableSubshape(output_shape, {i});
        *sub_shape->mutable_layout() = GetTPUInfeedLayout(*sub_shape).layout();
      }
    } else {
      *output_shape->mutable_layout() =
          GetTPUInfeedLayout(*output_shape).layout();
    }
    return Status::OK();
  }

  auto layout_func = [](const xla::Shape& shape) -> xla::Layout {
    return GetTPUInfeedLayout(shape).layout();
  };
  return GetShapeWithLayout(input_shape, minor_to_major, layout_func,
                            output_shape);
}

// LinearizedBuffersWrapper is an opaque C++ data structure for the outputs of
// PrelinearizeOp and PrelinearizeTupleOp. It holds the resultant linearized
// buffers and references to input tensors whose underlying storage are shared
// with linearized buffers.
// NOTE: This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `LinearizerBufferList` (aka `std::deque<LinearizerBuffer>`)
// object, so the `Encode()` and `Decode()` methods are not implemented.
struct LinearizedBuffersWrapper {
  explicit LinearizedBuffersWrapper() {}
  explicit LinearizedBuffersWrapper(LinearizerBufferList bufs,
                                    std::vector<tensorflow::Tensor> ts)
      : buffers(std::move(bufs)), tensors(std::move(ts)) {}
  LinearizedBuffersWrapper(const LinearizedBuffersWrapper& wrapper) {
    // tensorflow::Variant requires this copy constructor to compile.
    LOG(FATAL) << "LinearizedBuffersWrapper should not copy.";
  }
  LinearizedBuffersWrapper& operator=(const LinearizedBuffersWrapper& wrapper) =
      delete;
  LinearizedBuffersWrapper(LinearizedBuffersWrapper&&) = default;
  LinearizedBuffersWrapper& operator=(LinearizedBuffersWrapper&&) = default;
  ~LinearizedBuffersWrapper() = default;

  // These functions are tensorflow::Variant requirements.
  string TypeName() const { return "(anonymous)::LinearizedBuffersWrapper"; }
  void Encode(tensorflow::VariantTensorData* data) const {
    LOG(ERROR) << "Encode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
  }
  bool Decode(const tensorflow::VariantTensorData& data) {
    LOG(ERROR) << "Decode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
    return false;
  }

  LinearizerBufferList buffers;
  // Save references on tensors whose underlying storage are shared with
  // LiteralLinearizer::Buffer in `buffers`.
  std::vector<tensorflow::Tensor> tensors;
};

Status AutoTransposeAndLinearize(OpKernelContext* ctx,
                                 const Tensor& input_tensor,
                                 const xla::Shape& shape,
                                 LinearizerBufferList* linearized_buffers,
                                 std::vector<Tensor>* saved_input_tensors) {
  const Tensor* tensor = &input_tensor;
  bool has_transposed = false;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, shape));
    tensor = &transposed_tensor;
    has_transposed = true;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

  TF_RETURN_IF_ERROR(
      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()
          ->LinearizeToBuffers(literal, linearized_buffers));

  // The input tensor is ref-counted. Save a handle on the input tensor if
  // its underlying storage is shared with linearized buffers to prevent
  // input tensor from getting freed.
  for (const auto& buffer : *linearized_buffers) {
    if (!buffer.owns_data() && !has_transposed) {
      // `buffer` is created from zero-copy fast path from the un-transposed
      // input tensor so its underlying data is shared with input tensor.
      // Save a handle to input tensor to increment its ref-count and avoid
      // it getting deallocated after PrelinearizeTupleOp completes.
      saved_input_tensors->push_back(*tensor);
      // A literal can be linearized to zero to two buffers. If any of the
      // linearized buffers shares storage with the input tensor, we save
      // exactly one handle on the input tensor.
      break;
    }
  }
  return Status::OK();
}

// PrelinearizeOp is used to linearize one tensor to the device format.
class PrelinearizeOp : public OpKernel {
 public:
  explicit PrelinearizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    xla::Shape shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
    OP_REQUIRES_OK(ctx,
                   GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_tensor = ctx->input(0);
    // Validate input.
    OP_REQUIRES(
        ctx, input_tensor.dtype() == dtype_,
        errors::InvalidArgument("Prelinearize dtype mismatch; expected ",
                                DataType_Name(dtype_), ", got ",
                                DataType_Name(input_tensor.dtype())));
    OP_REQUIRES(
        ctx, input_tensor.shape() == shape_,
        errors::InvalidArgument("Prelinearize shape mismatch; expected ",
                                shape_.DebugString(), ", got ",
                                input_tensor.shape().DebugString()));

    // Auto-transpose and prelinearize.
    LinearizerBufferList linearized_buffers;
    std::vector<Tensor> saved_input_tensors;
    auto status =
        AutoTransposeAndLinearize(ctx, input_tensor, xla_shape_,
                                  &linearized_buffers, &saved_input_tensors);
    OP_REQUIRES_OK(ctx, status);

    // Write to output.
    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(linearized_buffers), std::move(saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // PrelinearizeOp is neither copyable nor movable.
  PrelinearizeOp(const PrelinearizeOp&) = delete;
  PrelinearizeOp& operator=(const PrelinearizeOp&) = delete;
};

// PrelinearizeTupleOp is used to linearize multiple tensors to the device
// format.
class PrelinearizeTupleOp : public OpKernel {
 public:
  explicit PrelinearizeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES(
        ctx, shapes_.size() == dtypes_.size(),
        errors::InvalidArgument(
            "shapes and dtypes must be the same length. shapes length = ",
            shapes_.size(), ", dtypes length = ", dtypes_.size()));

    std::vector<xla::Shape> xla_shapes;
    for (int i = 0; i < shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(ctx,
                     TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
      xla_shapes.push_back(xla_shape);
    }
    OP_REQUIRES_OK(
        ctx, GetInfeedShapeWithLayout(
                 ctx, "layouts", xla::ShapeUtil::MakeTupleShape(xla_shapes),
                 &tuple_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to PrelinearizeTuple."));

    LinearizerBufferList all_linearized_buffers;
    std::vector<Tensor> all_saved_input_tensors;
    for (int i = 0; i < values.size(); i++) {
      // Validate input.
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple dtype mismatch at tuple element ", i,
                      "; expected ", DataType_Name(dtypes_[i]), ", got ",
                      DataType_Name(input_tensor.dtype())));
      OP_REQUIRES(ctx, input_tensor.shape() == shapes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple shape mismatch at tuple element ", i,
                      "; expected ", shapes_[i].DebugString(), ", got ",
                      input_tensor.shape().DebugString()));

      // Auto-transpose and prelinearize.
      LinearizerBufferList linearized_buffers;
      std::vector<Tensor> saved_input_tensors;
      auto status = AutoTransposeAndLinearize(
          ctx, input_tensor, tuple_shape_.tuple_shapes(i), &linearized_buffers,
          &saved_input_tensors);
      OP_REQUIRES_OK(ctx, status);
      all_linearized_buffers.insert(
          all_linearized_buffers.end(),
          std::make_move_iterator(linearized_buffers.begin()),
          std::make_move_iterator(linearized_buffers.end()));
      all_saved_input_tensors.insert(
          all_saved_input_tensors.end(),
          std::make_move_iterator(saved_input_tensors.begin()),
          std::make_move_iterator(saved_input_tensors.end()));
    }

    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(all_linearized_buffers), std::move(all_saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  xla::Shape tuple_shape_;

  // PrelinearizeTupleOp is neither copyable nor movable.
  PrelinearizeTupleOp(const PrelinearizeTupleOp&) = delete;
  PrelinearizeTupleOp& operator=(const PrelinearizeTupleOp&) = delete;
};

// The InfeedEnqueuePrelinearizedBufferOp op is used to transfer prelinearized
// buffers to the device infeed queue.
class InfeedEnqueuePrelinearizedBufferOp : public TpuTransferAsyncOpKernel {
 public:
  explicit InfeedEnqueuePrelinearizedBufferOp(OpKernelConstruction* ctx)
      : TpuTransferAsyncOpKernel(ctx, "prelinearized_buffers_to_infeed", 8) {}

  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override {
    const Tensor& input_tensor = ctx->input(0);
    const LinearizedBuffersWrapper* wrapper =
        input_tensor.scalar<tensorflow::Variant>()()
            .get<LinearizedBuffersWrapper>();
    TF_RETURN_IF_ERROR(transfer_manager->TransferBuffersToInfeed(
        stream_executor, wrapper->buffers));

    return Status::OK();
  }

 private:
  // InfeedEnqueuePrelinearizedBufferOp is neither copyable nor movable.
  InfeedEnqueuePrelinearizedBufferOp(
      const InfeedEnqueuePrelinearizedBufferOp&) = delete;
  InfeedEnqueuePrelinearizedBufferOp& operator=(
      const InfeedEnqueuePrelinearizedBufferOp&) = delete;
};

}  // anonymous namespace

TpuInfeedEnqueueOp::TpuInfeedEnqueueOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  xla::Shape shape;
  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
  OP_REQUIRES_OK(ctx,
                 GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
}

Status TpuInfeedEnqueueOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  const Tensor& input_tensor = ctx->input(0);

  // Validate runtime shape and fail if it doesn't match the contract.
  if (input_tensor.dtype() != dtype_) {
    return errors::InvalidArgument("Infeed dtype mismatch.");
  }
  if (input_tensor.shape() != shape_) {
    return errors::InvalidArgument("Infeed shape mismatch; expected ",
                                   shape_.DebugString(), ", got ",
                                   input_tensor.shape().DebugString());
  }

  const Tensor* tensor = &input_tensor;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(xla_shape_.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, xla_shape_));
    tensor = &transposed_tensor;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_manager->TransferLiteralToInfeed(stream_executor, literal));
  return Status::OK();
}

TpuInfeedEnqueueTupleOp::TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(
      ctx, shapes_.size() == dtypes_.size(),
      errors::InvalidArgument("shapes and dtypes must be the same length."));

  std::vector<xla::Shape> xla_shapes;
  for (int i = 0; i < shapes_.size(); i++) {
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
    xla_shapes.push_back(xla_shape);
  }
  OP_REQUIRES_OK(
      ctx, GetInfeedShapeWithLayout(ctx, "layouts",
                                    xla::ShapeUtil::MakeTupleShape(xla_shapes),
                                    &tuple_shape_));
}

Status TpuInfeedEnqueueTupleOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  OpInputList values;
  TF_RETURN_IF_ERROR(ctx->input_list("inputs", &values));
  if (values.size() != shapes_.size()) {
    return errors::InvalidArgument(
        "Wrong number of inputs to InfeedEnqueueTuple.");
  }

  for (const auto& shapes : shapes_) {
    VLOG(1) << "TransferLiteralToInfeed " << shapes.DebugString();
  }

  std::vector<Tensor> maybe_transposed_tensors;
  maybe_transposed_tensors.reserve(values.size());
  for (int i = 0; i < values.size(); i++) {
    // Validate runtime shapes and fail if they don't match the contract.
    const Tensor* tensor = &values[i];
    if (tensor->shape() != shapes_[i]) {
      return errors::InvalidArgument("Infeed shape mismatch for tuple element ",
                                     i, "; expected ", shapes_[i].DebugString(),
                                     ", got ", tensor->shape().DebugString());
    }
    if (!xla::LayoutUtil::IsMonotonicWithDim0Major(
            tuple_shape_.tuple_shapes(i).layout())) {
      // If the given layout is not in dim0major layout, transpose the given
      // tensor.
      TF_ASSIGN_OR_RETURN(
          Tensor transposed_tensor,
          TransposeTensor(ctx, *tensor, tuple_shape_.tuple_shapes(i)));
      maybe_transposed_tensors.emplace_back(transposed_tensor);
    } else {
      maybe_transposed_tensors.emplace_back(*tensor);
    }
  }

  xla::BorrowingLiteral tuple;
  TF_RETURN_IF_ERROR(
      HostTensorsToBorrowingLiteralTuple(maybe_transposed_tensors, &tuple));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_manager->TransferLiteralToInfeed(stream_executor, tuple));

  VLOG(1) << "TransferLiteralToInfeed complete.";

  return Status::OK();
}

// These ops execute on either the TPU device or the CPU device. When running on
// CPU they must specify a non-negative value for device_ordinal to indicate
// which TPU to send infeed to.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueue").Device(DEVICE_TPU_NODE).HostMemory("input"),
    TpuInfeedEnqueueOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueue").Device(DEVICE_CPU),
                        TpuInfeedEnqueueOp);

REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueueTuple").Device(DEVICE_TPU_NODE).HostMemory("inputs"),
    TpuInfeedEnqueueTupleOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueueTuple").Device(DEVICE_CPU),
                        TpuInfeedEnqueueTupleOp);

// Prelinearize ops run on CPU as part of the tf.data input pipeline.
REGISTER_KERNEL_BUILDER(Name("Prelinearize").Device(DEVICE_CPU),
                        PrelinearizeOp);
REGISTER_KERNEL_BUILDER(Name("PrelinearizeTuple").Device(DEVICE_CPU),
                        PrelinearizeTupleOp);

// The InfeedEnqueuePrelinearizedBuffer op runs on CPU and takes a
// device_ordinal to select the right device to infeed.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueuePrelinearizedBuffer").Device(DEVICE_CPU),
    InfeedEnqueuePrelinearizedBufferOp);

}  // namespace tensorflow
tensorflow/core/tpu/kernels/infeed_ops.h (new file, 69 lines)
@@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"

namespace tensorflow {

// TODO(b/65200690): Rework this when there is a callback based infeed API to
// StreamExecutor.

// The InfeedEnqueue op is used to deliver data to the device infeed queue.
class TpuInfeedEnqueueOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuInfeedEnqueueOp(OpKernelConstruction* ctx);
  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // TpuInfeedEnqueueOp is neither copyable nor movable.
  TpuInfeedEnqueueOp(const TpuInfeedEnqueueOp&) = delete;
  TpuInfeedEnqueueOp& operator=(const TpuInfeedEnqueueOp&) = delete;
};

// The InfeedEnqueueTuple op is used on the host to deliver multiple tensors to
// the device infeed queue as an XLA tuple.
class TpuInfeedEnqueueTupleOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx);
  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  xla::Shape tuple_shape_;

  // TpuInfeedEnqueueTupleOp is neither copyable nor movable.
  TpuInfeedEnqueueTupleOp(const TpuInfeedEnqueueTupleOp&) = delete;
  TpuInfeedEnqueueTupleOp& operator=(const TpuInfeedEnqueueTupleOp&) = delete;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
tensorflow/core/tpu/kernels/outfeed_ops.cc (new file, 116 lines)
@@ -0,0 +1,116 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/outfeed_ops.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"

namespace tensorflow {

TpuOutfeedDequeueOp::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
}

Status TpuOutfeedDequeueOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  Tensor* output;
  TF_RETURN_IF_ERROR(ctx->allocate_output(0, shape_, &output));

  // Transfer from the outfeed interface of the device.
  xla::MutableBorrowingLiteral literal;
  TF_RETURN_IF_ERROR(
      HostTensorToMutableBorrowingLiteral(xla_shape_, output, &literal));

  VLOG(1) << "TransferLiteralFromOutfeed "
          << xla::ShapeUtil::HumanStringWithLayout(xla_shape_);

  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
      stream_executor, xla_shape_, literal));

  VLOG(1) << "TransferLiteralFromOutfeed complete.";

  return Status::OK();
}

// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
TpuOutfeedDequeueTupleOp::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(
      ctx, shapes_.size() == dtypes_.size(),
      errors::InvalidArgument("shapes and dtypes must be the same length."));
  // The `dtypes` list is inferred from the supplied inputs, so it
  // is always the correct length.
  for (int i = 0; i < shapes_.size(); i++) {
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
    xla_shapes_.push_back(xla_shape);
  }
  tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
}

Status TpuOutfeedDequeueTupleOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  VLOG(1) << "TransferLiteralFromOutfeed "
          << xla::ShapeUtil::HumanStringWithLayout(tuple_shape_);

  for (int i = 0; i < shapes_.size(); ++i) {
    Tensor* output;
    TF_RETURN_IF_ERROR(ctx->allocate_output(i, shapes_[i], &output));

    xla::MutableBorrowingLiteral literal;
    TF_RETURN_IF_ERROR(
        HostTensorToMutableBorrowingLiteral(xla_shapes_[i], output, &literal));
    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
        stream_executor, xla_shapes_[i], literal));
  }
  return Status::OK();
}

// These ops execute on either the TPU device or the CPU device. When
// running on CPU they must specify a non-negative value for
// device_ordinal to indicate which TPU to receive outfeed from.
REGISTER_KERNEL_BUILDER(
    Name("OutfeedDequeue").Device(DEVICE_TPU_NODE).HostMemory("output"),
    TpuOutfeedDequeueOp);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeue").Device(DEVICE_CPU),
                        TpuOutfeedDequeueOp);

REGISTER_KERNEL_BUILDER(
    Name("OutfeedDequeueTuple").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
    TpuOutfeedDequeueTupleOp);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeueTuple").Device(DEVICE_CPU),
                        TpuOutfeedDequeueTupleOp);

}  // namespace tensorflow
tensorflow/core/tpu/kernels/outfeed_ops.h (new file, 69 lines)
@@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"

namespace tensorflow {

// The OutfeedDequeue op is used to retrieve a single tensor from the device
// outfeed queue.
class TpuOutfeedDequeueOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuOutfeedDequeueOp(OpKernelConstruction* ctx);

  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // OutfeedDequeueOp is neither copyable nor movable.
  TpuOutfeedDequeueOp(const TpuOutfeedDequeueOp&) = delete;
  TpuOutfeedDequeueOp& operator=(const TpuOutfeedDequeueOp&) = delete;
};

// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
class TpuOutfeedDequeueTupleOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx);

  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  std::vector<xla::Shape> xla_shapes_;
  xla::Shape tuple_shape_;

  // OutfeedDequeueTupleOp is neither copyable nor movable.
  TpuOutfeedDequeueTupleOp(const TpuOutfeedDequeueTupleOp&) = delete;
  TpuOutfeedDequeueTupleOp& operator=(const TpuOutfeedDequeueTupleOp&) = delete;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
tensorflow/core/tpu/kernels/replication_ops.cc (new file, 27 lines)
@@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("_TPUReplicate").Device(DEVICE_TPU_SYSTEM),
                        XlaDeviceDummyOp);

}  // namespace tensorflow
tensorflow/core/tpu/kernels/tpu_handle_to_key_op.cc (new file, 62 lines)
@@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"

namespace tensorflow {

class TpuHandleToProtoKeyOp : public OpKernel {
 public:
  explicit TpuHandleToProtoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~TpuHandleToProtoKeyOp() override = default;
  TpuHandleToProtoKeyOp(const TpuHandleToProtoKeyOp&) = delete;
  TpuHandleToProtoKeyOp& operator=(const TpuHandleToProtoKeyOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "TpuHandleToProtoKeyOp::Compute " << ctx->op_kernel().name()
            << " on device " << ctx->op_kernel().requested_device();
    const Tensor& uid = ctx->input(0);

    ResourceMgr* rm = GetTPUConfigResourceMgr();
    tpu::TpuCompilationCacheInterface* cache;
    OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
                            rm->default_container(),
                            tpu::kCompilationCacheResourceName, &cache));
    core::ScopedUnref cache_unref(cache);

    std::vector<std::string> keys;
    OP_REQUIRES_OK(ctx, cache->GetKeysFromUid(uid.scalar<int64>()(), &keys));

    TensorShape output_shape;
    output_shape.AddDim(keys.size());
    Tensor* result = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
    for (int i = 0; i < keys.size(); ++i) {
      result->vec<tstring>()(i) = keys[i];
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TpuHandleToProtoKey").Device(DEVICE_CPU),
                        TpuHandleToProtoKeyOp);

}  // namespace tensorflow
tensorflow/core/tpu/kernels/transfer_ops.cc (new file, 98 lines)
@@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/transfer_ops.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {

TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
                                                   const string& transfer_type,
                                                   int number_of_threads)
    : AsyncOpKernel(ctx),
      thread_pool_(new thread::ThreadPool(
          ctx->env(),
          strings::StrCat(transfer_type, "_thread_",
                          SanitizeThreadSuffix(def().name())),
          /*num_threads=*/8)) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
    OP_REQUIRES(
        ctx, device_ordinal_ >= 0,
        errors::InvalidArgument(transfer_type,
                                " ops must specify a device_ordinal when "
                                "placed on CPU."));
  }
}

void TpuTransferAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
                                            DoneCallback done) {
  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  bool already_cancelled;
  {
    // Only protect registering the cancellation callback as mu_ cannot be held
    // at a point where `done` could be called.
    mutex_lock lock(mu_);
    already_cancelled = !ctx->cancellation_manager()->RegisterCallback(
        token, [this]() { Cancel(); });
  }
  OP_REQUIRES_ASYNC(ctx, !already_cancelled,
                    errors::Cancelled("Infeed was cancelled."), done);
  thread_pool_->Schedule([this, ctx, done, token]() {
    Status s = RunTransfer(ctx);
    ctx->cancellation_manager()->DeregisterCallback(token);
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  });
}

Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();

  int real_device_ordinal = device_ordinal_;
  if (real_device_ordinal < 0) {
    const XlaDevice::Metadata* metadata;
    TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
    real_device_ordinal = metadata->device_ordinal();
  }
  stream_executor::StreamExecutor* stream_executor =
      tpu_platform->ExecutorForDevice(real_device_ordinal).ValueOrDie();

  // When Xprof profiling is off (which is the default), constructing the
  // activity is simple enough that its overhead is negligible.
  profiler::TraceMe activity(
      [this] { return profiler::TraceMeOp(name(), type_string()); },
      profiler::TraceMeLevel::kInfo);
  return DoWork(
      ctx, xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager(),
      stream_executor);
}

void TpuTransferAsyncOpKernel::Cancel() {
  mutex_lock lock(mu_);
  TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
}

}  // namespace tensorflow
tensorflow/core/tpu/kernels/transfer_ops.h (new file, 56 lines)
@@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {

// Base class providing common functionality for async ops that transfer from
// host to TPU.
class TpuTransferAsyncOpKernel : public AsyncOpKernel {
 public:
  explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
                                    const string& transfer_type,
                                    int number_of_threads);

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 protected:
  virtual Status DoWork(OpKernelContext* context,
                        xla::TpuTransferManagerInterface* transfer_manager,
                        stream_executor::StreamExecutor* stream_executor) = 0;

 private:
  Status RunTransfer(OpKernelContext* ctx);
  void Cancel();

  std::unique_ptr<thread::ThreadPool> thread_pool_;
  int device_ordinal_;
  mutex mu_;

  // TpuTransferAsyncOpKernel is neither copyable nor movable.
  TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete;
  TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
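To make the intended use of this base class concrete, here is a minimal sketch of a derived transfer kernel, assuming the header above. The class name, op name, and the body of DoWork are hypothetical; the real infeed and outfeed kernels added by this commit perform their actual work against the transfer manager inside DoWork.

#include "tensorflow/core/tpu/kernels/transfer_ops.h"

namespace tensorflow {

// Hypothetical example; not part of this commit.
class ExampleEnqueueOp : public TpuTransferAsyncOpKernel {
 public:
  explicit ExampleEnqueueOp(OpKernelConstruction* ctx)
      : TpuTransferAsyncOpKernel(ctx, /*transfer_type=*/"example_enqueue",
                                 /*number_of_threads=*/8) {}

 protected:
  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override {
    // A real kernel would linearize its inputs here and hand the buffers to
    // `transfer_manager` (e.g. TransferBuffersToInfeed) on `stream_executor`.
    return Status::OK();
  }
};

// Registration would follow the usual pattern, e.g.:
// REGISTER_KERNEL_BUILDER(Name("ExampleEnqueue").Device(DEVICE_TPU_NODE),
//                         ExampleEnqueueOp);

}  // namespace tensorflow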
tensorflow/core/tpu/tpu_defs.cc
@@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/core/tpu/tpu_defs.h"

#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"

namespace tensorflow {

const char* const DEVICE_TPU_NODE = "TPU";
@@ -27,4 +31,18 @@ const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR =
const char* const kTPUReplicateAttr = "_tpu_replicate";
const char* const kOutsideCompilationAttr = "_xla_outside_compilation";

xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
  XLA_Shape c_shape;
  XLA_Shape c_infeed_shape;

  ApiConverter::ToC(shape, &c_shape);

  tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
                                                             &c_infeed_shape);
  xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
  ApiConverter::Free(&c_shape);
  ApiConverter::Free(&c_infeed_shape);
  return infeed_shape;
}

}  // namespace tensorflow
tensorflow/core/tpu/tpu_defs.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <array>

#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
@@ -56,6 +57,11 @@ static constexpr std::array<DataType, 16> kTpuAllTypes = {
    DT_COMPLEX64, DT_INT64, DT_UINT64, DT_QINT8, DT_QUINT8, DT_INT8, DT_UINT8,
    DT_INT16, DT_UINT16}};

// For the given shape, chooses a layout for infeed on TPU. The returned shape
// has the same dimensions as the original shape, and only the layout is
// changed.
xla::Shape GetTPUInfeedLayout(const xla::Shape& shape);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_TPU_DEFS_H_
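A short usage sketch of the new GetTPUInfeedLayout helper follows; the wrapper function and the example dimensions are assumptions made only for illustration.

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/tpu/tpu_defs.h"

// Hypothetical helper; picks the TPU infeed layout for an example host shape.
xla::Shape ExampleInfeedShape() {
  xla::Shape host_shape =
      xla::ShapeUtil::MakeShape(xla::F32, {128, 224, 224, 3});
  // The result has the same dimensions as host_shape; only its layout differs,
  // so it can drive any transpose/linearization done before the infeed enqueue.
  return tensorflow::GetTPUInfeedLayout(host_shape);
}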
@@ -161,6 +161,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
  TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralFromDevice);
  TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
  TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
  TFTPU_SET_FN(executor_fn, TpuTransferManager_GetInfeedLayout);
  TFTPU_SET_FN(executor_fn, TpuTransferManager_LinearizeToBuffers);
  TFTPU_SET_FN(executor_fn, TpuTransferManager_FreeBuffers);
tensorflow/stream_executor/tpu/BUILD
@@ -203,10 +203,12 @@ cc_library(

cc_library(
    name = "tpu_transfer_manager_interface",
    srcs = ["tpu_transfer_manager_interface.cc"],
    hdrs = ["tpu_transfer_manager_interface.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":noncopyable_buffer",
        ":tpu_platform_interface",
        "//tensorflow/compiler/xla/service:transfer_manager",
    ],
)
@@ -182,6 +182,8 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
    XLA_TransferManager* manager, SE_Stream* stream,
    SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
    SE_DeviceMemoryBase* region, SE_Status* status);
void TpuTransferManager_GetInfeedLayout(XLA_Shape* shape,
                                        XLA_Shape* infeed_shape);
void TpuTransferManager_LinearizeToBuffers(
    XLA_TransferManager* manager, XLA_Literal* c_literal, char*** buffers_array,
    int64_t** buffers_size, int64_t* buffers_array_size, SE_Status* status);
@@ -341,6 +343,7 @@ struct TfTpu_ExecutorApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralFromDevice);
  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetInfeedLayout);
  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_LinearizeToBuffers);
  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_FreeBuffers);
@@ -81,6 +81,12 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
      const xla::Shape& shape,
      stream_executor::DeviceMemoryBase* region) override;

  Status LinearizeToBuffers(
      const xla::LiteralSlice& literal,
      std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) override {
    LOG(FATAL) << "Not yet implemented.";
  }

 private:
  XLA_TransferManager* manager_;
};
tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.cc (new file, 40 lines)
@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace xla {

/*static*/ TpuTransferManagerInterface*
TpuTransferManagerInterface::GetRegisteredTpuTransferManager() {
  auto* platform = tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(
      /*initialize_platform=*/false);
  if (platform == nullptr) {
    LOG(ERROR) << "Unable to retrieve registered TPU platform.";
    return nullptr;
  }
  auto tm = xla::TransferManager::GetForPlatform(platform);
  if (!tm.ok()) {
    LOG(ERROR) << "Unable to retrieve TpuTransferManager. No TPU platform is "
                  "registered for platform "
               << platform->Name() << " and ID " << platform->id();
    return nullptr;
  }
  return static_cast<TpuTransferManagerInterface*>(tm.ValueOrDie());
}

}  // namespace xla
tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h
@@ -24,9 +24,16 @@ limitations under the License.
namespace xla {

class TpuTransferManagerInterface : public xla::TransferManager {
 public:
  virtual Status TransferBuffersToInfeed(
      se::StreamExecutor* executor,
      const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) = 0;

  virtual Status LinearizeToBuffers(
      const LiteralSlice& literal,
      std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) = 0;

  static TpuTransferManagerInterface* GetRegisteredTpuTransferManager();
};

}  // namespace xla
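Finally, a hedged sketch of how client code is expected to reach this interface. The wrapper function and error message are assumptions; the two interface calls are the ones declared above and implemented in tpu_transfer_manager_interface.cc.

#include <deque>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

// Hypothetical helper; not part of this commit.
tensorflow::Status LinearizeForInfeedExample(
    const xla::LiteralSlice& literal,
    std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) {
  xla::TpuTransferManagerInterface* transfer_manager =
      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager();
  if (transfer_manager == nullptr) {
    // GetRegisteredTpuTransferManager returns nullptr when no TPU platform is
    // registered, so callers must check before use.
    return tensorflow::errors::Internal("No TPU platform registered.");
  }
  // On success, `buffers` would typically be handed to TransferBuffersToInfeed
  // on the stream executor for the target TPU device.
  return transfer_manager->LinearizeToBuffers(literal, buffers);
}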