This is currently GPU-only, but will become useful for CPU soon. HloEvaluator for dot goes through Eigen, so add an overload for int32 to call into that too. PiperOrigin-RevId: 269338175
679 lines
27 KiB
C++
679 lines
27 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/reference_util.h"
|
|
|
|
#include <array>
|
|
#include <utility>
|
|
|
|
#include "absl/container/flat_hash_set.h"
|
|
#include "absl/memory/memory.h"
|
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
|
#include "tensorflow/compiler/xla/literal_util.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
|
#include "tensorflow/compiler/xla/window_util.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
#include "tensorflow/core/lib/math/math_util.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
|
|
namespace xla {
|
|
|
|
/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
|
|
const Array2D<float>& input) {
|
|
auto result =
|
|
absl::make_unique<Array2D<double>>(input.height(), input.width());
|
|
for (int64 rowno = 0; rowno < input.height(); ++rowno) {
|
|
for (int64 colno = 0; colno < input.height(); ++colno) {
|
|
(*result)(rowno, colno) = input(rowno, colno);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
|
|
const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
|
|
Padding padding) {
|
|
return ConvArray3DGeneralDimensionsDilated(
|
|
lhs, rhs, kernel_stride, padding, 1, 1,
|
|
XlaBuilder::CreateDefaultConvDimensionNumbers(1));
|
|
}
|
|
|
|
/*static*/ std::unique_ptr<Array3D<float>>
|
|
ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
|
|
const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
|
|
Padding padding, int64 lhs_dilation, int64 rhs_dilation,
|
|
const ConvolutionDimensionNumbers& dnums) {
|
|
CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
|
|
CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
|
|
CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
|
|
// Reuse the code for Array4D-convolution by extending the 3D input into a 4D
|
|
// array by adding a fourth dummy dimension of size 1 without stride, padding
|
|
// and dilation.
|
|
Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
|
|
a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
|
|
CHECK_EQ(indices[3], 0);
|
|
*value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
|
|
});
|
|
Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
|
|
a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
|
|
CHECK_EQ(indices[3], 0);
|
|
*value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
|
|
});
|
|
// Add a second dummy spatial dimensions.
|
|
ConvolutionDimensionNumbers dnums2d = dnums;
|
|
dnums2d.add_input_spatial_dimensions(3);
|
|
dnums2d.add_kernel_spatial_dimensions(3);
|
|
dnums2d.add_output_spatial_dimensions(3);
|
|
std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
|
|
a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
|
|
{rhs_dilation, 1}, dnums2d);
|
|
|
|
auto convr3 = absl::make_unique<Array3D<float>>(
|
|
convr4->planes(), convr4->depth(), convr4->height());
|
|
convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
|
|
CHECK_EQ(indices[3], 0);
|
|
convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
|
|
});
|
|
return convr3;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
|
|
const Array4D<float>& lhs, const Array4D<float>& rhs,
|
|
std::pair<int64, int64> kernel_stride, Padding padding) {
|
|
return ConvArray4DGeneralDimensions(
|
|
lhs, rhs, kernel_stride, padding,
|
|
XlaBuilder::CreateDefaultConvDimensionNumbers());
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>>
|
|
ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
|
|
const Array4D<float>& depthwise_weights,
|
|
const Array4D<float>& pointwise_weights,
|
|
std::pair<int64, int64> kernel_stride,
|
|
Padding padding) {
|
|
const int64 depth_multiplier = depthwise_weights.planes();
|
|
CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
|
|
|
|
// Combine the two weights by reducing the depth_multiplier, so that we can
|
|
// apply a single convolution on the combined weights.
|
|
Array4D<float> weights(pointwise_weights.planes(), input.depth(),
|
|
depthwise_weights.height(), depthwise_weights.width());
|
|
for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
|
|
for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
|
|
for (int64 kz = 0; kz < input.depth(); ++kz) {
|
|
for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
|
|
float weight = 0.0;
|
|
for (int64 depth = 0; depth < depth_multiplier; ++depth) {
|
|
weight +=
|
|
depthwise_weights(depth, kz, ky, kx) *
|
|
pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
|
|
}
|
|
weights(out, kz, ky, kx) = weight;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return ConvArray4D(input, weights, kernel_stride, padding);
|
|
}
|
|
|
|
// Returns the number of window positions along one dimension.
// For Padding::kValid this is window_util::StridedBound(unpadded_width,
// window_len, stride); for any other padding it is
// ceil(unpadded_width / stride).
/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
                                              int64 window_len, int64 stride,
                                              Padding padding) {
  if (padding != Padding::kValid) {
    return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
  }
  return window_util::StridedBound(unpadded_width, window_len, stride);
}
|
|
|
|
/* static */ std::unique_ptr<std::vector<float>>
|
|
ReferenceUtil::ReduceWindow1DGeneric(
|
|
absl::Span<const float> operand, float init,
|
|
const std::function<float(float, float)>& reduce_func,
|
|
absl::Span<const int64> window, absl::Span<const int64> stride,
|
|
absl::Span<const std::pair<int64, int64>> padding) {
|
|
CHECK_EQ(window.size(), 1);
|
|
CHECK_EQ(stride.size(), 1);
|
|
CHECK_EQ(padding.size(), 1);
|
|
|
|
int64 padded_width = padding[0].first + operand.size() + padding[0].second;
|
|
int64 stride_amount = stride[0];
|
|
int64 window_size = window[0];
|
|
int64 result_size =
|
|
window_util::StridedBound(padded_width, window_size, stride_amount);
|
|
int64 pad_low = padding[0].first;
|
|
auto result = absl::make_unique<std::vector<float>>(result_size);
|
|
|
|
// Do a full 1D reduce window.
|
|
for (int64 i0 = 0; i0 < result_size; ++i0) {
|
|
int64 i0_base = i0 * stride_amount - pad_low;
|
|
float val = init;
|
|
for (int64 i0_win = 0; i0_win < window_size; ++i0_win) {
|
|
if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) {
|
|
val = reduce_func(val, operand[i0_base + i0_win]);
|
|
}
|
|
}
|
|
(*result)[i0] = val;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<std::vector<float>>
|
|
ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
|
|
absl::Span<const int64> window,
|
|
absl::Span<const int64> stride,
|
|
Padding padding) {
|
|
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
|
|
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
|
|
return ReduceWindow1DGeneric(
|
|
operand, init, add_reduce, window, stride,
|
|
xla::MakePadding(dim_lengths, window, stride, padding));
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
|
|
const Array3D<float>& operand, float init, absl::Span<const int64> window,
|
|
absl::Span<const int64> stride, Padding padding) {
|
|
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
|
|
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
|
|
|
|
std::vector<int64> window_counts(window.size(), 0);
|
|
std::vector<int64> pad_low(window.size(), 0);
|
|
for (int64 i = 0; i < window.size(); ++i) {
|
|
window_counts[i] =
|
|
WindowCount(dim_lengths[i], window[i], stride[i], padding);
|
|
pad_low[i] = padding_both[i].first;
|
|
}
|
|
auto result = absl::make_unique<Array3D<float>>(
|
|
window_counts[0], window_counts[1], window_counts[2]);
|
|
|
|
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
|
|
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
|
|
for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
|
|
int64 i0_base = i0 * stride[0] - pad_low[0];
|
|
int64 i1_base = i1 * stride[1] - pad_low[1];
|
|
int64 i2_base = i2 * stride[2] - pad_low[2];
|
|
|
|
float val = init;
|
|
for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
|
|
for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
|
|
for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
|
|
if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
|
|
i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
|
|
i1_base + i1_win < operand.n2() &&
|
|
i2_base + i2_win < operand.n3()) {
|
|
val += operand(i0_base + i0_win, i1_base + i1_win,
|
|
i2_base + i2_win);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
(*result)(i0, i1, i2) = val;
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>>
|
|
ReferenceUtil::ReduceWindow4DGeneric(
|
|
const Array4D<float>& operand, float init,
|
|
const std::function<float(float, float)>& reduce_func,
|
|
absl::Span<const int64> window, absl::Span<const int64> stride,
|
|
Padding padding) {
|
|
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
|
|
operand.n4()};
|
|
return ReduceWindow4DGeneric(
|
|
operand, init, reduce_func, window, stride,
|
|
xla::MakePadding(dim_lengths, window, stride, padding));
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>>
|
|
ReferenceUtil::ReduceWindow4DGeneric(
|
|
const Array4D<float>& operand, float init,
|
|
const std::function<float(float, float)>& reduce_func,
|
|
absl::Span<const int64> window, absl::Span<const int64> stride,
|
|
absl::Span<const std::pair<int64, int64>> padding) {
|
|
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
|
|
operand.n4()};
|
|
|
|
std::vector<int64> window_counts(window.size(), 0);
|
|
std::vector<int64> pad_low(window.size(), 0);
|
|
for (int64 i = 0; i < window.size(); ++i) {
|
|
int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
|
|
window_counts[i] =
|
|
window_util::StridedBound(padded_width, window[i], stride[i]);
|
|
pad_low[i] = padding[i].first;
|
|
}
|
|
auto result = absl::make_unique<Array4D<float>>(
|
|
window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
|
|
// Do a full 4D reduce window.
|
|
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
|
|
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
|
|
for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
|
|
for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
|
|
int64 i0_base = i0 * stride[0] - pad_low[0];
|
|
int64 i1_base = i1 * stride[1] - pad_low[1];
|
|
int64 i2_base = i2 * stride[2] - pad_low[2];
|
|
int64 i3_base = i3 * stride[3] - pad_low[3];
|
|
|
|
float val = init;
|
|
for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
|
|
for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
|
|
for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
|
|
for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
|
|
if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
|
|
i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
|
|
i0_base + i0_win < operand.n1() &&
|
|
i1_base + i1_win < operand.n2() &&
|
|
i2_base + i2_win < operand.n3() &&
|
|
i3_base + i3_win < operand.n4()) {
|
|
val = reduce_func(
|
|
val, operand(i0_base + i0_win, i1_base + i1_win,
|
|
i2_base + i2_win, i3_base + i3_win));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
(*result)(i0, i1, i2, i3) = val;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
|
|
const Array4D<float>& operand, float init, absl::Span<const int64> window,
|
|
absl::Span<const int64> stride, Padding padding) {
|
|
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
|
|
return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
|
|
padding);
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
|
|
const Array4D<float>& input, const Array4D<float>& mean,
|
|
const Array4D<float>& var, const Array4D<float>& scale,
|
|
const Array4D<float>& offset, float epsilon) {
|
|
auto normalized =
|
|
*MapArray4D(input, mean, [](float a, float b) { return a - b; });
|
|
normalized = *MapArray4D(normalized, var, [&](float a, float b) {
|
|
return a / std::sqrt(b + epsilon);
|
|
});
|
|
normalized =
|
|
*MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
|
|
return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>>
|
|
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
|
|
const Array4D<float>& source,
|
|
float init,
|
|
absl::Span<const int64> window,
|
|
absl::Span<const int64> stride,
|
|
bool same_padding) {
|
|
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
|
|
auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
|
|
operand.n3(), operand.n4());
|
|
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
|
|
operand.n4()};
|
|
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
|
|
// Fill the output, with the initial value.
|
|
result->Fill(init);
|
|
|
|
std::vector<int64> window_counts(window.size(), 0);
|
|
std::vector<int64> pad_low(window.size(), 0);
|
|
for (int64 i = 0; i < window.size(); ++i) {
|
|
window_counts[i] =
|
|
WindowCount(dim_lengths[i], window[i], stride[i], padding);
|
|
pad_low[i] = padding_both[i].first;
|
|
}
|
|
CHECK_EQ(window_counts[0], source.n1());
|
|
CHECK_EQ(window_counts[1], source.n2());
|
|
CHECK_EQ(window_counts[2], source.n3());
|
|
CHECK_EQ(window_counts[3], source.n4());
|
|
|
|
// Do a full 4D select and Scatter.
|
|
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
|
|
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
|
|
for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
|
|
for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
|
|
// Now we are inside a window and need to find the max and the argmax.
|
|
int64 i0_base = i0 * stride[0] - pad_low[0];
|
|
int64 i1_base = i1 * stride[1] - pad_low[1];
|
|
int64 i2_base = i2 * stride[2] - pad_low[2];
|
|
int64 i3_base = i3 * stride[3] - pad_low[3];
|
|
int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
|
|
int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
|
|
int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
|
|
int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
|
|
float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
|
|
for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
|
|
for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
|
|
for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
|
|
for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
|
|
if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
|
|
i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
|
|
i0_base + i0_win < operand.n1() &&
|
|
i1_base + i1_win < operand.n2() &&
|
|
i2_base + i2_win < operand.n3() &&
|
|
i3_base + i3_win < operand.n4()) {
|
|
float tmp = operand(i0_base + i0_win, i1_base + i1_win,
|
|
i2_base + i2_win, i3_base + i3_win);
|
|
if (tmp > val) {
|
|
val = tmp;
|
|
scatter_0 = i0_base + i0_win;
|
|
scatter_1 = i1_base + i1_win;
|
|
scatter_2 = i2_base + i2_win;
|
|
scatter_3 = i3_base + i3_win;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
(*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
|
|
source(i0, i1, i2, i3);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>>
|
|
ReferenceUtil::ConvArray4DGeneralDimensions(
|
|
const Array4D<float>& lhs, const Array4D<float>& rhs,
|
|
std::pair<int64, int64> kernel_stride, Padding padding,
|
|
ConvolutionDimensionNumbers dimension_numbers) {
|
|
return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
|
|
{1, 1}, {1, 1},
|
|
std::move(dimension_numbers));
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>>
|
|
ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
|
|
const Array4D<float>& lhs, const Array4D<float>& rhs,
|
|
std::pair<int64, int64> kernel_stride, Padding padding,
|
|
std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
|
|
ConvolutionDimensionNumbers dnums) {
|
|
HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
|
|
auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
|
|
auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
|
|
|
|
std::array<int64, 2> ordered_kernel_strides;
|
|
std::array<int64, 2> ordered_input_dimensions;
|
|
std::array<int64, 2> ordered_kernel_dimensions;
|
|
if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
|
|
ordered_kernel_strides[0] = kernel_stride.second;
|
|
ordered_kernel_strides[1] = kernel_stride.first;
|
|
} else {
|
|
ordered_kernel_strides[0] = kernel_stride.first;
|
|
ordered_kernel_strides[1] = kernel_stride.second;
|
|
}
|
|
|
|
ordered_input_dimensions[0] =
|
|
lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
|
|
ordered_input_dimensions[1] =
|
|
lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
|
|
ordered_kernel_dimensions[0] =
|
|
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
|
|
ordered_kernel_dimensions[1] =
|
|
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
|
|
|
|
std::vector<std::pair<int64, int64>> paddings =
|
|
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
|
|
ordered_kernel_strides, padding);
|
|
CHECK_EQ(paddings.size(), 2);
|
|
|
|
Window window;
|
|
|
|
WindowDimension dim;
|
|
dim.set_size(
|
|
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
|
|
dim.set_stride(kernel_stride.first);
|
|
dim.set_padding_low(paddings[0].first);
|
|
dim.set_padding_high(paddings[0].second);
|
|
dim.set_window_dilation(rhs_dilation.first);
|
|
dim.set_base_dilation(lhs_dilation.first);
|
|
*window.add_dimensions() = dim;
|
|
|
|
WindowDimension dim2;
|
|
dim2.set_size(
|
|
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
|
|
dim2.set_stride(kernel_stride.second);
|
|
dim2.set_padding_low(paddings[1].first);
|
|
dim2.set_padding_high(paddings[1].second);
|
|
dim2.set_window_dilation(rhs_dilation.second);
|
|
dim2.set_base_dilation(lhs_dilation.second);
|
|
*window.add_dimensions() = dim2;
|
|
|
|
const Shape& shape =
|
|
ShapeInference::InferConvolveShape(
|
|
lhs_literal.shape(), rhs_literal.shape(),
|
|
/*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums)
|
|
.ConsumeValueOrDie();
|
|
|
|
HloInstruction* lhs_instruction =
|
|
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
|
|
HloInstruction* rhs_instruction =
|
|
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
|
|
|
|
PrecisionConfig precision_config;
|
|
precision_config.mutable_operand_precision()->Resize(
|
|
/*new_size=*/2, PrecisionConfig::DEFAULT);
|
|
b.AddInstruction(HloInstruction::CreateConvolve(
|
|
shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
|
|
/*batch_group_count=*/1, window, dnums, precision_config));
|
|
HloModuleConfig config;
|
|
HloModule module("ReferenceUtil", config);
|
|
auto computation = module.AddEntryComputation(b.Build());
|
|
|
|
HloEvaluator evaluator;
|
|
Literal result_literal =
|
|
evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();
|
|
|
|
CHECK_EQ(result_literal.shape().rank(), 4);
|
|
auto result =
|
|
absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
|
|
result_literal.shape().dimensions(1),
|
|
result_literal.shape().dimensions(2),
|
|
result_literal.shape().dimensions(3));
|
|
|
|
result->Each([&](absl::Span<const int64> indices, float* value) {
|
|
*value = result_literal.Get<float>(indices);
|
|
});
|
|
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<std::vector<float>>
|
|
ReferenceUtil::ReduceToColArray2D(
|
|
const Array2D<float>& matrix, float init,
|
|
const std::function<float(float, float)>& reduce_function) {
|
|
int64 rows = matrix.height();
|
|
int64 cols = matrix.width();
|
|
auto result = absl::make_unique<std::vector<float>>();
|
|
for (int64 i = 0; i < rows; ++i) {
|
|
float acc = init;
|
|
for (int64 j = 0; j < cols; ++j) {
|
|
acc = reduce_function(acc, matrix(i, j));
|
|
}
|
|
result->push_back(acc);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<std::vector<float>>
|
|
ReferenceUtil::ReduceToRowArray2D(
|
|
const Array2D<float>& matrix, float init,
|
|
const std::function<float(float, float)>& reduce_function) {
|
|
int64 rows = matrix.height();
|
|
int64 cols = matrix.width();
|
|
auto result = absl::make_unique<std::vector<float>>();
|
|
for (int64 i = 0; i < cols; ++i) {
|
|
float acc = init;
|
|
for (int64 j = 0; j < rows; ++j) {
|
|
acc = reduce_function(acc, matrix(j, i));
|
|
}
|
|
result->push_back(acc);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
|
|
const Array4D<float>& array, float init, absl::Span<const int64> dims,
|
|
const std::function<float(float, float)>& reduce_function) {
|
|
std::vector<float> result;
|
|
CHECK_EQ(dims.size(), 3);
|
|
const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
|
|
CHECK_EQ(dim_set.size(), 3);
|
|
for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
|
|
++a0) {
|
|
for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
|
|
++a1) {
|
|
for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
|
|
++a2) {
|
|
for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4());
|
|
++a3) {
|
|
float accumulator = init;
|
|
for (int64 i0 = 0;
|
|
i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
|
|
for (int64 i1 = 0;
|
|
i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
|
|
for (int64 i2 = 0;
|
|
i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
|
|
for (int64 i3 = 0;
|
|
i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
|
|
++i3) {
|
|
// Handle zero-sized arrays.
|
|
if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
|
|
array.n4() > 0) {
|
|
accumulator = reduce_function(
|
|
accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
result.push_back(accumulator);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
|
|
const std::vector<float>& array, const std::vector<int64>& bounds,
|
|
int64 broadcast_from_dim) {
|
|
auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
|
|
bounds[2], bounds[3]);
|
|
for (int64 i = 0; i < result->n1(); ++i) {
|
|
for (int64 j = 0; j < result->n2(); ++j) {
|
|
for (int64 k = 0; k < result->n3(); ++k) {
|
|
for (int64 l = 0; l < result->n4(); ++l) {
|
|
switch (broadcast_from_dim) {
|
|
case 0:
|
|
(*result)(i, j, k, l) = array[i];
|
|
break;
|
|
case 1:
|
|
(*result)(i, j, k, l) = array[j];
|
|
break;
|
|
case 2:
|
|
(*result)(i, j, k, l) = array[k];
|
|
break;
|
|
case 3:
|
|
(*result)(i, j, k, l) = array[l];
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
|
|
const Array3D<float>& array, float init, absl::Span<const int64> dims,
|
|
const std::function<float(float, float)>& reduce_function) {
|
|
CHECK_EQ(dims.size(), 1);
|
|
int64 rows = dims[0] == 0 ? array.n2() : array.n1();
|
|
int64 cols = dims[0] == 2 ? array.n2() : array.n3();
|
|
auto result = absl::make_unique<Array2D<float>>(rows, cols);
|
|
result->Fill(init);
|
|
for (int i0 = 0; i0 < array.n1(); ++i0) {
|
|
for (int i1 = 0; i1 < array.n2(); ++i1) {
|
|
for (int i2 = 0; i2 < array.n3(); ++i2) {
|
|
int64 row = dims[0] == 0 ? i1 : i0;
|
|
int64 col = dims[0] == 2 ? i1 : i2;
|
|
(*result)(row, col) =
|
|
reduce_function((*result)(row, col), array(i0, i1, i2));
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
|
|
const Array2D<float>& matrix,
|
|
const std::function<float(float)>& map_function) {
|
|
int64 rows = matrix.height();
|
|
int64 cols = matrix.width();
|
|
auto result = absl::make_unique<Array2D<float>>(rows, cols);
|
|
for (int64 i = 0; i < rows; ++i) {
|
|
for (int64 j = 0; j < cols; ++j) {
|
|
(*result)(i, j) = map_function(matrix(i, j));
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
|
|
const Array2D<float>& lhs, const Array2D<float>& rhs,
|
|
const std::function<float(float, float)>& map_function) {
|
|
CHECK_EQ(lhs.height(), rhs.height());
|
|
CHECK_EQ(lhs.width(), rhs.width());
|
|
int64 rows = lhs.height();
|
|
int64 cols = rhs.width();
|
|
auto result = absl::make_unique<Array2D<float>>(rows, cols);
|
|
for (int64 i = 0; i < rows; ++i) {
|
|
for (int64 j = 0; j < cols; ++j) {
|
|
(*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
|
|
const Array2D<float>& matrix,
|
|
const std::function<float(float, int64, int64)>& map_function) {
|
|
int64 rows = matrix.height();
|
|
int64 cols = matrix.width();
|
|
auto result = absl::make_unique<Array2D<float>>(rows, cols);
|
|
for (int64 i = 0; i < rows; ++i) {
|
|
for (int64 j = 0; j < cols; ++j) {
|
|
(*result)(i, j) = map_function(matrix(i, j), i, j);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
} // namespace xla
|