Introduce a ShapeUtil::ForEachIndexWithStatus, change index type to ArraySlice
This is not used yet, but I need it in a later CL. I don't specifically need the argument to be an ArraySlice, but it seemed cleaner than taking a const ref to a vector. No functional change intended. PiperOrigin-RevId: 187352376
This commit is contained in:
parent
12d8142dc1
commit
39a43c4f1d
@ -223,7 +223,7 @@ Status Literal::CopySliceFromInternal(
|
||||
Literal::StrideConfig stride_config(src_literal.shape(), shape(),
|
||||
copy_size);
|
||||
|
||||
auto copy_proc = [&](const std::vector<int64>& indexes) {
|
||||
auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
|
||||
// Map from multi-dimensional index, to source index.
|
||||
std::transform(indexes.begin(), indexes.end(), src_base.begin(),
|
||||
src_indexes.begin(), std::plus<int64>());
|
||||
|
@ -1269,7 +1269,7 @@ Status Literal::Populate(const FnType& generator) {
|
||||
int64 minor_dimension_size =
|
||||
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
|
||||
|
||||
auto init_function = [&](const std::vector<int64>& indexes) {
|
||||
auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
|
||||
const int64 index =
|
||||
IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
|
||||
std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using tensorflow::gtl::ArraySlice;
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
@ -214,11 +215,11 @@ TEST_F(LiteralUtilTest, CreateSparse) {
|
||||
std::vector<int64> expected_values = {8, 9, 7, 10};
|
||||
|
||||
EXPECT_EQ(literal->sparse_indices()->data(),
|
||||
tensorflow::gtl::ArraySlice<int64>(
|
||||
expected_indices.data(), expected_indices.num_elements()));
|
||||
EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(),
|
||||
expected_values.size()),
|
||||
tensorflow::gtl::ArraySlice<int64>(expected_values));
|
||||
ArraySlice<int64>(expected_indices.data(),
|
||||
expected_indices.num_elements()));
|
||||
EXPECT_EQ(
|
||||
ArraySlice<int64>(literal->data<int64>().data(), expected_values.size()),
|
||||
ArraySlice<int64>(expected_values));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
|
||||
@ -290,7 +291,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
|
||||
// clang-format on
|
||||
std::vector<std::tuple<int64, int64, string>> seen;
|
||||
literal->EachCellAsString(
|
||||
[&seen](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
|
||||
[&seen](ArraySlice<int64> indices, const string& value) {
|
||||
seen.emplace_back(indices[0], indices[1], value);
|
||||
});
|
||||
|
||||
@ -622,11 +623,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
|
||||
// clang-format on
|
||||
auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
|
||||
|
||||
reshape->EachCell<float>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
|
||||
EXPECT_EQ(value, original->Get<float>(
|
||||
{indices[2], indices[3], indices[0], indices[1]}));
|
||||
});
|
||||
reshape->EachCell<float>([&](ArraySlice<int64> indices, float value) {
|
||||
EXPECT_EQ(value, original->Get<float>(
|
||||
{indices[2], indices[3], indices[0], indices[1]}));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
|
||||
@ -863,7 +863,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
|
||||
const int64 zero_base[] = {0, 0, 0, 0};
|
||||
const int64 step[] = {1, 1, 1, 1};
|
||||
uint32 seqnr = 0;
|
||||
auto init_proc = [&](const std::vector<int64>& indexes) {
|
||||
auto init_proc = [&](ArraySlice<int64> indexes) {
|
||||
source->Set(indexes, ++seqnr);
|
||||
return true;
|
||||
};
|
||||
@ -879,7 +879,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
|
||||
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
|
||||
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
|
||||
bool matched = true;
|
||||
auto check_proc = [&](const std::vector<int64>& indexes) {
|
||||
auto check_proc = [&](ArraySlice<int64> indexes) {
|
||||
std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
|
||||
std::transform(source_indexes.begin(), source_indexes.end(), src_base,
|
||||
source_indexes.begin(), std::plus<int64>());
|
||||
@ -1067,7 +1067,7 @@ TEST_F(LiteralUtilTest, Populate) {
|
||||
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
|
||||
data.layout);
|
||||
auto literal = Literal::CreateFromShape(shape);
|
||||
auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
|
||||
auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
|
||||
// Offsets from linear index just to avoid R0 literals to be initialized
|
||||
// with zero.
|
||||
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
|
||||
@ -1079,7 +1079,7 @@ TEST_F(LiteralUtilTest, Populate) {
|
||||
std::vector<int64> zero_base(data.dimensions.size(), 0);
|
||||
std::vector<int64> step(data.dimensions.size(), 1);
|
||||
bool matched = true;
|
||||
auto check_function = [&](const std::vector<int64>& indexes) {
|
||||
auto check_function = [&](ArraySlice<int64> indexes) {
|
||||
auto value = literal->Get<uint32>(indexes);
|
||||
matched = matched && (value == generator(indexes));
|
||||
return matched;
|
||||
|
@ -1222,7 +1222,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
// corresponding index of the resulting padded literal.
|
||||
const PaddingConfig& pad_config = pad->padding_config();
|
||||
|
||||
auto func = [&](const std::vector<int64>& input_index) {
|
||||
auto func = [&](ArraySlice<int64> input_index) {
|
||||
for (auto i = 0; i < input_index.size(); ++i) {
|
||||
// Interior padding occurs logically before edge padding, so in the case
|
||||
// of negative edge padding elements are removed from the
|
||||
@ -1518,7 +1518,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
base[result_to_arg_index[i]] = multi_index[i];
|
||||
}
|
||||
|
||||
auto func = [&](const std::vector<int64>& input_index) {
|
||||
auto func = [&](ArraySlice<int64> input_index) {
|
||||
auto curr_val = arg_literal.Get<ReturnT>(input_index);
|
||||
|
||||
// Evaluate computation with specified literal operands.
|
||||
@ -1954,7 +1954,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
auto result = operand_literal.CloneToUnique();
|
||||
std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
|
||||
|
||||
auto func = [&](const std::vector<int64>& update_index) {
|
||||
auto func = [&](ArraySlice<int64> update_index) {
|
||||
std::transform(update_index.begin(), update_index.end(), start.begin(),
|
||||
result_index.begin(), std::plus<int64>());
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -564,16 +565,16 @@ class ShapeUtil {
|
||||
// The visitor_function visitor function should return true if it wants to
|
||||
// continue, or false otherwise.
|
||||
//
|
||||
// visitor_function must be a callable of type bool(const std::vector<int64>&)
|
||||
// or compatible.
|
||||
// visitor_function must be a callable of type
|
||||
// StatusOr<bool>(ArraySlice<int64>) or compatible.
|
||||
template <typename FnType>
|
||||
static void ForEachIndex(const Shape& shape,
|
||||
tensorflow::gtl::ArraySlice<int64> base,
|
||||
tensorflow::gtl::ArraySlice<int64> count,
|
||||
tensorflow::gtl::ArraySlice<int64> incr,
|
||||
const FnType& visitor_function) {
|
||||
static Status ForEachIndexWithStatus(const Shape& shape,
|
||||
tensorflow::gtl::ArraySlice<int64> base,
|
||||
tensorflow::gtl::ArraySlice<int64> count,
|
||||
tensorflow::gtl::ArraySlice<int64> incr,
|
||||
const FnType& visitor_function) {
|
||||
if (ShapeUtil::HasZeroElements(shape)) {
|
||||
return;
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_EQ(Rank(shape), base.size());
|
||||
CHECK_EQ(incr.size(), base.size());
|
||||
@ -583,7 +584,11 @@ class ShapeUtil {
|
||||
// once with the proper empty indexes.
|
||||
int64 n = -1;
|
||||
std::vector<int64> indexes(base.begin(), base.end());
|
||||
while (n < rank && visitor_function(indexes)) {
|
||||
while (n < rank) {
|
||||
TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
|
||||
if (!should_continue) {
|
||||
break;
|
||||
}
|
||||
// Increments dimensions in minor to major order.
|
||||
for (n = 0; n < rank; ++n) {
|
||||
int64 dim = LayoutUtil::Minor(shape.layout(), n);
|
||||
@ -594,6 +599,21 @@ class ShapeUtil {
|
||||
indexes[dim] = base[dim];
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename FnType>
|
||||
static void ForEachIndex(const Shape& shape,
|
||||
tensorflow::gtl::ArraySlice<int64> base,
|
||||
tensorflow::gtl::ArraySlice<int64> count,
|
||||
tensorflow::gtl::ArraySlice<int64> incr,
|
||||
const FnType& visitor_function) {
|
||||
ForEachIndexWithStatus(shape, base, count, incr,
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices) {
|
||||
return StatusOr<bool>(visitor_function(indices));
|
||||
})
|
||||
.IgnoreError();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -573,10 +573,11 @@ TEST(ShapeUtilTest, ForEachIndex) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
|
||||
// Increments at every invocation.
|
||||
int invocations = 0;
|
||||
auto increment_func = [&invocations](const std::vector<int64>& indexes) {
|
||||
invocations++;
|
||||
return true;
|
||||
};
|
||||
auto increment_func =
|
||||
[&invocations](tensorflow::gtl::ArraySlice<int64> indexes) {
|
||||
invocations++;
|
||||
return true;
|
||||
};
|
||||
|
||||
std::vector<int64> zero_base(data.dimensions.size(), 0);
|
||||
std::vector<int64> step(data.dimensions.size(), 1);
|
||||
@ -588,6 +589,29 @@ TEST(ShapeUtilTest, ForEachIndex) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ForEachIndexWithStatus) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
|
||||
// Increments at every invocation.
|
||||
int invocations = 0;
|
||||
auto increment_func =
|
||||
[&invocations](
|
||||
tensorflow::gtl::ArraySlice<int64> indexes) -> StatusOr<bool> {
|
||||
if (++invocations == 5) {
|
||||
return Unimplemented("Cannot increment beyond 5.");
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
Status error_status = ShapeUtil::ForEachIndexWithStatus(
|
||||
shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1},
|
||||
increment_func);
|
||||
|
||||
EXPECT_FALSE(error_status.ok());
|
||||
EXPECT_THAT(error_status.error_message(),
|
||||
::testing::HasSubstr("Cannot increment beyond 5."));
|
||||
EXPECT_EQ(invocations, 5);
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
|
||||
// All output dimensions should be unmodified. One of the input dimensions is
|
||||
// modified because the input rank is larger by one.
|
||||
|
Loading…
Reference in New Issue
Block a user