STT-tensorflow/tensorflow/compiler/xla/tests/scatter_test.cc
Gunhan Gulsoy a7a3231e44 Static cast the large literal in scatter_test to avoid narrowing conversion errors when using msvc.
PiperOrigin-RevId: 294835705
Change-Id: I59c8bd4b422529c1aba1f9eb8f8fdcbcd415c35d
2020-02-12 22:10:21 -08:00

784 lines
24 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
using absl::nullopt;
class ScatterTest : public HloTestBase {
protected:
void RunTest(const string& hlo_text, Literal* operand,
Literal* scatter_indices, Literal* updates) {
RunTest(hlo_text, {operand, scatter_indices, updates});
}
void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text, config));
EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
}
};
// Scatter of an s32[2,3] update into an s32[3,3] operand at indices {0, 2}.
// The combiner `update_s32` returns its second parameter (rhs).
XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) {
  const string hlo = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_lit =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Same scatter configuration as the V1 update test, except that the operand,
// indices and updates are each produced by adding a parameter to itself
// before being fed to the scatter. Indices passed in are {0, 1}.
XLA_TEST_F(ScatterTest, TensorFlowScatterV1_WithFusedAdds) {
  const string hlo = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
p0 = s32[3,3] parameter(0)
operand = s32[3,3] add(p0, p0)
p1 = s32[2] parameter(1)
indices = s32[2] add(p1, p1)
p2 = s32[2,3] parameter(2)
updates = s32[2,3] add(p2, p2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({0, 1});
  Literal updates_lit =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter of an s32[3,2] update into an s32[3,3] operand where the indices
// {0, 2} address operand dimension 1 (scatter_dims_to_operand_dims={1}) and
// the update window dimension is 0. The combiner returns its rhs parameter.
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
  const char* hlo = R"(
HloModule TensorFlowScatterV2
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[3,2] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={0},
inserted_window_dims={1},
scatter_dims_to_operand_dims={1},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_lit =
      LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Builds scatter indices from a row-wise permutation parameter (concatenated
// with a row iota) and scatters a column iota through them. Unlike the other
// tests, the result is checked against an explicit literal: each expected row
// is the inverse of the corresponding input permutation row.
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_InversePermutation) {
  const char* hlo = R"(
HloModule TensorFlowScatterV2
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
permutation = s32[3,4] parameter(0)
reshape = s32[3,4,1] reshape(permutation)
operand = s32[3,4] iota(), iota_dimension=1
updates = s32[3,4,1,1] iota(), iota_dimension=1
iota = s32[3,4,1] iota(), iota_dimension=0
indices = s32[3,4,2] concatenate(iota, reshape), dimensions={2}
ROOT scatter = s32[3,4] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={2,3},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=2
}
)";
  // Input permutations and their hand-computed inverses.
  Literal permutation_lit =
      LiteralUtil::CreateR2<int32>({{1, 3, 2, 0}, {3, 0, 2, 1}, {2, 3, 1, 0}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{3, 0, 2, 1}, {1, 3, 2, 0}, {3, 2, 0, 1}});

  HloModuleConfig module_config;
  module_config.set_debug_options(GetDebugOptionsForTest());
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo, module_config));
  auto actual = ExecuteAndTransfer(std::move(module), {&permutation_lit});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
