[XLA] Clamp indices in DynamicSlice and DynamicUpdateSlice instead of wrapping.

This implements the following index clamping in all backends (CPU, GPU, Interpreter):

for(int i = 0; i < rank; ++i)
  start_index[i] = clamp(start_index[i], 0, output_dim_size[i] - update_dim_size[i])

Which ensures the slice (or update region) is always inbounds w.r.t the input.

PiperOrigin-RevId: 197082276
This commit is contained in:
Michael Kuperstein 2018-05-17 18:09:45 -07:00 committed by TensorFlower Gardener
parent a7fcec1b6e
commit aca0458707
5 changed files with 130 additions and 152 deletions

View File

@ -330,13 +330,14 @@ class ReferenceUtil {
return result;
}
// Slices with modulo-wrapping.
// Slices with index clamping.
template <typename T>
static std::vector<T> ModSlice1D(const tensorflow::gtl::ArraySlice<T>& input,
int64 start, int64 size) {
static std::vector<T> ClampSlice1D(
const tensorflow::gtl::ArraySlice<T>& input, int64 start, int64 size) {
start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
std::vector<T> result;
for (int64 i = 0; i < size; ++i) {
result.push_back(input[(start + i) % input.size()]);
result.push_back(input[(start + i)]);
}
return result;
}

View File

@ -1547,6 +1547,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
operand_to_generator.at(hlo->operand(1))(dim_index));
// Clamp the start index so that the sliced portion fits in the operand:
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
// to officially document different behavior.
start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
index[i]->getType());
llvm::Value* operand_dim_size = llvm::ConstantInt::get(
start_index_value->getType(), input_hlo->shape().dimensions(i));
llvm::Value* output_dim_size = llvm::ConstantInt::get(
start_index_value->getType(), hlo->shape().dimensions(i));
start_index_value = EmitIntegralMin(
ir_builder_->CreateSub(operand_dim_size, output_dim_size),
EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
start_index_value, /*is_signed=*/true),
/*is_signed=*/true);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
slice_start_index[i] = start_index_value;
@ -1555,14 +1575,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
llvm_ir::IrArray::Index input_index(rank);
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = (start_index + offset_index) % dim_size
// Security note: this is the code that keeps the indices in-bounds.
llvm::Value* dim_size = llvm::ConstantInt::get(
index[i]->getType(), input_hlo->shape().dimensions(i));
llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
slice_start_index[i], index[i]->getType());
input_index[i] = ir_builder_->CreateURem(
ir_builder_->CreateAdd(start_index, index[i]), dim_size);
// input_index = start_index + offset_index
input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]);
}
return operand_to_generator.at(input_hlo)(input_index);
}
@ -1661,104 +1675,48 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
llvm_ir::IrArray::Index slice_start_index(rank);
llvm_ir::IrArray::Index slice_limit_index(rank);
// Slice starts at update[index - slice_start_index_adjusted],
// where adjusted value = slice_start_index when in bounds, and
// adjusted value = slice_start_index - input_dim, when wrapping.
llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
// Slice intersection gathers (ANDs) conditions on all ranks for which
// 'input' is set to 'update'
llvm::Value* slice_intersection = ir_builder_->getTrue();
for (int64 i = 0; i < rank; ++i) {
// Emit IR to read dynamic start indices from 'start_hlo'.
llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
operand_to_generator.at(start_hlo)(dim_index));
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
start_index_value, index[i]->getType());
// Clamp the start index so that the update region fits in the operand.
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
// to officially document different behavior.
start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
index[i]->getType());
llvm::Value* input_dim_size = llvm::ConstantInt::get(
index[i]->getType(), input_hlo->shape().dimensions(i));
llvm::Value* update_dim_size = llvm::ConstantInt::get(
index[i]->getType(), update_hlo->shape().dimensions(i));
// Generate code to handle wrapping semantics:
// slice_start_index[i] = slice_start_index[i] % input_dim_size;
// slice_limit_index[i] = slice_start_index[i] + update_dim_size.
// slice_start_index[i] is updated in place and it will now be in
// range. slice_limit_index[i] may be out of range, and it's being
// URem-ed below if so.
slice_start_index[i] =
ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
start_index_value = EmitIntegralMin(
ir_builder_->CreateSub(input_dim_size, update_dim_size),
EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
start_index_value, /*is_signed=*/true),
/*is_signed=*/true);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
slice_start_index[i] = start_index_value;
slice_limit_index[i] =
ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
// Test if slice_limit_index[i] is in bounds
llvm::Value* in_bounds =
ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
llvm_ir::LlvmIfData if_in_bounds =
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
// Handle true BB (slice_limit_index[i] <= input_dim_size).
SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
// Check that index[i] >= slice_start_index[i] &&
// index[i] < slice_limit_index[i]
llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
slice_intersection = ir_builder_->CreateAnd(
slice_intersection,
ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
"slice_intersection_in");
slice_intersection_in_bounds = ir_builder_->CreateAnd(
slice_intersection_in_bounds,
"slice_intersection");
slice_intersection = ir_builder_->CreateAnd(
slice_intersection,
ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
"slice_intersection_in");
// Handle false BB (slice_limit_index[i] > input_dim_size).
SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
// Check that index[i] >= slice_start_index[i] ||
// index[i] < slice_limit_index[i]%input_dim_size.
llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
index[i],
ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps,
"slice_intersection_out");
llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd(
slice_intersection, slice_intersection_or, "slice_intersection_out");
// Create value for slice_start_index_adjusted[i] when out of bounds.
// If within out-of-bounds if.
llvm_ir::LlvmIfData if_start_needs_adjustment =
llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_);
llvm::Value* slice_start_index_adjusted_oob =
ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_);
llvm::PHINode* slice_start_index_adjusted_phi =
ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2);
slice_start_index_adjusted_phi->addIncoming(
slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block);
slice_start_index_adjusted_phi->addIncoming(
slice_start_index[i], if_start_needs_adjustment.false_block);
// End of if within if.
// After checking in/out of bounds.
SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
llvm::PHINode* phi_slice_intersection =
ir_builder_->CreatePHI(slice_intersection->getType(), 2);
phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
if_in_bounds.true_block);
phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds,
if_start_needs_adjustment.after_block);
slice_intersection = phi_slice_intersection;
llvm::PHINode* phi_index =
ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
phi_index->addIncoming(slice_start_index_adjusted_phi,
if_start_needs_adjustment.after_block);
slice_start_index_adjusted[i] = phi_index;
"slice_intersection");
}
// Emit:
@ -1775,12 +1733,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Compute update index for intersection case.
llvm_ir::IrArray::Index update_index(rank);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* update_dim_size = llvm::ConstantInt::get(
index[i]->getType(), update_hlo->shape().dimensions(i));
// NOTE: Subtraction will be positive due to bounds checking above.
update_index[i] = ir_builder_->CreateURem(
ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
update_dim_size);
update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]);
}
TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
operand_to_generator.at(update_hlo)(update_index));

