[XLA] Clamp indices in DynamicSlice and DynamicUpdateSlice instead of wrapping.
This implements the following index clamping in all backends (CPU, GPU, Interpreter): for(int i = 0; i < rank; ++i) start_index[i] = clamp(start_index[i], 0, output_dim_size[i] - update_dim_size[i]) Which ensures the slice (or update region) is always inbounds w.r.t the input. PiperOrigin-RevId: 197082276
This commit is contained in:
parent
a7fcec1b6e
commit
aca0458707
@ -330,13 +330,14 @@ class ReferenceUtil {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Slices with modulo-wrapping.
|
||||
// Slices with index clamping
|
||||
template <typename T>
|
||||
static std::vector<T> ModSlice1D(const tensorflow::gtl::ArraySlice<T>& input,
|
||||
int64 start, int64 size) {
|
||||
static std::vector<T> ClampSlice1D(
|
||||
const tensorflow::gtl::ArraySlice<T>& input, int64 start, int64 size) {
|
||||
start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
|
||||
std::vector<T> result;
|
||||
for (int64 i = 0; i < size; ++i) {
|
||||
result.push_back(input[(start + i) % input.size()]);
|
||||
result.push_back(input[(start + i)]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -1547,6 +1547,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
|
||||
llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
|
||||
operand_to_generator.at(hlo->operand(1))(dim_index));
|
||||
|
||||
// Clamp the start index so that the sliced portion fits in the operand:
|
||||
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
|
||||
|
||||
// TODO(b/74360564): This is implementation defined behavior, but is
|
||||
// currently respected by all implementations. Change this if we ever decide
|
||||
// to oficially document different behavior.
|
||||
start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
|
||||
index[i]->getType());
|
||||
llvm::Value* operand_dim_size = llvm::ConstantInt::get(
|
||||
start_index_value->getType(), input_hlo->shape().dimensions(i));
|
||||
llvm::Value* output_dim_size = llvm::ConstantInt::get(
|
||||
start_index_value->getType(), hlo->shape().dimensions(i));
|
||||
|
||||
start_index_value = EmitIntegralMin(
|
||||
ir_builder_->CreateSub(operand_dim_size, output_dim_size),
|
||||
EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
|
||||
start_index_value, /*is_signed=*/true),
|
||||
/*is_signed=*/true);
|
||||
|
||||
start_index_value->setName(
|
||||
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
|
||||
slice_start_index[i] = start_index_value;
|
||||
@ -1555,14 +1575,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
|
||||
llvm_ir::IrArray::Index input_index(rank);
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
// Emit IR which computes:
|
||||
// input_index = (start_index + offset_index) % dim_size
|
||||
// Security note: this is the code that keeps the indices in-bounds.
|
||||
llvm::Value* dim_size = llvm::ConstantInt::get(
|
||||
index[i]->getType(), input_hlo->shape().dimensions(i));
|
||||
llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
|
||||
slice_start_index[i], index[i]->getType());
|
||||
input_index[i] = ir_builder_->CreateURem(
|
||||
ir_builder_->CreateAdd(start_index, index[i]), dim_size);
|
||||
// input_index = start_index + offset_index
|
||||
input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]);
|
||||
}
|
||||
return operand_to_generator.at(input_hlo)(input_index);
|
||||
}
|
||||
@ -1661,104 +1675,48 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
|
||||
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
|
||||
llvm_ir::IrArray::Index slice_start_index(rank);
|
||||
llvm_ir::IrArray::Index slice_limit_index(rank);
|
||||
// Slice starts at update[index - slice_start_index_adjusted],
|
||||
// where adjusted value = slice_start_index when in bounds, and
|
||||
// adjusted value = slice_start_index - input_dim, when wrapping.
|
||||
llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
|
||||
|
||||
// Slice intersection gathers (ANDs) conditions on all ranks for which
|
||||
// 'input' is set to 'update'
|
||||
llvm::Value* slice_intersection = ir_builder_->getTrue();
|
||||
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
// Emit IR to read dynamic start indices from 'start_hlo'.
|
||||
llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
|
||||
operand_to_generator.at(start_hlo)(dim_index));
|
||||
start_index_value->setName(
|
||||
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
|
||||
slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
|
||||
start_index_value, index[i]->getType());
|
||||
|
||||
// Clamp the start index so that the update region fits in the operand.
|
||||
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
|
||||
|
||||
// TODO(b/74360564): This is implementation defined behavior, but is
|
||||
// currently respected by all implementations. Change this if we ever decide
|
||||
// to oficially document different behavior.
|
||||
start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
|
||||
index[i]->getType());
|
||||
llvm::Value* input_dim_size = llvm::ConstantInt::get(
|
||||
index[i]->getType(), input_hlo->shape().dimensions(i));
|
||||
llvm::Value* update_dim_size = llvm::ConstantInt::get(
|
||||
index[i]->getType(), update_hlo->shape().dimensions(i));
|
||||
|
||||
// Generate code to handle wrapping semantics:
|
||||
// slice_start_index[i] = slice_start_index[i] % input_dim_size;
|
||||
// slice_limit_index[i] = slice_start_index[i] + update_dim_size.
|
||||
// slice_start_index[i] is updated in place and it will now be in
|
||||
// range. slice_limit_index[i] may be out of range, and it's being
|
||||
// URem-ed below if so.
|
||||
slice_start_index[i] =
|
||||
ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
|
||||
start_index_value = EmitIntegralMin(
|
||||
ir_builder_->CreateSub(input_dim_size, update_dim_size),
|
||||
EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
|
||||
start_index_value, /*is_signed=*/true),
|
||||
/*is_signed=*/true);
|
||||
|
||||
start_index_value->setName(
|
||||
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
|
||||
slice_start_index[i] = start_index_value;
|
||||
slice_limit_index[i] =
|
||||
ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
|
||||
|
||||
// Test if slice_limit_index[i] is in bounds
|
||||
llvm::Value* in_bounds =
|
||||
ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
|
||||
llvm_ir::LlvmIfData if_in_bounds =
|
||||
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
|
||||
|
||||
// Handle true BB (slice_limit_index[i] <= input_dim_size).
|
||||
SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
|
||||
// Check that index[i] >= slice_start_index[i] &&
|
||||
// index[i] < slice_limit_index[i]
|
||||
llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
|
||||
slice_intersection = ir_builder_->CreateAnd(
|
||||
slice_intersection,
|
||||
ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
|
||||
"slice_intersection_in");
|
||||
slice_intersection_in_bounds = ir_builder_->CreateAnd(
|
||||
slice_intersection_in_bounds,
|
||||
"slice_intersection");
|
||||
slice_intersection = ir_builder_->CreateAnd(
|
||||
slice_intersection,
|
||||
ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
|
||||
"slice_intersection_in");
|
||||
|
||||
// Handle false BB (slice_limit_index[i] > input_dim_size).
|
||||
SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
|
||||
// Check that index[i] >= slice_start_index[i] ||
|
||||
// index[i] < slice_limit_index[i]%input_dim_size.
|
||||
llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
|
||||
index[i],
|
||||
ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
|
||||
llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
|
||||
ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps,
|
||||
"slice_intersection_out");
|
||||
llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd(
|
||||
slice_intersection, slice_intersection_or, "slice_intersection_out");
|
||||
// Create value for slice_start_index_adjusted[i] when out of bounds.
|
||||
// If within out-of-bounds if.
|
||||
llvm_ir::LlvmIfData if_start_needs_adjustment =
|
||||
llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
|
||||
SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_);
|
||||
llvm::Value* slice_start_index_adjusted_oob =
|
||||
ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
|
||||
SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_);
|
||||
llvm::PHINode* slice_start_index_adjusted_phi =
|
||||
ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2);
|
||||
slice_start_index_adjusted_phi->addIncoming(
|
||||
slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block);
|
||||
slice_start_index_adjusted_phi->addIncoming(
|
||||
slice_start_index[i], if_start_needs_adjustment.false_block);
|
||||
// End of if within if.
|
||||
|
||||
// After checking in/out of bounds.
|
||||
SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
|
||||
llvm::PHINode* phi_slice_intersection =
|
||||
ir_builder_->CreatePHI(slice_intersection->getType(), 2);
|
||||
phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
|
||||
if_in_bounds.true_block);
|
||||
phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds,
|
||||
if_start_needs_adjustment.after_block);
|
||||
slice_intersection = phi_slice_intersection;
|
||||
|
||||
llvm::PHINode* phi_index =
|
||||
ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
|
||||
phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
|
||||
phi_index->addIncoming(slice_start_index_adjusted_phi,
|
||||
if_start_needs_adjustment.after_block);
|
||||
slice_start_index_adjusted[i] = phi_index;
|
||||
"slice_intersection");
|
||||
}
|
||||
|
||||
// Emit:
|
||||
@ -1775,12 +1733,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
|
||||
// Compute update index for intersection case.
|
||||
llvm_ir::IrArray::Index update_index(rank);
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
llvm::Value* update_dim_size = llvm::ConstantInt::get(
|
||||
index[i]->getType(), update_hlo->shape().dimensions(i));
|
||||
// NOTE: Subtraction will be positive due to bounds checking above.
|
||||
update_index[i] = ir_builder_->CreateURem(
|
||||
ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
|
||||
update_dim_size);
|
||||
update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
|
||||
operand_to_generator.at(update_hlo)(update_index));
|
||||
|
@ -1986,17 +1986,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
std::vector<int64> start(start_indices_typed.begin(),
|
||||
start_indices_typed.end());
|
||||
|
||||
std::vector<int64> operand_indices(start.size());
|
||||
// Clamp the start indices so the slice is in-bounds w.r.t the operand.
|
||||
|
||||
// TODO(b/74360564): This is implementation defined behavior, but is
|
||||
// currently respected by all implementations. Change this if we ever decide
|
||||
// to oficially document different behavior.
|
||||
for (int64 i = 0; i < start.size(); ++i) {
|
||||
start[i] = std::min<int64>(
|
||||
std::max(0LL, start[i]),
|
||||
operand_literal.shape().dimensions(i) - result_shape.dimensions(i));
|
||||
}
|
||||
|
||||
std::vector<int64> operand_indices(start.size());
|
||||
auto result = MakeUnique<Literal>(result_shape);
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||
for (int64 i = 0; i < operand_indices.size(); ++i) {
|
||||
CHECK_GE(multi_index[i] + start[i], 0);
|
||||
// Mod is only used here to be consistent with the existing
|
||||
// backends' behavior.
|
||||
operand_indices[i] = (multi_index[i] + start[i]) %
|
||||
operand_literal.shape().dimensions(i);
|
||||
operand_indices[i] = multi_index[i] + start[i];
|
||||
}
|
||||
|
||||
auto result = operand_literal.Get<ReturnT>(operand_indices);
|
||||
@ -2013,23 +2020,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
auto result = operand_literal.CloneToUnique();
|
||||
auto start_indices_typed = start_indices_literal.data<IndexT>();
|
||||
const auto rank = ShapeUtil::Rank(result->shape());
|
||||
std::vector<int64> start(rank, 0);
|
||||
std::vector<int64> start(start_indices_typed.begin(),
|
||||
start_indices_typed.end());
|
||||
// Clamp the update start indices so the slice is in-bounds w.r.t the
|
||||
// operand.
|
||||
|
||||
// TODO(b/74360564): This is implementation defined behavior, but is
|
||||
// currently respected by all implementations. Change this if we ever decide
|
||||
// to oficially document different behavior.
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
// All other implementations currently wrap-around the index, so this
|
||||
// should do so as well.
|
||||
start[i] = (start_indices_typed[i] % result->shape().dimensions(i));
|
||||
start[i] += (start[i] < 0) * result->shape().dimensions(i);
|
||||
start[i] = std::min<int64>(
|
||||
std::max<int64>(0, start[i]),
|
||||
result->shape().dimensions(i) - update_literal.shape().dimensions(i));
|
||||
}
|
||||
std::vector<int64> result_index(rank, 0);
|
||||
|
||||
auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) {
|
||||
std::transform(update_index.begin(), update_index.end(), start.begin(),
|
||||
result_index.begin(), std::plus<int64>());
|
||||
// Same as above, wrap-around only to match other implementations'
|
||||
// semantics.
|
||||
std::transform(result_index.begin(), result_index.end(),
|
||||
result->shape().dimensions().begin(), result_index.begin(),
|
||||
std::modulus<int64>());
|
||||
result->Set<ReturnT>(result_index,
|
||||
update_literal.Get<ReturnT>(update_index));
|
||||
return true;
|
||||
|
@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
IrArray::Index dim_index({ir_builder->getInt64(i)});
|
||||
TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index));
|
||||
llvm::Value* output_dim_size = llvm::ConstantInt::get(
|
||||
start_index[i]->getType(), output_shape.dimensions(i));
|
||||
llvm::Value* update_dim_size = llvm::ConstantInt::get(
|
||||
start_index[i]->getType(), update_shape.dimensions(i));
|
||||
|
||||
// Clamp the start index so that the update region fits in the operand.
|
||||
// start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
|
||||
|
||||
// TODO(b/74360564): This is implementation defined behavior, but is
|
||||
// currently respected by all implementations. Change this if we ever decide
|
||||
// to oficially document different behavior.
|
||||
llvm::Value* max_bound =
|
||||
ir_builder->CreateSub(output_dim_size, update_dim_size);
|
||||
llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0);
|
||||
start_index[i] = ir_builder->CreateSelect(
|
||||
ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]),
|
||||
zero, start_index[i]);
|
||||
|
||||
start_index[i] = ir_builder->CreateSelect(
|
||||
ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound,
|
||||
start_index[i]),
|
||||
max_bound, start_index[i]);
|
||||
}
|
||||
|
||||
auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
|
||||
// Calculate output_index, where we'll write the value from update. For
|
||||
// each dimension,
|
||||
//
|
||||
// output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size.
|
||||
// output_index[dim] = start_index[dim] + update_index[dim]
|
||||
//
|
||||
IrArray::Index output_index(rank);
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
llvm::Value* dim_size = llvm::ConstantInt::get(
|
||||
update_index[i]->getType(), output_shape.dimensions(i));
|
||||
llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast(
|
||||
llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast(
|
||||
start_index[i], update_index[i]->getType());
|
||||
output_index[i] = ir_builder->CreateURem(
|
||||
ir_builder->CreateAdd(start_index0, update_index[i]), dim_size);
|
||||
output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]);
|
||||
}
|
||||
|
||||
// Do output[output_index] = update[update_index].
|
||||
|
@ -53,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
void TestR1Wrap() {
|
||||
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
|
||||
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1});
|
||||
void TestR1OOB() {
|
||||
// Slice at dimension boundaries, but with out of bounds indices.
|
||||
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7});
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
@ -78,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase {
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
void TestR2Wrap() {
|
||||
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
|
||||
void TestR2OOB() {
|
||||
// Slice at dimension boundaries, but with out of bounds indices.
|
||||
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
|
||||
{{5, 6, 4}, {8, 9, 7}, {2, 3, 1}});
|
||||
{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
@ -106,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
void TestR3Wrap() {
|
||||
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
|
||||
void TestR3OOB() {
|
||||
// Slice at dimension boundaries, but with out of bounds indices.
|
||||
RunR3<IndexT, DataT>(
|
||||
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
|
||||
{2, 1, 2}, {{{6, 5}}, {{12, 11}}});
|
||||
{2, 1, 2}, {{{5, 6}}, {{11, 12}}});
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
@ -199,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase {
|
||||
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1<int32, bfloat16>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32, int32>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap<int32, int32>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB<int32, int32>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64, float>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64, float>(); }
|
||||
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32, int32>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap<int32, int32>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB<int32, int32>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64, float>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
|
||||
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32, float>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap<int32, float>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB<int32, float>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64, float>(); }
|
||||
XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64, float>(); }
|
||||
|
||||
@ -332,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
void TestWrap() {
|
||||
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
|
||||
void TestOOB() {
|
||||
// // Slice at dimension boundaries, but with out of bounds indices.
|
||||
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
|
||||
{10, 1, 2, 3, 4, 5, 8, 9});
|
||||
{0, 1, 2, 3, 4, 8, 9, 10});
|
||||
// R2 Shape: [3, 3]
|
||||
RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
|
||||
{{1, 2, 3}, {4, 5, 6}, {11, 8, 10}});
|
||||
{{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
|
||||
// R3 Shape: [2, 3, 2]
|
||||
RunR3<IndexT, DataT>(
|
||||
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
|
||||
{1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}});
|
||||
{1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
|
||||
}
|
||||
|
||||
template <typename IndexT, typename DataT>
|
||||
@ -476,20 +476,19 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
|
||||
Array3D<T> input_values(kSeq, kBatch, kDim);
|
||||
Array3D<T> update_values(size, kBatch, kDim);
|
||||
Array3D<T> expected_values(kSeq, kBatch, kDim);
|
||||
index = std::min(std::max(0, index), kSeq - size);
|
||||
|
||||
input_values.FillIota(static_cast<T>(0));
|
||||
T value = static_cast<T>(10);
|
||||
update_values.FillIota(static_cast<T>(value));
|
||||
|
||||
// TODO(b/34128753) Expected values may vary depending on backend when
|
||||
// the update wraps. According to documentation, the results are technically
|
||||
// implementation specific where the update is out of bounds, and hence
|
||||
// we don't really know what to pass into ComputeAndCompareR3.
|
||||
// the indices are out of bounds.
|
||||
expected_values.FillIota(static_cast<T>(0));
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int j = 0; j < kBatch; j++) {
|
||||
for (int k = 0; k < kDim; k++) {
|
||||
expected_values((index + i) % kSeq, j, k) = value++;
|
||||
expected_values(index + i, j, k) = value++;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -547,12 +546,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32, float>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64, int64>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64, uint64>(); }
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) {
|
||||
TestWrap<int32, bfloat16>();
|
||||
}
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap<int32, float>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap<int64, int64>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap<uint64, uint64>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB<int32, bfloat16>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB<int32, float>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB<int64, int64>(); }
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB<uint64, uint64>(); }
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
|
||||
// Slice at dimension start.
|
||||
@ -615,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
|
||||
// Tests for simple R3 case where the update is contiguous (i.e. the minor
|
||||
// two dimensions are not sliced).
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
|
||||
// Single element, no wrap.
|
||||
// Single element, index in-bounds
|
||||
std::vector<int32> operand_shape({4, 5, 2});
|
||||
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
|
||||
}
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) {
|
||||
// Single element, no wrap.
|
||||
// Single element, index in-bounds
|
||||
std::vector<int32> operand_shape({4, 5, 2});
|
||||
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
|
||||
}
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
|
||||
// Multiple element, no wrap.
|
||||
// Multiples element, index in-bounds.
|
||||
std::vector<int32> operand_shape({4, 5, 2});
|
||||
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2);
|
||||
}
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) {
|
||||
// Multiple element, no wrap.
|
||||
// Multiples element, index in-bounds.
|
||||
std::vector<int32> operand_shape({4, 5, 2});
|
||||
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/2);
|
||||
}
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) {
|
||||
// Multiple element, wrapping.
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) {
|
||||
// Multiple element, index out of bounds.
|
||||
std::vector<int32> operand_shape({4, 5, 2});
|
||||
RunR3Contiguous<float>(operand_shape, /*index=*/3, /*size=*/2);
|
||||
}
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) {
|
||||
// Multiple element, wrapping.
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) {
|
||||
// Multiple element, index out of bounds.
|
||||
std::vector<int32> operand_shape({4, 5, 2});
|
||||
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/3, /*size=*/2);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user