// Rank-4 f32 scatter with an additive combiner: a single index vector
// {0, 0, 0} maps onto operand dimensions {0, 2, 1}, and the update window
// covers update dimensions 1, 2 and 3. The operand is all zeros.
XLA_TEST_F(ScatterTest, SimpleR4) {
  const char* hlo = R"(
HloModule SimpleR4
add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}
ENTRY main {
operand = f32[1,2,2,1] parameter(0)
indices = s32[1,3] parameter(1)
updates = f32[1,2,2,1] parameter(2)
ROOT scatter = f32[1,2,2,1] scatter(operand, indices, updates),
to_apply=add_f32,
update_window_dims={1,2,3},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0, 2, 1},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR4<float>({{{{0.f}, {0.f}}, {{0.f}, {0.f}}}});
  Literal updates_lit =
      LiteralUtil::CreateR4<float>({{{{0.12}, {0.28}}, {{0.018}, {0.42}}}});
  Literal indices_lit = LiteralUtil::CreateR2<int32>({{0, 0, 0}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter of an s32[2,3] update into an s32[3,3] operand at indices {0, 2},
// combining values with `add_s32` (lhs + rhs).
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
  const string hlo = R"(
HloModule TensorFlowScatter_Add
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_lit =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Additive scatter identical in shapes and data to TensorFlowScatter_Add, but
// with the scatter instruction additionally marked unique_indices=true. The
// supplied indices {0, 2} are distinct.
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_UniqueIndices) {
  const string hlo = R"(
HloModule TensorFlowScatter_Add
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1,
unique_indices=true
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_lit =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter of an s32[2,3] update into an s32[3,3] operand at indices {0, 2},
// combining values with `mul_s32` (lhs * rhs).
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
  const string hlo = R"(
HloModule TensorFlowScatter_Mul
mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=mul_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_lit =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Floating-point variant: an f32[2,3] update is scattered into an f32[3,3]
// operand at s32 indices {2, 1}, combining values with `add_f32`.
XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
  const string hlo = R"(
HloModule TensorFlowScatter_F32
add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}
ENTRY main {
operand = f32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = f32[2,3] parameter(2)
ROOT scatter = f32[3,3] scatter(operand, indices, updates),
to_apply=add_f32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({2, 1});
  Literal updates_lit =
      LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Additive s32 scatter in which both index entries are the same value
// ({1, 1}), so the two update rows target the same index.
XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
  const char* hlo = R"(
HloModule TensorFlowScatter
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({1, 1});
  Literal updates_lit =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Additive s32 scatter with rank-2 indices (s32[2,2], index_vector_dim=2) and
// a rank-3 update (s32[2,3,2]) whose only window dimension is dimension 1.
XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
  const char* hlo = R"(
HloModule TensorFlowScatterMultipleBatchDims
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,3,2] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={1},
scatter_dims_to_operand_dims={1},
index_vector_dim=2
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  Literal updates_lit = LiteralUtil::CreateR3<int32>(
      {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// ScatterNd-style case: two-component index vectors (index_vector_dim=1)
// address operand dimensions 0 and 1 of an s32[3,3,2] operand, and each
// update row spans the remaining dimension. The combiner returns its rhs.
XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
  const char* hlo = R"(
HloModule TensorFlowScatterNd
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,2] parameter(2)
ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices_lit = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal updates_lit = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Same shapes and scatter dimension numbers as TensorFlowScatterNd, except
// that the index vectors are read along dimension 0 of the indices
// (index_vector_dim=0) and the second update row is {-20, 20}.
XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
  const char* hlo = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,2] parameter(2)
ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=0
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices_lit = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal updates_lit = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter configured like a dynamic-update-slice: a single index vector
// {1, 1} (index_vector_dim=0), no inserted window dims, and a 1x1 update
// written into an s32[3,3] operand. The combiner returns its rhs.
XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
  const char* hlo = R"(
HloModule DynamicUpdateSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[1,1] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={0,1},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=0
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({1, 1});
  Literal updates_lit = LiteralUtil::CreateR2<int32>({{10}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Batched form of the dynamic-update-slice case: s32[2,2] indices with
// index_vector_dim=0 and an s32[2,1,1] update holding two 1x1 windows.
XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
  const char* hlo = R"(
HloModule BatchDynamicUpdateSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
updates = s32[2,1,1] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=0
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  Literal updates_lit = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter where the operand (s32[3,0]) and updates (s32[2,0]) both have a
// zero-sized trailing dimension, so the literals contain no elements.
XLA_TEST_F(ScatterTest, ZeroDimBounds) {
  const char* hlo = R"(
HloModule TensorFlowScatter_ZeroDimBounds
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,0] parameter(2)
ROOT scatter = s32[3,0] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_lit = LiteralUtil::CreateR2<int32>({{}, {}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Additive scatter into a rank-1 operand with an empty update_window_dims
// list: indices are s32[2,2,1] (index_vector_dim=2) and updates are s32[2,2].
// Index value 1 appears twice in the supplied indices.
XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
  const string hlo = R"(
HloModule Scatter_NoUpdateWindowDims
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3] parameter(0)
indices = s32[2,2,1] parameter(1)
updates = s32[2,2] parameter(2)
ROOT scatter = s32[3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2
}
)";
  Literal operand_lit = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal indices_lit =
      LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  Literal updates_lit = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Six 1x1 updates scattered into an s32[3,3] operand using signed indices,
// several of which lie outside the operand bounds (7, 5 and 2147483647).
XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
  const string hlo = R"(
HloModule BatchDynamicSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
updates = s32[6,1,1]{2,1,0} parameter(2)
ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR2<int32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
  Literal updates_lit = LiteralUtil::CreateR3<int32>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Unsigned-index variant of the out-of-bounds case: indices are u32[6,2] and
// include 2147483648u, a value above the signed 32-bit maximum.
XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
  const string hlo = R"(
HloModule BatchDynamicSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = u32[6,2]{1,0} parameter(1)
updates = s32[6,1,1]{2,1,0} parameter(2)
ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_lit = LiteralUtil::CreateR2<uint32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
  Literal updates_lit = LiteralUtil::CreateR3<int32>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Signed-index scatter whose indices include negative values (-500 and the
// minimum 32-bit integer) alongside in-range and too-large entries.
XLA_TEST_F(ScatterTest, NegativeIndex) {
  const string hlo = R"(
HloModule BatchDynamicSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
updates = s32[6,1,1]{2,1,0} parameter(2)
ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // The explicit cast on the large literal avoids narrowing-conversion errors
  // when building with MSVC.
  Literal indices_lit =
      LiteralUtil::CreateR2<int32>({{2, 7},
                                    {2, 1},
                                    {1, 1},
                                    {-500, 1},
                                    {static_cast<int32>(-2147483648), 1},
                                    {1, 2}});
  Literal updates_lit = LiteralUtil::CreateR3<int32>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter of one 2x2 update window into an s32[3,3,2] operand at index
// vector {0, 2}. As the test name indicates, the window is meant to extend
// past the operand bounds (index 2 plus a window extent of 2 in a dimension
// of size 3).
XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) {
  const char* hlo = R"(
HloModule TensorFlowScatterNd_OobUpdateWindow
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[1,2] parameter(1)
updates = s32[1,2,2] parameter(2)
ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
  Literal operand_lit =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices_lit = LiteralUtil::CreateR2<int32>({{0, 2}});
  Literal updates_lit =
      LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter driven by a single scalar index (s32[], value 1) addressing operand
// dimension 0 of an s32[2,3,2] operand, with an s32[1,3,2] update window.
XLA_TEST_F(ScatterTest, OneScalarIndex) {
  const char* hlo = R"(
HloModule OneScalarIndex
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[2,3,2]{2,1,0} parameter(0)
index = s32[] parameter(1)
updates = s32[1,3,2]{2,1,0} parameter(2)
ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates),
to_apply=update_s32,
update_window_dims={0,1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
}
)";
  Literal operand_lit = LiteralUtil::CreateR3<int32>(
      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
  Literal index_lit = LiteralUtil::CreateR0<int32>(1);
  Literal updates_lit =
      LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
  RunTest(hlo, &operand_lit, &index_lit, &updates_lit);
}
// Scatter in which both the index (value 1) and the update (value 25) are
// scalars, targeting a rank-1 s32[4] operand.
XLA_TEST_F(ScatterTest, ScalarUpdate) {
  const char* hlo = R"(
HloModule ScalarUpdate
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
updates = s32[] parameter(2)
ROOT scatter = s32[4]{0} scatter(operand, index, updates),
to_apply=update_s32,
update_window_dims={},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
}
)";
  Literal operand_lit = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
  Literal index_lit = LiteralUtil::CreateR0<int32>(1);
  Literal update_lit = LiteralUtil::CreateR0<int32>(25);
  RunTest(hlo, &operand_lit, &index_lit, &update_lit);
}
// Scatter with zero-length indices and updates (both s32[0]) applied to an
// s32[3] operand.
XLA_TEST_F(ScatterTest, EmptyIndices) {
  const string hlo = R"(
HloModule EmptyIndices
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3] parameter(0)
indices = s32[0] parameter(1)
updates = s32[0] parameter(2)
ROOT scatter = s32[3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal operand_lit = LiteralUtil::CreateR1<int32>({1, 2, 3});
  Literal indices_lit = LiteralUtil::CreateR1<int32>({});
  Literal updates_lit = LiteralUtil::CreateR1<int32>({});
  RunTest(hlo, &operand_lit, &indices_lit, &updates_lit);
}
// Scatter whose operand and update are both scalars (s32[]) and whose index
// parameter is a zero-length vector (s32[0]); all dimension lists are empty.
XLA_TEST_F(ScatterTest, ScatterIntoScalar) {
  const char* hlo = R"(
HloModule ScatterIntoScalar
update_s32 {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
parameter.1 = s32[] parameter(0)
parameter.2 = s32[0]{0} parameter(1)
parameter.3 = s32[] parameter(2)
ROOT scatter = s32[] scatter(parameter.1, parameter.2, parameter.3),
update_window_dims={},
inserted_window_dims={},
scatter_dims_to_operand_dims={},
index_vector_dim=0,
to_apply=update_s32
}
)";
  Literal operand_lit = LiteralUtil::CreateR0<int32>(1);
  Literal indices_lit = LiteralUtil::CreateR1<int32>({});
  Literal update_lit = LiteralUtil::CreateR0<int32>(2);
  RunTest(hlo, &operand_lit, &indices_lit, &update_lit);
}
} // namespace
} // namespace xla