Expose XLA Gather/Scatter to python via the tf2xla bridge
The semantics of these operations are a lot more flexible then tf.gather and tf.scatter so this change enables the usage of these more powerful APIs. This API is *not* stable so any user have to be prepared for potential breaking changes. PiperOrigin-RevId: 286569124 Change-Id: If1690d749deb5d08bb89532842fecaad2c2921d4
This commit is contained in:
parent
067ffdd467
commit
a5a0ad4300
@ -2036,6 +2036,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
|||||||
"XlaDynamicSlice",
|
"XlaDynamicSlice",
|
||||||
"XlaDynamicUpdateSlice",
|
"XlaDynamicUpdateSlice",
|
||||||
"XlaEinsum",
|
"XlaEinsum",
|
||||||
|
"XlaGather",
|
||||||
"XlaIf",
|
"XlaIf",
|
||||||
"XlaKeyValueSort",
|
"XlaKeyValueSort",
|
||||||
"XlaPad",
|
"XlaPad",
|
||||||
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
|||||||
"XlaReduce",
|
"XlaReduce",
|
||||||
"XlaReduceWindow",
|
"XlaReduceWindow",
|
||||||
"XlaReplicaId",
|
"XlaReplicaId",
|
||||||
|
"XlaScatter",
|
||||||
"XlaSelectAndScatter",
|
"XlaSelectAndScatter",
|
||||||
"XlaSelfAdjointEig",
|
"XlaSelfAdjointEig",
|
||||||
"XlaSend",
|
"XlaSend",
|
||||||
|
|||||||
@ -48,6 +48,7 @@ tf_kernel_library(
|
|||||||
"function_ops.cc",
|
"function_ops.cc",
|
||||||
"gather_op.cc",
|
"gather_op.cc",
|
||||||
"gather_op_helpers.h",
|
"gather_op_helpers.h",
|
||||||
|
"gather_scatter_ops.cc",
|
||||||
"identity_op.cc",
|
"identity_op.cc",
|
||||||
"image_ops.cc",
|
"image_ops.cc",
|
||||||
"image_resize_ops.cc",
|
"image_resize_ops.cc",
|
||||||
|
|||||||
102
tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
Normal file
102
tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class GatherOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
|
||||||
|
string dnums_attr;
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, dnums_.ParsePartialFromString(dnums_attr),
|
||||||
|
errors::InvalidArgument("Error parsing gather dimension numbers"));
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
std::vector<int64> slice_sizes;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes));
|
||||||
|
xla::XlaOp result =
|
||||||
|
xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_,
|
||||||
|
slice_sizes, indices_are_sorted_);
|
||||||
|
ctx->SetOutput(0, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
xla::GatherDimensionNumbers dnums_;
|
||||||
|
bool indices_are_sorted_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("XlaGather"), GatherOp);
|
||||||
|
|
||||||
|
class ScatterOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->GetAttr("update_computation", &update_computation_));
|
||||||
|
string dnums_attr;
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, dnums_.ParsePartialFromString(dnums_attr),
|
||||||
|
errors::InvalidArgument("Error parsing scatter dimension numbers"));
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
const DataType dtype = ctx->input_type(0);
|
||||||
|
|
||||||
|
XlaCompiler::Argument update_computation_arg;
|
||||||
|
update_computation_arg.kind = XlaCompiler::Argument::kParameter;
|
||||||
|
update_computation_arg.type = dtype;
|
||||||
|
update_computation_arg.shape = TensorShape();
|
||||||
|
|
||||||
|
XlaCompiler::CompileOptions compile_options;
|
||||||
|
compile_options.use_tuple_arg = false;
|
||||||
|
compile_options.always_return_tuple = false;
|
||||||
|
compile_options.is_entry_computation = false;
|
||||||
|
XlaCompiler::CompilationResult update_computation;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction(
|
||||||
|
compile_options, *update_computation_,
|
||||||
|
{update_computation_arg, update_computation_arg},
|
||||||
|
&update_computation));
|
||||||
|
|
||||||
|
xla::XlaOp result =
|
||||||
|
xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"),
|
||||||
|
ctx->Input("updates"), *update_computation.computation,
|
||||||
|
dnums_, indices_are_sorted_);
|
||||||
|
ctx->SetOutput(0, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const NameAttrList* update_computation_;
|
||||||
|
xla::ScatterDimensionNumbers dnums_;
|
||||||
|
bool indices_are_sorted_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
||||||
@ -665,5 +665,50 @@ REGISTER_OP("XlaReplicaId")
|
|||||||
})
|
})
|
||||||
.Doc("Replica ID.");
|
.Doc("Replica ID.");
|
||||||
|
|
||||||
|
REGISTER_OP("XlaGather")
|
||||||
|
.Input("operand: T")
|
||||||
|
.Input("start_indices: Tindices")
|
||||||
|
.Input("slice_sizes: Tindices")
|
||||||
|
.Attr("dimension_numbers: string")
|
||||||
|
.Attr("indices_are_sorted: bool")
|
||||||
|
.Attr("T: numbertype")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Output("output: T")
|
||||||
|
.SetShapeFn(UnchangedRank)
|
||||||
|
.Doc(R"doc(
|
||||||
|
Wraps the XLA Gather operator documented at
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#gather
|
||||||
|
operand: The array we're gathering from.
|
||||||
|
start_indices: Array containing the starting indices of the slices we gather.
|
||||||
|
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
|
||||||
|
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
|
||||||
|
indices_are_sorted: Boolean indicating if the indices are sorted.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("XlaScatter")
|
||||||
|
.Input("operand: T")
|
||||||
|
.Input("scatter_indices: Tindices")
|
||||||
|
.Input("updates: T")
|
||||||
|
.Attr("update_computation: func")
|
||||||
|
.Attr("dimension_numbers: string")
|
||||||
|
.Attr("indices_are_sorted: bool")
|
||||||
|
.Attr("T: numbertype")
|
||||||
|
.Attr("Tindices: {int32, int64}")
|
||||||
|
.Output("output: T")
|
||||||
|
.SetShapeFn(UnchangedRank)
|
||||||
|
.Doc(R"doc(
|
||||||
|
Wraps the XLA Scatter operator documented at
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#scatter.
|
||||||
|
|
||||||
|
operand: Array to be scattered into.
|
||||||
|
scatter_indices: Array containing the starting indices of the slices that must
|
||||||
|
be scattered to.
|
||||||
|
updates: Array containing the values that must be used for scattering.
|
||||||
|
update_computation: Computation to be used for combining the existing values in
|
||||||
|
the input array and the updates during scatter.
|
||||||
|
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
|
||||||
|
indices_are_sorted: Boolean indicating if the indices are sorted.
|
||||||
|
)doc");
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -416,3 +416,27 @@ sort = gen_xla_ops.xla_sort
|
|||||||
key_value_sort = gen_xla_ops.xla_key_value_sort
|
key_value_sort = gen_xla_ops.xla_key_value_sort
|
||||||
while_loop = gen_xla_ops.xla_while
|
while_loop = gen_xla_ops.xla_while
|
||||||
dequantize = gen_xla_ops.xla_dequantize
|
dequantize = gen_xla_ops.xla_dequantize
|
||||||
|
|
||||||
|
|
||||||
|
def gather(operand, start_indices, dimension_numbers, slice_sizes,
|
||||||
|
indices_are_sorted=False, name=None):
|
||||||
|
return gen_xla_ops.xla_gather(
|
||||||
|
operand,
|
||||||
|
start_indices,
|
||||||
|
slice_sizes=slice_sizes,
|
||||||
|
dimension_numbers=dimension_numbers.SerializeToString(),
|
||||||
|
indices_are_sorted=indices_are_sorted,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def scatter(operand, scatter_indices, updates, update_computation,
|
||||||
|
dimension_numbers, indices_are_sorted=False, name=None):
|
||||||
|
return gen_xla_ops.xla_scatter(
|
||||||
|
operand,
|
||||||
|
scatter_indices,
|
||||||
|
updates,
|
||||||
|
update_computation=update_computation,
|
||||||
|
dimension_numbers=dimension_numbers.SerializeToString(),
|
||||||
|
indices_are_sorted=indices_are_sorted,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user