Introduce additional XLA TPU Ops to open source
PiperOrigin-RevId: 326343558
Change-Id: I47da1dc0c96cdf8223ccebef012e2a5088a857a4

parent 8e01ae829b
commit 3ea5fc7f3f
@@ -38,6 +38,7 @@ tf_kernel_library(
         ":tpu_execute_op",
         ":tpu_handle_to_key_op",
         ":transfer_ops",
+        "//tensorflow/core/tpu/kernels/xla:xla_ops",
     ],
 )

tensorflow/core/tpu/kernels/xla/BUILD (new file, 52 lines)
@@ -0,0 +1,52 @@
# XLA Ops for TPUs

package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "xla_ops",
    srcs = [
        "get_item_op.cc",
        "host_compute_ops.cc",
        "index_ops.cc",
        "infeed_op.cc",
        "inplace_ops.cc",
        "outfeed_ops.cc",
        "segment_reduction_ops.cc",
        "where_op.cc",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:sharding_util",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_context",
        "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/tf2xla/kernels:if_op",
        "//tensorflow/compiler/tf2xla/kernels:while_op",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/lib:scatter",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client/lib:arithmetic",
        "//tensorflow/compiler/xla/client/lib:comparators",
        "//tensorflow/compiler/xla/client/lib:constants",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_defs",
        "//tensorflow/core/tpu/kernels:cross_replica_ops",
        "//tensorflow/stream_executor/tpu:c_api_conversions",
        "//tensorflow/stream_executor/tpu:c_api_decl",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

tensorflow/core/tpu/kernels/xla/get_item_op.cc (new file, 75 lines)
@@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"

namespace tensorflow {
namespace {

// The Xla kernel to build up the computation for get_item(data, index).
class GetItemXlaOp : public XlaOpKernel {
 public:
  explicit GetItemXlaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape& data_shape = ctx->InputShape(0);
    const TensorShape& index_shape = ctx->InputShape(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVectorOrHigher(data_shape),
        errors::InvalidArgument("data must be at least 1 dimensional."));
    OP_REQUIRES(ctx, index_shape.dims() == 1 && index_shape.dim_size(0) == 1,
                errors::InvalidArgument("index must be a vector of size 1."));

    // NOTE(pbar) Use Concat to extend the indices to match cl/142279605.
    // This isn't the simplest way to emit the indices, but the code for
    // dynamic slice needs to be able to see that minor dims are const zero.
    auto const_zero = xla::ConstantR0(ctx->builder(), 0);
    std::vector<xla::XlaOp> operands;
    operands.push_back(xla::Reshape(ctx->Input(1), {}));
    for (int i = 1; i < data_shape.dims(); i++) {
      operands.push_back(const_zero);
    }

    std::vector<int64> dims = {0};
    std::vector<int64> slice_sizes = {1};
    std::vector<int64> out_sizes = {};
    for (int i = 1; i < data_shape.dims(); i++) {
      dims.push_back(i);
      auto size = data_shape.dim_size(i);
      slice_sizes.push_back(size);
      out_sizes.push_back(size);
    }
    // NOTE: DynamicSlice here doesn't raise an error or wrap the index if it
    // is out of range.
    auto slice = xla::DynamicSlice(ctx->Input(0), operands, slice_sizes);
    // In-order collapse to remove the 1st dim.
    auto reshape = xla::Reshape(slice, dims, out_sizes);
    ctx->SetOutput(0, reshape);
  }
};

REGISTER_XLA_OP(Name("GetItem"), GetItemXlaOp);

}  // namespace
}  // namespace tensorflow

