[XLA] Use HloEvaluator for convolution in reference_util.

Also Speed up HloEvaluator's HandleConvolution in non-opt build, by moving calls
to HloInstruction::shape() out of the inner loop.

PiperOrigin-RevId: 163416183
This commit is contained in:
Kay Zhu 2017-07-27 17:59:23 -07:00 committed by TensorFlower Gardener
parent 569a00e681
commit 253bcbb71b
6 changed files with 193 additions and 281 deletions

View File

@ -563,6 +563,9 @@ cc_library(
":xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
],

View File

@ -470,10 +470,11 @@ class Literal {
// Populates literal values by calling the generator function for every cell
// in this literal object.
template <typename NativeT>
Status Populate(
const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
generator);
//
// generator must be a callable of the type
// NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
template <typename NativeT, typename FnType>
Status Populate(const FnType& generator);
// Creates a Literal of the given dimensions with all elements set to the
// given value.
@ -1107,12 +1108,10 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
}
template <typename NativeT>
Status Literal::Populate(
const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
generator) {
template <typename NativeT, typename FnType>
Status Literal::Populate(const FnType& generator) {
const Shape& this_shape = shape();
int64 rank = ShapeUtil::Rank(this_shape);
const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
tensorflow::gtl::MutableArraySlice<NativeT> data =
@ -1125,7 +1124,7 @@ Status Literal::Populate(
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
auto init_function = [&](const std::vector<int64>& indexes) {
int64 index = LinearIndex(indexes);
const int64 index = LinearIndex(indexes);
std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
for (int64 i = 0; i < minor_dimension_size; ++i) {
minor_scan_indexes[stride_config.minor_dimension] = i;

View File

@ -20,6 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
@ -446,179 +449,85 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
std::pair<int64, int64> kernel_stride, Padding padding,
std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
ConvolutionDimensionNumbers dnums) {
std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
const int64 ksy = kernel_stride.first;
const int64 ksx = kernel_stride.second;
const int64 dy = lhs_dilation.first;
const int64 dx = lhs_dilation.second;
const int64 dky = rhs_dilation.first;
const int64 dkx = rhs_dilation.second;
CHECK_GE(dky, 1);
CHECK_GE(dkx, 1);
CHECK_GE(dy, 1);
CHECK_GE(dx, 1);
// Get all dimension sizes in lhs and rhs based on the given convolution
// dimension configuration.
const int64 ix = window_util::DilatedBound(
lhs_dimensions[dnums.spatial_dimensions(1)], dx);
const int64 iy = window_util::DilatedBound(
lhs_dimensions[dnums.spatial_dimensions(0)], dy);
const int64 iz = lhs_dimensions[dnums.feature_dimension()];
const int64 samples = lhs_dimensions[dnums.batch_dimension()];
const int64 kx = window_util::DilatedBound(
rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
const int64 ky = window_util::DilatedBound(
rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
{
const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
CHECK_EQ(kiz, iz);
std::array<int64, 2> ordered_kernel_strides;
std::array<int64, 2> ordered_input_dimensions;
std::array<int64, 2> ordered_kernel_dimensions;
if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
ordered_kernel_strides[0] = kernel_stride.second;
ordered_kernel_strides[1] = kernel_stride.first;
} else {
ordered_kernel_strides[0] = kernel_stride.first;
ordered_kernel_strides[1] = kernel_stride.second;
}
if (padding == Padding::kSame) {
// We reject same padding with kernel striding, since it's somewhat
// nonsensical. We can always follow up to implement this with the desired
// semantics if anybody actually uses it.
CHECK_EQ(1, ksy);
CHECK_EQ(1, ksx);
}
ordered_input_dimensions[0] =
lhs_literal->shape().dimensions(dnums.spatial_dimensions(0));
ordered_input_dimensions[1] =
lhs_literal->shape().dimensions(dnums.spatial_dimensions(1));
ordered_kernel_dimensions[0] =
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
const int64 ox =
padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
const int64 oy =
padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
const int64 istartx =
padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
const int64 istarty =
padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
// Create the output result array and reset the values to 0.
std::array<int64, 4> result_dimensions;
result_dimensions[dnums.batch_dimension()] = samples;
result_dimensions[dnums.feature_dimension()] = oz;
result_dimensions[dnums.spatial_dimensions(0)] = oy;
result_dimensions[dnums.spatial_dimensions(1)] = ox;
std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
ordered_kernel_strides, padding);
CHECK_EQ(paddings.size(), 2);
Window window;
WindowDimension dim;
dim.set_size(
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
dim.set_window_dilation(rhs_dilation.first);
dim.set_base_dilation(lhs_dilation.first);
*window.add_dimensions() = dim;
WindowDimension dim2;
dim2.set_size(
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
dim2.set_window_dilation(rhs_dilation.second);
dim2.set_base_dilation(lhs_dilation.second);
*window.add_dimensions() = dim2;
const Shape& shape =
ShapeInference::InferConvolveShape(lhs_literal->shape(),
rhs_literal->shape(), window, dnums)
.ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
b.AddInstruction(HloInstruction::CreateConvolve(
shape, lhs_instruction, rhs_instruction, window, dnums));
HloEvaluator evaluator;
std::unique_ptr<Literal> result_literal =
evaluator.Evaluate(*b.Build(), {}).ConsumeValueOrDie();
CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
auto result =
MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
result_dimensions[2], result_dimensions[3]);
result->Fill(0.0);
MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
result_literal->shape().dimensions(1),
result_literal->shape().dimensions(2),
result_literal->shape().dimensions(3));
const auto is_int32 = [](int64 x) {
return x >= std::numeric_limits<int32>::min() &&
x <= std::numeric_limits<int32>::max();
};
result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
*value = result_literal->Get<float>(indices);
});
// 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at
// least on x86-64), so we avoid them where possible.
const auto fast_idiv64 = [&](int64 a, int64 b) {
if (is_int32(a) && is_int32(b)) {
return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
}
return a / b;
};
const auto fast_imod64 = [&](int64 a, int64 b) {
if (is_int32(a) && is_int32(b)) {
return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
}
return a % b;
};
// Lambda to access the lhs operand at the given 4D index.
const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
int64 width) {
if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
return 0.0f;
}
std::array<int64, 4> index;
index[dnums.batch_dimension()] = batch;
index[dnums.feature_dimension()] = feature;
index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
return lhs(index[0], index[1], index[2], index[3]);
};
// Lambda to access the rhs operand at the given 4D index. height_over_dky
// should be equal to height / dky, and width_over_dkx should be equal to
// width / dkx. (This is an optimization to avoid doing divisions.)
const auto rhs_element =
[&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
int64 width, int64 height_over_dky, int64 width_over_dkx) {
DCHECK_EQ(height % dky, 0);
DCHECK_EQ(width % dkx, 0);
DCHECK_EQ(height / dky, height_over_dky);
DCHECK_EQ(width / dkx, width_over_dkx);
std::array<int64, 4> index;
index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
return rhs(index[0], index[1], index[2], index[3]);
};
// Lambda to access the result data at the given 4D index.
const auto result_element = [&](int64 batch, int64 kernel_output_feature,
int64 height, int64 width) -> float& {
std::array<int64, 4> index;
index[dnums.batch_dimension()] = batch;
index[dnums.feature_dimension()] = kernel_output_feature;
index[dnums.spatial_dimensions(0)] = height;
index[dnums.spatial_dimensions(1)] = width;
return (*result)(index[0], index[1], index[2], index[3]);
};
for (int64 oyi = 0; oyi < oy; ++oyi) {
for (int64 oxi = 0; oxi < ox; ++oxi) {
for (int64 sample = 0; sample < samples; ++sample) {
for (int64 izi = 0; izi < iz; ++izi) {
for (int64 ozi = 0; ozi < oz; ++ozi) {
for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
kyi += dky, kyi_over_dky++) {
for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
kxi += dkx, kxi_over_dkx++) {
int64 iyi = istarty + ksy * oyi + kyi;
int64 ixi = istartx + ksx * oxi + kxi;
float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
? 0.0
: lhs_element(sample, izi, iyi, ixi);
float gain =
rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
float addend = input * gain;
result_element(sample, ozi, oyi, oxi) += addend;
}
}
}
}
}
}
}
if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 ||
iz == 0) {
LOG(INFO) << "Output will be trivially empty because one of these "
"dimensions is 0: samples: "
<< samples << " kx: " << kx << " ky: " << ky << " ox: " << ox
<< " oy: " << oy << " oz: " << oz << " iz: " << iz;
return result;
}
bool trivial = true;
auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices,
float value) {
if (value != 0.0) {
trivial = false;
}
};
lhs.Each(check_trivial);
if (trivial) {
LOG(FATAL) << "LHS is all 0.0.";
}
trivial = true;
rhs.Each(check_trivial);
if (trivial) {
LOG(FATAL) << "RHS is all 0.0.";
}
return result;
}

