Add a test for Scatter without atomics.
This commit is contained in:
parent
27903a3978
commit
7d22d959e0
@ -225,6 +225,36 @@ ENTRY main {
|
||||
RunTest(hlo_text, &operand, &scatter_indices, &updates);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_NoAtomic) {
|
||||
const string hlo_text = R"(
|
||||
HloModule TensorFlowScatter_Add
|
||||
|
||||
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||
lhs = s32[] parameter(0)
|
||||
rhs = s32[] parameter(1)
|
||||
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3,3] parameter(0)
|
||||
indices = s32[2] parameter(1)
|
||||
updates = s32[2,3] parameter(2)
|
||||
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
|
||||
to_apply=add_s32,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1,
|
||||
use_atomic=false
|
||||
}
|
||||
)";
|
||||
Literal operand =
|
||||
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||
Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
|
||||
Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
|
||||
RunTest(hlo_text, &operand, &scatter_indices, &updates);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
|
||||
const string hlo_text = R"(
|
||||
HloModule TensorFlowScatter_Mul
|
||||
|
Loading…
Reference in New Issue
Block a user