From 39a43c4f1d73b0210795d2003b127d3ffa284e98 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 28 Feb 2018 11:07:10 -0800 Subject: [PATCH] Introduce a ShapeUtil::ForEachIndexWithStatus, change index type to ArraySlice This is not used yet, but I need it in a later CL. I don't specifically need the argument to be an ArraySlice, but it seemed cleaner than taking a const ref to a vector. No functional change intended. PiperOrigin-RevId: 187352376 --- tensorflow/compiler/xla/literal_util.cc | 2 +- tensorflow/compiler/xla/literal_util.h | 2 +- tensorflow/compiler/xla/literal_util_test.cc | 30 +++++++-------- .../compiler/xla/service/hlo_evaluator.cc | 6 +-- tensorflow/compiler/xla/shape_util.h | 38 ++++++++++++++----- tensorflow/compiler/xla/shape_util_test.cc | 32 ++++++++++++++-- 6 files changed, 77 insertions(+), 33 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 823da43b5ab..3962a9b3161 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -223,7 +223,7 @@ Status Literal::CopySliceFromInternal( Literal::StrideConfig stride_config(src_literal.shape(), shape(), copy_size); - auto copy_proc = [&](const std::vector& indexes) { + auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index d5ae3fd7232..1d58f0cbc72 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -1269,7 +1269,7 @@ Status Literal::Populate(const FnType& generator) { int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](const std::vector& indexes) { + auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index ee2f4fe8744..9ff07711100 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -30,6 +30,7 @@ limitations under the License. namespace xla { namespace { +using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -214,11 +215,11 @@ TEST_F(LiteralUtilTest, CreateSparse) { std::vector expected_values = {8, 9, 7, 10}; EXPECT_EQ(literal->sparse_indices()->data(), - tensorflow::gtl::ArraySlice( - expected_indices.data(), expected_indices.num_elements())); - EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), - expected_values.size()), - tensorflow::gtl::ArraySlice(expected_values)); + ArraySlice(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ( + ArraySlice(literal->data().data(), expected_values.size()), + ArraySlice(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -290,7 +291,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format on std::vector> seen; literal->EachCellAsString( - [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { + [&seen](ArraySlice indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -622,11 +623,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { // clang-format on auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float value) { - EXPECT_EQ(value, original->Get( - {indices[2], indices[3], indices[0], indices[1]})); - }); + reshape->EachCell([&](ArraySlice indices, float value) { + EXPECT_EQ(value, original->Get( + {indices[2], indices[3], indices[0], indices[1]})); + }); } TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { @@ -863,7 +863,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](const std::vector& indexes) { + auto init_proc = [&](ArraySlice indexes) { source->Set(indexes, ++seqnr); return true; }; @@ -879,7 +879,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](const std::vector& indexes) { + auto check_proc = [&](ArraySlice indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); @@ -1067,7 +1067,7 @@ TEST_F(LiteralUtilTest, Populate) { primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); auto literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { + auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1079,7 +1079,7 @@ TEST_F(LiteralUtilTest, Populate) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](const std::vector& indexes) { + auto check_function = [&](ArraySlice indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index c3a3251b7dd..edb1ad23609 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1222,7 +1222,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // corresponding index of the resulting padded literal. const PaddingConfig& pad_config = pad->padding_config(); - auto func = [&](const std::vector& input_index) { + auto func = [&](ArraySlice input_index) { for (auto i = 0; i < input_index.size(); ++i) { // Interior padding occurs logically before edge padding, so in the case // of negative edge padding elements are removed from the @@ -1518,7 +1518,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { base[result_to_arg_index[i]] = multi_index[i]; } - auto func = [&](const std::vector& input_index) { + auto func = [&](ArraySlice input_index) { auto curr_val = arg_literal.Get(input_index); // Evaluate computation with specified literal operands. @@ -1954,7 +1954,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = operand_literal.CloneToUnique(); std::vector result_index(ShapeUtil::Rank(result->shape()), 0); - auto func = [&](const std::vector& update_index) { + auto func = [&](ArraySlice update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8ee263fe5e5..923315e001f 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -564,16 +565,16 @@ class ShapeUtil { // The visitor_function visitor function should return true if it wants to // continue, or false otherwise. // - // visitor_function must be a callable of type bool(const std::vector&) - // or compatible. + // visitor_function must be a callable of type + // StatusOr(ArraySlice) or compatible. template - static void ForEachIndex(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, - const FnType& visitor_function) { + static Status ForEachIndexWithStatus(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { if (ShapeUtil::HasZeroElements(shape)) { - return; + return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); CHECK_EQ(incr.size(), base.size()); @@ -583,7 +584,11 @@ class ShapeUtil { // once with the proper empty indexes. int64 n = -1; std::vector indexes(base.begin(), base.end()); - while (n < rank && visitor_function(indexes)) { + while (n < rank) { + TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); + if (!should_continue) { + break; + } // Increments dimensions in minor to major order. for (n = 0; n < rank; ++n) { int64 dim = LayoutUtil::Minor(shape.layout(), n); @@ -594,6 +599,21 @@ class ShapeUtil { indexes[dim] = base[dim]; } } + + return Status::OK(); + } + + template + static void ForEachIndex(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { + ForEachIndexWithStatus(shape, base, count, incr, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); } private: diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4db97d45b20..a3574156983 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -573,10 +573,11 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = [&invocations](const std::vector& indexes) { - invocations++; - return true; - }; + auto increment_func = + [&invocations](tensorflow::gtl::ArraySlice indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -588,6 +589,29 @@ TEST(ShapeUtilTest, ForEachIndex) { } } +TEST(ShapeUtilTest, ForEachIndexWithStatus) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); + // Increments at every invocation. + int invocations = 0; + auto increment_func = + [&invocations]( + tensorflow::gtl::ArraySlice indexes) -> StatusOr { + if (++invocations == 5) { + return Unimplemented("Cannot increment beyond 5."); + } + return true; + }; + + Status error_status = ShapeUtil::ForEachIndexWithStatus( + shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1}, + increment_func); + + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT(error_status.error_message(), + ::testing::HasSubstr("Cannot increment beyond 5.")); + EXPECT_EQ(invocations, 5); +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one.