tensorflow/core/tpu/kernels/xla/host_compute_ops.cc (new file, 498 lines)
@@ -0,0 +1,498 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {

namespace {

// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
static const char* const kSendFromHostOp = "_XlaSendFromHost";
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";

Status MakeXlaShapes(gtl::ArraySlice<TensorShape> shapes,
                     gtl::ArraySlice<DataType> dtypes,
                     std::vector<xla::Shape>* xla_shapes,
                     xla::Shape* xla_shape) {
  for (int i = 0; i < shapes.size(); i++) {
    xla::Shape single_xla_shape;
    TF_RETURN_IF_ERROR(
        TensorShapeToXLAShape(dtypes[i], shapes[i], &single_xla_shape));
    VLOG(2) << "Shape " << single_xla_shape.DebugString();
    xla_shapes->push_back(single_xla_shape);
  }
  // Temporarily add a dummy output to the shape array before making the tuple:
  // this output is used for control dependencies between host compute ops.
  xla_shapes->push_back(xla::ShapeUtil::MakeShape(xla::PRED, {}));
  *xla_shape = xla::ShapeUtil::MakeTupleShape(*xla_shapes);
  // Remove the dummy output from the vector that will be used to copy real
  // outputs from host to device.
  xla_shapes->pop_back();
  return Status::OK();
}

// This TensorFlow pseudo-op is used to record host-side computation.
class HostComputeOp : public XlaOpKernel {
 public:
  explicit HostComputeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cost_estimate_ns", &cost_estimate_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("tpu_core", &tpu_core_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_dtypes_));
    OP_REQUIRES(ctx, ctx->num_inputs() == input_dtypes_.size(),
                errors::InvalidArgument("Tinputs size=", input_dtypes_.size(),
                                        " but expected ", ctx->num_inputs(),
                                        " inputs."));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_dtypes_));
    OP_REQUIRES(ctx, ctx->num_outputs() == output_dtypes_.size(),
                errors::InvalidArgument("Toutputs size=", output_dtypes_.size(),
                                        " but expected ", ctx->num_outputs(),
                                        " outputs."));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ancestors", &ancestors_));
    NameAttrList shape_inference_graph;
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("shape_inference_graph", &shape_inference_graph));
    if (shape_inference_graph.name().empty()) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &static_output_shapes_));
      OP_REQUIRES(ctx, static_output_shapes_.size() == output_dtypes_.size(),
                  errors::InvalidArgument(
                      "shapes attr list size ", static_output_shapes_.size(),
                      " differs from dtypes size ", output_dtypes_.size()));
      OP_REQUIRES_OK(ctx, MakeXlaShapes(static_output_shapes_, output_dtypes_,
                                        &static_xla_output_shapes_,
                                        &static_xla_output_shape_));
      VLOG(2) << "Output Shape: " << static_xla_output_shape_.DebugString();
    } else {
      FunctionLibraryRuntime* flib_runtime = ctx->function_library();
      OP_REQUIRES(ctx, flib_runtime != nullptr,
                  errors::Internal(
                      "No function library runtime at kernel construction"));
      const FunctionLibraryDefinition* library =
          flib_runtime->GetFunctionLibraryDefinition();
      const FunctionDef* fdef = library->Find(shape_inference_graph.name());
      OP_REQUIRES(ctx, fdef != nullptr,
                  errors::Internal("Failed to find function ",
                                   shape_inference_graph.name(),
                                   " in function library."));
      OP_REQUIRES_OK(ctx, FunctionDefToBodyHelper(
                              *fdef, AttrSlice(&shape_inference_graph.attr()),
                              library, &shape_inference_graph_function_));
      VLOG(2) << "Output Shape to be inferred at compile time";
    }
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
    OP_REQUIRES(ctx, !token_input_nodes_.empty(),
                errors::InvalidArgument("XlaHostCompute node does not have ",
                                        kXlaTokenInputNodesAttrName, " attr"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
                                     &original_node_name_));
  }

  ~HostComputeOp() override {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    XlaCompiler* compiler = ctx->compiler();

    std::vector<xla::XlaOp> input_handles;
    std::vector<TensorShape> input_shapes;
    auto inputs = ctx->InputList("inputs", &input_handles, &input_shapes);
    const auto device_sharding = xla::sharding_builder::AssignDevice(tpu_core_);
    xla::XlaScopedShardingAssignment assign_sharding(b, device_sharding);

    std::vector<xla::XlaOp> input_tokens;
    for (auto& token_input_node : token_input_nodes_) {
      auto token_or = compiler->GetNodeToken(token_input_node);
      OP_REQUIRES_OK(ctx, token_or.status());
      input_tokens.push_back(token_or.ValueOrDie());
    }
    xla::XlaOp token = xla::AfterAll(b, input_tokens);

    // Send values to the host.
    std::vector<xla::XlaOp> send_to_host_tokens;
    for (int i = 0; i < input_handles.size(); ++i) {
      const string channel_name = absl::StrCat(key_, "_dtoh_", i);
      xla::Shape xla_shape;
      OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtypes_[i],
                                                input_shapes[i], &xla_shape));
      // Specify frontend attributes.
      xla::FrontendAttributes attrs;
      (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
      (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
          xla::primitive_util::LowercasePrimitiveTypeName(
              xla_shape.element_type());
      b->SetFrontendAttributes(attrs);
      xla::ChannelHandle channel;
      OP_REQUIRES_OK(
          ctx, compiler->GetDeviceToHostChannelHandle(channel_name, &channel));
      send_to_host_tokens.push_back(
          xla::SendToHost(input_handles[i], token, xla_shape, channel));
      b->ClearOpMetadata();
    }
    xla::XlaOp recv_from_host_token_input =
        send_to_host_tokens.empty() ? token
                                    : xla::AfterAll(b, send_to_host_tokens);
    if (!input_handles.empty()) {
      // Register the shapes used in this transfer.
      OP_REQUIRES_OK(ctx, ctx->compiler()->SetDeviceToHostMetadata(
                              key_, input_dtypes_, input_shapes));
    }
    // Compute the shapes of the values to copy to the device, if necessary.
    std::vector<TensorShape>* output_shapes;
    std::vector<xla::Shape>* xla_output_shapes;
    xla::Shape* xla_output_shape;
    std::vector<TensorShape> inferred_output_shapes;
    std::vector<xla::Shape> inferred_xla_output_shapes;
    xla::Shape inferred_xla_output_shape;
    if (shape_inference_graph_function_) {
      OP_REQUIRES_OK(
          ctx, InferOutputShapes(
                   ctx, ctx->function_library()->GetFunctionLibraryDefinition(),
                   &inferred_output_shapes));
      OP_REQUIRES_OK(ctx, MakeXlaShapes(inferred_output_shapes, output_dtypes_,
                                        &inferred_xla_output_shapes,
                                        &inferred_xla_output_shape));
      output_shapes = &inferred_output_shapes;
      xla_output_shapes = &inferred_xla_output_shapes;
      xla_output_shape = &inferred_xla_output_shape;
    } else {
      output_shapes = &static_output_shapes_;
      xla_output_shapes = &static_xla_output_shapes_;
      xla_output_shape = &static_xla_output_shape_;
    }
    OP_REQUIRES(
        ctx, output_shapes->size() == ctx->num_outputs(),
        errors::InvalidArgument("Op has ", ctx->num_outputs(), " outputs ",
                                " but output shape vector of size ",
                                output_shapes->size()));
    if (ctx->num_outputs() > 0) {
      // Register the shapes used in this transfer.
      OP_REQUIRES_OK(ctx, ctx->compiler()->SetHostToDeviceMetadata(
                              key_, output_dtypes_, *output_shapes));
    }
    // Copy results to the device.
    std::vector<xla::XlaOp> recv_from_host_tokens;
    for (int i = 0; i < output_shapes->size(); ++i) {
      const string channel_name = absl::StrCat(key_, "_htod_", i);
      // Specify frontend attributes.
      xla::FrontendAttributes attrs;
      (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
      (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
          xla::primitive_util::LowercasePrimitiveTypeName(
              xla_output_shapes->at(i).element_type());
      b->SetFrontendAttributes(attrs);
      xla::ChannelHandle channel;
      OP_REQUIRES_OK(
          ctx, compiler->GetHostToDeviceChannelHandle(channel_name, &channel));

      const auto result_token_tuple = xla::RecvFromHost(
          recv_from_host_token_input, xla_output_shapes->at(i), channel);
      b->ClearOpMetadata();
      recv_from_host_tokens.push_back(
          xla::GetTupleElement(result_token_tuple, /*index=*/1));
      ctx->SetOutput(i, xla::GetTupleElement(result_token_tuple, 0));
    }

    // Set token output.
    xla::XlaOp token_output = recv_from_host_tokens.empty()
                                  ? recv_from_host_token_input
                                  : xla::AfterAll(b, recv_from_host_tokens);
    OP_REQUIRES_OK(
        ctx, ctx->compiler()->SetNodeToken(original_node_name_, token_output));
  }

 private:
  Status LowerFunctionalOps(Graph* g,
                            const FunctionLibraryDefinition& flib_def) {
    bool modified;
    do {
      modified = false;

      // Lower "If" nodes first. Their body functions will be expanded as
      // function call nodes, which we will lower later.
      // We do not need to lower "While" nodes because shape inference can
      // handle them correctly (output shapes are input shapes).
      std::vector<Node*> if_nodes;
      for (Node* n : g->op_nodes()) {
        if (n->type_string() == "If") {
          if_nodes.push_back(n);
        }
      }
      for (Node* if_node : if_nodes) {
        TF_RETURN_IF_ERROR(
            RewriteIfNode(if_node, g, /*keep_node_fetchable=*/false));
      }
      if (!if_nodes.empty()) {
        modified = true;
      }

      // Lower function call nodes.
      std::vector<Node*> call_nodes;
      for (Node* n : g->op_nodes()) {
        if (IsFunctionCall(flib_def, *n)) {
          call_nodes.push_back(n);
        }
      }
      for (Node* call_node : call_nodes) {
        TF_RETURN_IF_ERROR(RewriteFunctionCallNode(
            call_node, g, flib_def, /*keep_caller_fetchable=*/false));
      }
      if (!call_nodes.empty()) {
        modified = true;
      }
    } while (modified);

    return Status::OK();
  }

  Status InferOutputShapes(XlaOpKernelContext* ctx,
                           const FunctionLibraryDefinition* flib_def,
                           std::vector<TensorShape>* output_shapes) {
    // First unpack the inference graphdef from the attr into graph. Don't do
    // any shape inference at this point.
    Graph* graph = shape_inference_graph_function_->graph;

    // Lower functional ops, because they are not friendly to shape inference.
    TF_RETURN_IF_ERROR(LowerFunctionalOps(graph, *flib_def));

    // Now run shape inference, filling in the shapes of recvathost nodes.
    bool got_output_shapes = false;
    ShapeRefiner shape_refiner{graph->versions().producer(),
                               graph->op_registry()};
    std::vector<Node*> nodes;
    GetReversePostOrder(*graph, &nodes);
    for (auto node : nodes) {
      TF_RETURN_IF_ERROR(shape_refiner.AddNode(node));
      if (node->type_string() == kRecvAtHostOp) {
        const AttrValue* key_attr = node->attrs().Find("key");
        if (key_attr == nullptr) {
          return errors::InvalidArgument("Node ", node->name(),
                                         " has no key attribute");
        }
        std::vector<TensorShape> dtoh_shapes;
        if (!ctx->compiler()
                 ->GetDeviceToHostShapes(key_attr->s(), &dtoh_shapes)
                 .ok()) {
          return errors::InvalidArgument(
              "Shape inference for HostCompute ", ctx->op_kernel().name(),
              " failed: host recv node ", node->name(), " with key '",
              key_attr->s(), "' has unknown shapes.");
        }
        if (dtoh_shapes.size() != node->num_outputs()) {
          return errors::InvalidArgument(
              "Shape inference for HostCompute ", ctx->op_kernel().name(),
              " failed: host recv node ", node->name(), " with key '",
              key_attr->s(), "' has ", node->num_outputs(),
              " outputs but inferred shapes expect ", dtoh_shapes.size());
        }
        for (int i = 0; i < node->num_outputs(); ++i) {
          shape_inference::InferenceContext* shape_ctx =
              shape_refiner.GetContext(node);
          shape_inference::ShapeHandle handle;
          TF_RETURN_IF_ERROR(
              shape_ctx->MakeShapeFromTensorShape(dtoh_shapes.at(i), &handle));
          shape_ctx->set_output(i, handle);
        }
      } else if (node->type_string() == kSendFromHostOp) {
        if (got_output_shapes) {
          return errors::InvalidArgument(
              "Shape inference for HostCompute ", ctx->op_kernel().name(),
              " failed: inference graph has multiple send from host nodes");
        } else {
          got_output_shapes = true;
          // The last input is the dynamic key so don't record its shape.
          output_shapes->resize(node->num_inputs() - 1);
          shape_inference::InferenceContext* shape_ctx =
              shape_refiner.GetContext(node);
          for (int i = 0; i < node->num_inputs() - 1; ++i) {
            shape_inference::ShapeHandle handle = shape_ctx->input(i);
            if (!shape_ctx->FullyDefined(handle)) {
              return errors::InvalidArgument(
                  "Shape inference for HostCompute ", ctx->op_kernel().name(),
                  " failed: send from host node ", node->name(),
                  " has non-fully defined shape of input index ", i);
            }
            TensorShapeProto shape_proto;
            shape_ctx->ShapeHandleToProto(handle, &shape_proto);
            (*output_shapes)[i] = TensorShape(shape_proto);
            VLOG(2) << "Inferred shape " << shape_proto.DebugString();
          }
        }
      }
    }
    if (!got_output_shapes) {
      return errors::InvalidArgument(
          "Shape inference for HostCompute ", ctx->op_kernel().name(),
          " failed: inference graph has no send from host node");
    }
    return Status::OK();
  }

  DataTypeVector input_dtypes_;
  DataTypeVector output_dtypes_;
  std::vector<string> ancestors_;
  std::vector<TensorShape> static_output_shapes_;
  std::vector<xla::Shape> static_xla_output_shapes_;
  string original_node_name_;
  // If static_xla_output_shapes_.size() == 1 then xla_output_shape_ is the
  // unique output shape, otherwise it is a tuple of all the xla_output_shapes_.
  xla::Shape static_xla_output_shape_;
  string key_;
  // If shape inference is performed at runtime, the graph needed to perform
  // shape inference is stored in this function.
  std::unique_ptr<FunctionBody> shape_inference_graph_function_;
  int64 cost_estimate_;
  int64 tpu_core_;
  std::vector<string> token_input_nodes_;

  TF_DISALLOW_COPY_AND_ASSIGN(HostComputeOp);
};

class SendToHostOp : public XlaOpKernel {
 public:
  explicit SendToHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinput", &input_dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
    OP_REQUIRES(ctx, !token_input_nodes_.empty(),
                errors::InvalidArgument("XlaSendToHost node does not have ",
                                        kXlaTokenInputNodesAttrName, " attr"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
                                     &original_node_name_));
  }

  ~SendToHostOp() override {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();

    XlaCompiler* compiler = ctx->compiler();
    xla::XlaOp operand = ctx->Input(0);
    std::vector<xla::XlaOp> input_tokens;
    for (auto& token_input_node : token_input_nodes_) {
      auto token_or = compiler->GetNodeToken(token_input_node);
      OP_REQUIRES_OK(ctx, token_or.status());
      input_tokens.push_back(token_or.ValueOrDie());
    }
    xla::XlaOp token = xla::AfterAll(b, input_tokens);
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtype_, ctx->InputShape(0),
                                              &xla_shape));
    // Specify frontend attributes.
    xla::FrontendAttributes attrs;
    (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
    (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
        xla::primitive_util::LowercasePrimitiveTypeName(
            xla_shape.element_type());
    b->SetFrontendAttributes(attrs);
    xla::ChannelHandle channel;
    OP_REQUIRES_OK(ctx, compiler->GetDeviceToHostChannelHandle(key_, &channel));
    xla::XlaOp output_token =
        xla::SendToHost(operand, token, xla_shape, channel);
    OP_REQUIRES_OK(ctx,
                   compiler->SetNodeToken(original_node_name_, output_token));
  }

 private:
  DataType input_dtype_;
  string key_;
  std::vector<string> token_input_nodes_;
  string original_node_name_;
  TF_DISALLOW_COPY_AND_ASSIGN(SendToHostOp);
};