View File

@ -1986,17 +1986,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
std::vector<int64> operand_indices(start.size());
// Clamp the start indices so the slice is in-bounds w.r.t the operand.
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
// to officially document different behavior.
for (int64 i = 0; i < start.size(); ++i) {
start[i] = std::min<int64>(
std::max(0LL, start[i]),
operand_literal.shape().dimensions(i) - result_shape.dimensions(i));
}
std::vector<int64> operand_indices(start.size());
auto result = MakeUnique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
// Mod is only used here to be consistent with the existing
// backends' behavior.
operand_indices[i] = (multi_index[i] + start[i]) %
operand_literal.shape().dimensions(i);
operand_indices[i] = multi_index[i] + start[i];
}
auto result = operand_literal.Get<ReturnT>(operand_indices);
@ -2013,23 +2020,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = operand_literal.CloneToUnique();
auto start_indices_typed = start_indices_literal.data<IndexT>();
const auto rank = ShapeUtil::Rank(result->shape());
std::vector<int64> start(rank, 0);
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
// operand.
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
// to officially document different behavior.
for (int64 i = 0; i < rank; ++i) {
// All other implementations currently wrap-around the index, so this
// should do so as well.
start[i] = (start_indices_typed[i] % result->shape().dimensions(i));
start[i] += (start[i] < 0) * result->shape().dimensions(i);
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
result->shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
// Same as above, wrap-around only to match other implementations'
// semantics.
std::transform(result_index.begin(), result_index.end(),
result->shape().dimensions().begin(), result_index.begin(),
std::modulus<int64>());
result->Set<ReturnT>(result_index,
update_literal.Get<ReturnT>(update_index));
return true;

View File

@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
for (int64 i = 0; i < rank; ++i) {
IrArray::Index dim_index({ir_builder->getInt64(i)});
TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index));
llvm::Value* output_dim_size = llvm::ConstantInt::get(
start_index[i]->getType(), output_shape.dimensions(i));
llvm::Value* update_dim_size = llvm::ConstantInt::get(
start_index[i]->getType(), update_shape.dimensions(i));
// Clamp the start index so that the update region fits in the operand.
// start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
// to officially document different behavior.
llvm::Value* max_bound =
ir_builder->CreateSub(output_dim_size, update_dim_size);
llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0);
start_index[i] = ir_builder->CreateSelect(
ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]),
zero, start_index[i]);
start_index[i] = ir_builder->CreateSelect(
ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound,
start_index[i]),
max_bound, start_index[i]);
}
auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
// Calculate output_index, where we'll write the value from update. For
// each dimension,
//
// output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size.
// output_index[dim] = start_index[dim] + update_index[dim]
//
IrArray::Index output_index(rank);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* dim_size = llvm::ConstantInt::get(
update_index[i]->getType(), output_shape.dimensions(i));
llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast(
llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast(
start_index[i], update_index[i]->getType());
output_index[i] = ir_builder->CreateURem(
ir_builder->CreateAdd(start_index0, update_index[i]), dim_size);
output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]);
}
// Do output[output_index] = update[update_index].

