Expose XLA Gather/Scatter to python via the tf2xla bridge
The semantics of these operations are a lot more flexible then tf.gather and tf.scatter so this change enables the usage of these more powerful APIs. This API is *not* stable so any user have to be prepared for potential breaking changes. PiperOrigin-RevId: 286569124 Change-Id: If1690d749deb5d08bb89532842fecaad2c2921d4
This commit is contained in:
parent
067ffdd467
commit
a5a0ad4300
@ -2036,6 +2036,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"XlaDynamicSlice",
|
||||
"XlaDynamicUpdateSlice",
|
||||
"XlaEinsum",
|
||||
"XlaGather",
|
||||
"XlaIf",
|
||||
"XlaKeyValueSort",
|
||||
"XlaPad",
|
||||
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"XlaReduce",
|
||||
"XlaReduceWindow",
|
||||
"XlaReplicaId",
|
||||
"XlaScatter",
|
||||
"XlaSelectAndScatter",
|
||||
"XlaSelfAdjointEig",
|
||||
"XlaSend",
|
||||
|
@ -48,6 +48,7 @@ tf_kernel_library(
|
||||
"function_ops.cc",
|
||||
"gather_op.cc",
|
||||
"gather_op_helpers.h",
|
||||
"gather_scatter_ops.cc",
|
||||
"identity_op.cc",
|
||||
"image_ops.cc",
|
||||
"image_resize_ops.cc",
|
||||
|
102
tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
Normal file
102
tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
Normal file
@ -0,0 +1,102 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class GatherOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
|
||||
string dnums_attr;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
|
||||
OP_REQUIRES(
|
||||
context, dnums_.ParsePartialFromString(dnums_attr),
|
||||
errors::InvalidArgument("Error parsing gather dimension numbers"));
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
std::vector<int64> slice_sizes;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes));
|
||||
xla::XlaOp result =
|
||||
xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_,
|
||||
slice_sizes, indices_are_sorted_);
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
|
||||
private:
|
||||
xla::GatherDimensionNumbers dnums_;
|
||||
bool indices_are_sorted_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaGather"), GatherOp);
|
||||
|
||||
class ScatterOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("update_computation", &update_computation_));
|
||||
string dnums_attr;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
|
||||
OP_REQUIRES(
|
||||
context, dnums_.ParsePartialFromString(dnums_attr),
|
||||
errors::InvalidArgument("Error parsing scatter dimension numbers"));
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const DataType dtype = ctx->input_type(0);
|
||||
|
||||
XlaCompiler::Argument update_computation_arg;
|
||||
update_computation_arg.kind = XlaCompiler::Argument::kParameter;
|
||||
update_computation_arg.type = dtype;
|
||||
update_computation_arg.shape = TensorShape();
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.use_tuple_arg = false;
|
||||
compile_options.always_return_tuple = false;
|
||||
compile_options.is_entry_computation = false;
|
||||
XlaCompiler::CompilationResult update_computation;
|
||||
OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction(
|
||||
compile_options, *update_computation_,
|
||||
{update_computation_arg, update_computation_arg},
|
||||
&update_computation));
|
||||
|
||||
xla::XlaOp result =
|
||||
xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"),
|
||||
ctx->Input("updates"), *update_computation.computation,
|
||||
dnums_, indices_are_sorted_);
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
|
||||
private:
|
||||
const NameAttrList* update_computation_;
|
||||
xla::ScatterDimensionNumbers dnums_;
|
||||
bool indices_are_sorted_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -665,5 +665,50 @@ REGISTER_OP("XlaReplicaId")
|
||||
})
|
||||
.Doc("Replica ID.");
|
||||
|
||||
REGISTER_OP("XlaGather")
|
||||
.Input("operand: T")
|
||||
.Input("start_indices: Tindices")
|
||||
.Input("slice_sizes: Tindices")
|
||||
.Attr("dimension_numbers: string")
|
||||
.Attr("indices_are_sorted: bool")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(UnchangedRank)
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA Gather operator documented at
|
||||
https://www.tensorflow.org/xla/operation_semantics#gather
|
||||
operand: The array we're gathering from.
|
||||
start_indices: Array containing the starting indices of the slices we gather.
|
||||
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
|
||||
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
|
||||
indices_are_sorted: Boolean indicating if the indices are sorted.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaScatter")
|
||||
.Input("operand: T")
|
||||
.Input("scatter_indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Attr("update_computation: func")
|
||||
.Attr("dimension_numbers: string")
|
||||
.Attr("indices_are_sorted: bool")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(UnchangedRank)
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA Scatter operator documented at
|
||||
https://www.tensorflow.org/xla/operation_semantics#scatter.
|
||||
|
||||
operand: Array to be scattered into.
|
||||
scatter_indices: Array containing the starting indices of the slices that must
|
||||
be scattered to.
|
||||
updates: Array containing the values that must be used for scattering.
|
||||
update_computation: Computation to be used for combining the existing values in
|
||||
the input array and the updates during scatter.
|
||||
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
|
||||
indices_are_sorted: Boolean indicating if the indices are sorted.
|
||||
)doc");
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -416,3 +416,27 @@ sort = gen_xla_ops.xla_sort
|
||||
key_value_sort = gen_xla_ops.xla_key_value_sort
|
||||
while_loop = gen_xla_ops.xla_while
|
||||
dequantize = gen_xla_ops.xla_dequantize
|
||||
|
||||
|
||||
def gather(operand, start_indices, dimension_numbers, slice_sizes,
|
||||
indices_are_sorted=False, name=None):
|
||||
return gen_xla_ops.xla_gather(
|
||||
operand,
|
||||
start_indices,
|
||||
slice_sizes=slice_sizes,
|
||||
dimension_numbers=dimension_numbers.SerializeToString(),
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
name=name)
|
||||
|
||||
|
||||
def scatter(operand, scatter_indices, updates, update_computation,
|
||||
dimension_numbers, indices_are_sorted=False, name=None):
|
||||
return gen_xla_ops.xla_scatter(
|
||||
operand,
|
||||
scatter_indices,
|
||||
updates,
|
||||
update_computation=update_computation,
|
||||
dimension_numbers=dimension_numbers.SerializeToString(),
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
name=name)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user