class RecvFromHostOp : public XlaOpKernel {
 public:
  explicit RecvFromHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput", &output_dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &output_shape_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
    OP_REQUIRES(ctx, !token_input_nodes_.empty(),
                errors::InvalidArgument("XlaRecvFromHost node does not have ",
                                        kXlaTokenInputNodesAttrName, " attr"));
  }

  ~RecvFromHostOp() override {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();

    XlaCompiler* compiler = ctx->compiler();
    std::vector<xla::XlaOp> input_tokens;
    for (auto& token_input_node : token_input_nodes_) {
      auto token_or = compiler->GetNodeToken(token_input_node);
      OP_REQUIRES_OK(ctx, token_or.status());
      input_tokens.push_back(token_or.ValueOrDie());
    }
    xla::XlaOp token = xla::AfterAll(b, input_tokens);
    xla::Shape xla_shape;
    OP_REQUIRES_OK(
        ctx, TensorShapeToXLAShape(output_dtype_, output_shape_, &xla_shape));
    // Specify frontend attributes.
    xla::FrontendAttributes attrs;
    (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
    (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
        xla::primitive_util::LowercasePrimitiveTypeName(
            xla_shape.element_type());
    b->SetFrontendAttributes(attrs);
    xla::ChannelHandle channel;
    OP_REQUIRES_OK(ctx, compiler->GetHostToDeviceChannelHandle(key_, &channel));
    xla::XlaOp result = xla::RecvFromHost(token, xla_shape, channel);
    // xla::RecvFromHost returns a tuple of (received data, token).
    ctx->SetOutput(0, xla::GetTupleElement(result, 0));
    OP_REQUIRES_OK(
        ctx, compiler->SetNodeToken(name(), xla::GetTupleElement(result, 1)));
  }

 private:
  DataType output_dtype_;
  TensorShape output_shape_;
  string key_;
  std::vector<string> token_input_nodes_;
  TF_DISALLOW_COPY_AND_ASSIGN(RecvFromHostOp);
};

REGISTER_XLA_OP(Name("XlaHostCompute"), HostComputeOp);
REGISTER_XLA_OP(Name("XlaSendToHost"), SendToHostOp);
REGISTER_XLA_OP(Name("XlaRecvFromHost"), RecvFromHostOp);

}  // anonymous namespace
}  // namespace tensorflow