View File

@ -53,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
void TestR1Wrap() {
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1});
void TestR1OOB() {
// Slice at dimension boundaries, but with out of bounds indices.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7});
}
template <typename IndexT, typename DataT>
@ -78,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
void TestR2Wrap() {
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
void TestR2OOB() {
// Slice at dimension boundaries, but with out of bounds indices.
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
{{5, 6, 4}, {8, 9, 7}, {2, 3, 1}});
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
}
template <typename IndexT, typename DataT>
@ -106,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
void TestR3Wrap() {
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
void TestR3OOB() {
// Slice at dimension boundaries, but with out of bounds indices.
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
{2, 1, 2}, {{{6, 5}}, {{12, 11}}});
{2, 1, 2}, {{{5, 6}}, {{11, 12}}});
}
template <typename IndexT, typename DataT>
@ -199,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase {
XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1<int32, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64, float>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32, float>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap<int32, float>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB<int32, float>(); }
XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64, float>(); }
@ -332,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
void TestWrap() {
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
void TestOOB() {
// Slice at dimension boundaries, but with out-of-bounds indices.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
{10, 1, 2, 3, 4, 5, 8, 9});
{0, 1, 2, 3, 4, 8, 9, 10});
// R2 Shape: [3, 3]
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
{{1, 2, 3}, {4, 5, 6}, {11, 8, 10}});
{{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
// R3 Shape: [2, 3, 2]
RunR3<IndexT, DataT>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
{1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}});
{1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
}
template <typename IndexT, typename DataT>
@ -476,20 +476,19 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
Array3D<T> input_values(kSeq, kBatch, kDim);
Array3D<T> update_values(size, kBatch, kDim);
Array3D<T> expected_values(kSeq, kBatch, kDim);
index = std::min(std::max(0, index), kSeq - size);
input_values.FillIota(static_cast<T>(0));
T value = static_cast<T>(10);
update_values.FillIota(static_cast<T>(value));
// TODO(b/34128753) Expected values may vary depending on backend when
// the update wraps. According to documentation, the results are technically
// implementation specific where the update is out of bounds, and hence
// we don't really know what to pass into ComputeAndCompareR3.
// the indices are out of bounds.
expected_values.FillIota(static_cast<T>(0));
for (int i = 0; i < size; i++) {
for (int j = 0; j < kBatch; j++) {
for (int k = 0; k < kDim; k++) {
expected_values((index + i) % kSeq, j, k) = value++;
expected_values(index + i, j, k) = value++;
}
}
}
@ -547,12 +546,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64, uint64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) {
TestWrap<int32, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap<uint64, uint64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB<uint64, uint64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
// Slice at dimension start.
@ -615,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
// Tests for simple R3 case where the update is contiguous (i.e. the minor
// two dimensions are not sliced).
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
// Single element, no wrap.
// Single element, index in-bounds
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) {
// Single element, no wrap.
// Single element, index in-bounds
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
// Multiple element, no wrap.
// Multiple elements, index in-bounds.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) {
// Multiple element, no wrap.
// Multiple elements, index in-bounds.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) {
// Multiple element, wrapping.
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) {
// Multiple element, index out of bounds.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/3, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) {
// Multiple element, wrapping.
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) {
// Multiple element, index out of bounds.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/3, /*size=*/2);
}