View File

@ -404,12 +404,16 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override {
CHECK(ShapeUtil::IsArray(lhs->shape()));
CHECK(ShapeUtil::IsArray(rhs->shape()));
CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
CHECK(ShapeUtil::SameElementType(lhs->shape(), conv->shape()));
TF_CHECK_OK(ShapeUtil::ValidateShape(lhs->shape()));
TF_CHECK_OK(ShapeUtil::ValidateShape(rhs->shape()));
const Shape& result_shape = conv->shape();
const Shape& lhs_shape = lhs->shape();
const Shape& rhs_shape = rhs->shape();
TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
CHECK(ShapeUtil::IsArray(lhs_shape));
CHECK(ShapeUtil::IsArray(rhs_shape));
CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
const auto& dnums = conv->convolution_dimension_numbers();
const int64 num_spatial_dims = dnums.spatial_dimensions_size();
@ -417,23 +421,23 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
CHECK_GE(num_spatial_dims, 1);
CHECK_EQ(window.dimensions_size(), num_spatial_dims);
CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(lhs->shape()));
CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(rhs->shape()));
const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
CHECK_EQ(num_spatial_dims + 2, lhs_rank);
CHECK_EQ(num_spatial_dims + 2, rhs_rank);
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferConvolveShape(
lhs->shape(), rhs->shape(), window, dnums));
CHECK(ShapeUtil::Compatible(conv->shape(), inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(conv->shape())
ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
window, dnums));
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
// Dimension number applicable for both input (lhs), and output.
const int64 batch_dim = dnums.batch_dimension();
const int64 z_dim = dnums.feature_dimension();
@ -441,78 +445,78 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
const int64 z_size = ShapeUtil::GetDimension(lhs->shape(), z_dim);
const int64 z_size = ShapeUtil::GetDimension(lhs_shape, z_dim);
std::vector<int64> window_dimension_sizes;
for (auto i : dnums.kernel_spatial_dimensions()) {
window_dimension_sizes.push_back(
ShapeUtil::GetDimension(rhs->shape(), i));
window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
}
const Shape& window_shape = ShapeUtil::MakeShape(
rhs->shape().element_type(), window_dimension_sizes);
const Shape& window_shape =
ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
auto result = Literal::CreateFromShape(conv->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> out_index) {
ReturnT result_val = static_cast<ReturnT>(0);
DimensionVector lhs_index(lhs_rank);
DimensionVector rhs_index(rhs_rank);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
std::vector<int64> lhs_index(lhs_rank, 0);
std::vector<int64> rhs_index(rhs_rank, 0);
auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
ReturnT result_val = static_cast<ReturnT>(0);
lhs_index[batch_dim] = out_index[batch_dim];
rhs_index[kernel_output_z_dim] = out_index[z_dim];
std::fill(lhs_index.begin(), lhs_index.end(), 0);
std::fill(rhs_index.begin(), rhs_index.end(), 0);
std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
std::vector<int64> rhs_spatial_index(
dnums.kernel_spatial_dimensions_size(), 0);
lhs_index[batch_dim] = out_index[batch_dim];
rhs_index[kernel_output_z_dim] = out_index[z_dim];
// Convolve input feature with kernel.
do {
for (int64 iz = 0; iz < z_size; ++iz) {
lhs_index[z_dim] = iz;
rhs_index[kernel_input_z_dim] = iz;
// Convolve input feature with kernel.
do {
for (int64 iz = 0; iz < z_size; ++iz) {
lhs_index[z_dim] = iz;
rhs_index[kernel_input_z_dim] = iz;
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
// Spatial dimension number for input (lhs) and output.
const int64 spatial_dim = dnums.spatial_dimensions(ki);
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
// Spatial dimension number for input (lhs) and output.
const int64 spatial_dim = dnums.spatial_dimensions(ki);
// Calculate lhs (input) index without taking base dilation into
// account.
const int64 undilated_index =
out_index[spatial_dim] * window.dimensions(ki).stride() -
window.dimensions(ki).padding_low() +
rhs_spatial_index[ki] *
window.dimensions(ki).window_dilation();
// Skip if the lhs (input) index is to be dilated.
if (undilated_index % window.dimensions(ki).base_dilation() !=
0) {
goto cnt;
}
// Calculate the actual lhs (input) index after dilation.
lhs_index[spatial_dim] =
undilated_index / window.dimensions(ki).base_dilation();
// Skip if input index is not in bound.
if (!(lhs_index[spatial_dim] >= 0 &&
lhs_index[spatial_dim] <
lhs->shape().dimensions(spatial_dim))) {
goto cnt;
}
rhs_index[dnums.kernel_spatial_dimensions(ki)] =
rhs_spatial_index[ki];
}
result_val += lhs_literal.Get<ReturnT>(lhs_index) *
rhs_literal.Get<ReturnT>(rhs_index);
// Calculate lhs (input) index without taking base dilation into
// account.
const auto& window_dim = window.dimensions(ki);
const int64 undilated_index =
out_index[spatial_dim] * window_dim.stride() -
window_dim.padding_low() +
rhs_spatial_index[ki] * window_dim.window_dilation();
// Skip if the lhs (input) index is to be dilated.
if (undilated_index % window_dim.base_dilation() != 0) {
goto cnt;
}
cnt:;
} while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
return result_val;
}));
// Calculate the actual lhs (input) index after dilation.
lhs_index[spatial_dim] =
undilated_index / window_dim.base_dilation();
// Skip if input index is not in bound.
if (!(lhs_index[spatial_dim] >= 0 &&
lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) {
goto cnt;
}
rhs_index[dnums.kernel_spatial_dimensions(ki)] =
rhs_spatial_index[ki];
}
result_val += lhs_literal.Get<ReturnT>(lhs_index) *
rhs_literal.Get<ReturnT>(rhs_index);
}
cnt:;
} while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
return result_val;
};
auto result = Literal::CreateFromShape(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();

View File

@ -1218,34 +1218,4 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
return shape;
}
/* static */ void ShapeUtil::ForEachIndex(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> base,
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
const IndexVisitorFunction& visitor_function) {
if (ShapeUtil::HasZeroElements(shape)) {
return;
}
DCHECK_EQ(Rank(shape), base.size());
DCHECK_EQ(incr.size(), base.size());
DCHECK_EQ(count.size(), base.size());
const Layout& layout = shape.layout();
int64 rank = layout.minor_to_major_size();
// Allows handling R0 arrays, such that the visitor function will be called
// once with the proper empty indexes.
int64 n = -1;
std::vector<int64> indexes(base.begin(), base.end());
while (n < rank && visitor_function(indexes)) {
// Increments dimensions in minor to major order.
for (n = 0; n < rank; ++n) {
int64 dim = layout.minor_to_major(n);
indexes[dim] += incr[dim];
if (indexes[dim] < base[dim] + count[dim]) {
break;
}
indexes[dim] = base[dim];
}
}
}
} // namespace xla