tensorflow/core/tpu/kernels/xla/index_ops.cc (new file, 34 lines)
@@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/index_ops.h"

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

// This registration is needed here because the ArgMax Op is defined in
// third_party where DEVICE_TPU_XLA_JIT is not visible. Most Ops don't need a
// specific TPU whitelist, but ArgMax does because it has a separate CustomCall
// implementation on CPU.
REGISTER_XLA_OP(Name("ArgMax")
                    .Device(DEVICE_TPU_XLA_JIT)
                    .CompileTimeConstantInput("dimension"),
                XlaArgMaxOp);

}  // namespace
}  // namespace tensorflow

tensorflow/core/tpu/kernels/xla/infeed_op.cc (new file, 162 lines)
@@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"

namespace tensorflow {

namespace {

xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
  XLA_Shape c_shape;
  XLA_Shape c_infeed_shape;

  ApiConverter::ToC(shape, &c_shape);

  tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
                                                             &c_infeed_shape);
  xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
  ApiConverter::Free(&c_shape);
  ApiConverter::Free(&c_infeed_shape);
  return infeed_shape;
}

// Updates the layout of the given infeed shape, optionally considering the
// sharding of the op. If the op has tile sharding, assign the layout based on
// the shard shape.
Status UpdateInfeedLayout(xla::Shape* shape,
                          absl::optional<xla::OpSharding> sharding) {
  if (sharding && sharding->type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
                        xla::HloSharding::FromProto(*sharding));
    for (int64 i = 0; i < sharding->tile_assignment_devices_size(); ++i) {
      auto device = sharding->tile_assignment_devices(i);
      auto shard_shape =
          GetTPUInfeedLayout(hlo_sharding.TileShape(*shape, device));
      if (i == 0) {
        *shape->mutable_layout() = shard_shape.layout();
      }
      if (xla::ShapeUtil::ElementsIn(shard_shape) == 0) {
        // Shapes with 0 dimensions may be assigned with a different layout, but
        // it doesn't matter since we're not sending any data.
        continue;
      }
      if (!xla::LayoutUtil::Equal(shard_shape.layout(), shape->layout())) {
        return xla::Unimplemented(
            "Sharded infeed with non-uniform layouts is not supported. Try "
            "turning off the infeed layout optimization "
            "(--transpose_tpu_infeed=false) and report to XLA team.");
      }
    }
    return Status::OK();
  }
  *shape = GetTPUInfeedLayout(*shape);
  return Status::OK();
}

