Introduce a ShapeUtil::ForEachIndexWithStatus, change index type to ArraySlice

This is not used yet, but I need it in a later CL.  I don't specifically need
the argument to be an ArraySlice, but it seemed cleaner than taking a const ref
to a vector.

No functional change intended.

PiperOrigin-RevId: 187352376
Sanjoy Das, 2018-02-28 11:07:10 -08:00 (committed by TensorFlower Gardener)
parent 12d8142dc1
commit 39a43c4f1d
6 changed files with 77 additions and 33 deletions
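
For context, a minimal sketch of how a caller of the new overload might look. The shape, bounds, and failure condition below are invented for illustration; ForEachIndexWithStatus, StatusOr, and Unimplemented are the actual XLA APIs this change uses. Note that the visitor takes its indices as an ArraySlice<int64> by value, which is the "cleaner than a const ref to a vector" point from the message: the same visitor works whether the caller's indexes live in a std::vector, a C array, or a braced list.

  #include "tensorflow/compiler/xla/shape_util.h"
  #include "tensorflow/compiler/xla/statusor.h"
  #include "tensorflow/compiler/xla/util.h"  // Unimplemented()

  namespace xla {

  // Walks every index of a hypothetical 4x4 F32 shape and aborts the
  // traversal with a real error Status (not just a silent `false`) on the
  // first index below the diagonal.
  Status VisitUpperTriangle() {
    Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
    return ShapeUtil::ForEachIndexWithStatus(
        shape, /*base=*/{0, 0}, /*count=*/{4, 4}, /*incr=*/{1, 1},
        [](tensorflow::gtl::ArraySlice<int64> indexes) -> StatusOr<bool> {
          if (indexes[0] > indexes[1]) {
            return Unimplemented("visiting below the diagonal is unsupported");
          }
          return true;  // Keep iterating.
        });
  }

  }  // namespace xla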

tensorflow/compiler/xla/literal_util.cc

@@ -223,7 +223,7 @@ Status Literal::CopySliceFromInternal(
   Literal::StrideConfig stride_config(src_literal.shape(), shape(),
                                       copy_size);
 
-  auto copy_proc = [&](const std::vector<int64>& indexes) {
+  auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
     // Map from multi-dimensional index, to source index.
     std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                    src_indexes.begin(), std::plus<int64>());

tensorflow/compiler/xla/literal_util.h

@@ -1269,7 +1269,7 @@ Status Literal::Populate(const FnType& generator) {
   int64 minor_dimension_size =
       ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
 
-  auto init_function = [&](const std::vector<int64>& indexes) {
+  auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
     const int64 index =
         IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
     std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());

tensorflow/compiler/xla/literal_util_test.cc

@@ -30,6 +30,7 @@ limitations under the License.
 namespace xla {
 namespace {
 
+using tensorflow::gtl::ArraySlice;
 using ::testing::ElementsAre;
 using ::testing::HasSubstr;
@@ -214,11 +215,11 @@ TEST_F(LiteralUtilTest, CreateSparse) {
   std::vector<int64> expected_values = {8, 9, 7, 10};
 
   EXPECT_EQ(literal->sparse_indices()->data(),
-            tensorflow::gtl::ArraySlice<int64>(
-                expected_indices.data(), expected_indices.num_elements()));
-  EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(),
-                                               expected_values.size()),
-            tensorflow::gtl::ArraySlice<int64>(expected_values));
+            ArraySlice<int64>(expected_indices.data(),
+                              expected_indices.num_elements()));
+  EXPECT_EQ(
+      ArraySlice<int64>(literal->data<int64>().data(), expected_values.size()),
+      ArraySlice<int64>(expected_values));
 }
 
 TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -290,7 +291,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
   // clang-format on
   std::vector<std::tuple<int64, int64, string>> seen;
   literal->EachCellAsString(
-      [&seen](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+      [&seen](ArraySlice<int64> indices, const string& value) {
        seen.emplace_back(indices[0], indices[1], value);
       });
@@ -622,11 +623,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
   // clang-format on
   auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
 
-  reshape->EachCell<float>(
-      [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
-        EXPECT_EQ(value, original->Get<float>(
-                             {indices[2], indices[3], indices[0], indices[1]}));
-      });
+  reshape->EachCell<float>([&](ArraySlice<int64> indices, float value) {
+    EXPECT_EQ(value, original->Get<float>(
+                         {indices[2], indices[3], indices[0], indices[1]}));
+  });
 }
 
 TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
@@ -863,7 +863,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
   const int64 zero_base[] = {0, 0, 0, 0};
   const int64 step[] = {1, 1, 1, 1};
   uint32 seqnr = 0;
-  auto init_proc = [&](const std::vector<int64>& indexes) {
+  auto init_proc = [&](ArraySlice<int64> indexes) {
     source->Set(indexes, ++seqnr);
     return true;
   };
@@ -879,7 +879,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
   std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
   std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
   bool matched = true;
-  auto check_proc = [&](const std::vector<int64>& indexes) {
+  auto check_proc = [&](ArraySlice<int64> indexes) {
     std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
     std::transform(source_indexes.begin(), source_indexes.end(), src_base,
                    source_indexes.begin(), std::plus<int64>());
@@ -1067,7 +1067,7 @@ TEST_F(LiteralUtilTest, Populate) {
         primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
         data.layout);
     auto literal = Literal::CreateFromShape(shape);
-    auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
+    auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
       // Offsets from linear index just to avoid R0 literals to be initialized
       // with zero.
       return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
@@ -1079,7 +1079,7 @@ TEST_F(LiteralUtilTest, Populate) {
     std::vector<int64> zero_base(data.dimensions.size(), 0);
     std::vector<int64> step(data.dimensions.size(), 1);
     bool matched = true;
-    auto check_function = [&](const std::vector<int64>& indexes) {
+    auto check_function = [&](ArraySlice<int64> indexes) {
      auto value = literal->Get<uint32>(indexes);
       matched = matched && (value == generator(indexes));
       return matched;

tensorflow/compiler/xla/service/hlo_evaluator.cc

@@ -1222,7 +1222,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     // corresponding index of the resulting padded literal.
     const PaddingConfig& pad_config = pad->padding_config();
 
-    auto func = [&](const std::vector<int64>& input_index) {
+    auto func = [&](ArraySlice<int64> input_index) {
       for (auto i = 0; i < input_index.size(); ++i) {
         // Interior padding occurs logically before edge padding, so in the case
         // of negative edge padding elements are removed from the
@@ -1518,7 +1518,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
       base[result_to_arg_index[i]] = multi_index[i];
     }
 
-    auto func = [&](const std::vector<int64>& input_index) {
+    auto func = [&](ArraySlice<int64> input_index) {
       auto curr_val = arg_literal.Get<ReturnT>(input_index);
 
       // Evaluate computation with specified literal operands.
@@ -1954,7 +1954,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     auto result = operand_literal.CloneToUnique();
     std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
 
-    auto func = [&](const std::vector<int64>& update_index) {
+    auto func = [&](ArraySlice<int64> update_index) {
       std::transform(update_index.begin(), update_index.end(), start.begin(),
                      result_index.begin(), std::plus<int64>());

tensorflow/compiler/xla/shape_util.h

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -564,16 +565,16 @@ class ShapeUtil {
   // The visitor_function visitor function should return true if it wants to
   // continue, or false otherwise.
   //
-  // visitor_function must be a callable of type bool(const std::vector<int64>&)
-  // or compatible.
+  // visitor_function must be a callable of type
+  // StatusOr<bool>(ArraySlice<int64>) or compatible.
   template <typename FnType>
-  static void ForEachIndex(const Shape& shape,
-                           tensorflow::gtl::ArraySlice<int64> base,
-                           tensorflow::gtl::ArraySlice<int64> count,
-                           tensorflow::gtl::ArraySlice<int64> incr,
-                           const FnType& visitor_function) {
+  static Status ForEachIndexWithStatus(const Shape& shape,
+                                       tensorflow::gtl::ArraySlice<int64> base,
+                                       tensorflow::gtl::ArraySlice<int64> count,
+                                       tensorflow::gtl::ArraySlice<int64> incr,
+                                       const FnType& visitor_function) {
     if (ShapeUtil::HasZeroElements(shape)) {
-      return;
+      return Status::OK();
     }
     CHECK_EQ(Rank(shape), base.size());
     CHECK_EQ(incr.size(), base.size());
@@ -583,7 +584,11 @@ class ShapeUtil {
     // once with the proper empty indexes.
     int64 n = -1;
     std::vector<int64> indexes(base.begin(), base.end());
-    while (n < rank && visitor_function(indexes)) {
+    while (n < rank) {
+      TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
+      if (!should_continue) {
+        break;
+      }
       // Increments dimensions in minor to major order.
       for (n = 0; n < rank; ++n) {
         int64 dim = LayoutUtil::Minor(shape.layout(), n);
@@ -594,6 +599,21 @@ class ShapeUtil {
         indexes[dim] = base[dim];
       }
     }
+    return Status::OK();
   }
+
+  template <typename FnType>
+  static void ForEachIndex(const Shape& shape,
+                           tensorflow::gtl::ArraySlice<int64> base,
+                           tensorflow::gtl::ArraySlice<int64> count,
+                           tensorflow::gtl::ArraySlice<int64> incr,
+                           const FnType& visitor_function) {
+    ForEachIndexWithStatus(shape, base, count, incr,
+                           [&](tensorflow::gtl::ArraySlice<int64> indices) {
+                             return StatusOr<bool>(visitor_function(indices));
+                           })
+        .IgnoreError();
+  }
 
  private:
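
A note on the wrapper above (editorial gloss, not text from the change): the re-implemented ForEachIndex keeps its old void-returning, bool-visitor contract by adapting the caller's visitor. The adapter's StatusOr<bool>(visitor_function(indices)) lifts the plain bool through StatusOr's value constructor, so the visitor can never yield a non-OK Status; the Status that ForEachIndexWithStatus returns can then only be OK, which is what makes the trailing .IgnoreError() safe. A self-contained sketch of that conversion, with a made-up visitor name:

  #include "tensorflow/compiler/xla/statusor.h"
  #include "tensorflow/compiler/xla/types.h"
  #include "tensorflow/core/lib/gtl/array_slice.h"

  namespace xla {

  // A bool-returning visitor in the style existing ForEachIndex callers write.
  bool KeepGoing(tensorflow::gtl::ArraySlice<int64> indexes) { return true; }

  void AdapterSketch() {
    // What the adapter inside ForEachIndex effectively does: the bool result
    // is lifted into an always-OK StatusOr<bool> via the value constructor.
    StatusOr<bool> wrapped(KeepGoing({0, 0}));
    // Here wrapped.ok() is true and wrapped.ValueOrDie() is true.
  }

  }  // namespace xla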

tensorflow/compiler/xla/shape_util_test.cc

@@ -573,10 +573,11 @@ TEST(ShapeUtilTest, ForEachIndex) {
     Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
     // Increments at every invocation.
     int invocations = 0;
-    auto increment_func = [&invocations](const std::vector<int64>& indexes) {
-      invocations++;
-      return true;
-    };
+    auto increment_func =
+        [&invocations](tensorflow::gtl::ArraySlice<int64> indexes) {
+          invocations++;
+          return true;
+        };
 
     std::vector<int64> zero_base(data.dimensions.size(), 0);
     std::vector<int64> step(data.dimensions.size(), 1);
@@ -588,6 +589,29 @@ TEST(ShapeUtilTest, ForEachIndex) {
   }
 }
 
+TEST(ShapeUtilTest, ForEachIndexWithStatus) {
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
+
+  // Increments at every invocation.
+  int invocations = 0;
+  auto increment_func =
+      [&invocations](
+          tensorflow::gtl::ArraySlice<int64> indexes) -> StatusOr<bool> {
+    if (++invocations == 5) {
+      return Unimplemented("Cannot increment beyond 5.");
+    }
+    return true;
+  };
+
+  Status error_status = ShapeUtil::ForEachIndexWithStatus(
+      shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1},
+      increment_func);
+
+  EXPECT_FALSE(error_status.ok());
+  EXPECT_THAT(error_status.error_message(),
+              ::testing::HasSubstr("Cannot increment beyond 5."));
+  EXPECT_EQ(invocations, 5);
+}
+
 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
   // All output dimensions should be unmodified. One of the input dimensions is
   // modified because the input rank is larger by one.