diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 0940a873fa4..de573429fdc 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -27,9 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace { - -using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, XlaBuilder* builder, @@ -45,69 +42,50 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, const Shape scalar = ShapeUtil::MakeShape(type, {}); auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); - generator(b.get(), lhs, rhs); + generator(lhs, rhs); return b->BuildAndNoteError(); } -} // namespace - XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "add", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Add(lhs, rhs); - }); + "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); }); } XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "mul", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Mul(lhs, rhs); - }); + "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); }); } XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder) { - return CreateScalarComputation("ge", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, - const XlaOp& rhs) { return Ge(lhs, rhs); }); + return CreateScalarComputation( + "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); }); } XlaComputation CreateScalarMaxComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "max", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Max(lhs, rhs); - }); + "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); }); } XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "min", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return Min(lhs, rhs); - }); + "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); }); } XlaComputation CreateScalarAndComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( - "and", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return And(lhs, rhs); - }); + "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); }); } XlaComputation CreateScalarOrComputation(PrimitiveType type, XlaBuilder* builder) { - return CreateScalarComputation("or", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, - const XlaOp& rhs) { return Or(lhs, rhs); }); + return CreateScalarComputation( + "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); }); } XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 270076a1586..350dcc5531d 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -24,6 +24,13 @@ limitations under the License. namespace xla { +using XlaOpGenerator = std::function; + +// Creates a scalar computation based on a lambda and returns it. +XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, + XlaBuilder* builder, + XlaOpGenerator generator); + // Creates a scalar add computation and returns it. XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder); diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 7c577265740..b47ddb7919f 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -208,6 +208,43 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) { }); } +XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim, + const std::function& combiner) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + std::vector index_broacast_dims; + std::vector sizes; + for (int64 i = 0; i < index_shape.rank(); ++i) { + if (i < dim) { + index_broacast_dims.push_back(i); + } else { + if (i == dim) { + sizes.push_back(input_shape.dimensions(i)); + } + index_broacast_dims.push_back(i + 1); + } + sizes.push_back(index_shape.dimensions(i)); + } + auto mask = + Eq(BroadcastInDim(index, sizes, index_broacast_dims), + Iota(builder, + ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim)); + auto masked_src = + Select(mask, BroadcastInDim(src, sizes, index_broacast_dims), + Zeros(builder, + ShapeUtil::MakeShape(input_shape.element_type(), sizes))); + + return combiner( + input, + Reduce(masked_src, Zero(builder, input_shape.element_type()), + CreateScalarComputation("reducer", input_shape.element_type(), + builder, combiner), + {dim + 1})); + }); +} + XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) { XlaBuilder* builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 9a59a048b9f..cf83d63cec2 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -57,6 +57,13 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, // `index`. XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true); +// idx = index[i][j][k] +// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0 +// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1 +// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2 +XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim, + const std::function& combiner); + // Returns a new tensor which indexes the input tensor along dimension dim using // the entries in index. // diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 7ebd45681e6..8e2e713c45c 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -132,6 +132,24 @@ XLA_TEST_F(SlicingTest, TorchGatherDense) { {input_data.get(), index_data.get()}); } +XLA_TEST_F(SlicingTest, TorchScatterDense) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp src, index, input; + auto input_data = CreateR2Parameter({{0, 0, 0}, {0, 0, 0}}, 0, "input", + &builder, &input); + auto index_data = + CreateR2Parameter({{1, 0}, {1, 2}}, 1, "index", &builder, &index); + auto src_data = + CreateR2Parameter({{1, 2}, {3, 4}}, 2, "src", &builder, &src); + TorchScatterDense(input, index, src, 1, + [](XlaOp l, XlaOp r) { return l + r; }); + + ComputeAndCompareR2( + &builder, {{2, 1, 0}, {0, 3, 4}}, + {input_data.get(), index_data.get(), src_data.get()}); +} + XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) { xla::XlaBuilder builder(TestName());