// TODO(pbar) Work out if we need to Infeed Tuples - if so then
// this op will need a way to provide a list of shapes
// since they can't be provided by the runtime JIT mechanism.
// (InfeedDequeue has no inputs!)
// Compare this op to tf.Queue operations which operate on N tensors.

// This TensorFlow op supports the XLA Infeed primitive.
class InfeedDequeueOp : public XlaOpKernel {
 public:
  explicit InfeedDequeueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shape_, b->sharding()));
    ctx->SetOutput(0, xla::Infeed(b, xla_shape_));
  }

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueOp);
};

REGISTER_XLA_OP(Name("InfeedDequeue"), InfeedDequeueOp);

// This TensorFlow op supports the XLA Infeed primitive for tuple types.
class InfeedDequeueTupleOp : public XlaOpKernel {
 public:
  explicit InfeedDequeueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    for (int i = 0; i < shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(ctx,
                     TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
      xla_shapes_.push_back(xla_shape);
    }
  }

  ~InfeedDequeueTupleOp() override {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    for (int64 i = 0; i < xla_shapes_.size(); ++i) {
      absl::optional<xla::OpSharding> sharding;
      if (b->sharding()) {
        sharding = b->sharding()->type() == xla::OpSharding::TUPLE
                       ? b->sharding()->tuple_shardings(i)
                       : b->sharding();
      }
      OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shapes_[i], sharding));
    }
    tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
    auto tuple = xla::Infeed(b, tuple_shape_);

    // Don't apply the infeed tuple sharding to the get-tuple-elements. They
    // need non-tuple shardings.
    xla::XlaScopedShardingAssignment clear_sharding(b, absl::nullopt);
    for (int i = 0; i < shapes_.size(); ++i) {
      ctx->SetOutput(i, xla::GetTupleElement(tuple, i));
    }
  }

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  std::vector<xla::Shape> xla_shapes_;
  xla::Shape tuple_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueTupleOp);
};

