[XLA] Adds additional reference utility for R2 windowed reduction.
Change: 146889081
This commit is contained in:
parent
81430dc9bd
commit
6bbbd7e9d2
@ -134,6 +134,44 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
|
||||
return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
|
||||
const Array2D<float>& operand, float init,
|
||||
const tensorflow::gtl::ArraySlice<int64>& window,
|
||||
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
|
||||
std::vector<int64> dim_lengths{operand.height(), operand.width()};
|
||||
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
|
||||
|
||||
std::vector<int64> window_counts(window.size(), 0);
|
||||
std::vector<int64> pad_low(window.size(), 0);
|
||||
for (int64 i = 0; i < window.size(); ++i) {
|
||||
window_counts[i] =
|
||||
WindowCount(dim_lengths[i], window[i], stride[i], padding);
|
||||
pad_low[i] = padding_both[i].first;
|
||||
}
|
||||
auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);
|
||||
|
||||
// Do a full 2D reduce window.
|
||||
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
|
||||
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
|
||||
int64 i0_base = i0 * stride[0] - pad_low[0];
|
||||
int64 i1_base = i1 * stride[1] - pad_low[1];
|
||||
|
||||
float val = init;
|
||||
for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
|
||||
for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
|
||||
if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
|
||||
i0_base + i0_win < operand.n1() &&
|
||||
i1_base + i1_win < operand.n2()) {
|
||||
val += operand(i0_base + i0_win, i1_base + i1_win);
|
||||
}
|
||||
}
|
||||
}
|
||||
(*result)(i0, i1) = val;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
|
||||
const Array4D<float>& operand, float init,
|
||||
const tensorflow::gtl::ArraySlice<int64>& window,
|
||||
|
@ -144,6 +144,12 @@ class ReferenceUtil {
|
||||
static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride,
|
||||
Padding padding);
|
||||
|
||||
// Performs a 2D window reduction with Add as the function to apply.
|
||||
static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
|
||||
const Array2D<float>& operand, float init,
|
||||
const tensorflow::gtl::ArraySlice<int64>& window,
|
||||
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
|
||||
|
||||
// Performs a 4D window reduction with Add as the function to apply.
|
||||
static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
|
||||
const Array4D<float>& operand, float init,
|
||||
|
Loading…
Reference in New Issue
Block a user