View File

@ -421,12 +421,39 @@ class ShapeUtil {
// current index.
// The visitor_function visitor function should return true if it wants to
// continue, or false otherwise.
using IndexVisitorFunction = std::function<bool(const std::vector<int64>&)>;
//
// visitor_function must be a callable of type bool(const std::vector<int64>&)
// or compatible.
template <typename FnType>
static void ForEachIndex(const Shape& shape,
tensorflow::gtl::ArraySlice<int64> base,
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
const IndexVisitorFunction& visitor_function);
const FnType& visitor_function) {
if (ShapeUtil::HasZeroElements(shape)) {
return;
}
CHECK_EQ(Rank(shape), base.size());
CHECK_EQ(incr.size(), base.size());
CHECK_EQ(count.size(), base.size());
const Layout& layout = shape.layout();
const int64 rank = layout.minor_to_major_size();
// Allows handling R0 arrays, such that the visitor function will be called
// once with the proper empty indexes.
int64 n = -1;
std::vector<int64> indexes(base.begin(), base.end());
while (n < rank && visitor_function(indexes)) {
// Increments dimensions in minor to major order.
for (n = 0; n < rank; ++n) {
int64 dim = layout.minor_to_major(n);
indexes[dim] += incr[dim];
if (indexes[dim] < base[dim] + count[dim]) {
break;
}
indexes[dim] = base[dim];
}
}
}
private:
// Validates all of the non-layout properties of the shape -- this is a helper