REGISTER_XLA_OP(Name("InfeedDequeueTuple"), InfeedDequeueTupleOp);

}  // anonymous namespace
}  // namespace tensorflow

tensorflow/core/tpu/kernels/xla/inplace_ops.cc (new file, 142 lines)
@@ -0,0 +1,142 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {
namespace {

class InplaceUpdateOp : public XlaOpKernel {
 public:
  explicit InplaceUpdateOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* ctx) override {
    VLOG(3) << "InplaceUpdateOp::Compile";

    DataType index_type = input_type(1);
    OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
                errors::InvalidArgument("index must be int32 or int64"));

    // TF Args are X, I, V
    const TensorShape x_shape = ctx->InputShape(0);
    const TensorShape i_shape = ctx->InputShape(1);
    const TensorShape v_shape = ctx->InputShape(2);

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(i_shape) ||
                    TensorShapeUtils::IsVector(i_shape),
                errors::InvalidArgument("index must be Rank 0 or 1"));
    OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
                errors::InvalidArgument("X and V must have the same Rank,"
                                        " X.shape=",
                                        x_shape.DebugString(),
                                        " V.shape=", v_shape.DebugString()));

    auto* builder = ctx->builder();
    auto const_zero = xla::ConstantR0(builder, 0);
    auto current = ctx->Input(0);

    for (int64 i = 0; i < i_shape.num_elements(); i++) {
      std::vector<xla::XlaOp> update_indices;
      update_indices.push_back(
          xla::Reshape(xla::SliceInDim(ctx->Input(1), i, i + 1, 1, 0), {}));
      for (int xi = 1; xi < x_shape.dims(); xi++) {
        update_indices.push_back(const_zero);
      }
      current = xla::DynamicUpdateSlice(
          current, xla::SliceInDim(ctx->Input(2), i, i + 1, 1, 0),
          update_indices);
    }
    ctx->SetOutput(0, current);

    // TODO(b/118122460): Uncomment+format this code to use XLA Scatter.
    // auto* builder = ctx->builder();
    // const auto initial = ctx->Input(0);
    // const auto indices = ctx->Input(1);
    // const auto updates = ctx->Input(2);
    //
    // auto result = XlaScatter(
    //     initial, updates, indices, /*indices_are_vectors=*/false,
    //     [](xla::XlaOp, xla::XlaOp second, xla::XlaBuilder*) { return
    //     second; }, builder);
    // OP_REQUIRES_OK(ctx, result.status());
    // ctx->SetOutput(0, result.ValueOrDie());
  }
};

REGISTER_XLA_OP(Name("InplaceUpdate"), InplaceUpdateOp);

class InplaceAddOp : public XlaOpKernel {
 public:
  explicit InplaceAddOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* ctx) override {
    VLOG(3) << "InplaceAddOp::Compile";

    DataType index_type = input_type(1);
    OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
                errors::InvalidArgument("index must be int32 or int64"));

    // TF Args are X, I, V
    const TensorShape x_shape = ctx->InputShape(0);
    const TensorShape i_shape = ctx->InputShape(1);
    const TensorShape v_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx,
                (TensorShapeUtils::IsScalar(i_shape) ||
                 ((i_shape.dims() == 1) && (i_shape.num_elements() == 1))),
                errors::InvalidArgument("index must be Rank 1 and size 1"));
    OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
                errors::InvalidArgument("X and V must have the same Rank,"
                                        " X.shape=",
                                        x_shape.DebugString(),
                                        " V.shape=", v_shape.DebugString()));
    // Pad the indices out to match the rank of params.
    auto* builder = ctx->builder();
    std::vector<xla::XlaOp> padded_indices;
    padded_indices.push_back(xla::Reshape(ctx->Input(1), {}));
    for (int i = 0; i < x_shape.dims() - 1; ++i) {
      padded_indices.push_back(XlaHelpers::Zero(builder, index_type));
    }

    std::vector<int64> sizes;
    sizes.push_back(1);
    for (int i = 1; i < x_shape.dims(); i++) {
      sizes.push_back(x_shape.dim_size(i));
    }

    auto prev = xla::DynamicSlice(ctx->Input(0), padded_indices, sizes);
    auto updated = xla::Add(prev, ctx->Input(2));
    auto result =
        xla::DynamicUpdateSlice(ctx->Input(0), updated, padded_indices);
    ctx->SetOutput(0, result);
  }
};

