Add a test for Scatter without atomics.
This commit is contained in:
parent
27903a3978
commit
7d22d959e0
@ -225,6 +225,36 @@ ENTRY main {
|
|||||||
RunTest(hlo_text, &operand, &scatter_indices, &updates);
|
RunTest(hlo_text, &operand, &scatter_indices, &updates);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_NoAtomic) {
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule TensorFlowScatter_Add
|
||||||
|
|
||||||
|
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||||
|
lhs = s32[] parameter(0)
|
||||||
|
rhs = s32[] parameter(1)
|
||||||
|
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
operand = s32[3,3] parameter(0)
|
||||||
|
indices = s32[2] parameter(1)
|
||||||
|
updates = s32[2,3] parameter(2)
|
||||||
|
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
|
||||||
|
to_apply=add_s32,
|
||||||
|
update_window_dims={1},
|
||||||
|
inserted_window_dims={0},
|
||||||
|
scatter_dims_to_operand_dims={0},
|
||||||
|
index_vector_dim=1,
|
||||||
|
use_atomic=false
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
Literal operand =
|
||||||
|
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||||
|
Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
|
||||||
|
Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
|
||||||
|
RunTest(hlo_text, &operand, &scatter_indices, &updates);
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
|
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
|
||||||
const string hlo_text = R"(
|
const string hlo_text = R"(
|
||||||
HloModule TensorFlowScatter_Mul
|
HloModule TensorFlowScatter_Mul
|
||||||
|
Loading…
Reference in New Issue
Block a user