Add a test for Scatter without atomics.

This commit is contained in:
Frederic Bastien 2019-07-11 08:39:40 -07:00
parent 27903a3978
commit 7d22d959e0

View File

@ -225,6 +225,36 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_NoAtomic) {
const string hlo_text = R"(
HloModule TensorFlowScatter_Add
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1,
use_atomic=false
}
)";
Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
const string hlo_text = R"(
HloModule TensorFlowScatter_Mul