REGISTER_XLA_OP(Name("InplaceAdd"), InplaceAddOp);

}  // namespace
}  // namespace tensorflow

tensorflow/core/tpu/kernels/xla/outfeed_ops.cc (new file, 91 lines)
@@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {

namespace {

// This TensorFlow op implements the XLA Outfeed primitive.
class OutfeedEnqueueOp : public XlaOpKernel {
 public:
  explicit OutfeedEnqueueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::Shape xla_shape;
    OP_REQUIRES_OK(
        ctx, TensorShapeToXLAShape(dtype_, ctx->InputShape(0), &xla_shape));
    // Outfeed configuration is only needed for embedding outfeed.
    const string outfeed_config;
    xla::Outfeed(ctx->Input(0), xla_shape, outfeed_config);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueOp);
};

REGISTER_XLA_OP(Name("OutfeedEnqueue"), OutfeedEnqueueOp);

// This TensorFlow op implements the XLA Outfeed primitive for tuple types.
class OutfeedEnqueueTupleOp : public XlaOpKernel {
 public:
  explicit OutfeedEnqueueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<xla::XlaOp> handles;
    std::vector<TensorShape> shapes;
    auto inputs = ctx->InputList("inputs", &handles, &shapes);

    std::vector<xla::Shape> xla_shapes;
    for (int i = 0; i < shapes.size(); ++i) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(ctx,
                     TensorShapeToXLAShape(dtypes_[i], shapes[i], &xla_shape));
      xla_shapes.push_back(xla_shape);
    }
    xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(xla_shapes);
    VLOG(1) << "OutfeedEnqueueTuple: "
            << xla::ShapeUtil::HumanStringWithLayout(tuple_shape);
    auto b = ctx->builder();
    auto tuple = xla::Tuple(b, handles);
    // Outfeed configuration is only needed for embedding outfeed.
    const string outfeed_config;
    xla::Outfeed(tuple, tuple_shape, outfeed_config);
  }

 private:
  DataTypeVector dtypes_;

  TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueTupleOp);
};

REGISTER_XLA_OP(Name("OutfeedEnqueueTuple"), OutfeedEnqueueTupleOp);

}  // anonymous namespace
}  // namespace tensorflow

tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc (new file, 145 lines)
@@ -0,0 +1,145 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {
// TODO(b/32945756): Add a scatter op in XLA and move this to a HLO optimization
// pass. Optimization for UnsortedSegmentSum on TPU: use k-hot matmul. This
// optimization requires:
// 1. data has dtype supported by TPU matmul and has rank of 1 or 2.
// 2. indices has rank of 1.
// 3. matmul op count is less than 800 billion.
//
// Example of calculating UnsortedSegmentSum by k-hot matmul:
//   data shape        [A, B]
//   indices shape     [A]
//   num_segment       N
//   output shape      [N, B]
//   matmul op count   N * A * B
// Step 1: create k-hot matrix
//   k-hot matrix has shape of [A, N], where row i is responsible for
//   collecting the sum of the i-th segment, concretely
//   k-hot[i][j] = 1 if indices[i] = j
// Step 2: perform matmul
//   the final result is obtained by multiplying the k-hot matrix with the data
//   matrix, namely
//   k-hot * data => result
//   shape: [N, A] * [A, B] => [N, B]
xla::XlaOp KHotMatmul(XlaOpKernelContext* ctx, xla::XlaBuilder* builder,
                      const xla::XlaOp data, const xla::XlaOp indices,
                      int64 num_segments) {
  DataType data_dtype = ctx->input_type(0);
  xla::PrimitiveType indices_type = ctx->input_xla_type(1);
  TensorShape data_shape = ctx->InputShape(0);
  TensorShape indices_shape = ctx->InputShape(1);
  xla::XlaOp linspace = xla::Iota(builder, indices_type, num_segments);
  xla::XlaOp linspace_col = xla::Reshape(linspace, {num_segments, 1});
  TensorShape indices_row_shape = indices_shape;
  indices_row_shape.InsertDim(0, 1);
  xla::XlaOp indices_row = xla::Reshape(indices, indices_row_shape.dim_sizes());
  xla::XlaOp k_hot = xla::Eq(indices_row, linspace_col);
  xla::XlaOp k_hot_with_data_dtype =
      XlaHelpers::ConvertElementType(k_hot, data_dtype);
  // The F32 version of KHotMatmul splits the F32 data into three BF16 partial
  // data tensors and runs KHotMatmul for each of them. The final result is the
  // summation of the three BF16 results.
  // Note that this still doesn't fully retain f32 precision.
  // In particular, values smaller than 2^-111 may see loss of precision.
  xla::PrecisionConfig precision_config;
  if (data_dtype == DT_FLOAT) {
    precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
  } else {
    CHECK_EQ(data_dtype, DT_BFLOAT16);
    precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
  }
  precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
  return xla::Dot(k_hot_with_data_dtype, data, &precision_config);
}

class UnsortedSegmentSum : public XlaOpKernel {
 public:
  explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // output = unsorted_segment_sum(data, indices, num_segments)
    // Compute a tensor such that:
    //   output[i] = sum over {j where indices[j] == i} of data[j]
    //   output[i] == 0 if i does not appear in indices
    //
    // Contrast with segment_sum(), which assumes indices are sorted and that
    // max(indices)+1 is the desired size of the output.
    //
    // The returned output tensor has the same type as data, and the same shape
    // as data, with the first indices.rank dimensions replaced
    // by a single dimension with size num_segments.
    xla::XlaOp data = ctx->Input(0);
    TensorShape data_shape = ctx->InputShape(0);

    xla::XlaOp indices = ctx->Input(1);
    TensorShape indices_shape = ctx->InputShape(1);

    int64 num_segments;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));

    OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
                errors::InvalidArgument(
                    "UnsortedSegmentSum requires that indices' rank be"
                    " less than or equal to data's rank."));
    // Validate that indices.shape is a prefix of data.shape.
    for (int d = 0; d < indices_shape.dims(); ++d) {
      OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
                  errors::InvalidArgument(
                      "UnsortedSegmentSum requires indices shape to be prefix"
                      " of data_shape, but dimension ",
                      d, " differs ", data_shape.dim_size(d), " vs. ",
                      indices_shape.dim_size(d)));
    }
    xla::XlaBuilder* builder = ctx->builder();
    TensorShape buffer_shape = data_shape;
    buffer_shape.RemoveDimRange(0, indices_shape.dims());
    buffer_shape.InsertDim(0, num_segments);
    auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_),
                                 buffer_shape.dim_sizes());

    auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
      return a + b;
    };

    auto result = XlaScatter(buffer, /*updates=*/data, indices,
                             /*indices_are_vectors=*/false, combiner, builder);
    OP_REQUIRES_OK(ctx, result.status());
    ctx->SetOutput(0, result.ValueOrDie());
  }

 private:
  DataType dtype_;
};

