Introduce additional TPU infeed and outfeed ops

PiperOrigin-RevId: 325542225
Change-Id: Ie972e60d6c5639b71719837c500ecc716eda2ebd
Frank Chen 2020-08-07 17:51:06 -07:00 committed by TensorFlower Gardener
parent 769155a21e
commit 3ba0deba91
19 changed files with 1378 additions and 1 deletion


@@ -88,7 +88,13 @@ cc_library(
name = "tpu_defs",
srcs = ["tpu_defs.cc"],
hdrs = ["tpu_defs.h"],
deps = ["//tensorflow/core:protos_all_cc"],
deps = [
":tpu_api",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/tpu:c_api_conversions",
"//tensorflow/stream_executor/tpu:c_api_decl",
],
)
cc_library(


@@ -28,10 +28,16 @@ tf_kernel_library(
deps = [
":cross_replica_ops",
":host_compute_ops",
":image_resize_ops",
":infeed_ops",
":outfeed_ops",
":replication_ops",
":topk_ops",
":tpu_compile_op",
":tpu_configuration_ops",
":tpu_execute_op",
":tpu_handle_to_key_op",
":transfer_ops",
],
)
@@ -684,3 +690,104 @@ cc_library(
],
alwayslink = 1,
)
cc_library(
name = "infeed_ops",
srcs = ["infeed_ops.cc"],
hdrs = ["infeed_ops.h"],
visibility = ["//visibility:public"],
deps = [
":transfer_ops",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:framework",
"//tensorflow/core/common_runtime:dma_helper",
"//tensorflow/core/framework:protos_all_cc",
"//tensorflow/core/kernels:transpose_functor",
"//tensorflow/core/platform:status",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor/tpu:tpu_transfer_manager_base",
"//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
],
alwayslink = True,
)
cc_library(
name = "transfer_ops",
srcs = ["transfer_ops.cc"],
hdrs = ["transfer_ops.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor/tpu:tpu_node_context",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
],
alwayslink = True,
)
cc_library(
name = "outfeed_ops",
srcs = ["outfeed_ops.cc"],
hdrs = ["outfeed_ops.h"],
visibility = ["//visibility:public"],
deps = [
":transfer_ops",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/core:framework",
"//tensorflow/core/framework:protos_all_cc",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor:multi_platform_manager",
],
alwayslink = True,
)
cc_library(
name = "image_resize_ops",
srcs = ["image_resize_ops.cc"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:framework",
"//tensorflow/core/tpu:tpu_defs",
"@com_google_absl//absl/strings",
],
alwayslink = True,
)
cc_library(
name = "replication_ops",
srcs = ["replication_ops.cc"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/core:framework",
"//tensorflow/core/tpu:tpu_defs",
],
alwayslink = True,
)
cc_library(
name = "tpu_handle_to_key_op",
srcs = ["tpu_handle_to_key_op.cc"],
visibility = ["//visibility:public"],
deps = [
":tpu_compilation_cache_interface",
":tpu_op_consts",
"//tensorflow/core:framework",
"//tensorflow/core/tpu:tpu_configuration",
],
alwayslink = True,
)


@@ -0,0 +1,155 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
class TpuCustomResizeOp : public XlaOpKernel {
public:
explicit TpuCustomResizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
}
xla::Shape GetOutputShape(XlaOpKernelContext* ctx) const {
std::vector<int64> out_size;
auto status = ctx->ConstantInputAsIntVector(1, &out_size);
CHECK_EQ(out_size.size(), 2) << status.ToString();
xla::Shape output_shape =
TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
output_shape.mutable_dimensions()[1] = out_size[0];
output_shape.mutable_dimensions()[2] = out_size[1];
return output_shape;
}
string OpaqueField() const {
return absl::StrCat("\"", align_corners_, half_pixel_centers_, "\"");
}
void CompileGrad(XlaOpKernelContext* ctx, const char* target,
const xla::Shape& output_shape) {
auto input_shape =
TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
if (ctx->InputShape(1).dim_sizes() == ctx->InputShape(0).dim_sizes()) {
ctx->SetOutput(
0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
return;
}
// The gradient should be done in two phases for large resizes.
auto input = ctx->Input(0);
if (input_shape.dimensions(1) / output_shape.dimensions(1) > 3 &&
input_shape.dimensions(2) / output_shape.dimensions(2) > 3) {
auto intermediate_shape = output_shape;
intermediate_shape.mutable_dimensions()[1] = input_shape.dimensions(1);
input = xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
intermediate_shape, OpaqueField());
}
ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {input},
output_shape, OpaqueField()));
}
void CompileForward(XlaOpKernelContext* ctx, const char* target) {
auto output_shape = GetOutputShape(ctx);
if (ctx->InputShape(0).dim_size(1) == output_shape.dimensions(1) &&
ctx->InputShape(0).dim_size(2) == output_shape.dimensions(2)) {
ctx->SetOutput(
0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
return;
}
if (ctx->InputShape(0).dim_size(1) == 1 &&
ctx->InputShape(0).dim_size(2) == 1) {
ctx->SetOutput(0,
ctx->Input(0) + xla::Zeros(ctx->builder(), output_shape));
return;
}
ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
output_shape, OpaqueField()));
}
private:
bool align_corners_;
bool half_pixel_centers_;
};
class TpuResizeNearestNeighborOp : public TpuCustomResizeOp {
public:
explicit TpuResizeNearestNeighborOp(OpKernelConstruction* ctx)
: TpuCustomResizeOp(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
CompileForward(ctx, "ResizeNearest");
}
};
class TpuResizeBilinearOp : public TpuCustomResizeOp {
public:
explicit TpuResizeBilinearOp(OpKernelConstruction* ctx)
: TpuCustomResizeOp(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
CompileForward(ctx, "ResizeBilinear");
}
};
class TpuResizeNearestNeighborGradOp : public TpuCustomResizeOp {
public:
explicit TpuResizeNearestNeighborGradOp(OpKernelConstruction* ctx)
: TpuCustomResizeOp(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
CompileGrad(ctx, "ResizeNearestGrad", GetOutputShape(ctx));
}
};
class TpuResizeBilinearGradOp : public TpuCustomResizeOp {
public:
explicit TpuResizeBilinearGradOp(OpKernelConstruction* ctx)
: TpuCustomResizeOp(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
auto output_shape =
TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(1));
CompileGrad(ctx, "ResizeBilinearGrad", output_shape);
}
};
REGISTER_XLA_OP(Name("ResizeNearestNeighbor")
.CompileTimeConstantInput("size")
.Device(DEVICE_TPU_XLA_JIT),
TpuResizeNearestNeighborOp);
REGISTER_XLA_OP(Name("ResizeNearestNeighborGrad")
.CompileTimeConstantInput("size")
.Device(DEVICE_TPU_XLA_JIT),
TpuResizeNearestNeighborGradOp);
REGISTER_XLA_OP(Name("ResizeBilinear")
.CompileTimeConstantInput("size")
.Device(DEVICE_TPU_XLA_JIT),
TpuResizeBilinearOp);
REGISTER_XLA_OP(Name("ResizeBilinearGrad").Device(DEVICE_TPU_XLA_JIT),
TpuResizeBilinearGradOp);
} // namespace tensorflow
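For reference, a minimal sketch (plain C++ against the XLA shape utilities this file already uses; the concrete sizes are hypothetical and not part of this change) of how CompileGrad derives the intermediate shape when both spatial reduction ratios exceed 3:
#include <iostream>
#include "tensorflow/compiler/xla/shape_util.h"
int main() {
  // Incoming gradient (e.g. from a 64x64 -> 512x512 bilinear upsample) and the
  // desired gradient shape. Both spatial ratios are 512/64 = 8 > 3, so
  // CompileGrad splits the work into two CustomCalls.
  xla::Shape grad_in = xla::ShapeUtil::MakeShape(xla::F32, {1, 512, 512, 3});
  xla::Shape grad_out = xla::ShapeUtil::MakeShape(xla::F32, {1, 64, 64, 3});
  xla::Shape intermediate = grad_out;
  // The first CustomCall shrinks only dimension 2; the second finishes dim 1.
  intermediate.mutable_dimensions()[1] = grad_in.dimensions(1);
  std::cout << xla::ShapeUtil::HumanString(intermediate) << std::endl;  // f32[1,512,64,3]
  return 0;
}
Splitting the reduction this way keeps each pass's per-dimension shrink factor bounded, which appears to be the reason for the two-phase comment in CompileGrad.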


@@ -0,0 +1,529 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/infeed_ops.h"
#include <algorithm>
#include <vector>
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
namespace tensorflow {
namespace {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef tensorflow::tpu::NoncopyableBuffer LinearizerBuffer;
typedef std::deque<LinearizerBuffer> LinearizerBufferList;
// Transposes the given tensor using the tensorflow C++ transpose implementation
// to obtain an XLA literal for the host tensor laid out with the given layout. The
// returned tensor is normalized to the dim0major layout -- F32[10,20,30]{2,0,1}
// is returned as F32[20,10,30]{2,1,0}.
xla::StatusOr<Tensor> TransposeTensor(OpKernelContext* ctx,
const Tensor& input_tensor,
const xla::Shape& xla_shape) {
profiler::TraceMe trace_me("TransposeTensor", /*level=*/2);
const int64 rank = xla_shape.rank();
std::vector<int32> permutation(rank);
std::vector<int64> transposed_shapes(rank);
for (int64 i = 0; i < rank; ++i) {
permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i);
transposed_shapes[i] = xla_shape.dimensions(permutation[i]);
}
Tensor transposed_tensor;
// If this is a trivial transpose (i.e., bitcast), just create an aliased
// tensor with the transposed shape.
if (xla::LayoutUtil::IsMonotonicWithDim0Major(
xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) {
TensorShape shape;
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(transposed_shapes, &shape));
TF_RETURN_IF_ERROR(transposed_tensor.BitcastFrom(
input_tensor, input_tensor.dtype(), shape));
return transposed_tensor;
}
AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(input_tensor.dtype(),
TensorShape(transposed_shapes),
&transposed_tensor, alloc_attr));
// Eigen Transpose fails with SIGFPE if there is a dimension of size 0.
if (input_tensor.NumElements() > 0) {
TF_RETURN_IF_ERROR(DoTranspose<CPUDevice>(ctx->eigen_device<CPUDevice>(),
input_tensor, permutation,
&transposed_tensor));
}
return transposed_tensor;
}
xla::StatusOr<bool> GetLayoutOverride(OpKernelConstruction* ctx,
const char* attrn_name,
std::vector<int64>* minor_to_major) {
if (!ctx->HasAttr(attrn_name)) {
return false;
}
TF_RETURN_IF_ERROR(ctx->GetAttr(attrn_name, minor_to_major));
return !minor_to_major->empty();
}
Status GetInfeedShapeWithLayout(OpKernelConstruction* ctx,
const char* attrn_name,
const xla::Shape& input_shape,
xla::Shape* output_shape) {
std::vector<int64> minor_to_major;
TF_ASSIGN_OR_RETURN(bool has_override,
GetLayoutOverride(ctx, attrn_name, &minor_to_major));
if (!has_override) {
*output_shape = input_shape;
if (output_shape->IsTuple()) {
int64 tuple_elements = xla::ShapeUtil::TupleElementCount(*output_shape);
for (int64 i = 0; i < tuple_elements; ++i) {
xla::Shape* sub_shape =
xla::ShapeUtil::GetMutableSubshape(output_shape, {i});
*sub_shape->mutable_layout() = GetTPUInfeedLayout(*sub_shape).layout();
}
} else {
*output_shape->mutable_layout() =
GetTPUInfeedLayout(*output_shape).layout();
}
return Status::OK();
}
auto layout_func = [](const xla::Shape& shape) -> xla::Layout {
return GetTPUInfeedLayout(shape).layout();
};
return GetShapeWithLayout(input_shape, minor_to_major, layout_func,
output_shape);
}
// LinearizedBuffersWrapper is an opaque C++ data structure for the outputs of
// PrelinearizeOp and PrelinearizeTupleOp. It holds the resultant linearized
// buffers and references to input tensors whose underlying storage are shared
// with linearized buffers.
// NOTE: This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `LinearizerBufferList` (aka `std::deque<LinearizerBuffer>`)
// object, so the `Encode()` and `Decode()` methods are not implemented.
struct LinearizedBuffersWrapper {
explicit LinearizedBuffersWrapper() {}
explicit LinearizedBuffersWrapper(LinearizerBufferList bufs,
std::vector<tensorflow::Tensor> ts)
: buffers(std::move(bufs)), tensors(std::move(ts)) {}
LinearizedBuffersWrapper(const LinearizedBuffersWrapper& wrapper) {
// tensorflow::Variant requires this copy constructor to compile.
LOG(FATAL) << "LinearizedBuffersWrapper should not copy.";
}
LinearizedBuffersWrapper& operator=(const LinearizedBuffersWrapper& wrapper) =
delete;
LinearizedBuffersWrapper(LinearizedBuffersWrapper&&) = default;
LinearizedBuffersWrapper& operator=(LinearizedBuffersWrapper&&) = default;
~LinearizedBuffersWrapper() = default;
// These functions are tensorflow::Variant requirements.
string TypeName() const { return "(anonymous)::LinearizedBuffersWrapper"; }
void Encode(tensorflow::VariantTensorData* data) const {
LOG(ERROR) << "Encode() is not implemented for LinearizedBuffersWrapper "
"objects.";
}
bool Decode(const tensorflow::VariantTensorData& data) {
LOG(ERROR) << "Decode() is not implemented for LinearizedBuffersWrapper "
"objects.";
return false;
}
LinearizerBufferList buffers;
// Save references on tensors whose underlying storage are shared with
// LiteralLinearizer::Buffer in `buffers`.
std::vector<tensorflow::Tensor> tensors;
};
Status AutoTransposeAndLinearize(OpKernelContext* ctx,
const Tensor& input_tensor,
const xla::Shape& shape,
LinearizerBufferList* linearized_buffers,
std::vector<Tensor>* saved_input_tensors) {
const Tensor* tensor = &input_tensor;
// If the given layout is not in dim0major layout, transposes the tensor.
bool has_transposed = false;
Tensor transposed_tensor;
if (!xla::LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
// If the given layout is not in dim0major layout, transpose the tensor.
TF_ASSIGN_OR_RETURN(transposed_tensor,
TransposeTensor(ctx, input_tensor, shape));
tensor = &transposed_tensor;
has_transposed = true;
}
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));
TF_RETURN_IF_ERROR(
xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()
->LinearizeToBuffers(literal, linearized_buffers));
// The input tensor is ref-counted. Save a handle on the input tensor if
// its underlying storage is shared with the linearized buffers, to prevent
// the input tensor from being freed.
for (const auto& buffer : *linearized_buffers) {
if (!buffer.owns_data() && !has_transposed) {
// `buffer` is created from zero-copy fast path from the un-transposed
// input tensor so its underlying data is shared with input tensor.
// Save a handle to input tensor to increment its ref-count and avoid
// it getting deallocated after PrelinearizeTupleOp completes.
saved_input_tensors->push_back(*tensor);
// A literal can be linearized into anywhere from zero to two buffers. If any
// of the linearized buffers shares storage with the input tensor, we save
// exactly one handle on the input tensor.
break;
}
}
return Status::OK();
}
// PrelinearizeOp is used to linearize one tensor to the device format.
class PrelinearizeOp : public OpKernel {
public:
explicit PrelinearizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
xla::Shape shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
OP_REQUIRES_OK(ctx,
GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
}
void Compute(OpKernelContext* ctx) override {
const Tensor& input_tensor = ctx->input(0);
// Validate input.
OP_REQUIRES(
ctx, input_tensor.dtype() == dtype_,
errors::InvalidArgument("Prelinearize dtype mismatch; expected ",
DataType_Name(dtype_), ", got ",
DataType_Name(input_tensor.dtype())));
OP_REQUIRES(
ctx, input_tensor.shape() == shape_,
errors::InvalidArgument("Prelinearize shape mismatch; expected ",
shape_.DebugString(), ", got ",
input_tensor.shape().DebugString()));
// Auto-transpose and prelinearize.
LinearizerBufferList linearized_buffers;
std::vector<Tensor> saved_input_tensors;
auto status =
AutoTransposeAndLinearize(ctx, input_tensor, xla_shape_,
&linearized_buffers, &saved_input_tensors);
OP_REQUIRES_OK(ctx, status);
// Write to output.
tensorflow::Tensor* output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
std::move(linearized_buffers), std::move(saved_input_tensors)};
}
bool IsExpensive() override { return true; }
private:
TensorShape shape_;
DataType dtype_;
xla::Shape xla_shape_;
// PrelinearizeOp is neither copyable nor movable.
PrelinearizeOp(const PrelinearizeOp&) = delete;
PrelinearizeOp& operator=(const PrelinearizeOp&) = delete;
};
// PrelinearizeTupleOp is used to linearize multiple tensors to the device
// format.
class PrelinearizeTupleOp : public OpKernel {
public:
explicit PrelinearizeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
OP_REQUIRES(
ctx, shapes_.size() == dtypes_.size(),
errors::InvalidArgument(
"shapes and dtypes must be the same length. shapes length = ",
shapes_.size(), ", dtypes length = ", dtypes_.size()));
std::vector<xla::Shape> xla_shapes;
for (int i = 0; i < shapes_.size(); i++) {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx,
TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
xla_shapes.push_back(xla_shape);
}
OP_REQUIRES_OK(
ctx, GetInfeedShapeWithLayout(
ctx, "layouts", xla::ShapeUtil::MakeTupleShape(xla_shapes),
&tuple_shape_));
}
void Compute(OpKernelContext* ctx) override {
OpInputList values;
OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
OP_REQUIRES(ctx, values.size() == shapes_.size(),
errors::InvalidArgument(
"Wrong number of inputs to PrelinearizeTuple."));
LinearizerBufferList all_linearized_buffers;
std::vector<Tensor> all_saved_input_tensors;
for (int i = 0; i < values.size(); i++) {
// Validate input.
const Tensor& input_tensor = values[i];
OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
errors::InvalidArgument(
"PrelinearizeTuple dtype mismatch at tuple element ", i,
"; expected ", DataType_Name(dtypes_[i]), ", got ",
DataType_Name(input_tensor.dtype())));
OP_REQUIRES(ctx, input_tensor.shape() == shapes_[i],
errors::InvalidArgument(
"PrelinearizeTuple shape mismatch at tuple element ", i,
"; expected ", shapes_[i].DebugString(), ", got ",
input_tensor.shape().DebugString()));
// Auto-transpose and prelinearize.
LinearizerBufferList linearized_buffers;
std::vector<Tensor> saved_input_tensors;
auto status = AutoTransposeAndLinearize(
ctx, input_tensor, tuple_shape_.tuple_shapes(i), &linearized_buffers,
&saved_input_tensors);
OP_REQUIRES_OK(ctx, status);
all_linearized_buffers.insert(
all_linearized_buffers.end(),
std::make_move_iterator(linearized_buffers.begin()),
std::make_move_iterator(linearized_buffers.end()));
all_saved_input_tensors.insert(
all_saved_input_tensors.end(),
std::make_move_iterator(saved_input_tensors.begin()),
std::make_move_iterator(saved_input_tensors.end()));
}
tensorflow::Tensor* output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
std::move(all_linearized_buffers), std::move(all_saved_input_tensors)};
}
bool IsExpensive() override { return true; }
private:
std::vector<TensorShape> shapes_;
DataTypeVector dtypes_;
xla::Shape tuple_shape_;
// PrelinearizeTupleOp is neither copyable nor movable.
PrelinearizeTupleOp(const PrelinearizeTupleOp&) = delete;
PrelinearizeTupleOp& operator=(const PrelinearizeTupleOp&) = delete;
};
// The InfeedEnqueuePrelinearizedBufferOp op is used to transfer prelinearized
// buffers to the device infeed queue.
class InfeedEnqueuePrelinearizedBufferOp : public TpuTransferAsyncOpKernel {
public:
explicit InfeedEnqueuePrelinearizedBufferOp(OpKernelConstruction* ctx)
: TpuTransferAsyncOpKernel(ctx, "prelinearized_buffers_to_infeed", 8) {}
Status DoWork(OpKernelContext* ctx,
xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) override {
const Tensor& input_tensor = ctx->input(0);
const LinearizedBuffersWrapper* wrapper =
input_tensor.scalar<tensorflow::Variant>()()
.get<LinearizedBuffersWrapper>();
TF_RETURN_IF_ERROR(transfer_manager->TransferBuffersToInfeed(
stream_executor, wrapper->buffers));
return Status::OK();
}
private:
// InfeedEnqueuePrelinearizedBufferOp is neither copyable nor movable.
InfeedEnqueuePrelinearizedBufferOp(
const InfeedEnqueuePrelinearizedBufferOp&) = delete;
InfeedEnqueuePrelinearizedBufferOp& operator=(
const InfeedEnqueuePrelinearizedBufferOp&) = delete;
};
} // anonymous namespace
TpuInfeedEnqueueOp::TpuInfeedEnqueueOp(OpKernelConstruction* ctx)
: TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
xla::Shape shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
OP_REQUIRES_OK(ctx,
GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
}
Status TpuInfeedEnqueueOp::DoWork(
OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) {
const Tensor& input_tensor = ctx->input(0);
// Validate runtime shape and fail if it doesn't match the contract.
if (input_tensor.dtype() != dtype_) {
return errors::InvalidArgument("Infeed dtype mismatch.");
}
if (input_tensor.shape() != shape_) {
return errors::InvalidArgument("Infeed shape mismatch; expected ",
shape_.DebugString(), ", got ",
input_tensor.shape().DebugString());
}
const Tensor* tensor = &input_tensor;
Tensor transposed_tensor;
if (!xla::LayoutUtil::IsMonotonicWithDim0Major(xla_shape_.layout())) {
// If the given layout is not in dim0major layout, transpose the tensor.
TF_ASSIGN_OR_RETURN(transposed_tensor,
TransposeTensor(ctx, input_tensor, xla_shape_));
tensor = &transposed_tensor;
}
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));
// Transfer the given literal to the Infeed interface of the device.
TF_RETURN_IF_ERROR(
transfer_manager->TransferLiteralToInfeed(stream_executor, literal));
return Status::OK();
}
TpuInfeedEnqueueTupleOp::TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx)
: TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
OP_REQUIRES(
ctx, shapes_.size() == dtypes_.size(),
errors::InvalidArgument("shapes and dtypes must be the same length."));
std::vector<xla::Shape> xla_shapes;
for (int i = 0; i < shapes_.size(); i++) {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx,
TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
xla_shapes.push_back(xla_shape);
}
OP_REQUIRES_OK(
ctx, GetInfeedShapeWithLayout(ctx, "layouts",
xla::ShapeUtil::MakeTupleShape(xla_shapes),
&tuple_shape_));
}
Status TpuInfeedEnqueueTupleOp::DoWork(
OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) {
OpInputList values;
TF_RETURN_IF_ERROR(ctx->input_list("inputs", &values));
if (values.size() != shapes_.size()) {
return errors::InvalidArgument(
"Wrong number of inputs to InfeedEnqueueTuple.");
}
for (const auto& shapes : shapes_) {
VLOG(1) << "TransferLiteralToInfeed " << shapes.DebugString();
}
std::vector<Tensor> maybe_transposed_tensors;
maybe_transposed_tensors.reserve(values.size());
for (int i = 0; i < values.size(); i++) {
// Validate runtime shapes and fail if it doesn't match the contract.
const Tensor* tensor = &values[i];
if (tensor->shape() != shapes_[i]) {
return errors::InvalidArgument("Infeed shape mismatch for tuple element ",
i, "; expected ", shapes_[i].DebugString(),
", got ", tensor->shape().DebugString());
}
if (!xla::LayoutUtil::IsMonotonicWithDim0Major(
tuple_shape_.tuple_shapes(i).layout())) {
// If the given layout is not in dim0major layout, transpose the given
// tensor.
TF_ASSIGN_OR_RETURN(
Tensor transposed_tensor,
TransposeTensor(ctx, *tensor, tuple_shape_.tuple_shapes(i)));
maybe_transposed_tensors.emplace_back(transposed_tensor);
} else {
maybe_transposed_tensors.emplace_back(*tensor);
}
}
xla::BorrowingLiteral tuple;
TF_RETURN_IF_ERROR(
HostTensorsToBorrowingLiteralTuple(maybe_transposed_tensors, &tuple));
// Transfer the given literal to the Infeed interface of the device.
TF_RETURN_IF_ERROR(
transfer_manager->TransferLiteralToInfeed(stream_executor, tuple));
VLOG(1) << "TransferLiteralToInfeed complete.";
return Status::OK();
}
// These ops execute on either the TPU device or the CPU device. When running on
// CPU they must specify a non-negative value for device_ordinal to indicate
// which TPU to send infeed to.
REGISTER_KERNEL_BUILDER(
Name("InfeedEnqueue").Device(DEVICE_TPU_NODE).HostMemory("input"),
TpuInfeedEnqueueOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueue").Device(DEVICE_CPU),
TpuInfeedEnqueueOp);
REGISTER_KERNEL_BUILDER(
Name("InfeedEnqueueTuple").Device(DEVICE_TPU_NODE).HostMemory("inputs"),
TpuInfeedEnqueueTupleOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueueTuple").Device(DEVICE_CPU),
TpuInfeedEnqueueTupleOp);
// Prelinearize ops run on the CPU as part of the tf.data input pipeline.
REGISTER_KERNEL_BUILDER(Name("Prelinearize").Device(DEVICE_CPU),
PrelinearizeOp);
REGISTER_KERNEL_BUILDER(Name("PrelinearizeTuple").Device(DEVICE_CPU),
PrelinearizeTupleOp);
// The InfeedEnqueuePrelinearizedBuffer op runs on the CPU and takes a
// device_ordinal to select the right device to infeed.
REGISTER_KERNEL_BUILDER(
Name("InfeedEnqueuePrelinearizedBuffer").Device(DEVICE_CPU),
InfeedEnqueuePrelinearizedBufferOp);
} // namespace tensorflow
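As a sanity check on the dim0major normalization documented above TransposeTensor, here is a standalone sketch (plain C++, no TensorFlow dependency; the shape F32[10,20,30]{2,0,1} is the one from that comment) that reproduces the permutation arithmetic:
#include <cstdint>
#include <iostream>
#include <vector>
int main() {
  // Host tensor F32[10,20,30] with XLA layout {2,0,1} (minor-to-major).
  std::vector<int64_t> dims = {10, 20, 30};
  std::vector<int64_t> minor_to_major = {2, 0, 1};
  const int64_t rank = static_cast<int64_t>(dims.size());
  std::vector<int32_t> permutation(rank);
  std::vector<int64_t> transposed(rank);
  for (int64_t i = 0; i < rank; ++i) {
    // Same recurrence as TransposeTensor above.
    permutation[i] = minor_to_major[rank - 1 - i];
    transposed[i] = dims[permutation[i]];
  }
  // Prints "1 0 2" and "20 10 30": F32[10,20,30]{2,0,1} is handed to the
  // linearizer as the dim0major tensor F32[20,10,30]{2,1,0}.
  for (int32_t p : permutation) std::cout << p << " ";
  std::cout << std::endl;
  for (int64_t d : transposed) std::cout << d << " ";
  std::cout << std::endl;
  return 0;
}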


@@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
namespace tensorflow {
// TODO(b/65200690): Rework this when there is a callback based infeed API to
// StreamExecutor.
// The InfeedEnqueue op is used to deliver data to the device infeed queue.
class TpuInfeedEnqueueOp : public TpuTransferAsyncOpKernel {
public:
explicit TpuInfeedEnqueueOp(OpKernelConstruction* ctx);
Status DoWork(OpKernelContext* ctx,
xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) override;
private:
TensorShape shape_;
DataType dtype_;
xla::Shape xla_shape_;
// TpuInfeedEnqueueOp is neither copyable nor movable.
TpuInfeedEnqueueOp(const TpuInfeedEnqueueOp&) = delete;
TpuInfeedEnqueueOp& operator=(const TpuInfeedEnqueueOp&) = delete;
};
// The InfeedEnqueueTuple op is used on the host to deliver multiple tensors to
// the device infeed queue as an XLA tuple.
class TpuInfeedEnqueueTupleOp : public TpuTransferAsyncOpKernel {
public:
explicit TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx);
Status DoWork(OpKernelContext* ctx,
xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) override;
private:
std::vector<TensorShape> shapes_;
DataTypeVector dtypes_;
xla::Shape tuple_shape_;
// TpuInfeedEnqueueTupleOp is neither copyable nor movable.
TpuInfeedEnqueueTupleOp(const TpuInfeedEnqueueTupleOp&) = delete;
TpuInfeedEnqueueTupleOp& operator=(const TpuInfeedEnqueueTupleOp&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_


@@ -0,0 +1,116 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/outfeed_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
namespace tensorflow {
TpuOutfeedDequeueOp::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
: TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
}
Status TpuOutfeedDequeueOp::DoWork(
OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) {
Tensor* output;
TF_RETURN_IF_ERROR(ctx->allocate_output(0, shape_, &output));
// Transfer from the outfeed interface of the device.
xla::MutableBorrowingLiteral literal;
TF_RETURN_IF_ERROR(
HostTensorToMutableBorrowingLiteral(xla_shape_, output, &literal));
VLOG(1) << "TransferLiteralFromOutfeed "
<< xla::ShapeUtil::HumanStringWithLayout(xla_shape_);
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
stream_executor, xla_shape_, literal));
VLOG(1) << "TransferLiteralFromOutfeed complete.";
return Status::OK();
}
// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
TpuOutfeedDequeueTupleOp::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
: TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
OP_REQUIRES(
ctx, shapes_.size() == dtypes_.size(),
errors::InvalidArgument("shapes and dtypes must be the same length."));
// The `dtypes` list is inferred from the supplied inputs, so it
// is always the correct length.
for (int i = 0; i < shapes_.size(); i++) {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx,
TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
xla_shapes_.push_back(xla_shape);
}
tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
}
Status TpuOutfeedDequeueTupleOp::DoWork(
OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) {
VLOG(1) << "TransferLiteralFromOutfeed "
<< xla::ShapeUtil::HumanStringWithLayout(tuple_shape_);
for (int i = 0; i < shapes_.size(); ++i) {
Tensor* output;
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shapes_[i], &output));
xla::MutableBorrowingLiteral literal;
TF_RETURN_IF_ERROR(
HostTensorToMutableBorrowingLiteral(xla_shapes_[i], output, &literal));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
stream_executor, xla_shapes_[i], literal));
}
return Status::OK();
}
// These ops execute on either the TPU device or the CPU device. When
// running on CPU they must specify a non-negative value for
// device_ordinal to indicate which TPU to receive outfeed from.
REGISTER_KERNEL_BUILDER(
Name("OutfeedDequeue").Device(DEVICE_TPU_NODE).HostMemory("output"),
TpuOutfeedDequeueOp);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeue").Device(DEVICE_CPU),
TpuOutfeedDequeueOp);
REGISTER_KERNEL_BUILDER(
Name("OutfeedDequeueTuple").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
TpuOutfeedDequeueTupleOp);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeueTuple").Device(DEVICE_CPU),
TpuOutfeedDequeueTupleOp);
} // namespace tensorflow


@@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
namespace tensorflow {
// The OutfeedDequeue op is used to retrieve a single tensor from the device
// outfeed queue.
class TpuOutfeedDequeueOp : public TpuTransferAsyncOpKernel {
public:
explicit TpuOutfeedDequeueOp(OpKernelConstruction* ctx);
Status DoWork(OpKernelContext* ctx,
xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) override;
private:
TensorShape shape_;
DataType dtype_;
xla::Shape xla_shape_;
// TpuOutfeedDequeueOp is neither copyable nor movable.
TpuOutfeedDequeueOp(const TpuOutfeedDequeueOp&) = delete;
TpuOutfeedDequeueOp& operator=(const TpuOutfeedDequeueOp&) = delete;
};
// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
class TpuOutfeedDequeueTupleOp : public TpuTransferAsyncOpKernel {
public:
explicit TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx);
Status DoWork(OpKernelContext* ctx,
xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) override;
private:
std::vector<TensorShape> shapes_;
DataTypeVector dtypes_;
std::vector<xla::Shape> xla_shapes_;
xla::Shape tuple_shape_;
// TpuOutfeedDequeueTupleOp is neither copyable nor movable.
TpuOutfeedDequeueTupleOp(const TpuOutfeedDequeueTupleOp&) = delete;
TpuOutfeedDequeueTupleOp& operator=(const TpuOutfeedDequeueTupleOp&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_


@@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("_TPUReplicate").Device(DEVICE_TPU_SYSTEM),
XlaDeviceDummyOp);
} // namespace tensorflow


@@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
namespace tensorflow {
class TpuHandleToProtoKeyOp : public OpKernel {
public:
explicit TpuHandleToProtoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TpuHandleToProtoKeyOp() override = default;
TpuHandleToProtoKeyOp(const TpuHandleToProtoKeyOp&) = delete;
TpuHandleToProtoKeyOp& operator=(const TpuHandleToProtoKeyOp&) = delete;
void Compute(OpKernelContext* ctx) override {
VLOG(1) << "TpuHandleToProtoKeyOp::Compute " << ctx->op_kernel().name()
<< " on device " << ctx->op_kernel().requested_device();
const Tensor& uid = ctx->input(0);
ResourceMgr* rm = GetTPUConfigResourceMgr();
tpu::TpuCompilationCacheInterface* cache;
OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
rm->default_container(),
tpu::kCompilationCacheResourceName, &cache));
core::ScopedUnref cache_unref(cache);
std::vector<std::string> keys;
OP_REQUIRES_OK(ctx, cache->GetKeysFromUid(uid.scalar<int64>()(), &keys));
TensorShape output_shape;
output_shape.AddDim(keys.size());
Tensor* result = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
for (int i = 0; i < keys.size(); ++i) {
result->vec<tstring>()(i) = keys[i];
}
};
};
REGISTER_KERNEL_BUILDER(Name("TpuHandleToProtoKey").Device(DEVICE_CPU),
TpuHandleToProtoKeyOp);
} // namespace tensorflow


@@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
namespace tensorflow {
TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
const string& transfer_type,
int number_of_threads)
: AsyncOpKernel(ctx),
thread_pool_(new thread::ThreadPool(
ctx->env(),
strings::StrCat(transfer_type, "_thread_",
SanitizeThreadSuffix(def().name())),
/*num_threads=*/8)) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
OP_REQUIRES(
ctx, device_ordinal_ >= 0,
errors::InvalidArgument(transfer_type,
" ops must specify a device_ordinal when "
"placed on CPU."));
}
}
void TpuTransferAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
DoneCallback done) {
CancellationToken token =
ctx->cancellation_manager()->get_cancellation_token();
bool already_cancelled;
{
// Only protect registering the cancellation callback as mu_ cannot be held
// at a point where `done` could be called.
mutex_lock lock(mu_);
already_cancelled = !ctx->cancellation_manager()->RegisterCallback(
token, [this]() { Cancel(); });
}
OP_REQUIRES_ASYNC(ctx, !already_cancelled,
errors::Cancelled("Infeed was cancelled."), done);
thread_pool_->Schedule([this, ctx, done, token]() {
Status s = RunTransfer(ctx);
ctx->cancellation_manager()->DeregisterCallback(token);
OP_REQUIRES_OK_ASYNC(ctx, s, done);
done();
});
}
Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
int real_device_ordinal = device_ordinal_;
if (real_device_ordinal < 0) {
const XlaDevice::Metadata* metadata;
TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
real_device_ordinal = metadata->device_ordinal();
}
stream_executor::StreamExecutor* stream_executor =
tpu_platform->ExecutorForDevice(real_device_ordinal).ValueOrDie();
// When Xprof profiling is off (which is the default), constructing the
// activity is simple enough that its overhead is negligible.
profiler::TraceMe activity(
[this] { return profiler::TraceMeOp(name(), type_string()); },
profiler::TraceMeLevel::kInfo);
return DoWork(
ctx, xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager(),
stream_executor);
}
void TpuTransferAsyncOpKernel::Cancel() {
mutex_lock lock(mu_);
TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
}
} // namespace tensorflow


@@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
namespace tensorflow {
// Base class providing common functionality for async ops that transfer from
// host to TPU.
class TpuTransferAsyncOpKernel : public AsyncOpKernel {
public:
explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
const string& transfer_type,
int number_of_threads);
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
protected:
virtual Status DoWork(OpKernelContext* context,
xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) = 0;
private:
Status RunTransfer(OpKernelContext* ctx);
void Cancel();
std::unique_ptr<thread::ThreadPool> thread_pool_;
int device_ordinal_;
mutex mu_;
// TpuTransferAsyncOpKernel is neither copyable nor movable.
TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete;
TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
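For orientation, a minimal sketch of how a new transfer kernel would plug into this base class. The class name and thread-pool string below are illustrative only, not part of this change, and a real kernel would also need a registered op carrying a device_ordinal attr plus a REGISTER_KERNEL_BUILDER call, as the infeed and outfeed kernels above have:
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
namespace tensorflow {
// Hypothetical example kernel: shows where DoWork's arguments come from.
class ExampleTransferOp : public TpuTransferAsyncOpKernel {
 public:
  explicit ExampleTransferOp(OpKernelConstruction* ctx)
      // The base constructor reads the "device_ordinal" attr, so the wrapped
      // op definition must declare it.
      : TpuTransferAsyncOpKernel(ctx, "example_transfer",
                                 /*number_of_threads=*/1) {}
  Status DoWork(OpKernelContext* ctx,
                xla::TpuTransferManagerInterface* transfer_manager,
                stream_executor::StreamExecutor* stream_executor) override {
    // RunTransfer() has already resolved the device ordinal and looked up both
    // arguments; a real op would call TransferLiteralToInfeed or
    // TransferLiteralFromOutfeed here.
    VLOG(1) << "ExampleTransferOp::DoWork on " << name();
    return Status::OK();
  }
};
}  // namespace tensorflow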


@@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
namespace tensorflow {
const char* const DEVICE_TPU_NODE = "TPU";
@@ -27,4 +31,18 @@ const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR =
const char* const kTPUReplicateAttr = "_tpu_replicate";
const char* const kOutsideCompilationAttr = "_xla_outside_compilation";
xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
XLA_Shape c_shape;
XLA_Shape c_infeed_shape;
ApiConverter::ToC(shape, &c_shape);
tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
&c_infeed_shape);
xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
ApiConverter::Free(&c_shape);
ApiConverter::Free(&c_infeed_shape);
return infeed_shape;
}
} // namespace tensorflow


@@ -20,6 +20,7 @@ limitations under the License.
#include <array>
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
@@ -56,6 +57,11 @@ static constexpr std::array<DataType, 16> kTpuAllTypes = {
DT_COMPLEX64, DT_INT64, DT_UINT64, DT_QINT8, DT_QUINT8, DT_INT8, DT_UINT8,
DT_INT16, DT_UINT16}};
// For the given shape, chooses a layout for infeed on TPU. The returned shape
// has the same dimensions as the original shape, and only the layout is
// changed.
xla::Shape GetTPUInfeedLayout(const xla::Shape& shape);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_TPU_DEFS_H_
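A hedged usage sketch for the helper declared above: it assumes the TPU runtime library has already been loaded and initialized (the implementation goes through ExecutorApiFn()), and the shape is hypothetical.
#include <iostream>
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/tpu/tpu_defs.h"
int main() {
  // Host-side shape of one infeed operand.
  xla::Shape host_shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 28, 28, 3});
  // Same dimensions, but carrying the layout the TPU runtime wants for infeed.
  xla::Shape infeed_shape = tensorflow::GetTPUInfeedLayout(host_shape);
  std::cout << xla::ShapeUtil::HumanStringWithLayout(infeed_shape) << std::endl;
  return 0;
}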


@@ -161,6 +161,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralFromDevice);
TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
TFTPU_SET_FN(executor_fn, TpuTransferManager_GetInfeedLayout);
TFTPU_SET_FN(executor_fn, TpuTransferManager_LinearizeToBuffers);
TFTPU_SET_FN(executor_fn, TpuTransferManager_FreeBuffers);


@@ -203,10 +203,12 @@ cc_library(
cc_library(
name = "tpu_transfer_manager_interface",
srcs = ["tpu_transfer_manager_interface.cc"],
hdrs = ["tpu_transfer_manager_interface.h"],
visibility = ["//visibility:public"],
deps = [
":noncopyable_buffer",
":tpu_platform_interface",
"//tensorflow/compiler/xla/service:transfer_manager",
],
)


@@ -182,6 +182,8 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
XLA_TransferManager* manager, SE_Stream* stream,
SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
SE_DeviceMemoryBase* region, SE_Status* status);
void TpuTransferManager_GetInfeedLayout(XLA_Shape* shape,
XLA_Shape* infeed_shape);
void TpuTransferManager_LinearizeToBuffers(
XLA_TransferManager* manager, XLA_Literal* c_literal, char*** buffers_array,
int64_t** buffers_size, int64_t* buffers_array_size, SE_Status* status);
@@ -341,6 +343,7 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralFromDevice);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetInfeedLayout);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_LinearizeToBuffers);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_FreeBuffers);


@@ -81,6 +81,12 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
const xla::Shape& shape,
stream_executor::DeviceMemoryBase* region) override;
Status LinearizeToBuffers(
const xla::LiteralSlice& literal,
std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) override {
LOG(FATAL) << "Not yet implemented.";
}
private:
XLA_TransferManager* manager_;
};


@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace xla {
/*static*/ TpuTransferManagerInterface*
TpuTransferManagerInterface::GetRegisteredTpuTransferManager() {
auto* platform = tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(
/*initialize_platform=*/false);
if (platform == nullptr) {
LOG(ERROR) << "Unable to retrieve registered TPU platform.";
return nullptr;
}
auto tm = xla::TransferManager::GetForPlatform(platform);
if (!tm.ok()) {
LOG(ERROR) << "Unable to retrieve TpuTransferManager. No TPU platform is "
"registered for platform "
<< platform->Name() << " and ID " << platform->id();
return nullptr;
}
return static_cast<TpuTransferManagerInterface*>(tm.ValueOrDie());
}
} // namespace xla
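A minimal sketch (the helper name is hypothetical, not part of this change) of how kernel code is expected to reach this interface: through the registered platform rather than by constructing a transfer manager directly, mirroring AutoTransposeAndLinearize in infeed_ops.cc.
#include <deque>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/tpu/noncopyable_buffer.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
namespace tensorflow {
// Hypothetical helper: linearize a literal into buffers ready for the infeed
// queue, failing cleanly if no TPU platform is registered.
Status LinearizeForInfeed(const xla::LiteralSlice& literal,
                          std::deque<tpu::NoncopyableBuffer>* buffers) {
  auto* transfer_manager =
      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager();
  if (transfer_manager == nullptr) {
    return errors::Internal("No TPU platform is registered.");
  }
  return transfer_manager->LinearizeToBuffers(literal, buffers);
}
}  // namespace tensorflow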


@@ -24,9 +24,16 @@ limitations under the License.
namespace xla {
class TpuTransferManagerInterface : public xla::TransferManager {
public:
virtual Status TransferBuffersToInfeed(
se::StreamExecutor* executor,
const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) = 0;
virtual Status LinearizeToBuffers(
const LiteralSlice& literal,
std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) = 0;
static TpuTransferManagerInterface* GetRegisteredTpuTransferManager();
};
} // namespace xla