[XLA:CLIENT] Add a dense gradient of TorchGather.
PiperOrigin-RevId: 267660412
This commit is contained in:
parent
113c7f0ace
commit
354b298bd8
@ -27,9 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
|
||||
|
||||
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
|
||||
XlaBuilder* builder,
|
||||
@ -45,69 +42,50 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
|
||||
const Shape scalar = ShapeUtil::MakeShape(type, {});
|
||||
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
|
||||
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
|
||||
generator(b.get(), lhs, rhs);
|
||||
generator(lhs, rhs);
|
||||
return b->BuildAndNoteError();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaComputation CreateScalarAddComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
return CreateScalarComputation(
|
||||
"add", type, builder,
|
||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
||||
return Add(lhs, rhs);
|
||||
});
|
||||
"add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
return CreateScalarComputation(
|
||||
"mul", type, builder,
|
||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
||||
return Mul(lhs, rhs);
|
||||
});
|
||||
"mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarGeComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
return CreateScalarComputation("ge", type, builder,
|
||||
[](XlaBuilder* b, const XlaOp& lhs,
|
||||
const XlaOp& rhs) { return Ge(lhs, rhs); });
|
||||
return CreateScalarComputation(
|
||||
"ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarMaxComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
return CreateScalarComputation(
|
||||
"max", type, builder,
|
||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
||||
return Max(lhs, rhs);
|
||||
});
|
||||
"max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarMinComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
return CreateScalarComputation(
|
||||
"min", type, builder,
|
||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
||||
return Min(lhs, rhs);
|
||||
});
|
||||
"min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarAndComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
return CreateScalarComputation(
|
||||
"and", type, builder,
|
||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
||||
return And(lhs, rhs);
|
||||
});
|
||||
"and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
return CreateScalarComputation("or", type, builder,
|
||||
[](XlaBuilder* b, const XlaOp& lhs,
|
||||
const XlaOp& rhs) { return Or(lhs, rhs); });
|
||||
return CreateScalarComputation(
|
||||
"or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
|
||||
|
@ -24,6 +24,13 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
using XlaOpGenerator = std::function<XlaOp(XlaOp, XlaOp)>;
|
||||
|
||||
// Creates a scalar computation based on a lambda and returns it.
// `generator` is called with two scalar parameters of element type `type`,
// named "lhs" and "rhs".
|
||||
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
|
||||
XlaBuilder* builder,
|
||||
XlaOpGenerator generator);
|
||||
|
||||
// Creates a scalar add computation and returns it.
|
||||
XlaComputation CreateScalarAddComputation(PrimitiveType type,
|
||||
XlaBuilder* builder);
|
||||
|
@ -208,6 +208,43 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) {
|
||||
});
|
||||
}
|
||||
|
||||
// Scatters `src` into `input` along `dim` at the positions given by `index`,
// using only broadcast, iota, compare, select and reduce ops plus whatever
// `combiner` emits.
//
// A one-hot mask along `dim` is built by comparing `index` against an iota,
// `src` is replaced by zero wherever the mask is false, the index axis is
// reduced away with `combiner`, and the reduced value is merged into `input`
// with `combiner`.
//
// NOTE(review): zero is used both as the fill for unselected entries and as
// the initial value of the reduction, which looks correct only for combiners
// whose identity is zero (e.g. addition) -- confirm before relying on others.
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
                        const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));

    // The expanded shape is the shape of `index` with one extra axis, of
    // extent input.dimensions(dim), inserted at position `dim`. The axes of
    // `index`/`src` from `dim` onwards therefore map to `axis + 1`.
    std::vector<int64> expanded_dims;
    std::vector<int64> operand_axes;
    for (int64 axis = 0; axis < index_shape.rank(); ++axis) {
      if (axis == dim) {
        expanded_dims.push_back(input_shape.dimensions(axis));
      }
      expanded_dims.push_back(index_shape.dimensions(axis));
      operand_axes.push_back(axis < dim ? axis : axis + 1);
    }

    const auto value_type = input_shape.element_type();
    const Shape position_shape =
        ShapeUtil::MakeShape(index_shape.element_type(), expanded_dims);
    const Shape value_shape = ShapeUtil::MakeShape(value_type, expanded_dims);

    // hit[..., p, q, ...] is true iff index[..., q, ...] == p, where p runs
    // over the inserted axis and q over the original `dim` axis of `index`.
    auto hit = Eq(BroadcastInDim(index, expanded_dims, operand_axes),
                  Iota(builder, position_shape, dim));
    auto contributions =
        Select(hit, BroadcastInDim(src, expanded_dims, operand_axes),
               Zeros(builder, value_shape));

    // Fold away the original index axis (now at `dim + 1`), then merge the
    // folded contributions into `input`.
    auto reducer =
        CreateScalarComputation("reducer", value_type, builder, combiner);
    auto folded =
        Reduce(contributions, Zero(builder, value_type), reducer, {dim + 1});
    return combiner(input, folded);
  });
}
|
||||
|
||||
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
|
||||
XlaBuilder* builder = input.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -57,6 +57,13 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
|
||||
// `index`.
|
||||
XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true);
|
||||
|
||||
// Scatters the entries of `src` into `input` along dimension `dim`, at the
// positions named by `index`, merging each entry into its destination with
// `combiner`. For rank-3 operands:
//
// idx = index[i][j][k]
|
||||
// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0
|
||||
// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1
|
||||
// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2
//
// NOTE(review): the implementation fills unselected entries with zero and
// starts its reduction from zero, so it presumably expects a combiner whose
// identity is zero (e.g. addition) -- confirm before using other combiners.
|
||||
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
|
||||
const std::function<XlaOp(XlaOp, XlaOp)>& combiner);
|
||||
|
||||
// Returns a new tensor which indexes the input tensor along dimension dim using
|
||||
// the entries in index.
|
||||
//
|
||||
|
@ -132,6 +132,24 @@ XLA_TEST_F(SlicingTest, TorchGatherDense) {
|
||||
{input_data.get(), index_data.get()});
|
||||
}
|
||||
|
||||
// Scatters a 2x2 `src` into a zero-filled 2x3 `input` along dimension 1 with
// an additive combiner and checks the resulting matrix.
XLA_TEST_F(SlicingTest, TorchScatterDense) {
  xla::XlaBuilder builder(TestName());

  xla::XlaOp input;
  xla::XlaOp index;
  xla::XlaOp src;
  auto input_data = CreateR2Parameter<int>({{0, 0, 0}, {0, 0, 0}}, 0, "input",
                                           &builder, &input);
  auto index_data =
      CreateR2Parameter<int>({{1, 0}, {1, 2}}, 1, "index", &builder, &index);
  auto src_data =
      CreateR2Parameter<int>({{1, 2}, {3, 4}}, 2, "src", &builder, &src);

  const auto add = [](XlaOp l, XlaOp r) { return l + r; };
  TorchScatterDense(input, index, src, 1, add);

  // Row 0: src {1, 2} lands in columns {1, 0} -> {2, 1, 0}.
  // Row 1: src {3, 4} lands in columns {1, 2} -> {0, 3, 4}.
  ComputeAndCompareR2<int>(
      &builder, {{2, 1, 0}, {0, 3, 4}},
      {input_data.get(), index_data.get(), src_data.get()});
}
|
||||
|
||||
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user