New implementation of the transposed convolution kernel which is faster and easier on the eyes.

depth.tflite runs 7% faster (from 96ms to 89ms) on a Pixel 3

PiperOrigin-RevId: 260017515
This commit is contained in:
A. Unique TensorFlower 2019-07-25 14:09:14 -07:00 committed by TensorFlower Gardener
parent 25cd82af1e
commit ce08d5a7ad
3 changed files with 46 additions and 43 deletions

View File

@ -29,21 +29,9 @@ constexpr int kPhwc4ChannelsInPlane = 4;
constexpr int kPhwo4i4ChannelsInPlane = 4;
constexpr int kPiohw4ChannelsInPlane = 4;
} // namespace
uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape) {
return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
}
uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape) {
return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
}
// Layout is Po,H,W,OI4x4.
Status ConvertToPHWO4I4(absl::Span<const float> in, const OHWI& shape,
absl::Span<float> out) {
absl::Span<float> out, bool reverse_space) {
if (in.size() != shape.DimensionsProduct()) {
return InvalidArgumentError(absl::StrCat(
"ConvertToPHWO4I4: Input data size does not match expected size: ",
@ -70,7 +58,9 @@ Status ConvertToPHWO4I4(absl::Span<const float> in, const OHWI& shape,
// tensor is in OHWI
int tensor_o = p * kPhwo4i4ChannelsInPlane + co;
int tensor_i = c * kPhwo4i4ChannelsInPlane + ci;
value = in[shape.LinearIndex({tensor_o, h, w, tensor_i})];
const int in_h = reverse_space ? shape.h - 1 - h : h;
const int in_w = reverse_space ? shape.w - 1 - w : w;
value = in[shape.LinearIndex({tensor_o, in_h, in_w, tensor_i})];
}
(*output++) = value;
}
@ -82,11 +72,34 @@ Status ConvertToPHWO4I4(absl::Span<const float> in, const OHWI& shape,
return OkStatus();
}
} // namespace
uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape) {
return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
}
uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape) {
return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
}
std::vector<float> ConvertToPHWO4I4(
const Tensor<OHWI, DataType::FLOAT32>& tensor) {
std::vector<float> transposed(GetElementsSizeForPHWO4I4(tensor.shape));
ConvertToPHWO4I4(tensor.data, tensor.shape,
absl::MakeSpan(transposed.data(), transposed.size()))
absl::MakeSpan(transposed.data(), transposed.size()),
/*reverse_space=*/false)
.IgnoreError();
return transposed;
}
std::vector<float> ConvertToPHWO4I4Transposed(
const Tensor<OHWI, DataType::FLOAT32>& tensor) {
std::vector<float> transposed(GetElementsSizeForPHWO4I4(tensor.shape));
ConvertToPHWO4I4(tensor.data, tensor.shape,
absl::MakeSpan(transposed.data(), transposed.size()),
/*reverse_space=*/true)
.IgnoreError();
return transposed;
}

View File

@ -63,14 +63,14 @@ std::vector<float> ConvertToPIOHW4(
// @return number of elements when shape is converted into PHWO4I4.
uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape);
// Layout is Po,H,W,OI4x4.
Status ConvertToPHWO4I4(absl::Span<const float> in, const OHWI& shape,
absl::Span<float> out);
// Convenience wrapper around a method above.
std::vector<float> ConvertToPHWO4I4(
const Tensor<OHWI, DataType::FLOAT32>& tensor);
// Convenience wrapper around a method above, for Transposed Convolution.
std::vector<float> ConvertToPHWO4I4Transposed(
const Tensor<OHWI, DataType::FLOAT32>& tensor);
// @return (x,y,z) size for PHWO4I4 to access elements where each element
// consists of 4 values.
uint3 Get3DSizeForPHWO4I4(const OHWI& shape);

View File

@ -41,8 +41,6 @@ class ConvolutionTransposedBuffers : public NodeShader {
auto attr = absl::any_cast<const ConvolutionTransposedAttributes&>(
ctx.node->operation.attributes);
auto weights = attr.weights.shape;
const int32_t inner_size_w = (weights.w - 1) / attr.stride.w + 1;
const int32_t inner_size_h = (weights.h - 1) / attr.stride.h + 1;
std::vector<Variable> parameters = {
{"input_data_0_h", input->tensor.shape.h},
@ -50,33 +48,25 @@ class ConvolutionTransposedBuffers : public NodeShader {
{"src_depth", IntegralDivideRoundUp(weights.i, 4)},
{"kernel_size", int2(weights.w, weights.h)},
{"stride", int2(attr.stride.w, attr.stride.h)},
{"padding", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
{"inner_size", int2(inner_size_w, inner_size_h)},
{"padding", int2(weights.w - 1 - attr.padding.prepended.w,
weights.h - 1 - attr.padding.prepended.h)},
};
std::vector<std::pair<std::string, Object>> objects = {
{"weights", MakeReadonlyObject(Get3DSizeForPHWO4I4(attr.weights.shape),
ConvertToPHWO4I4(attr.weights))}};
{"weights",
MakeReadonlyObject(Get3DSizeForPHWO4I4(attr.weights.shape),
ConvertToPHWO4I4Transposed(attr.weights))}};
std::string source = R"(
ivec2 kernel_offset = $kernel_size$ - ivec2(1,1);
ivec2 offset = gid.xy + $padding$ - kernel_offset;
offset %= $stride$;
offset += $stride$;
offset %= $stride$;
ivec2 f_offset;
f_offset.x = offset.x == 0 ? 0 : ($stride.x$ - offset.x);
f_offset.y = offset.y == 0 ? 0 : ($stride.y$ - offset.y);
for (int ky = 0; ky < $inner_size.y$; ++ky) {
for (int kx = 0; kx < $inner_size.x$; ++kx) {
ivec2 index = ivec2(kx, ky) * $stride$ + f_offset;
bool inside_kernel = index.x < $kernel_size.x$ && index.y < $kernel_size.y$;
ivec2 coord = (gid.xy + index + $padding$ - kernel_offset) / $stride$;
bool outside = coord.x < 0 || coord.y < 0 ||
coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$;
if (inside_kernel && !outside) {
index = kernel_offset - index;
int i = index.y * $kernel_size.x$ + index.x;
#define IN_BOUNDS(p, p0, p1) (all(greaterThanEqual(p, p0)) && all(lessThan(p, p1)))
ivec2 p0 = ($padding$ + $stride$ - gid.xy % $stride$) % $stride$;
for (int y = p0.y; y < $kernel_size.y$; y += $stride.y$) {
for (int x = p0.x; x < $kernel_size.x$; x += $stride.x$) {
int i = y * $kernel_size.x$ + x;
ivec2 idx = gid.xy + ivec2(x, y) - $padding$;
if (IN_BOUNDS(idx, ivec2(0), ivec2($input_data_0_w$, $input_data_0_h$) * $stride$)) {
ivec2 coord = idx / $stride$;
for (int l = 0; l < $src_depth$; ++l) {
vec4 src_color = $input_data_0[coord.x, coord.y, l]$;
value_0.x += dot(src_color, $weights[l * 4 + 0, i, gid.z]$);