Introduce additional XLA TPU Ops to open source

PiperOrigin-RevId: 326343558
Change-Id: I47da1dc0c96cdf8223ccebef012e2a5088a857a4
Frank Chen 2020-08-12 16:58:21 -07:00 committed by TensorFlower Gardener
parent 8e01ae829b
commit 3ea5fc7f3f
10 changed files with 1291 additions and 0 deletions


@@ -38,6 +38,7 @@ tf_kernel_library(
":tpu_execute_op",
":tpu_handle_to_key_op",
":transfer_ops",
"//tensorflow/core/tpu/kernels/xla:xla_ops",
],
)


@@ -0,0 +1,52 @@
# XLA Ops for TPUs
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "xla_ops",
srcs = [
"get_item_op.cc",
"host_compute_ops.cc",
"index_ops.cc",
"infeed_op.cc",
"inplace_ops.cc",
"outfeed_ops.cc",
"segment_reduction_ops.cc",
"where_op.cc",
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:sharding_util",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_context",
"//tensorflow/compiler/tf2xla:xla_helpers",
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/compiler/tf2xla/kernels:if_op",
"//tensorflow/compiler/tf2xla/kernels:while_op",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:comparators",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu/kernels:cross_replica_ops",
"//tensorflow/stream_executor/tpu:c_api_conversions",
"//tensorflow/stream_executor/tpu:c_api_decl",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)


@@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
namespace tensorflow {
namespace {
// The Xla kernel to build up the computation for get_item(data, index).
class GetItemXlaOp : public XlaOpKernel {
public:
explicit GetItemXlaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape& data_shape = ctx->InputShape(0);
const TensorShape& index_shape = ctx->InputShape(1);
OP_REQUIRES(
ctx, TensorShapeUtils::IsVectorOrHigher(data_shape),
errors::InvalidArgument("data must be at least 1 dimensional."));
OP_REQUIRES(ctx, index_shape.dims() == 1 && index_shape.dim_size(0) == 1,
errors::InvalidArgument("index must be a vector of size 1."));
// NOTE(pbar) Use Concat to extend the indices to match cl/142279605.
// This isn't the simplest way to emit the indices, but the code for
// dynamic slice needs to be able to see that minor dims are const zero.
auto const_zero = xla::ConstantR0(ctx->builder(), 0);
std::vector<xla::XlaOp> operands;
operands.push_back(xla::Reshape(ctx->Input(1), {}));
for (int i = 1; i < data_shape.dims(); i++) {
operands.push_back(const_zero);
}
std::vector<int64> dims = {0};
std::vector<int64> slice_sizes = {1};
std::vector<int64> out_sizes = {};
for (int i = 1; i < data_shape.dims(); i++) {
dims.push_back(i);
auto size = data_shape.dim_size(i);
slice_sizes.push_back(size);
out_sizes.push_back(size);
}
// NOTE: DynamicSlice here doesn't raise an error or wrap the index
// if it's out of range.
auto slice = xla::DynamicSlice(ctx->Input(0), operands, slice_sizes);
// In-order collapse to remove the 1st dim.
auto reshape = xla::Reshape(slice, dims, out_sizes);
ctx->SetOutput(0, reshape);
}
};
REGISTER_XLA_OP(Name("GetItem"), GetItemXlaOp);
} // namespace
} // namespace tensorflow
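As a reading aid, here is a minimal plain-C++ sketch (illustrative only, not part of this commit) of the semantics the kernel above lowers to: get_item(data, index) returns the slice data[index[0]] with the leading dimension dropped, which is what the DynamicSlice + Reshape pair emits. The 2-D data and the helper name GetItem below are assumptions made for the example.

#include <cassert>
#include <iostream>
#include <vector>

// get_item(data, index): data is [N, M], index is a size-1 vector.
// The result is row data[index[0]] of shape [M] -- the leading dimension is
// dropped, mirroring the DynamicSlice + Reshape emitted by GetItemXlaOp.
std::vector<float> GetItem(const std::vector<std::vector<float>>& data,
                           const std::vector<int>& index) {
  assert(index.size() == 1);
  return data[index[0]];
}

int main() {
  const std::vector<std::vector<float>> data = {{1, 2, 3}, {4, 5, 6}};
  const std::vector<float> row = GetItem(data, {1});
  std::cout << row[0] << " " << row[1] << " " << row[2] << "\n";  // prints: 4 5 6
}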


@@ -0,0 +1,498 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
namespace {
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
static const char* const kSendFromHostOp = "_XlaSendFromHost";
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
Status MakeXlaShapes(gtl::ArraySlice<TensorShape> shapes,
gtl::ArraySlice<DataType> dtypes,
std::vector<xla::Shape>* xla_shapes,
xla::Shape* xla_shape) {
for (int i = 0; i < shapes.size(); i++) {
xla::Shape single_xla_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(dtypes[i], shapes[i], &single_xla_shape));
VLOG(2) << "Shape " << single_xla_shape.DebugString();
xla_shapes->push_back(single_xla_shape);
}
// Temporarily add a dummy output to the shape array before making the tuple:
// this output is used for control dependencies between host compute ops.
xla_shapes->push_back(xla::ShapeUtil::MakeShape(xla::PRED, {}));
*xla_shape = xla::ShapeUtil::MakeTupleShape(*xla_shapes);
// Remove the dummy output from the vector that will be used to copy real
// outputs from host to device.
xla_shapes->pop_back();
return Status::OK();
}
// This TensorFlow pseudo-op is used to record host-side computation.
class HostComputeOp : public XlaOpKernel {
public:
explicit HostComputeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("cost_estimate_ns", &cost_estimate_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("tpu_core", &tpu_core_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_dtypes_));
OP_REQUIRES(ctx, ctx->num_inputs() == input_dtypes_.size(),
errors::InvalidArgument("Tinputs size=", input_dtypes_.size(),
" but expected ", ctx->num_inputs(),
" inputs."));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_dtypes_));
OP_REQUIRES(ctx, ctx->num_outputs() == output_dtypes_.size(),
errors::InvalidArgument("Toutputs size=", output_dtypes_.size(),
" but expected ", ctx->num_outputs(),
" outputs."));
OP_REQUIRES_OK(ctx, ctx->GetAttr("ancestors", &ancestors_));
NameAttrList shape_inference_graph;
OP_REQUIRES_OK(
ctx, ctx->GetAttr("shape_inference_graph", &shape_inference_graph));
if (shape_inference_graph.name().empty()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &static_output_shapes_));
OP_REQUIRES(ctx, static_output_shapes_.size() == output_dtypes_.size(),
errors::InvalidArgument(
"shapes attr list size ", static_output_shapes_.size(),
" differs from dtypes size ", output_dtypes_.size()));
OP_REQUIRES_OK(ctx, MakeXlaShapes(static_output_shapes_, output_dtypes_,
&static_xla_output_shapes_,
&static_xla_output_shape_));
VLOG(2) << "Output Shape: " << static_xla_output_shape_.DebugString();
} else {
FunctionLibraryRuntime* flib_runtime = ctx->function_library();
OP_REQUIRES(ctx, flib_runtime != nullptr,
errors::Internal(
"No function library runtime at kernel construction"));
const FunctionLibraryDefinition* library =
flib_runtime->GetFunctionLibraryDefinition();
const FunctionDef* fdef = library->Find(shape_inference_graph.name());
OP_REQUIRES(ctx, fdef != nullptr,
errors::Internal("Failed to find function ",
shape_inference_graph.name(),
" in function library."));
OP_REQUIRES_OK(ctx, FunctionDefToBodyHelper(
*fdef, AttrSlice(&shape_inference_graph.attr()),
library, &shape_inference_graph_function_));
VLOG(2) << "Output Shape to be inferred at compile time";
}
OP_REQUIRES_OK(
ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
OP_REQUIRES(ctx, !token_input_nodes_.empty(),
errors::InvalidArgument("XlaHostCompute node does not have ",
kXlaTokenInputNodesAttrName, " attr"));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
&original_node_name_));
}
~HostComputeOp() override {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
std::vector<xla::XlaOp> input_handles;
std::vector<TensorShape> input_shapes;
OP_REQUIRES_OK(ctx, ctx->InputList("inputs", &input_handles, &input_shapes));
const auto device_sharding = xla::sharding_builder::AssignDevice(tpu_core_);
xla::XlaScopedShardingAssignment assign_sharding(b, device_sharding);
std::vector<xla::XlaOp> input_tokens;
for (auto& token_input_node : token_input_nodes_) {
auto token_or = compiler->GetNodeToken(token_input_node);
OP_REQUIRES_OK(ctx, token_or.status());
input_tokens.push_back(token_or.ValueOrDie());
}
xla::XlaOp token = xla::AfterAll(b, input_tokens);
// Send values to the host.
std::vector<xla::XlaOp> send_to_host_tokens;
for (int i = 0; i < input_handles.size(); ++i) {
const string channel_name = absl::StrCat(key_, "_dtoh_", i);
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtypes_[i],
input_shapes[i], &xla_shape));
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_shape.element_type());
b->SetFrontendAttributes(attrs);
xla::ChannelHandle channel;
OP_REQUIRES_OK(
ctx, compiler->GetDeviceToHostChannelHandle(channel_name, &channel));
send_to_host_tokens.push_back(
xla::SendToHost(input_handles[i], token, xla_shape, channel));
b->ClearOpMetadata();
}
xla::XlaOp recv_from_host_token_input =
send_to_host_tokens.empty() ? token
: xla::AfterAll(b, send_to_host_tokens);
if (!input_handles.empty()) {
// Register the shapes used in this transfer.
OP_REQUIRES_OK(ctx, ctx->compiler()->SetDeviceToHostMetadata(
key_, input_dtypes_, input_shapes));
}
// Compute the shapes of the values to copy to the device, if necessary.
std::vector<TensorShape>* output_shapes;
std::vector<xla::Shape>* xla_output_shapes;
xla::Shape* xla_output_shape;
std::vector<TensorShape> inferred_output_shapes;
std::vector<xla::Shape> inferred_xla_output_shapes;
xla::Shape inferred_xla_output_shape;
if (shape_inference_graph_function_) {
OP_REQUIRES_OK(
ctx, InferOutputShapes(
ctx, ctx->function_library()->GetFunctionLibraryDefinition(),
&inferred_output_shapes));
OP_REQUIRES_OK(ctx, MakeXlaShapes(inferred_output_shapes, output_dtypes_,
&inferred_xla_output_shapes,
&inferred_xla_output_shape));
output_shapes = &inferred_output_shapes;
xla_output_shapes = &inferred_xla_output_shapes;
xla_output_shape = &inferred_xla_output_shape;
} else {
output_shapes = &static_output_shapes_;
xla_output_shapes = &static_xla_output_shapes_;
xla_output_shape = &static_xla_output_shape_;
}
OP_REQUIRES(
ctx, output_shapes->size() == ctx->num_outputs(),
errors::InvalidArgument("Op has ", ctx->num_outputs(),
" outputs but an output shape vector of size ",
output_shapes->size()));
if (ctx->num_outputs() > 0) {
// Register the shapes used in this transfer.
OP_REQUIRES_OK(ctx, ctx->compiler()->SetHostToDeviceMetadata(
key_, output_dtypes_, *output_shapes));
}
// Copy results to the device.
std::vector<xla::XlaOp> recv_from_host_tokens;
for (int i = 0; i < output_shapes->size(); ++i) {
const string channel_name = absl::StrCat(key_, "_htod_", i);
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_output_shapes->at(i).element_type());
b->SetFrontendAttributes(attrs);
xla::ChannelHandle channel;
OP_REQUIRES_OK(
ctx, compiler->GetHostToDeviceChannelHandle(channel_name, &channel));
const auto result_token_tuple = xla::RecvFromHost(
recv_from_host_token_input, xla_output_shapes->at(i), channel);
b->ClearOpMetadata();
recv_from_host_tokens.push_back(
xla::GetTupleElement(result_token_tuple, /*index=*/1));
ctx->SetOutput(i, xla::GetTupleElement(result_token_tuple, 0));
}
// Set token output.
xla::XlaOp token_output = recv_from_host_tokens.empty()
? recv_from_host_token_input
: xla::AfterAll(b, recv_from_host_tokens);
OP_REQUIRES_OK(
ctx, ctx->compiler()->SetNodeToken(original_node_name_, token_output));
}
private:
Status LowerFunctionalOps(Graph* g,
const FunctionLibraryDefinition& flib_def) {
bool modified;
do {
modified = false;
// Lower "If" nodes first. Their body functions will be expanded as
// function call nodes, which we will lower later.
// We do not need to lower "While" nodes because shape inference can
// handle them correctly (output shapes are input shapes).
std::vector<Node*> if_nodes;
for (Node* n : g->op_nodes()) {
if (n->type_string() == "If") {
if_nodes.push_back(n);
}
}
for (Node* if_node : if_nodes) {
TF_RETURN_IF_ERROR(
RewriteIfNode(if_node, g, /*keep_node_fetchable=*/false));
}
if (!if_nodes.empty()) {
modified = true;
}
// Lower function call nodes.
std::vector<Node*> call_nodes;
for (Node* n : g->op_nodes()) {
if (IsFunctionCall(flib_def, *n)) {
call_nodes.push_back(n);
}
}
for (Node* call_node : call_nodes) {
TF_RETURN_IF_ERROR(RewriteFunctionCallNode(
call_node, g, flib_def, /*keep_caller_fetchable=*/false));
}
if (!call_nodes.empty()) {
modified = true;
}
} while (modified);
return Status::OK();
}
Status InferOutputShapes(XlaOpKernelContext* ctx,
const FunctionLibraryDefinition* flib_def,
std::vector<TensorShape>* output_shapes) {
// The shape inference graph was already unpacked from the attr into
// shape_inference_graph_function_ at construction time; no shape inference
// has been run on it yet.
Graph* graph = shape_inference_graph_function_->graph;
// Lower functional ops, because they are not friendly to shape inference.
TF_RETURN_IF_ERROR(LowerFunctionalOps(graph, *flib_def));
// Now run shape inference, filling in the shapes of the _XlaRecvAtHost nodes.
bool got_output_shapes = false;
ShapeRefiner shape_refiner{graph->versions().producer(),
graph->op_registry()};
std::vector<Node*> nodes;
GetReversePostOrder(*graph, &nodes);
for (auto node : nodes) {
TF_RETURN_IF_ERROR(shape_refiner.AddNode(node));
if (node->type_string() == kRecvAtHostOp) {
const AttrValue* key_attr = node->attrs().Find("key");
if (key_attr == nullptr) {
return errors::InvalidArgument("Node ", node->name(),
" has no key attribute");
}
std::vector<TensorShape> dtoh_shapes;
if (!ctx->compiler()
->GetDeviceToHostShapes(key_attr->s(), &dtoh_shapes)
.ok()) {
return errors::InvalidArgument(
"Shape inference for HostCompute ", ctx->op_kernel().name(),
" failed: host recv node ", node->name(), " with key '",
key_attr->s(), "' has unknown shapes.");
}
if (dtoh_shapes.size() != node->num_outputs()) {
return errors::InvalidArgument(
"Shape inference for HostCompute ", ctx->op_kernel().name(),
" failed: host recv node ", node->name(), " with key '",
key_attr->s(), "' has ", node->num_outputs(),
" outputs but inferred shapes expect ", dtoh_shapes.size());
}
for (int i = 0; i < node->num_outputs(); ++i) {
shape_inference::InferenceContext* shape_ctx =
shape_refiner.GetContext(node);
shape_inference::ShapeHandle handle;
TF_RETURN_IF_ERROR(
shape_ctx->MakeShapeFromTensorShape(dtoh_shapes.at(i), &handle));
shape_ctx->set_output(i, handle);
}
} else if (node->type_string() == kSendFromHostOp) {
if (got_output_shapes) {
return errors::InvalidArgument(
"Shape inference for HostCompute ", ctx->op_kernel().name(),
" failed: inference graph has multiple send from host nodes");
} else {
got_output_shapes = true;
// The last input is the dynamic key so don't record its shape.
output_shapes->resize(node->num_inputs() - 1);
shape_inference::InferenceContext* shape_ctx =
shape_refiner.GetContext(node);
for (int i = 0; i < node->num_inputs() - 1; ++i) {
shape_inference::ShapeHandle handle = shape_ctx->input(i);
if (!shape_ctx->FullyDefined(handle)) {
return errors::InvalidArgument(
"Shape inference for HostCompute ", ctx->op_kernel().name(),
" failed: send from host node ", node->name(),
" has non-fully defined shape of input index ", i);
}
TensorShapeProto shape_proto;
shape_ctx->ShapeHandleToProto(handle, &shape_proto);
(*output_shapes)[i] = TensorShape(shape_proto);
VLOG(2) << "Inferred shape " << shape_proto.DebugString();
}
}
}
}
if (!got_output_shapes) {
return errors::InvalidArgument(
"Shape inference for HostCompute ", ctx->op_kernel().name(),
" failed: inference graph has no send from host node");
}
return Status::OK();
}
DataTypeVector input_dtypes_;
DataTypeVector output_dtypes_;
std::vector<string> ancestors_;
std::vector<TensorShape> static_output_shapes_;
std::vector<xla::Shape> static_xla_output_shapes_;
string original_node_name_;
// static_xla_output_shape_ is the tuple of all static_xla_output_shapes_,
// plus the trailing dummy PRED element added by MakeXlaShapes for control
// dependencies between host compute ops.
xla::Shape static_xla_output_shape_;
string key_;
// If output shapes are inferred at compile time (a shape_inference_graph was
// provided instead of static shapes), the graph needed to perform that
// inference is stored in this function body.
std::unique_ptr<FunctionBody> shape_inference_graph_function_;
int64 cost_estimate_;
int64 tpu_core_;
std::vector<string> token_input_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(HostComputeOp);
};
class SendToHostOp : public XlaOpKernel {
public:
explicit SendToHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinput", &input_dtype_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
OP_REQUIRES(ctx, !token_input_nodes_.empty(),
errors::InvalidArgument("XlaSendToHost node does not have ",
kXlaTokenInputNodesAttrName, " attr"));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
&original_node_name_));
}
~SendToHostOp() override {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
xla::XlaOp operand = ctx->Input(0);
std::vector<xla::XlaOp> input_tokens;
for (auto& token_input_node : token_input_nodes_) {
auto token_or = compiler->GetNodeToken(token_input_node);
OP_REQUIRES_OK(ctx, token_or.status());
input_tokens.push_back(token_or.ValueOrDie());
}
xla::XlaOp token = xla::AfterAll(b, input_tokens);
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtype_, ctx->InputShape(0),
&xla_shape));
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_shape.element_type());
b->SetFrontendAttributes(attrs);
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetDeviceToHostChannelHandle(key_, &channel));
xla::XlaOp output_token =
xla::SendToHost(operand, token, xla_shape, channel);
OP_REQUIRES_OK(ctx,
compiler->SetNodeToken(original_node_name_, output_token));
}
private:
DataType input_dtype_;
string key_;
std::vector<string> token_input_nodes_;
string original_node_name_;
TF_DISALLOW_COPY_AND_ASSIGN(SendToHostOp);
};
class RecvFromHostOp : public XlaOpKernel {
public:
explicit RecvFromHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput", &output_dtype_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &output_shape_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
OP_REQUIRES(ctx, !token_input_nodes_.empty(),
errors::InvalidArgument("XlaRecvFromHost node does not have ",
kXlaTokenInputNodesAttrName, " attr"));
}
~RecvFromHostOp() override {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
std::vector<xla::XlaOp> input_tokens;
for (auto& token_input_node : token_input_nodes_) {
auto token_or = compiler->GetNodeToken(token_input_node);
OP_REQUIRES_OK(ctx, token_or.status());
input_tokens.push_back(token_or.ValueOrDie());
}
xla::XlaOp token = xla::AfterAll(b, input_tokens);
xla::Shape xla_shape;
OP_REQUIRES_OK(
ctx, TensorShapeToXLAShape(output_dtype_, output_shape_, &xla_shape));
// Specify frontend attributes.
xla::FrontendAttributes attrs;
(*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
(*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
xla::primitive_util::LowercasePrimitiveTypeName(
xla_shape.element_type());
b->SetFrontendAttributes(attrs);
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetHostToDeviceChannelHandle(key_, &channel));
xla::XlaOp result = xla::RecvFromHost(token, xla_shape, channel);
// xla::RecvFromHost returns a tuple of (received data, token).
ctx->SetOutput(0, xla::GetTupleElement(result, 0));
OP_REQUIRES_OK(
ctx, compiler->SetNodeToken(name(), xla::GetTupleElement(result, 1)));
}
private:
DataType output_dtype_;
TensorShape output_shape_;
string key_;
std::vector<string> token_input_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(RecvFromHostOp);
};
REGISTER_XLA_OP(Name("XlaHostCompute"), HostComputeOp);
REGISTER_XLA_OP(Name("XlaSendToHost"), SendToHostOp);
REGISTER_XLA_OP(Name("XlaRecvFromHost"), RecvFromHostOp);
} // anonymous namespace
} // namespace tensorflow
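The transfer protocol that HostComputeOp sets up can be summarized with a small plain-C++ sketch (illustrative only, not part of this commit): each device-to-host transfer uses a channel named "<key>_dtoh_<i>", each host-to-device transfer a channel named "<key>_htod_<i>", and the op's output token is the AfterAll of the per-transfer tokens. The key "hc0" and the input/output counts below are hypothetical.

#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string key = "hc0";  // hypothetical value of the "key" attr
  const int num_inputs = 2;       // values sent from device to host
  const int num_outputs = 1;      // values sent back from host to device
  std::vector<std::string> channels;
  // Device-to-host channels, one per input (the "_dtoh_" naming above).
  for (int i = 0; i < num_inputs; ++i)
    channels.push_back(key + "_dtoh_" + std::to_string(i));
  // Host-to-device channels, one per output (the "_htod_" naming above).
  for (int i = 0; i < num_outputs; ++i)
    channels.push_back(key + "_htod_" + std::to_string(i));
  for (const std::string& c : channels) std::cout << c << "\n";
  // Prints: hc0_dtoh_0, hc0_dtoh_1, hc0_htod_0
}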


@@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/index_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
namespace {
// This registration is needed here because the ArgMax Op is defined in
// third_party where DEVICE_TPU_XLA_JIT is not visible. Most Ops don't need a
// specific TPU whitelist, but ArgMax does because it has a separate CustomCall
// implementation on CPU.
REGISTER_XLA_OP(Name("ArgMax")
.Device(DEVICE_TPU_XLA_JIT)
.CompileTimeConstantInput("dimension"),
XlaArgMaxOp);
} // namespace
} // namespace tensorflow


@@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
namespace tensorflow {
namespace {
xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
XLA_Shape c_shape;
XLA_Shape c_infeed_shape;
ApiConverter::ToC(shape, &c_shape);
tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
&c_infeed_shape);
xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
ApiConverter::Free(&c_shape);
ApiConverter::Free(&c_infeed_shape);
return infeed_shape;
}
// Updates the layout of the given infeed shape, optionally considering the
// sharding of the op. If the op has tile sharding, assign the layout based on
// the shard shape.
Status UpdateInfeedLayout(xla::Shape* shape,
absl::optional<xla::OpSharding> sharding) {
if (sharding && sharding->type() == xla::OpSharding::OTHER) {
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
xla::HloSharding::FromProto(*sharding));
for (int64 i = 0; i < sharding->tile_assignment_devices_size(); ++i) {
auto device = sharding->tile_assignment_devices(i);
auto shard_shape =
GetTPUInfeedLayout(hlo_sharding.TileShape(*shape, device));
if (i == 0) {
*shape->mutable_layout() = shard_shape.layout();
}
if (xla::ShapeUtil::ElementsIn(shard_shape) == 0) {
// Shards with no elements may be assigned a different layout, but that
// doesn't matter since we're not sending any data for them.
continue;
}
if (!xla::LayoutUtil::Equal(shard_shape.layout(), shape->layout())) {
return xla::Unimplemented(
"Sharded infeed with non-uniform layouts is not supported. Try "
"turning off the infeed layout optimization "
"(--transpose_tpu_infeed=false) and report to XLA team.");
}
}
return Status::OK();
}
*shape = GetTPUInfeedLayout(*shape);
return Status::OK();
}
// TODO(pbar) Work out if we need to Infeed Tuples - if so then
// this op will need a way to provide a list of shapes
// since they can't be provided by the runtime JIT mechanism.
// (InfeedDequeue has no inputs!)
// Compare this op to tf.Queue operations which operate on N tensors.
// This TensorFlow op supports the XLA Infeed primitive.
class InfeedDequeueOp : public XlaOpKernel {
public:
explicit InfeedDequeueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shape_, b->sharding()));
ctx->SetOutput(0, xla::Infeed(b, xla_shape_));
}
private:
TensorShape shape_;
DataType dtype_;
xla::Shape xla_shape_;
TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueOp);
};
REGISTER_XLA_OP(Name("InfeedDequeue"), InfeedDequeueOp);
// This TensorFlow op supports the XLA Infeed primitive for tuple types.
class InfeedDequeueTupleOp : public XlaOpKernel {
public:
explicit InfeedDequeueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
for (int i = 0; i < shapes_.size(); i++) {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx,
TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
xla_shapes_.push_back(xla_shape);
}
}
~InfeedDequeueTupleOp() override {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
for (int64 i = 0; i < xla_shapes_.size(); ++i) {
absl::optional<xla::OpSharding> sharding;
if (b->sharding()) {
sharding = b->sharding()->type() == xla::OpSharding::TUPLE
? b->sharding()->tuple_shardings(i)
: b->sharding();
}
OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shapes_[i], sharding));
}
tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
auto tuple = xla::Infeed(b, tuple_shape_);
// Don't apply the infeed tuple sharding to the get-tuple-elements. They
// need non-tuple shardings.
xla::XlaScopedShardingAssignment clear_sharding(b, absl::nullopt);
for (int i = 0; i < shapes_.size(); ++i) {
ctx->SetOutput(i, xla::GetTupleElement(tuple, i));
}
}
private:
std::vector<TensorShape> shapes_;
DataTypeVector dtypes_;
std::vector<xla::Shape> xla_shapes_;
xla::Shape tuple_shape_;
TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueTupleOp);
};
REGISTER_XLA_OP(Name("InfeedDequeueTuple"), InfeedDequeueTupleOp);
} // anonymous namespace
} // namespace tensorflow
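The layout-selection policy in UpdateInfeedLayout can be read as the following plain-C++ sketch (illustrative only, not part of this commit): every non-empty shard must agree on its infeed layout, and the layout of shard 0 is adopted for the full infeed shape. The struct, function name, and string-encoded layouts below are assumptions made for the example.

#include <iostream>
#include <string>
#include <vector>

// Each entry stands in for the infeed layout chosen for one shard, plus its
// element count. Layouts are represented as opaque strings for the sketch.
struct ShardLayout { std::string layout; int num_elements; };

// Returns the layout to use, or an empty string if the shards disagree
// (mirroring the Unimplemented error in UpdateInfeedLayout).
std::string PickInfeedLayout(const std::vector<ShardLayout>& shards) {
  std::string chosen = shards[0].layout;  // layout of shard 0 wins
  for (const ShardLayout& s : shards) {
    if (s.num_elements == 0) continue;    // empty shards may differ; ignore
    if (s.layout != chosen) return "";    // non-uniform layouts: unsupported
  }
  return chosen;
}

int main() {
  std::cout << PickInfeedLayout({{"{1,0}", 8}, {"{1,0}", 8}}) << "\n";  // {1,0}
  std::cout << (PickInfeedLayout({{"{1,0}", 8}, {"{0,1}", 8}}).empty()
                    ? "non-uniform" : "uniform") << "\n";               // non-uniform
}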


@@ -0,0 +1,142 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
namespace {
class InplaceUpdateOp : public XlaOpKernel {
public:
explicit InplaceUpdateOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* ctx) override {
VLOG(3) << "InplaceUpdateOp::Compile";
DataType index_type = input_type(1);
OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("index must be int32 or int64"));
// TF Args are X, I, V
const TensorShape x_shape = ctx->InputShape(0);
const TensorShape i_shape = ctx->InputShape(1);
const TensorShape v_shape = ctx->InputShape(2);
OP_REQUIRES(ctx,
TensorShapeUtils::IsScalar(i_shape) ||
TensorShapeUtils::IsVector(i_shape),
errors::InvalidArgument("index must be Rank 0 or 1"));
OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
errors::InvalidArgument("X and V must have the same Rank,"
" X.shape=",
x_shape.DebugString(),
" V.shape=", v_shape.DebugString()));
auto* builder = ctx->builder();
auto const_zero = xla::ConstantR0(builder, 0);
auto current = ctx->Input(0);
for (int64 i = 0; i < i_shape.num_elements(); i++) {
std::vector<xla::XlaOp> update_indices;
update_indices.push_back(
xla::Reshape(xla::SliceInDim(ctx->Input(1), i, i + 1, 1, 0), {}));
for (int xi = 1; xi < x_shape.dims(); xi++) {
update_indices.push_back(const_zero);
}
current = xla::DynamicUpdateSlice(
current, xla::SliceInDim(ctx->Input(2), i, i + 1, 1, 0),
update_indices);
}
ctx->SetOutput(0, current);
// TODO(b/118122460): Uncomment+format this code to use XLA Scatter.
// auto* builder = ctx->builder();
// const auto initial = ctx->Input(0);
// const auto indices = ctx->Input(1);
// const auto updates = ctx->Input(2);
//
// auto result = XlaScatter(
// initial, updates, indices, /*indices_are_vectors=*/false,
// [](xla::XlaOp, xla::XlaOp second, xla::XlaBuilder*) { return
// second; }, builder);
// OP_REQUIRES_OK(ctx, result.status());
// ctx->SetOutput(0, result.ValueOrDie());
}
};
REGISTER_XLA_OP(Name("InplaceUpdate"), InplaceUpdateOp);
class InplaceAddOp : public XlaOpKernel {
public:
explicit InplaceAddOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* ctx) override {
VLOG(3) << "InplaceAddOp::Compile";
DataType index_type = input_type(1);
OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("index must be int32 or int64"));
// TF Args are X, I, V
const TensorShape x_shape = ctx->InputShape(0);
const TensorShape i_shape = ctx->InputShape(1);
const TensorShape v_shape = ctx->InputShape(2);
OP_REQUIRES(ctx,
(TensorShapeUtils::IsScalar(i_shape) ||
((i_shape.dims() == 1) && (i_shape.num_elements() == 1))),
errors::InvalidArgument("index must be Rank 1 and size 1"));
OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
errors::InvalidArgument("X and V must have the same Rank,"
" X.shape=",
x_shape.DebugString(),
" V.shape=", v_shape.DebugString()));
// Pad the indices out to match the rank of params.
auto* builder = ctx->builder();
std::vector<xla::XlaOp> padded_indices;
padded_indices.push_back(xla::Reshape(ctx->Input(1), {}));
for (int i = 0; i < x_shape.dims() - 1; ++i) {
padded_indices.push_back(XlaHelpers::Zero(builder, index_type));
}
std::vector<int64> sizes;
sizes.push_back(1);
for (int i = 1; i < x_shape.dims(); i++) {
sizes.push_back(x_shape.dim_size(i));
}
auto prev = xla::DynamicSlice(ctx->Input(0), padded_indices, sizes);
auto updated = xla::Add(prev, ctx->Input(2));
auto result =
xla::DynamicUpdateSlice(ctx->Input(0), updated, padded_indices);
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(Name("InplaceAdd"), InplaceAddOp);
} // namespace
} // namespace tensorflow
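A minimal plain-C++ sketch (illustrative only, not part of this commit) of the semantics both kernels above lower to: InplaceUpdate overwrites the rows of x named by i with the rows of v, one DynamicUpdateSlice per index, while InplaceAdd accumulates v into the single row named by the size-1 index. The 2-D matrices and helper names below are assumptions made for the example.

#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// x[i[k], :] = v[k, :] for each k -- one DynamicUpdateSlice per index above.
void InplaceUpdate(Matrix& x, const std::vector<int>& i, const Matrix& v) {
  for (size_t k = 0; k < i.size(); ++k) x[i[k]] = v[k];
}

// x[i, :] += v[0, :] -- the DynamicSlice, Add, DynamicUpdateSlice sequence above.
void InplaceAdd(Matrix& x, int i, const Matrix& v) {
  for (size_t c = 0; c < x[i].size(); ++c) x[i][c] += v[0][c];
}

int main() {
  Matrix x = {{0, 0}, {0, 0}, {0, 0}};
  InplaceUpdate(x, {2}, {{7, 8}});  // row 2 becomes {7, 8}
  InplaceAdd(x, 2, {{1, 1}});       // row 2 becomes {8, 9}
  std::cout << x[2][0] << " " << x[2][1] << "\n";  // prints: 8 9
}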


@@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
namespace {
// This TensorFlow op implements the XLA Outfeed primitive.
class OutfeedEnqueueOp : public XlaOpKernel {
public:
explicit OutfeedEnqueueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::Shape xla_shape;
OP_REQUIRES_OK(
ctx, TensorShapeToXLAShape(dtype_, ctx->InputShape(0), &xla_shape));
// Outfeed configuration is only needed for embedding outfeed.
const string outfeed_config;
xla::Outfeed(ctx->Input(0), xla_shape, outfeed_config);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueOp);
};
REGISTER_XLA_OP(Name("OutfeedEnqueue"), OutfeedEnqueueOp);
// This TensorFlow op implements the XLA Outfeed primitive for tuple types.
class OutfeedEnqueueTupleOp : public XlaOpKernel {
public:
explicit OutfeedEnqueueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
}
void Compile(XlaOpKernelContext* ctx) override {
std::vector<xla::XlaOp> handles;
std::vector<TensorShape> shapes;
auto inputs = ctx->InputList("inputs", &handles, &shapes);
std::vector<xla::Shape> xla_shapes;
for (int i = 0; i < shapes.size(); ++i) {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx,
TensorShapeToXLAShape(dtypes_[i], shapes[i], &xla_shape));
xla_shapes.push_back(xla_shape);
}
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(xla_shapes);
VLOG(1) << "OutfeedEnqueueTuple: "
<< xla::ShapeUtil::HumanStringWithLayout(tuple_shape);
auto b = ctx->builder();
auto tuple = xla::Tuple(b, handles);
// Outfeed configuration is only needed for embedding outfeed.
const string outfeed_config;
xla::Outfeed(tuple, tuple_shape, outfeed_config);
}
private:
DataTypeVector dtypes_;
TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueTupleOp);
};
REGISTER_XLA_OP(Name("OutfeedEnqueueTuple"), OutfeedEnqueueTupleOp);
} // anonymous namespace
} // namespace tensorflow


@@ -0,0 +1,145 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
namespace {
// TODO(b/32945756): Add a scatter op in XLA and move this to an HLO
// optimization pass.
//
// Optimization for UnsortedSegmentSum on TPU: use k-hot matmul. This
// optimization requires:
// 1. data has dtype supported by TPU matmul and has rank of 1 or 2.
// 2. indices has rank of 1.
// 3. matmul op count is less than 800 billion.
//
// Example of calculating UnsortedSegmentSum by k-hot matmul:
// data shape [A, B]
// indices shape [A]
// num_segment N
// output shape [N, B]
// matmul op count N * A * B
// Step 1: create k-hot matrix
// k-hot matrix has shape of [A, N], where row i is responsible for
// collecting the sum of the i-th segment, concretely
// k-hot[i][j] = 1 if indices[i] = j
// Step 2: perform matmul
// the final result is obtained by multiplying k-hot matrix with data
// matrix, namely
// k-hot * data => result
// shape: [N, A] * [A, B] => [N, B]
xla::XlaOp KHotMatmul(XlaOpKernelContext* ctx, xla::XlaBuilder* builder,
const xla::XlaOp data, const xla::XlaOp indices,
int64 num_segments) {
DataType data_dtype = ctx->input_type(0);
xla::PrimitiveType indices_type = ctx->input_xla_type(1);
TensorShape data_shape = ctx->InputShape(0);
TensorShape indices_shape = ctx->InputShape(1);
xla::XlaOp linspace = xla::Iota(builder, indices_type, num_segments);
xla::XlaOp linspace_col = xla::Reshape(linspace, {num_segments, 1});
TensorShape indices_row_shape = indices_shape;
indices_row_shape.InsertDim(0, 1);
xla::XlaOp indices_row = xla::Reshape(indices, indices_row_shape.dim_sizes());
xla::XlaOp k_hot = xla::Eq(indices_row, linspace_col);
xla::XlaOp k_hot_with_data_dtype =
XlaHelpers::ConvertElementType(k_hot, data_dtype);
// F32 version of the KHotMatmul: it splits the F32 data into three BF16
// partial values and runs KHotMatmul for each of them. The final result
// is the summation of the three BF16 results.
// Note that this still doesn't fully retain f32 precision.
// In particular, values smaller than 2^-111 may see loss of precision.
xla::PrecisionConfig precision_config;
if (data_dtype == DT_FLOAT) {
precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
} else {
CHECK_EQ(data_dtype, DT_BFLOAT16);
precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
}
precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
return xla::Dot(k_hot_with_data_dtype, data, &precision_config);
}
class UnsortedSegmentSum : public XlaOpKernel {
public:
explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
// output = unsorted_segment_sum(data, indices, num_segments)
// Compute a tensor such that:
// output[i] = sum over {j where indices[j] == i} of data[j]
// output[i] == 0 if i does not appear in indices
//
// Contrast with segment_sum(), which assumes indices are sorted and that
// max(indices)+1 is the desired size of the output.
//
// The returned output tensor has the same type as data, and the same shape
// as data, except that the first indices.rank dimensions are replaced
// by a single dimension with size num_segments.
xla::XlaOp data = ctx->Input(0);
TensorShape data_shape = ctx->InputShape(0);
xla::XlaOp indices = ctx->Input(1);
TensorShape indices_shape = ctx->InputShape(1);
int64 num_segments;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
errors::InvalidArgument(
"UnsortedSegmentSum requires that indices' rank be"
" less than or equal to data's rank."));
// Validate that indices.shape is a prefix of data.shape.
for (int d = 0; d < indices_shape.dims(); ++d) {
OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
errors::InvalidArgument(
"UnsortedSegmentSum requires indices shape to be prefix"
" of data_shape, but dimension ",
d, " differs ", data_shape.dim_size(d), " vs. ",
indices_shape.dim_size(d)));
}
xla::XlaBuilder* builder = ctx->builder();
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_),
buffer_shape.dim_sizes());
auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
return a + b;
};
auto result = XlaScatter(buffer, /*updates=*/data, indices,
/*indices_are_vectors=*/false, combiner, builder);
OP_REQUIRES_OK(ctx, result.status());
ctx->SetOutput(0, result.ValueOrDie());
}
private:
DataType dtype_;
};
REGISTER_XLA_OP(Name("UnsortedSegmentSum")
.Device(DEVICE_TPU_XLA_JIT)
.CompileTimeConstantInput("num_segments"),
UnsortedSegmentSum);
} // namespace
} // namespace tensorflow
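A worked example of the k-hot equivalence described in the comment block above, as a plain-C++ sketch (illustrative only, not part of this commit): for data of shape [A, B] and indices of shape [A], the scatter-style segment sum and the [N, A] x [A, B] k-hot matmul produce the same [N, B] result. The helper names and the tiny inputs below are assumptions made for the example.

#include <cassert>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Reference semantics: output[indices[a]][b] += data[a][b];
// segments that never appear in indices stay zero.
Matrix SegmentSum(const Matrix& data, const std::vector<int>& indices, int n) {
  Matrix out(n, std::vector<float>(data[0].size(), 0.f));
  for (size_t a = 0; a < data.size(); ++a)
    for (size_t b = 0; b < data[a].size(); ++b) out[indices[a]][b] += data[a][b];
  return out;
}

// k-hot formulation: k_hot[seg][a] = (indices[a] == seg), result = k_hot * data.
Matrix KHotMatmul(const Matrix& data, const std::vector<int>& indices, int n) {
  const size_t A = data.size(), B = data[0].size();
  Matrix out(n, std::vector<float>(B, 0.f));
  for (int seg = 0; seg < n; ++seg)
    for (size_t a = 0; a < A; ++a)
      if (indices[a] == seg)
        for (size_t b = 0; b < B; ++b) out[seg][b] += data[a][b];
  return out;
}

int main() {
  const Matrix data = {{1, 2}, {3, 4}, {5, 6}};  // A=3, B=2
  const std::vector<int> indices = {1, 0, 1};    // N=2 segments
  assert(SegmentSum(data, indices, 2) == KHotMatmul(data, indices, 2));
  std::cout << "segment 1, column 0: "
            << SegmentSum(data, indices, 2)[1][0] << "\n";  // 1 + 5 = 6
}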


@@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
namespace {
class WhereOp : public XlaOpKernel {
public:
explicit WhereOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp condition = ctx->Input(0);
xla::StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(condition);
OP_REQUIRES_OK(ctx, input_shape.status());
// Use S32 as indices first, then convert to S64 in the end if needed.
auto iota_shape = input_shape.ValueOrDie();
iota_shape.set_element_type(xla::S32);
int64 flattened_size = xla::Product(iota_shape.dimensions());
xla::XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
xla::XlaOp zeros = xla::ZerosLike(reshaped_condition);
xla::XlaOp zeros_int = xla::ConvertElementType(zeros, xla::S32);
xla::XlaOp reshaped_condition_int =
xla::ConvertElementType(reshaped_condition, xla::S32);
xla::XlaOp compared = xla::ConvertElementType(
xla::Gt(reshaped_condition_int, zeros_int), xla::S32);
xla::XlaOp length = xla::ReduceAll(
compared, xla::Zero(ctx->builder(), xla::S32),
xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
std::vector<xla::XlaOp> to_sort = {reshaped_condition_int};
std::vector<xla::PrimitiveType> types_to_sort = {xla::S32};
// Generate an iota for each dimension; combined, they give the
// coordinates of each element.
for (int64 axis = 0; axis < iota_shape.rank(); ++axis) {
xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis);
xla::XlaOp reshaped = xla::Reshape(iota, {flattened_size});
to_sort.push_back(reshaped);
types_to_sort.push_back(xla::S32);
}
xla::XlaOp sorted = xla::Sort(
to_sort, xla::CreateScalarGtComputation(types_to_sort, ctx->builder()),
/*dimension=*/0,
/*is_stable=*/true);
std::vector<xla::XlaOp> to_concat;
for (int64 i = 0; i < iota_shape.rank(); ++i) {
xla::XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1);
to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}));
}
xla::XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1);
result = xla::ConvertElementType(result, ctx->output_xla_type(0));
// Dynamic padder will handle the dynamic dimension.
xla::XlaOp result_padded = xla::SetDimensionSize(result, length, 0);
ctx->SetOutput(0, result_padded);
}
};
REGISTER_XLA_OP(Name("Where").Device(DEVICE_TPU_XLA_JIT), WhereOp);
} // namespace
} // namespace tensorflow
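The sort-based lowering in WhereOp can be read as the following plain-C++ sketch (illustrative only, not part of this commit): flatten the condition, pair each element with its coordinates, stable-sort so that true entries come first, and report the number of true entries as the valid length of the index matrix, mirroring the final SetDimensionSize. The Entry struct and the 2x2 input below are assumptions made for the example.

#include <algorithm>
#include <iostream>
#include <vector>

// Where on a 2-D boolean input: prints the coordinates of true elements in
// row-major order, plus the count of valid rows.
int main() {
  const std::vector<std::vector<bool>> cond = {{false, true}, {true, false}};
  struct Entry { int flag; int row; int col; };
  std::vector<Entry> entries;
  for (int r = 0; r < 2; ++r)
    for (int c = 0; c < 2; ++c) entries.push_back({cond[r][c] ? 1 : 0, r, c});
  // Stable sort with true (1) entries first, like the Gt comparator above.
  std::stable_sort(entries.begin(), entries.end(),
                   [](const Entry& a, const Entry& b) { return a.flag > b.flag; });
  int length = 0;
  for (const Entry& e : entries) length += e.flag;  // count of true elements
  for (int i = 0; i < length; ++i)
    std::cout << entries[i].row << " " << entries[i].col << "\n";  // (0,1), (1,0)
  std::cout << "valid rows: " << length << "\n";
}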