[XLA:CLIENT] Add a dense gradient of TorchGather.
PiperOrigin-RevId: 267660412
This commit is contained in:
parent
113c7f0ace
commit
354b298bd8
@ -27,9 +27,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
|
||||||
|
|
||||||
using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
|
|
||||||
|
|
||||||
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
|
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
|
||||||
XlaBuilder* builder,
|
XlaBuilder* builder,
|
||||||
@ -45,69 +42,50 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
|
|||||||
const Shape scalar = ShapeUtil::MakeShape(type, {});
|
const Shape scalar = ShapeUtil::MakeShape(type, {});
|
||||||
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
|
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
|
||||||
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
|
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
|
||||||
generator(b.get(), lhs, rhs);
|
generator(lhs, rhs);
|
||||||
return b->BuildAndNoteError();
|
return b->BuildAndNoteError();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Delegates to CreateScalarComputation with the name "add" and a lambda that
// returns Add(lhs, rhs). This span shows the function before and after the
// change; the two revisions differ only in the lambda's parameter list (the
// unused XlaBuilder* argument is dropped and the operands are taken by value).
XlaComputation CreateScalarAddComputation(PrimitiveType type,
|
XlaComputation CreateScalarAddComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return CreateScalarComputation(
|
return CreateScalarComputation(
|
||||||
"add", type, builder,
|
"add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); });
|
||||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
|
||||||
return Add(lhs, rhs);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delegates to CreateScalarComputation with the name "mul" and a lambda that
// returns Mul(lhs, rhs). This span shows the function before and after the
// change; the two revisions differ only in the lambda's parameter list (the
// unused XlaBuilder* argument is dropped and the operands are taken by value).
XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
|
XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return CreateScalarComputation(
|
return CreateScalarComputation(
|
||||||
"mul", type, builder,
|
"mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); });
|
||||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
|
||||||
return Mul(lhs, rhs);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delegates to CreateScalarComputation with the name "ge" and a lambda that
// returns Ge(lhs, rhs). This span shows the function before and after the
// change; the two revisions differ only in the lambda's parameter list (the
// unused XlaBuilder* argument is dropped and the operands are taken by value).
XlaComputation CreateScalarGeComputation(PrimitiveType type,
|
XlaComputation CreateScalarGeComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return CreateScalarComputation("ge", type, builder,
|
return CreateScalarComputation(
|
||||||
[](XlaBuilder* b, const XlaOp& lhs,
|
"ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); });
|
||||||
const XlaOp& rhs) { return Ge(lhs, rhs); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delegates to CreateScalarComputation with the name "max" and a lambda that
// returns Max(lhs, rhs). This span shows the function before and after the
// change; the two revisions differ only in the lambda's parameter list (the
// unused XlaBuilder* argument is dropped and the operands are taken by value).
XlaComputation CreateScalarMaxComputation(PrimitiveType type,
|
XlaComputation CreateScalarMaxComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return CreateScalarComputation(
|
return CreateScalarComputation(
|
||||||
"max", type, builder,
|
"max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); });
|
||||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
|
||||||
return Max(lhs, rhs);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delegates to CreateScalarComputation with the name "min" and a lambda that
// returns Min(lhs, rhs). This span shows the function before and after the
// change; the two revisions differ only in the lambda's parameter list (the
// unused XlaBuilder* argument is dropped and the operands are taken by value).
XlaComputation CreateScalarMinComputation(PrimitiveType type,
|
XlaComputation CreateScalarMinComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return CreateScalarComputation(
|
return CreateScalarComputation(
|
||||||
"min", type, builder,
|
"min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); });
|
||||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
|
||||||
return Min(lhs, rhs);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delegates to CreateScalarComputation with the name "and" and a lambda that
// returns And(lhs, rhs). This span shows the function before and after the
// change; the two revisions differ only in the lambda's parameter list (the
// unused XlaBuilder* argument is dropped and the operands are taken by value).
XlaComputation CreateScalarAndComputation(PrimitiveType type,
|
XlaComputation CreateScalarAndComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return CreateScalarComputation(
|
return CreateScalarComputation(
|
||||||
"and", type, builder,
|
"and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); });
|
||||||
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
|
|
||||||
return And(lhs, rhs);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delegates to CreateScalarComputation with the name "or" and a lambda that
// returns Or(lhs, rhs). This span shows the function before and after the
// change; the two revisions differ only in the lambda's parameter list (the
// unused XlaBuilder* argument is dropped and the operands are taken by value).
XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
return CreateScalarComputation("or", type, builder,
|
return CreateScalarComputation(
|
||||||
[](XlaBuilder* b, const XlaOp& lhs,
|
"or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); });
|
||||||
const XlaOp& rhs) { return Or(lhs, rhs); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
|
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
|
||||||
|
@ -24,6 +24,13 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
// Callable that receives two operands (lhs, rhs) as XlaOps and returns the
// XlaOp produced from them. Being a std::function, it accepts capturing
// lambdas as well as plain function pointers.
using XlaOpGenerator = std::function<XlaOp(XlaOp, XlaOp)>;
|
||||||
|
|
||||||
|
// Creates a scalar computation based on a lambda and returns it.
// `generator` is called with two scalar parameters of element type `type`,
// named "lhs" and "rhs".
// NOTE(review): how `name` and the parent `builder` are used is not visible in
// this excerpt — confirm against the definition.
|
||||||
|
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
|
||||||
|
XlaBuilder* builder,
|
||||||
|
XlaOpGenerator generator);
|
||||||
|
|
||||||
// Creates a scalar add computation and returns it.
// The computation's body is Add(lhs, rhs), built via CreateScalarComputation
// for element type `type`.
|
// Creates a scalar add computation and returns it.
|
||||||
XlaComputation CreateScalarAddComputation(PrimitiveType type,
|
XlaComputation CreateScalarAddComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder);
|
XlaBuilder* builder);
|
||||||
|
@ -208,6 +208,43 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scatters `src` into `input` along dimension `dim` at the positions given by
// `index`, expressed with dense ops only (broadcast, compare, select, reduce).
// `combiner` is used both to merge scattered values and to fold the merged
// result into `input`. Failures of the shape queries propagate out of the
// lambda handed to ReportErrorOrReturn.
// NOTE(review): `dim` is not range-checked here; if it is not a valid
// dimension of `index`, the extra dimension below is never inserted.
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
|
||||||
|
const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
|
||||||
|
XlaBuilder* builder = input.builder();
|
||||||
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
|
||||||
|
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||||
|
// `sizes` becomes the index shape with one extra dimension, of extent
// input_shape.dimensions(dim), inserted at position `dim`.
// `index_broacast_dims` maps index dimension i into that larger shape:
// i stays i when i < dim and becomes i + 1 otherwise.
std::vector<int64> index_broacast_dims;
|
||||||
|
std::vector<int64> sizes;
|
||||||
|
for (int64 i = 0; i < index_shape.rank(); ++i) {
|
||||||
|
if (i < dim) {
|
||||||
|
index_broacast_dims.push_back(i);
|
||||||
|
} else {
|
||||||
|
if (i == dim) {
|
||||||
|
// New dimension that enumerates the candidate output positions.
sizes.push_back(input_shape.dimensions(i));
|
||||||
|
}
|
||||||
|
index_broacast_dims.push_back(i + 1);
|
||||||
|
}
|
||||||
|
sizes.push_back(index_shape.dimensions(i));
|
||||||
|
}
|
||||||
|
// True where a broadcast index value equals its own coordinate along the
// inserted dimension `dim` (the Iota), i.e. where that element of `index`
// targets that output position.
auto mask =
|
||||||
|
Eq(BroadcastInDim(index, sizes, index_broacast_dims),
|
||||||
|
Iota(builder,
|
||||||
|
ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
|
||||||
|
// Keep `src` where the mask is set; every other slot is zero.
auto masked_src =
|
||||||
|
Select(mask, BroadcastInDim(src, sizes, index_broacast_dims),
|
||||||
|
Zeros(builder,
|
||||||
|
ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
|
||||||
|
|
||||||
|
// Reduce over dimension dim + 1 (where the original index dimension `dim`
// now lives) with `combiner`, starting from zero, then fold the result into
// `input` with `combiner` again.
// NOTE(review): unselected slots and the initial value feed zeros into
// `combiner`, so this presumably assumes zero is neutral for it (true for
// addition) — confirm before using other combiners.
return combiner(
|
||||||
|
input,
|
||||||
|
Reduce(masked_src, Zero(builder, input_shape.element_type()),
|
||||||
|
CreateScalarComputation("reducer", input_shape.element_type(),
|
||||||
|
builder, combiner),
|
||||||
|
{dim + 1}));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
|
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
|
||||||
XlaBuilder* builder = input.builder();
|
XlaBuilder* builder = input.builder();
|
||||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
@ -57,6 +57,13 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
|
|||||||
// `index`.
|
// `index`.
|
||||||
XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true);
|
XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true);
|
||||||
|
|
||||||
|
// Scatters `src` into `input` along dimension `dim` at the positions given by
// `index`, merging values with `combiner`. For a 3-D tensor:
// idx = index[i][j][k]
|
||||||
|
// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0
|
||||||
|
// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1
|
||||||
|
// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2
|
||||||
|
// Elements of `src` that target the same output position are first merged with
// `combiner` starting from zero, and positions that no index selects receive
// combiner(input, 0).
// NOTE(review): this presumably requires zero to be neutral for `combiner`
// (as it is for addition) — confirm before passing other combiners.
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
|
||||||
|
const std::function<XlaOp(XlaOp, XlaOp)>& combiner);
|
||||||
|
|
||||||
// Returns a new tensor which indexes the input tensor along dimension dim using
|
// Returns a new tensor which indexes the input tensor along dimension dim using
|
||||||
// the entries in index.
|
// the entries in index.
|
||||||
//
|
//
|
||||||
|
@ -132,6 +132,24 @@ XLA_TEST_F(SlicingTest, TorchGatherDense) {
|
|||||||
{input_data.get(), index_data.get()});
|
{input_data.get(), index_data.get()});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scatter-add along dimension 1 into an all-zero 2x3 input.
XLA_TEST_F(SlicingTest, TorchScatterDense) {
|
||||||
|
xla::XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
xla::XlaOp src, index, input;
|
||||||
|
// Parameter 0: the 2x3 destination, all zeros.
auto input_data = CreateR2Parameter<int>({{0, 0, 0}, {0, 0, 0}}, 0, "input",
|
||||||
|
&builder, &input);
|
||||||
|
// Parameter 1: the target column for each element of `src`.
auto index_data =
|
||||||
|
CreateR2Parameter<int>({{1, 0}, {1, 2}}, 1, "index", &builder, &index);
|
||||||
|
// Parameter 2: the values to scatter.
auto src_data =
|
||||||
|
CreateR2Parameter<int>({{1, 2}, {3, 4}}, 2, "src", &builder, &src);
|
||||||
|
// dim == 1 with an additive combiner.
TorchScatterDense(input, index, src, 1,
|
||||||
|
[](XlaOp l, XlaOp r) { return l + r; });
|
||||||
|
|
||||||
|
// Row 0: src 1 goes to column 1 and src 2 to column 0 -> {2, 1, 0}.
// Row 1: src 3 goes to column 1 and src 4 to column 2 -> {0, 3, 4}.
ComputeAndCompareR2<int>(
|
||||||
|
&builder, {{2, 1, 0}, {0, 3, 4}},
|
||||||
|
{input_data.get(), index_data.get(), src_data.get()});
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
|
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
|
||||||
xla::XlaBuilder builder(TestName());
|
xla::XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user