Introduce additional TPU infeed and outfeed ops
PiperOrigin-RevId: 325542225
Change-Id: Ie972e60d6c5639b71719837c500ecc716eda2ebd
This commit is contained in:
parent
769155a21e
commit
3ba0deba91
@ -88,7 +88,13 @@ cc_library(
     name = "tpu_defs",
     srcs = ["tpu_defs.cc"],
     hdrs = ["tpu_defs.h"],
-    deps = ["//tensorflow/core:protos_all_cc"],
+    deps = [
+        ":tpu_api",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/stream_executor/tpu:c_api_conversions",
+        "//tensorflow/stream_executor/tpu:c_api_decl",
+    ],
 )
 
 cc_library(
@ -28,10 +28,16 @@ tf_kernel_library(
     deps = [
         ":cross_replica_ops",
         ":host_compute_ops",
+        ":image_resize_ops",
+        ":infeed_ops",
+        ":outfeed_ops",
+        ":replication_ops",
         ":topk_ops",
         ":tpu_compile_op",
         ":tpu_configuration_ops",
         ":tpu_execute_op",
+        ":tpu_handle_to_key_op",
+        ":transfer_ops",
     ],
 )
 
@ -684,3 +690,104 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "infeed_ops",
+    srcs = ["infeed_ops.cc"],
+    hdrs = ["infeed_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":transfer_ops",
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/common_runtime:dma_helper",
+        "//tensorflow/core/framework:protos_all_cc",
+        "//tensorflow/core/kernels:transpose_functor",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/stream_executor:multi_platform_manager",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_base",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "transfer_ops",
+    srcs = ["transfer_ops.cc"],
+    hdrs = ["transfer_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels:ops_util",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/stream_executor:multi_platform_manager",
+        "//tensorflow/stream_executor/tpu:tpu_node_context",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "outfeed_ops",
+    srcs = ["outfeed_ops.cc"],
+    hdrs = ["outfeed_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":transfer_ops",
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/framework:protos_all_cc",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/stream_executor:multi_platform_manager",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "image_resize_ops",
+    srcs = ["image_resize_ops.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "replication_ops",
+    srcs = ["replication_ops.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_defs",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "tpu_handle_to_key_op",
+    srcs = ["tpu_handle_to_key_op.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tpu_compilation_cache_interface",
+        ":tpu_op_consts",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_configuration",
+    ],
+    alwayslink = True,
+)
155 tensorflow/core/tpu/kernels/image_resize_ops.cc (new file)
@ -0,0 +1,155 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {

class TpuCustomResizeOp : public XlaOpKernel {
 public:
  explicit TpuCustomResizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

  xla::Shape GetOutputShape(XlaOpKernelContext* ctx) const {
    std::vector<int64> out_size;
    auto status = ctx->ConstantInputAsIntVector(1, &out_size);
    CHECK_EQ(out_size.size(), 2) << status.ToString();
    xla::Shape output_shape =
        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
    output_shape.mutable_dimensions()[1] = out_size[0];
    output_shape.mutable_dimensions()[2] = out_size[1];
    return output_shape;
  }

  string OpaqueField() const {
    return absl::StrCat("\"", align_corners_, half_pixel_centers_, "\"");
  }

  void CompileGrad(XlaOpKernelContext* ctx, const char* target,
                   const xla::Shape& output_shape) {
    auto input_shape =
        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
    if (ctx->InputShape(1).dim_sizes() == ctx->InputShape(0).dim_sizes()) {
      ctx->SetOutput(
          0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
      return;
    }
    // The gradient should be done in two phases for large resizes.
    auto input = ctx->Input(0);
    if (input_shape.dimensions(1) / output_shape.dimensions(1) > 3 &&
        input_shape.dimensions(2) / output_shape.dimensions(2) > 3) {
      auto intermediate_shape = output_shape;
      intermediate_shape.mutable_dimensions()[1] = input_shape.dimensions(1);
      input = xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
                              intermediate_shape, OpaqueField());
    }
    ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {input},
                                      output_shape, OpaqueField()));
  }

  void CompileForward(XlaOpKernelContext* ctx, const char* target) {
    auto output_shape = GetOutputShape(ctx);
    if (ctx->InputShape(0).dim_size(1) == output_shape.dimensions(1) &&
        ctx->InputShape(0).dim_size(2) == output_shape.dimensions(2)) {
      ctx->SetOutput(
          0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
      return;
    }
    if (ctx->InputShape(0).dim_size(1) == 1 &&
        ctx->InputShape(0).dim_size(2) == 1) {
      ctx->SetOutput(0,
                     ctx->Input(0) + xla::Zeros(ctx->builder(), output_shape));
      return;
    }
    ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
                                      output_shape, OpaqueField()));
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;
};

class TpuResizeNearestNeighborOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeNearestNeighborOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    CompileForward(ctx, "ResizeNearest");
  }
};

class TpuResizeBilinearOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeBilinearOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    CompileForward(ctx, "ResizeBilinear");
  }
};

class TpuResizeNearestNeighborGradOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeNearestNeighborGradOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    CompileGrad(ctx, "ResizeNearestGrad", GetOutputShape(ctx));
  }
};

class TpuResizeBilinearGradOp : public TpuCustomResizeOp {
 public:
  explicit TpuResizeBilinearGradOp(OpKernelConstruction* ctx)
      : TpuCustomResizeOp(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    auto output_shape =
        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(1));
    CompileGrad(ctx, "ResizeBilinearGrad", output_shape);
  }
};

REGISTER_XLA_OP(Name("ResizeNearestNeighbor")
                    .CompileTimeConstantInput("size")
                    .Device(DEVICE_TPU_XLA_JIT),
                TpuResizeNearestNeighborOp);

REGISTER_XLA_OP(Name("ResizeNearestNeighborGrad")
                    .CompileTimeConstantInput("size")
                    .Device(DEVICE_TPU_XLA_JIT),
                TpuResizeNearestNeighborGradOp);

REGISTER_XLA_OP(Name("ResizeBilinear")
                    .CompileTimeConstantInput("size")
                    .Device(DEVICE_TPU_XLA_JIT),
                TpuResizeBilinearOp);

REGISTER_XLA_OP(Name("ResizeBilinearGrad").Device(DEVICE_TPU_XLA_JIT),
                TpuResizeBilinearGradOp);

}  // namespace tensorflow
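A note on the two-phase gradient in CompileGrad above: when the spatial shrink ratio exceeds 3x in both height and width, the kernel lowers to two chained CustomCalls, reducing only the width first while keeping the input height, so each call's working set stays small. A minimal standalone sketch of just that shape arithmetic (the Nhwc struct and function names are illustrative, not part of the commit):

#include <cstdint>
#include <iostream>

// Sketch: decide whether the resize gradient should run in two phases and
// compute the intermediate NHWC shape, mirroring the ratio > 3 test in
// TpuCustomResizeOp::CompileGrad. Names here are illustrative only.
struct Nhwc { int64_t n, h, w, c; };

bool NeedsTwoPhase(const Nhwc& in, const Nhwc& out) {
  return in.h / out.h > 3 && in.w / out.w > 3;
}

Nhwc IntermediateShape(const Nhwc& in, const Nhwc& out) {
  // Phase one keeps the input height and only reduces the width, so the
  // second CustomCall starts from a much smaller buffer.
  return Nhwc{out.n, in.h, out.w, out.c};
}

int main() {
  Nhwc in{1, 512, 512, 3}, out{1, 64, 64, 3};
  if (NeedsTwoPhase(in, out)) {
    Nhwc mid = IntermediateShape(in, out);
    std::cout << "intermediate: " << mid.h << "x" << mid.w << "\n";  // 512x64
  }
  return 0;
}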
529 tensorflow/core/tpu/kernels/infeed_ops.cc (new file)
@ -0,0 +1,529 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/infeed_ops.h"

#include <algorithm>
#include <deque>
#include <vector>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef tensorflow::tpu::NoncopyableBuffer LinearizerBuffer;
typedef std::deque<LinearizerBuffer> LinearizerBufferList;

// Transposes the given tensor using the tensorflow C++ transpose
// implementation to obtain an XLA literal for the host tensor laid out as the
// given layout. The returned tensor is normalized to the dim0major layout --
// F32[10,20,30]{2,0,1} is returned as F32[20,10,30]{2,1,0}.
xla::StatusOr<Tensor> TransposeTensor(OpKernelContext* ctx,
                                      const Tensor& input_tensor,
                                      const xla::Shape& xla_shape) {
  profiler::TraceMe trace_me("TransposeTensor", /*level=*/2);
  const int64 rank = xla_shape.rank();
  std::vector<int32> permutation(rank);
  std::vector<int64> transposed_shapes(rank);
  for (int64 i = 0; i < rank; ++i) {
    permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i);
    transposed_shapes[i] = xla_shape.dimensions(permutation[i]);
  }

  Tensor transposed_tensor;

  // If this is a trivial transpose (i.e., bitcast), just create an aliased
  // tensor with the transposed shape.
  if (xla::LayoutUtil::IsMonotonicWithDim0Major(
          xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) {
    TensorShape shape;
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(transposed_shapes, &shape));
    TF_RETURN_IF_ERROR(transposed_tensor.BitcastFrom(
        input_tensor, input_tensor.dtype(), shape));
    return transposed_tensor;
  }

  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_temp(input_tensor.dtype(),
                                        TensorShape(transposed_shapes),
                                        &transposed_tensor, alloc_attr));
  // Eigen Transpose fails with SIGFPE if there is a dimension of size 0.
  if (input_tensor.NumElements() > 0) {
    TF_RETURN_IF_ERROR(DoTranspose<CPUDevice>(ctx->eigen_device<CPUDevice>(),
                                              input_tensor, permutation,
                                              &transposed_tensor));
  }
  return transposed_tensor;
}

xla::StatusOr<bool> GetLayoutOverride(OpKernelConstruction* ctx,
                                      const char* attrn_name,
                                      std::vector<int64>* minor_to_major) {
  if (!ctx->HasAttr(attrn_name)) {
    return false;
  }
  TF_RETURN_IF_ERROR(ctx->GetAttr(attrn_name, minor_to_major));
  return !minor_to_major->empty();
}

Status GetInfeedShapeWithLayout(OpKernelConstruction* ctx,
                                const char* attrn_name,
                                const xla::Shape& input_shape,
                                xla::Shape* output_shape) {
  std::vector<int64> minor_to_major;
  TF_ASSIGN_OR_RETURN(bool has_override,
                      GetLayoutOverride(ctx, attrn_name, &minor_to_major));
  if (!has_override) {
    *output_shape = input_shape;
    if (output_shape->IsTuple()) {
      int64 tuple_elements = xla::ShapeUtil::TupleElementCount(*output_shape);
      for (int64 i = 0; i < tuple_elements; ++i) {
        xla::Shape* sub_shape =
            xla::ShapeUtil::GetMutableSubshape(output_shape, {i});
        *sub_shape->mutable_layout() = GetTPUInfeedLayout(*sub_shape).layout();
      }
    } else {
      *output_shape->mutable_layout() =
          GetTPUInfeedLayout(*output_shape).layout();
    }
    return Status::OK();
  }

  auto layout_func = [](const xla::Shape& shape) -> xla::Layout {
    return GetTPUInfeedLayout(shape).layout();
  };
  return GetShapeWithLayout(input_shape, minor_to_major, layout_func,
                            output_shape);
}

// LinearizedBuffersWrapper is an opaque C++ data structure for the outputs of
// PrelinearizeOp and PrelinearizeTupleOp. It holds the resultant linearized
// buffers and references to input tensors whose underlying storage is shared
// with the linearized buffers.
// NOTE: This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `LinearizerBufferList` (aka `std::deque<LinearizerBuffer>`)
// object, so the `Encode()` and `Decode()` methods are not implemented.
struct LinearizedBuffersWrapper {
  explicit LinearizedBuffersWrapper() {}
  explicit LinearizedBuffersWrapper(LinearizerBufferList bufs,
                                    std::vector<tensorflow::Tensor> ts)
      : buffers(std::move(bufs)), tensors(std::move(ts)) {}
  LinearizedBuffersWrapper(const LinearizedBuffersWrapper& wrapper) {
    // tensorflow::Variant requires this copy constructor to compile.
    LOG(FATAL) << "LinearizedBuffersWrapper should not copy.";
  }
  LinearizedBuffersWrapper& operator=(const LinearizedBuffersWrapper& wrapper) =
      delete;
  LinearizedBuffersWrapper(LinearizedBuffersWrapper&&) = default;
  LinearizedBuffersWrapper& operator=(LinearizedBuffersWrapper&&) = default;
  ~LinearizedBuffersWrapper() = default;

  // These functions are tensorflow::Variant requirements.
  string TypeName() const { return "(anonymous)::LinearizedBuffersWrapper"; }
  void Encode(tensorflow::VariantTensorData* data) const {
    LOG(ERROR) << "Encode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
  }
  bool Decode(const tensorflow::VariantTensorData& data) {
    LOG(ERROR) << "Decode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
    return false;
  }

  LinearizerBufferList buffers;
  // Save references on tensors whose underlying storage is shared with
  // LiteralLinearizer::Buffer in `buffers`.
  std::vector<tensorflow::Tensor> tensors;
};

Status AutoTransposeAndLinearize(OpKernelContext* ctx,
                                 const Tensor& input_tensor,
                                 const xla::Shape& shape,
                                 LinearizerBufferList* linearized_buffers,
                                 std::vector<Tensor>* saved_input_tensors) {
  const Tensor* tensor = &input_tensor;
  bool has_transposed = false;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, shape));
    tensor = &transposed_tensor;
    has_transposed = true;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

  TF_RETURN_IF_ERROR(
      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()
          ->LinearizeToBuffers(literal, linearized_buffers));

  // The input tensor is ref-counted. Save a handle on the input tensor if
  // its underlying storage is shared with the linearized buffers, to prevent
  // the input tensor from getting freed.
  for (const auto& buffer : *linearized_buffers) {
    if (!buffer.owns_data() && !has_transposed) {
      // `buffer` is created via the zero-copy fast path from the
      // un-transposed input tensor, so its underlying data is shared with the
      // input tensor. Save a handle to the input tensor to increment its
      // ref-count and avoid it getting deallocated after PrelinearizeTupleOp
      // completes.
      saved_input_tensors->push_back(*tensor);
      // A literal can be linearized to zero to two buffers; if any of the
      // linearized buffers shares storage with the input tensor, we save
      // exactly one handle on the input tensor.
      break;
    }
  }
  return Status::OK();
}

// PrelinearizeOp is used to linearize one tensor to the device format.
class PrelinearizeOp : public OpKernel {
 public:
  explicit PrelinearizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    xla::Shape shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
    OP_REQUIRES_OK(ctx,
                   GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_tensor = ctx->input(0);
    // Validate input.
    OP_REQUIRES(
        ctx, input_tensor.dtype() == dtype_,
        errors::InvalidArgument("Prelinearize dtype mismatch; expected ",
                                DataType_Name(dtype_), ", got ",
                                DataType_Name(input_tensor.dtype())));
    OP_REQUIRES(
        ctx, input_tensor.shape() == shape_,
        errors::InvalidArgument("Prelinearize shape mismatch; expected ",
                                shape_.DebugString(), ", got ",
                                input_tensor.shape().DebugString()));

    // Auto-transpose and prelinearize.
    LinearizerBufferList linearized_buffers;
    std::vector<Tensor> saved_input_tensors;
    auto status =
        AutoTransposeAndLinearize(ctx, input_tensor, xla_shape_,
                                  &linearized_buffers, &saved_input_tensors);
    OP_REQUIRES_OK(ctx, status);

    // Write to output.
    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(linearized_buffers), std::move(saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // PrelinearizeOp is neither copyable nor movable.
  PrelinearizeOp(const PrelinearizeOp&) = delete;
  PrelinearizeOp& operator=(const PrelinearizeOp&) = delete;
};

// PrelinearizeTupleOp is used to linearize multiple tensors to the device
// format.
class PrelinearizeTupleOp : public OpKernel {
 public:
  explicit PrelinearizeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES(
        ctx, shapes_.size() == dtypes_.size(),
        errors::InvalidArgument(
            "shapes and dtypes must be the same length. shapes length = ",
            shapes_.size(), ", dtypes length = ", dtypes_.size()));

    std::vector<xla::Shape> xla_shapes;
    for (int i = 0; i < shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(ctx,
                     TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
      xla_shapes.push_back(xla_shape);
    }
    OP_REQUIRES_OK(
        ctx, GetInfeedShapeWithLayout(
                 ctx, "layouts", xla::ShapeUtil::MakeTupleShape(xla_shapes),
                 &tuple_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to PrelinearizeTuple."));

    LinearizerBufferList all_linearized_buffers;
    std::vector<Tensor> all_saved_input_tensors;
    for (int i = 0; i < values.size(); i++) {
      // Validate input.
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple dtype mismatch at tuple element ", i,
                      "; expected ", DataType_Name(dtypes_[i]), ", got ",
                      DataType_Name(input_tensor.dtype())));
      OP_REQUIRES(ctx, input_tensor.shape() == shapes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple shape mismatch at tuple element ", i,
                      "; expected ", shapes_[i].DebugString(), ", got ",
                      input_tensor.shape().DebugString()));

      // Auto-transpose and prelinearize.
      LinearizerBufferList linearized_buffers;
      std::vector<Tensor> saved_input_tensors;
      auto status = AutoTransposeAndLinearize(
          ctx, input_tensor, tuple_shape_.tuple_shapes(i), &linearized_buffers,
          &saved_input_tensors);
      OP_REQUIRES_OK(ctx, status);
      all_linearized_buffers.insert(
          all_linearized_buffers.end(),
          std::make_move_iterator(linearized_buffers.begin()),
          std::make_move_iterator(linearized_buffers.end()));
      all_saved_input_tensors.insert(
          all_saved_input_tensors.end(),
          std::make_move_iterator(saved_input_tensors.begin()),
          std::make_move_iterator(saved_input_tensors.end()));
    }

    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(all_linearized_buffers), std::move(all_saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  xla::Shape tuple_shape_;

  // PrelinearizeTupleOp is neither copyable nor movable.
  PrelinearizeTupleOp(const PrelinearizeTupleOp&) = delete;
  PrelinearizeTupleOp& operator=(const PrelinearizeTupleOp&) = delete;
};

// The InfeedEnqueuePrelinearizedBufferOp op is used to transfer prelinearized
// buffers to the device infeed queue.
class InfeedEnqueuePrelinearizedBufferOp : public TpuTransferAsyncOpKernel {
 public:
  explicit InfeedEnqueuePrelinearizedBufferOp(OpKernelConstruction* ctx)
      : TpuTransferAsyncOpKernel(ctx, "prelinearized_buffers_to_infeed", 8) {}

  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override {
    const Tensor& input_tensor = ctx->input(0);
    const LinearizedBuffersWrapper* wrapper =
        input_tensor.scalar<tensorflow::Variant>()()
            .get<LinearizedBuffersWrapper>();
    TF_RETURN_IF_ERROR(transfer_manager->TransferBuffersToInfeed(
        stream_executor, wrapper->buffers));

    return Status::OK();
  }

 private:
  // InfeedEnqueuePrelinearizedBufferOp is neither copyable nor movable.
  InfeedEnqueuePrelinearizedBufferOp(
      const InfeedEnqueuePrelinearizedBufferOp&) = delete;
  InfeedEnqueuePrelinearizedBufferOp& operator=(
      const InfeedEnqueuePrelinearizedBufferOp&) = delete;
};

}  // anonymous namespace

TpuInfeedEnqueueOp::TpuInfeedEnqueueOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  xla::Shape shape;
  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
  OP_REQUIRES_OK(ctx,
                 GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
}

Status TpuInfeedEnqueueOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  const Tensor& input_tensor = ctx->input(0);

  // Validate the runtime shape and fail if it doesn't match the contract.
  if (input_tensor.dtype() != dtype_) {
    return errors::InvalidArgument("Infeed dtype mismatch.");
  }
  if (input_tensor.shape() != shape_) {
    return errors::InvalidArgument("Infeed shape mismatch; expected ",
                                   shape_.DebugString(), ", got ",
                                   input_tensor.shape().DebugString());
  }

  const Tensor* tensor = &input_tensor;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(xla_shape_.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, xla_shape_));
    tensor = &transposed_tensor;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_manager->TransferLiteralToInfeed(stream_executor, literal));
  return Status::OK();
}

TpuInfeedEnqueueTupleOp::TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(
      ctx, shapes_.size() == dtypes_.size(),
      errors::InvalidArgument("shapes and dtypes must be the same length."));

  std::vector<xla::Shape> xla_shapes;
  for (int i = 0; i < shapes_.size(); i++) {
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
    xla_shapes.push_back(xla_shape);
  }
  OP_REQUIRES_OK(
      ctx, GetInfeedShapeWithLayout(ctx, "layouts",
                                    xla::ShapeUtil::MakeTupleShape(xla_shapes),
                                    &tuple_shape_));
}

Status TpuInfeedEnqueueTupleOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  OpInputList values;
  TF_RETURN_IF_ERROR(ctx->input_list("inputs", &values));
  if (values.size() != shapes_.size()) {
    return errors::InvalidArgument(
        "Wrong number of inputs to InfeedEnqueueTuple.");
  }

  for (const auto& shapes : shapes_) {
    VLOG(1) << "TransferLiteralToInfeed " << shapes.DebugString();
  }

  std::vector<Tensor> maybe_transposed_tensors;
  maybe_transposed_tensors.reserve(values.size());
  for (int i = 0; i < values.size(); i++) {
    // Validate the runtime shapes and fail if they don't match the contract.
    const Tensor* tensor = &values[i];
    if (tensor->shape() != shapes_[i]) {
      return errors::InvalidArgument("Infeed shape mismatch for tuple element ",
                                     i, "; expected ", shapes_[i].DebugString(),
                                     ", got ", tensor->shape().DebugString());
    }
    if (!xla::LayoutUtil::IsMonotonicWithDim0Major(
            tuple_shape_.tuple_shapes(i).layout())) {
      // If the given layout is not in dim0major layout, transpose the given
      // tensor.
      TF_ASSIGN_OR_RETURN(
          Tensor transposed_tensor,
          TransposeTensor(ctx, *tensor, tuple_shape_.tuple_shapes(i)));
      maybe_transposed_tensors.emplace_back(transposed_tensor);
    } else {
      maybe_transposed_tensors.emplace_back(*tensor);
    }
  }

  xla::BorrowingLiteral tuple;
  TF_RETURN_IF_ERROR(
      HostTensorsToBorrowingLiteralTuple(maybe_transposed_tensors, &tuple));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_manager->TransferLiteralToInfeed(stream_executor, tuple));

  VLOG(1) << "TransferLiteralToInfeed complete.";

  return Status::OK();
}

// These ops execute on either the TPU device or the CPU device. When running
// on CPU they must specify a non-negative value for device_ordinal to indicate
// which TPU to send infeed to.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueue").Device(DEVICE_TPU_NODE).HostMemory("input"),
    TpuInfeedEnqueueOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueue").Device(DEVICE_CPU),
                        TpuInfeedEnqueueOp);

REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueueTuple").Device(DEVICE_TPU_NODE).HostMemory("inputs"),
    TpuInfeedEnqueueTupleOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueueTuple").Device(DEVICE_CPU),
                        TpuInfeedEnqueueTupleOp);

// Prelinearize ops run on CPU as part of the tf.data input pipeline.
REGISTER_KERNEL_BUILDER(Name("Prelinearize").Device(DEVICE_CPU),
                        PrelinearizeOp);
REGISTER_KERNEL_BUILDER(Name("PrelinearizeTuple").Device(DEVICE_CPU),
                        PrelinearizeTupleOp);

// The InfeedEnqueuePrelinearizedBuffer op runs on CPU and takes a
// device_ordinal to select the right device to infeed.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueuePrelinearizedBuffer").Device(DEVICE_CPU),
    InfeedEnqueuePrelinearizedBufferOp);

}  // namespace tensorflow
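The layout normalization in TransposeTensor is the subtle part of this file: the permutation is read off the shape's minor_to_major list in reverse, so the physically transposed tensor comes out dim0-major. A self-contained sketch of that arithmetic, reproducing the F32[10,20,30]{2,0,1} -> F32[20,10,30]{2,1,0} example from the comment above (standalone illustration, not TF API):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the permutation arithmetic in TransposeTensor: given a shape's
// minor_to_major list, derive the transpose that yields an equivalent
// dim0-major (row-major) tensor.
int main() {
  std::vector<int64_t> dims = {10, 20, 30};         // F32[10,20,30]
  std::vector<int64_t> minor_to_major = {2, 0, 1};  // layout {2,0,1}
  const int64_t rank = dims.size();
  std::vector<int32_t> permutation(rank);
  std::vector<int64_t> transposed(rank);
  for (int64_t i = 0; i < rank; ++i) {
    permutation[i] = minor_to_major[rank - 1 - i];
    transposed[i] = dims[permutation[i]];
  }
  // Prints 20 10 30: F32[10,20,30]{2,0,1} becomes F32[20,10,30]{2,1,0}.
  for (int64_t d : transposed) std::cout << d << " ";
  std::cout << "\n";
  return 0;
}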
69 tensorflow/core/tpu/kernels/infeed_ops.h (new file)
@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"

namespace tensorflow {

// TODO(b/65200690): Rework this when there is a callback based infeed API to
// StreamExecutor.

// The InfeedEnqueue op is used to deliver data to the device infeed queue.
class TpuInfeedEnqueueOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuInfeedEnqueueOp(OpKernelConstruction* ctx);
  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // TpuInfeedEnqueueOp is neither copyable nor movable.
  TpuInfeedEnqueueOp(const TpuInfeedEnqueueOp&) = delete;
  TpuInfeedEnqueueOp& operator=(const TpuInfeedEnqueueOp&) = delete;
};

// The InfeedEnqueueTuple op is used on the host to deliver multiple tensors to
// the device infeed queue as an XLA tuple.
class TpuInfeedEnqueueTupleOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx);
  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  xla::Shape tuple_shape_;

  // TpuInfeedEnqueueTupleOp is neither copyable nor movable.
  TpuInfeedEnqueueTupleOp(const TpuInfeedEnqueueTupleOp&) = delete;
  TpuInfeedEnqueueTupleOp& operator=(const TpuInfeedEnqueueTupleOp&) = delete;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
116 tensorflow/core/tpu/kernels/outfeed_ops.cc (new file)
@ -0,0 +1,116 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/outfeed_ops.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"

namespace tensorflow {

TpuOutfeedDequeueOp::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
}

Status TpuOutfeedDequeueOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  Tensor* output;
  TF_RETURN_IF_ERROR(ctx->allocate_output(0, shape_, &output));

  // Transfer from the outfeed interface of the device.
  xla::MutableBorrowingLiteral literal;
  TF_RETURN_IF_ERROR(
      HostTensorToMutableBorrowingLiteral(xla_shape_, output, &literal));

  VLOG(1) << "TransferLiteralFromOutfeed "
          << xla::ShapeUtil::HumanStringWithLayout(xla_shape_);

  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
      stream_executor, xla_shape_, literal));

  VLOG(1) << "TransferLiteralFromOutfeed complete.";

  return Status::OK();
}

// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
TpuOutfeedDequeueTupleOp::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(
      ctx, shapes_.size() == dtypes_.size(),
      errors::InvalidArgument("shapes and dtypes must be the same length."));
  // The `dtypes` list is inferred from the supplied inputs, so it
  // is always the correct length.
  for (int i = 0; i < shapes_.size(); i++) {
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
    xla_shapes_.push_back(xla_shape);
  }
  tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
}

Status TpuOutfeedDequeueTupleOp::DoWork(
    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
    stream_executor::StreamExecutor* stream_executor) {
  VLOG(1) << "TransferLiteralFromOutfeed "
          << xla::ShapeUtil::HumanStringWithLayout(tuple_shape_);

  for (int i = 0; i < shapes_.size(); ++i) {
    Tensor* output;
    TF_RETURN_IF_ERROR(ctx->allocate_output(i, shapes_[i], &output));

    xla::MutableBorrowingLiteral literal;
    TF_RETURN_IF_ERROR(
        HostTensorToMutableBorrowingLiteral(xla_shapes_[i], output, &literal));
    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
        stream_executor, xla_shapes_[i], literal));
  }
  return Status::OK();
}

// These ops execute on either the TPU device or the CPU device. When
// running on CPU they must specify a non-negative value for
// device_ordinal to indicate which TPU to receive outfeed from.
REGISTER_KERNEL_BUILDER(
    Name("OutfeedDequeue").Device(DEVICE_TPU_NODE).HostMemory("output"),
    TpuOutfeedDequeueOp);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeue").Device(DEVICE_CPU),
                        TpuOutfeedDequeueOp);

REGISTER_KERNEL_BUILDER(
    Name("OutfeedDequeueTuple").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
    TpuOutfeedDequeueTupleOp);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeueTuple").Device(DEVICE_CPU),
                        TpuOutfeedDequeueTupleOp);

}  // namespace tensorflow
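Both dequeue paths above allocate the op's output tensor first and then wrap it in a MutableBorrowingLiteral, so the outfeed transfer writes straight into the output buffer with no staging copy. A toy model of that borrow-into-preallocated-storage pattern (PseudoLiteral and TransferFromOutfeed are illustrative stand-ins; the real call is transfer_manager->TransferLiteralFromOutfeed):

#include <algorithm>
#include <cstring>
#include <vector>

// PseudoLiteral borrows storage it does not own, like MutableBorrowingLiteral.
struct PseudoLiteral {
  float* data;
  size_t count;
};

void TransferFromOutfeed(const std::vector<float>& device_buf,
                         PseudoLiteral out) {
  // Writes directly into the borrowed destination; no intermediate buffer.
  std::memcpy(out.data, device_buf.data(),
              std::min(out.count, device_buf.size()) * sizeof(float));
}

int main() {
  std::vector<float> output(8);            // pre-allocated op output
  std::vector<float> device_buf(8, 1.0f);  // fake outfeed contents
  TransferFromOutfeed(device_buf, PseudoLiteral{output.data(), output.size()});
  return 0;
}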
69 tensorflow/core/tpu/kernels/outfeed_ops.h (new file)
@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"

namespace tensorflow {

// The OutfeedDequeue op is used to retrieve a single tensor from the device
// outfeed queue.
class TpuOutfeedDequeueOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuOutfeedDequeueOp(OpKernelConstruction* ctx);

  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // OutfeedDequeueOp is neither copyable nor movable.
  TpuOutfeedDequeueOp(const TpuOutfeedDequeueOp&) = delete;
  TpuOutfeedDequeueOp& operator=(const TpuOutfeedDequeueOp&) = delete;
};

// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
class TpuOutfeedDequeueTupleOp : public TpuTransferAsyncOpKernel {
 public:
  explicit TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx);

  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override;

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  std::vector<xla::Shape> xla_shapes_;
  xla::Shape tuple_shape_;

  // OutfeedDequeueTupleOp is neither copyable nor movable.
  TpuOutfeedDequeueTupleOp(const TpuOutfeedDequeueTupleOp&) = delete;
  TpuOutfeedDequeueTupleOp& operator=(const TpuOutfeedDequeueTupleOp&) = delete;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
27 tensorflow/core/tpu/kernels/replication_ops.cc (new file)
@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("_TPUReplicate").Device(DEVICE_TPU_SYSTEM),
                        XlaDeviceDummyOp);

}  // namespace tensorflow
62 tensorflow/core/tpu/kernels/tpu_handle_to_key_op.cc (new file)
@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"

namespace tensorflow {

class TpuHandleToProtoKeyOp : public OpKernel {
 public:
  explicit TpuHandleToProtoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~TpuHandleToProtoKeyOp() override = default;
  TpuHandleToProtoKeyOp(const TpuHandleToProtoKeyOp&) = delete;
  TpuHandleToProtoKeyOp& operator=(const TpuHandleToProtoKeyOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "TpuHandleToProtoKeyOp::Compute " << ctx->op_kernel().name()
            << " on device " << ctx->op_kernel().requested_device();
    const Tensor& uid = ctx->input(0);

    ResourceMgr* rm = GetTPUConfigResourceMgr();
    tpu::TpuCompilationCacheInterface* cache;
    OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
                            rm->default_container(),
                            tpu::kCompilationCacheResourceName, &cache));
    core::ScopedUnref cache_unref(cache);

    std::vector<std::string> keys;
    OP_REQUIRES_OK(ctx, cache->GetKeysFromUid(uid.scalar<int64>()(), &keys));

    TensorShape output_shape;
    output_shape.AddDim(keys.size());
    Tensor* result = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
    for (int i = 0; i < keys.size(); ++i) {
      result->vec<tstring>()(i) = keys[i];
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TpuHandleToProtoKey").Device(DEVICE_CPU),
                        TpuHandleToProtoKeyOp);

}  // namespace tensorflow
98 tensorflow/core/tpu/kernels/transfer_ops.cc (new file)
@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/transfer_ops.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {

TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
                                                   const string& transfer_type,
                                                   int number_of_threads)
    : AsyncOpKernel(ctx),
      thread_pool_(new thread::ThreadPool(
          ctx->env(),
          strings::StrCat(transfer_type, "_thread_",
                          SanitizeThreadSuffix(def().name())),
          /*num_threads=*/8)) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
    OP_REQUIRES(
        ctx, device_ordinal_ >= 0,
        errors::InvalidArgument(transfer_type,
                                " ops must specify a device_ordinal when "
                                "placed on CPU."));
  }
}

void TpuTransferAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
                                            DoneCallback done) {
  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  bool already_cancelled;
  {
    // Only protect registering the cancellation callback, as mu_ cannot be
    // held at a point where `done` could be called.
    mutex_lock lock(mu_);
    already_cancelled = !ctx->cancellation_manager()->RegisterCallback(
        token, [this]() { Cancel(); });
  }
  OP_REQUIRES_ASYNC(ctx, !already_cancelled,
                    errors::Cancelled("Infeed was cancelled."), done);
  thread_pool_->Schedule([this, ctx, done, token]() {
    Status s = RunTransfer(ctx);
    ctx->cancellation_manager()->DeregisterCallback(token);
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  });
}

Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();

  int real_device_ordinal = device_ordinal_;
  if (real_device_ordinal < 0) {
    const XlaDevice::Metadata* metadata;
    TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
    real_device_ordinal = metadata->device_ordinal();
  }
  stream_executor::StreamExecutor* stream_executor =
      tpu_platform->ExecutorForDevice(real_device_ordinal).ValueOrDie();

  // When Xprof profiling is off (which is the default), constructing the
  // activity is simple enough that its overhead is negligible.
  profiler::TraceMe activity(
      [this] { return profiler::TraceMeOp(name(), type_string()); },
      profiler::TraceMeLevel::kInfo);
  return DoWork(
      ctx, xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager(),
      stream_executor);
}

void TpuTransferAsyncOpKernel::Cancel() {
  mutex_lock lock(mu_);
  TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
}

}  // namespace tensorflow
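The cancellation protocol in ComputeAsync above is easy to get wrong: the callback must be registered before the work is scheduled, the lock may only cover registration (never a region where `done` could run), and the callback is deregistered once the transfer finishes. A toy model of that register/schedule/deregister sequence (ToyCancellationManager is an illustrative stand-in for tensorflow's CancellationManager):

#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// Minimal stand-in: Register() fails if cancellation already happened, which
// mirrors CancellationManager::RegisterCallback returning false.
class ToyCancellationManager {
 public:
  bool Register(std::function<void()> cb) {
    std::lock_guard<std::mutex> l(mu_);
    if (cancelled_) return false;
    cb_ = std::move(cb);
    return true;
  }
  void Cancel() {
    std::lock_guard<std::mutex> l(mu_);
    cancelled_ = true;
    if (cb_) cb_();
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  std::function<void()> cb_;
};

int main() {
  ToyCancellationManager cm;
  // Register first; bail out if we were already cancelled.
  if (!cm.Register([] { std::cout << "cancelled\n"; })) return 0;
  // Then schedule the transfer on a worker thread.
  std::thread worker([] { std::cout << "transfer runs\n"; });
  worker.join();  // the real op deregisters the callback at this point
  return 0;
}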
56 tensorflow/core/tpu/kernels/transfer_ops.h (new file)
@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {

// Base class providing common functionality for async ops that transfer from
// host to TPU.
class TpuTransferAsyncOpKernel : public AsyncOpKernel {
 public:
  explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
                                    const string& transfer_type,
                                    int number_of_threads);

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 protected:
  virtual Status DoWork(OpKernelContext* context,
                        xla::TpuTransferManagerInterface* transfer_manager,
                        stream_executor::StreamExecutor* stream_executor) = 0;

 private:
  Status RunTransfer(OpKernelContext* ctx);
  void Cancel();

  std::unique_ptr<thread::ThreadPool> thread_pool_;
  int device_ordinal_;
  mutex mu_;

  // TpuTransferAsyncOpKernel is neither copyable nor movable.
  TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete;
  TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
@ -15,6 +15,10 @@ limitations under the License.
 
 #include "tensorflow/core/tpu/tpu_defs.h"
 
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/c_api_decl.h"
+
 namespace tensorflow {
 
 const char* const DEVICE_TPU_NODE = "TPU";
@ -27,4 +31,18 @@ const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR =
 const char* const kTPUReplicateAttr = "_tpu_replicate";
 const char* const kOutsideCompilationAttr = "_xla_outside_compilation";
 
+xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
+  XLA_Shape c_shape;
+  XLA_Shape c_infeed_shape;
+
+  ApiConverter::ToC(shape, &c_shape);
+
+  tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
+                                                             &c_infeed_shape);
+  xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
+  ApiConverter::Free(&c_shape);
+  ApiConverter::Free(&c_infeed_shape);
+  return infeed_shape;
+}
+
 }  // namespace tensorflow
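GetTPUInfeedLayout above round-trips the shape through the C API boundary: convert to a C struct, call through the dynamically loaded function table, convert the result back, and free both C structs on the C++ side. A toy sketch of that ToC/FromC/Free ownership discipline (CShape and the helper functions are illustrative stand-ins for XLA_Shape and ApiConverter, not the real conversion code):

#include <cstring>
#include <string>

// CShape stands in for XLA_Shape: a C struct whose heap storage the C++
// caller is responsible for releasing via Free().
struct CShape { char* repr; };

void ToC(const std::string& shape, CShape* out) {
  out->repr = new char[shape.size() + 1];
  std::strcpy(out->repr, shape.c_str());
}
std::string FromC(const CShape* s) { return std::string(s->repr); }
void Free(CShape* s) { delete[] s->repr; s->repr = nullptr; }

std::string GetInfeedLayout(const std::string& shape) {
  CShape c_shape, c_infeed_shape;
  ToC(shape, &c_shape);
  ToC(shape + "{0,1}", &c_infeed_shape);  // stands in for the FFI call
  std::string result = FromC(&c_infeed_shape);
  Free(&c_shape);         // the caller frees everything it converted,
  Free(&c_infeed_shape);  // including out-params filled by the callee
  return result;
}

int main() { return GetInfeedLayout("f32[8,128]").empty(); }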
@ -20,6 +20,7 @@ limitations under the License.
 
 #include <array>
 
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/core/framework/types.pb.h"
 
 namespace tensorflow {
@ -56,6 +57,11 @@ static constexpr std::array<DataType, 16> kTpuAllTypes = {
     DT_COMPLEX64, DT_INT64, DT_UINT64, DT_QINT8, DT_QUINT8, DT_INT8, DT_UINT8,
     DT_INT16, DT_UINT16}};
 
+// For the given shape, chooses a layout for infeed on TPU. The returned shape
+// has the same dimensions as the original shape, and only the layout is
+// changed.
+xla::Shape GetTPUInfeedLayout(const xla::Shape& shape);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_TPU_TPU_DEFS_H_
@ -161,6 +161,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
   TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralFromDevice);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_GetInfeedLayout);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_LinearizeToBuffers);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_FreeBuffers);
 
@ -203,10 +203,12 @@ cc_library(
 
 cc_library(
     name = "tpu_transfer_manager_interface",
+    srcs = ["tpu_transfer_manager_interface.cc"],
     hdrs = ["tpu_transfer_manager_interface.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":noncopyable_buffer",
        ":tpu_platform_interface",
        "//tensorflow/compiler/xla/service:transfer_manager",
     ],
 )
@ -182,6 +182,8 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
     XLA_TransferManager* manager, SE_Stream* stream,
     SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
     SE_DeviceMemoryBase* region, SE_Status* status);
+void TpuTransferManager_GetInfeedLayout(XLA_Shape* shape,
+                                        XLA_Shape* infeed_shape);
 void TpuTransferManager_LinearizeToBuffers(
     XLA_TransferManager* manager, XLA_Literal* c_literal, char*** buffers_array,
     int64_t** buffers_size, int64_t* buffers_array_size, SE_Status* status);
@ -341,6 +343,7 @@ struct TfTpu_ExecutorApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralFromDevice);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetInfeedLayout);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_LinearizeToBuffers);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_FreeBuffers);
 
@ -81,6 +81,12 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
       const xla::Shape& shape,
       stream_executor::DeviceMemoryBase* region) override;
 
+  Status LinearizeToBuffers(
+      const xla::LiteralSlice& literal,
+      std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) override {
+    LOG(FATAL) << "Not yet implemented.";
+  }
+
  private:
   XLA_TransferManager* manager_;
 };
40 tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.cc (new file)
@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace xla {

/*static*/ TpuTransferManagerInterface*
TpuTransferManagerInterface::GetRegisteredTpuTransferManager() {
  auto* platform = tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(
      /*initialize_platform=*/false);
  if (platform == nullptr) {
    LOG(ERROR) << "Unable to retrieve registered TPU platform.";
    return nullptr;
  }
  auto tm = xla::TransferManager::GetForPlatform(platform);
  if (!tm.ok()) {
    LOG(ERROR) << "Unable to retrieve TpuTransferManager. No TPU platform is "
                  "registered for platform "
               << platform->Name() << " and ID " << platform->id();
    return nullptr;
  }
  return static_cast<TpuTransferManagerInterface*>(tm.ValueOrDie());
}

}  // namespace xla
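Callers are expected to pair this accessor with the platform's executor lookup, as RunTransfer in transfer_ops.cc does. A hedged sketch of that intended call sequence (EnqueueToDevice is a hypothetical helper; it presumes a registered TPU platform and will not run standalone):

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

// Hypothetical helper mirroring TpuTransferAsyncOpKernel::RunTransfer plus
// the infeed call from infeed_ops.cc; illustrative, not part of the commit.
tensorflow::Status EnqueueToDevice(const xla::LiteralSlice& literal,
                                   int device_ordinal) {
  auto* platform =
      tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform();
  auto* transfer_manager =
      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager();
  if (platform == nullptr || transfer_manager == nullptr) {
    return tensorflow::errors::Internal("No TPU platform is registered.");
  }
  // Resolve the per-device executor, then enqueue through the interface;
  // TransferLiteralToInfeed is inherited from xla::TransferManager.
  stream_executor::StreamExecutor* executor =
      platform->ExecutorForDevice(device_ordinal).ValueOrDie();
  return transfer_manager->TransferLiteralToInfeed(executor, literal);
}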
@ -24,9 +24,16 @@ limitations under the License.
 namespace xla {
 
 class TpuTransferManagerInterface : public xla::TransferManager {
+ public:
   virtual Status TransferBuffersToInfeed(
       se::StreamExecutor* executor,
       const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) = 0;
+
+  virtual Status LinearizeToBuffers(
+      const LiteralSlice& literal,
+      std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) = 0;
+
+  static TpuTransferManagerInterface* GetRegisteredTpuTransferManager();
 };
 
 }  // namespace xla