REGISTER_XLA_OP(Name("UnsortedSegmentSum")
                    .Device(DEVICE_TPU_XLA_JIT)
                    .CompileTimeConstantInput("num_segments"),
                UnsortedSegmentSum);

}  // namespace
}  // namespace tensorflow

tensorflow/core/tpu/kernels/xla/where_op.cc (new file, 91 lines)
@@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

class WhereOp : public XlaOpKernel {
 public:
  explicit WhereOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp condition = ctx->Input(0);
    xla::StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(condition);
    OP_REQUIRES_OK(ctx, input_shape.status());
    // Use S32 as indices first, then convert to S64 in the end if needed.
    auto iota_shape = input_shape.ValueOrDie();
    iota_shape.set_element_type(xla::S32);

    int64 flattened_size = xla::Product(iota_shape.dimensions());
    xla::XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
    xla::XlaOp zeros = xla::ZerosLike(reshaped_condition);
    xla::XlaOp zeros_int = xla::ConvertElementType(zeros, xla::S32);
    xla::XlaOp reshaped_condition_int =
        xla::ConvertElementType(reshaped_condition, xla::S32);
    xla::XlaOp compared = xla::ConvertElementType(
        xla::Gt(reshaped_condition_int, zeros_int), xla::S32);
    xla::XlaOp length = xla::ReduceAll(
        compared, xla::Zero(ctx->builder(), xla::S32),
        xla::CreateScalarAddComputation(xla::S32, ctx->builder()));

    std::vector<xla::XlaOp> to_sort = {reshaped_condition_int};
    std::vector<xla::PrimitiveType> types_to_sort = {xla::S32};
    // Generate iota for each dimension, which after combining becomes
    // indices of each element.
    for (int64 axis = 0; axis < iota_shape.rank(); ++axis) {
      xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis);
      xla::XlaOp reshaped = xla::Reshape(iota, {flattened_size});
      to_sort.push_back(reshaped);
      types_to_sort.push_back(xla::S32);
    }

    xla::XlaOp sorted = xla::Sort(
        to_sort, xla::CreateScalarGtComputation(types_to_sort, ctx->builder()),
        /*dimension=*/0,
        /*is_stable=*/true);
    std::vector<xla::XlaOp> to_concat;
    for (int64 i = 0; i < iota_shape.rank(); ++i) {
      xla::XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1);
      to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}));
    }

    xla::XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1);
    result = xla::ConvertElementType(result, ctx->output_xla_type(0));
    // Dynamic padder will handle the dynamic dimension.
    xla::XlaOp result_padded = xla::SetDimensionSize(result, length, 0);
    ctx->SetOutput(0, result_padded);
  }
};

REGISTER_XLA_OP(Name("Where").Device(DEVICE_TPU_XLA_JIT), WhereOp);

}  // namespace
}  // namespace tensorflow