[XLA:CLIENT] Add a dense gradient of TorchGather.

PiperOrigin-RevId: 267660412
This commit is contained in:
Blake Hechtman 2019-09-06 13:25:01 -07:00 committed by TensorFlower Gardener
parent 113c7f0ace
commit 354b298bd8
5 changed files with 79 additions and 32 deletions

View File

@ -27,9 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
XlaBuilder* builder,
@ -45,69 +42,50 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
const Shape scalar = ShapeUtil::MakeShape(type, {});
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
generator(b.get(), lhs, rhs);
generator(lhs, rhs);
return b->BuildAndNoteError();
}
} // namespace
XlaComputation CreateScalarAddComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation(
"add", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return Add(lhs, rhs);
});
"add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); });
}
XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation(
"mul", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return Mul(lhs, rhs);
});
"mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); });
}
XlaComputation CreateScalarGeComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation("ge", type, builder,
[](XlaBuilder* b, const XlaOp& lhs,
const XlaOp& rhs) { return Ge(lhs, rhs); });
return CreateScalarComputation(
"ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); });
}
XlaComputation CreateScalarMaxComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation(
"max", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return Max(lhs, rhs);
});
"max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); });
}
XlaComputation CreateScalarMinComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation(
"min", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return Min(lhs, rhs);
});
"min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); });
}
XlaComputation CreateScalarAndComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation(
"and", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return And(lhs, rhs);
});
"and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); });
}
XlaComputation CreateScalarOrComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation("or", type, builder,
[](XlaBuilder* b, const XlaOp& lhs,
const XlaOp& rhs) { return Or(lhs, rhs); });
return CreateScalarComputation(
"or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); });
}
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,

View File

@ -24,6 +24,13 @@ limitations under the License.
namespace xla {
using XlaOpGenerator = std::function<XlaOp(XlaOp, XlaOp)>;
// Creates a scalar computation based on a lambda and returns it.
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
XlaBuilder* builder,
XlaOpGenerator generator);
// Creates a scalar add computation and returns it.
XlaComputation CreateScalarAddComputation(PrimitiveType type,
XlaBuilder* builder);

View File

@ -208,6 +208,43 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) {
});
}
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
std::vector<int64> index_broacast_dims;
std::vector<int64> sizes;
for (int64 i = 0; i < index_shape.rank(); ++i) {
if (i < dim) {
index_broacast_dims.push_back(i);
} else {
if (i == dim) {
sizes.push_back(input_shape.dimensions(i));
}
index_broacast_dims.push_back(i + 1);
}
sizes.push_back(index_shape.dimensions(i));
}
auto mask =
Eq(BroadcastInDim(index, sizes, index_broacast_dims),
Iota(builder,
ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
auto masked_src =
Select(mask, BroadcastInDim(src, sizes, index_broacast_dims),
Zeros(builder,
ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
return combiner(
input,
Reduce(masked_src, Zero(builder, input_shape.element_type()),
CreateScalarComputation("reducer", input_shape.element_type(),
builder, combiner),
{dim + 1}));
});
}
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {

View File

@ -57,6 +57,13 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
// `index`.
XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true);
// idx = index[i][j][k]
// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0
// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1
// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
const std::function<XlaOp(XlaOp, XlaOp)>& combiner);
// Returns a new tensor which indexes the input tensor along dimension dim using
// the entries in index.
//

View File

@ -132,6 +132,24 @@ XLA_TEST_F(SlicingTest, TorchGatherDense) {
{input_data.get(), index_data.get()});
}
XLA_TEST_F(SlicingTest, TorchScatterDense) {
xla::XlaBuilder builder(TestName());
xla::XlaOp src, index, input;
auto input_data = CreateR2Parameter<int>({{0, 0, 0}, {0, 0, 0}}, 0, "input",
&builder, &input);
auto index_data =
CreateR2Parameter<int>({{1, 0}, {1, 2}}, 1, "index", &builder, &index);
auto src_data =
CreateR2Parameter<int>({{1, 2}, {3, 4}}, 2, "src", &builder, &src);
TorchScatterDense(input, index, src, 1,
[](XlaOp l, XlaOp r) { return l + r; });
ComputeAndCompareR2<int>(
&builder, {{2, 1, 0}, {0, 3, 4}},
{input_data.get(), index_data.get(), src_data.get()});
}
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
xla::XlaBuilder builder(TestName());