Expose XLA Gather/Scatter to python via the tf2xla bridge

The semantics of these operations are a lot more flexible then tf.gather
and tf.scatter so this change enables the usage of these more powerful
APIs.

This API is *not* stable so any user have to be prepared for potential
breaking changes.

PiperOrigin-RevId: 286569124
Change-Id: If1690d749deb5d08bb89532842fecaad2c2921d4
This commit is contained in:
A. Unique TensorFlower 2019-12-20 06:36:47 -08:00 committed by TensorFlower Gardener
parent 067ffdd467
commit a5a0ad4300
5 changed files with 174 additions and 0 deletions

View File

@ -2036,6 +2036,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"XlaDynamicSlice",
"XlaDynamicUpdateSlice",
"XlaEinsum",
"XlaGather",
"XlaIf",
"XlaKeyValueSort",
"XlaPad",
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"XlaReduce",
"XlaReduceWindow",
"XlaReplicaId",
"XlaScatter",
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
"XlaSend",

View File

@ -48,6 +48,7 @@ tf_kernel_library(
"function_ops.cc",
"gather_op.cc",
"gather_op_helpers.h",
"gather_scatter_ops.cc",
"identity_op.cc",
"image_ops.cc",
"image_resize_ops.cc",

View File

@ -0,0 +1,102 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace {
class GatherOp : public XlaOpKernel {
public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
string dnums_attr;
OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
OP_REQUIRES(
context, dnums_.ParsePartialFromString(dnums_attr),
errors::InvalidArgument("Error parsing gather dimension numbers"));
OP_REQUIRES_OK(
context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
}
void Compile(XlaOpKernelContext* ctx) override {
std::vector<int64> slice_sizes;
OP_REQUIRES_OK(ctx,
ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes));
xla::XlaOp result =
xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_,
slice_sizes, indices_are_sorted_);
ctx->SetOutput(0, result);
}
private:
xla::GatherDimensionNumbers dnums_;
bool indices_are_sorted_;
};
REGISTER_XLA_OP(Name("XlaGather"), GatherOp);
class ScatterOp : public XlaOpKernel {
public:
explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
OP_REQUIRES_OK(
context, context->GetAttr("update_computation", &update_computation_));
string dnums_attr;
OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
OP_REQUIRES(
context, dnums_.ParsePartialFromString(dnums_attr),
errors::InvalidArgument("Error parsing scatter dimension numbers"));
OP_REQUIRES_OK(
context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
}
void Compile(XlaOpKernelContext* ctx) override {
const DataType dtype = ctx->input_type(0);
XlaCompiler::Argument update_computation_arg;
update_computation_arg.kind = XlaCompiler::Argument::kParameter;
update_computation_arg.type = dtype;
update_computation_arg.shape = TensorShape();
XlaCompiler::CompileOptions compile_options;
compile_options.use_tuple_arg = false;
compile_options.always_return_tuple = false;
compile_options.is_entry_computation = false;
XlaCompiler::CompilationResult update_computation;
OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction(
compile_options, *update_computation_,
{update_computation_arg, update_computation_arg},
&update_computation));
xla::XlaOp result =
xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"),
ctx->Input("updates"), *update_computation.computation,
dnums_, indices_are_sorted_);
ctx->SetOutput(0, result);
}
private:
const NameAttrList* update_computation_;
xla::ScatterDimensionNumbers dnums_;
bool indices_are_sorted_;
};
REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp);
} // namespace
} // namespace tensorflow

View File

@ -665,5 +665,50 @@ REGISTER_OP("XlaReplicaId")
})
.Doc("Replica ID.");
REGISTER_OP("XlaGather")
.Input("operand: T")
.Input("start_indices: Tindices")
.Input("slice_sizes: Tindices")
.Attr("dimension_numbers: string")
.Attr("indices_are_sorted: bool")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA Gather operator documented at
https://www.tensorflow.org/xla/operation_semantics#gather
operand: The array we're gathering from.
start_indices: Array containing the starting indices of the slices we gather.
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");
REGISTER_OP("XlaScatter")
.Input("operand: T")
.Input("scatter_indices: Tindices")
.Input("updates: T")
.Attr("update_computation: func")
.Attr("dimension_numbers: string")
.Attr("indices_are_sorted: bool")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA Scatter operator documented at
https://www.tensorflow.org/xla/operation_semantics#scatter.
operand: Array to be scattered into.
scatter_indices: Array containing the starting indices of the slices that must
be scattered to.
updates: Array containing the values that must be used for scattering.
update_computation: Computation to be used for combining the existing values in
the input array and the updates during scatter.
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");
} // namespace
} // namespace tensorflow

View File

@ -416,3 +416,27 @@ sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
def gather(operand, start_indices, dimension_numbers, slice_sizes,
indices_are_sorted=False, name=None):
return gen_xla_ops.xla_gather(
operand,
start_indices,
slice_sizes=slice_sizes,
dimension_numbers=dimension_numbers.SerializeToString(),
indices_are_sorted=indices_are_sorted,
name=name)
def scatter(operand, scatter_indices, updates, update_computation,
dimension_numbers, indices_are_sorted=False, name=None):
return gen_xla_ops.xla_scatter(
operand,
scatter_indices,
updates,
update_computation=update_computation,
dimension_numbers=dimension_numbers.SerializeToString(),
indices_are_sorted=indices_are_sorted,
name=name)