Optimize transpose_conv

PiperOrigin-RevId: 243561442
This commit is contained in:
Renjie Liu 2019-04-15 00:01:48 -07:00 committed by TensorFlower Gardener
parent 0c464c70ce
commit 5e52b70188
6 changed files with 324 additions and 92 deletions

View File

@ -46,9 +46,12 @@ typedef enum {
kTfLiteMirrorPaddingSymmetric, kTfLiteMirrorPaddingSymmetric,
} TfLiteMirrorPaddingMode; } TfLiteMirrorPaddingMode;
// TODO(b/130259536): We should move this out of builtin_op_data.
typedef struct { typedef struct {
int width; int width;
int height; int height;
int width_offset;
int height_offset;
} TfLitePaddingValues; } TfLitePaddingValues;
typedef struct { typedef struct {

View File

@ -6130,33 +6130,112 @@ void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
} }
} }
// Returns in 'im_data' (assumes to be zero-initialized) image patch in storage
// order (height, width, depth), constructed from patches in 'col_data', which
// is required to be in storage order (out_height * out_width, filter_height,
// filter_width, in_depth). Implementation by Yangqing Jia (jiayq).
// Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
template <typename T>
void Col2im(const T* col_data, const int depth, const int height,
const int width, const int filter_h, const int filter_w,
const int pad_t, const int pad_l, const int pad_b, const int pad_r,
const int stride_h, const int stride_w, T* im_data) {
int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
int h_pad = -pad_t;
for (int h = 0; h < height_col; ++h) {
int w_pad = -pad_l;
for (int w = 0; w < width_col; ++w) {
T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
// TODO(andydavis) Vectorize this loop (if compiler does not).
for (int i = 0; i < depth; ++i) {
im_patch_data[i] += col_data[i];
}
}
im_patch_data += depth;
col_data += depth;
}
// Jump over remaining number of depth.
im_patch_data += depth * (width - filter_w);
}
w_pad += stride_w;
}
h_pad += stride_h;
}
}
// TransposeConvV2 expect the weights in HWOI order.
inline void TransposeConvV2(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& col2im_shape, float* col2im_data) {
gemmlowp::ScopedProfilingLabel label("TransposeConvV2");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
const int batch_size = input_shape.Dims(0);
TFLITE_DCHECK(col2im_data);
TFLITE_DCHECK(hwoi_ordered_filter_data);
const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_image_size = output_height * output_width;
const int input_depth =
MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
const int output_depth =
MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
const int input_offset = input_image_size * input_depth;
const int output_offset = output_image_size * output_depth;
const int filter_height = hwoi_ordered_filter_shape.Dims(0);
const int filter_width = hwoi_ordered_filter_shape.Dims(1);
const int padding_top = params.padding_values.height;
const int padding_bottom =
params.padding_values.height + params.padding_values.height_offset;
const int padding_left = params.padding_values.width;
const int padding_right =
params.padding_values.width + params.padding_values.width_offset;
const int stride_height = params.stride_height;
const int stride_width = params.stride_width;
const int hwoi_ordered_filter_total_size =
filter_height * filter_width * output_depth;
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
Matrix;
typedef Eigen::Map<Matrix> MatrixRef;
typedef Eigen::Map<const Matrix> ConstMatrixRef;
ConstMatrixRef hwoi_ordered_filter_matrix_map(
hwoi_ordered_filter_data, hwoi_ordered_filter_total_size, input_depth);
float* output_data_p = output_data;
tensor_utils::ZeroVector(output_data, output_offset * batch_size);
for (int i = 0; i < batch_size; ++i) {
ConstMatrixRef input_matrix_map(input_data + input_offset * i,
input_image_size, input_depth);
MatrixRef output_matrix_map(col2im_data, input_image_size,
hwoi_ordered_filter_total_size);
Gemm(input_matrix_map, hwoi_ordered_filter_matrix_map.transpose(),
&output_matrix_map);
Col2im(col2im_data, output_depth, output_height, output_width,
filter_height, filter_width, padding_top, padding_left,
padding_bottom, padding_right, stride_height, stride_width,
output_data_p);
output_data_p += output_offset;
}
}
// TODO(renjieliu): Investigate whether we need to keep this.
inline void TransposeConv( inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape, const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape, const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& output_shape, const float* filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) { float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeConv"); gemmlowp::ScopedProfilingLabel label("TransposeConv");
// The complexity of the reference implementation is input.flat_size() *
// filter.flat_size() / in_channel.
//
// While the complexity of im2col->gemm
// implmentation is batch * output_height * output_width *
// (filter.flat_size() / out_channel)^2 * out_channel.
//
// so if input.flat_size() * out_channel^2 is much smaller than
// output.flat_size() * filter.size() * in_channel we should fall back to the
// reference implementation.
//
// TODO(b/122331966): optimize the intuitive implementation.
const int out_channel = output_shape.Dims(3);
const int in_channel = input_shape.Dims(3);
if ((input_shape.FlatSize() * out_channel * out_channel * 4) <
(filter_shape.FlatSize() * output_shape.FlatSize() * in_channel)) {
reference_ops::TransposeConv(params, input_shape, input_data, filter_shape,
filter_data, output_shape, output_data,
im2col_shape, im2col_data);
return;
}
// Note we could use transposed weights with forward conv for unstrided // Note we could use transposed weights with forward conv for unstrided
// cases. But we are already getting good performance with this code as-is. // cases. But we are already getting good performance with this code as-is.
TFLITE_DCHECK(im2col_data); TFLITE_DCHECK(im2col_data);

View File

@ -28,6 +28,12 @@ enum class PaddingType : uint8 { kNone, kSame, kValid };
struct PaddingValues { struct PaddingValues {
int16 width; int16 width;
int16 height; int16 height;
// offset is used for calculating "remaining" padding, for example, `width`
// is 1 and `width_offset` is 1, so padding_left is 1 while padding_right is
// 1 + 1 = 2.
int16 width_offset;
// Same as width_offset except it's over the height dimension.
int16 height_offset;
}; };
// This enumeration allows for non-default formats for the weights array // This enumeration allows for non-default formats for the weights array

View File

@ -19,6 +19,7 @@ limitations under the License.
namespace tflite { namespace tflite {
// TODO(renjieliu): Migrate others to use ComputePaddingWithLeftover.
inline int ComputePadding(int stride, int dilation_rate, int in_size, inline int ComputePadding(int stride, int dilation_rate, int in_size,
int filter_size, int out_size) { int filter_size, int out_size) {
int effective_filter_size = (filter_size - 1) * dilation_rate + 1; int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
@ -26,6 +27,19 @@ inline int ComputePadding(int stride, int dilation_rate, int in_size,
return padding > 0 ? padding : 0; return padding > 0 ? padding : 0;
} }
// It's not guaranteed that padding is symmetric. It's important to keep
// offset for algorithms need all paddings.
inline int ComputePaddingWithOffset(int stride, int dilation_rate, int in_size,
int filter_size, int out_size,
int* offset) {
int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
int total_padding =
((out_size - 1) * stride + effective_filter_size - in_size);
total_padding = total_padding > 0 ? total_padding : 0;
*offset = total_padding % 2;
return total_padding / 2;
}
// Matching GetWindowedOutputSize in TensorFlow. // Matching GetWindowedOutputSize in TensorFlow.
inline int ComputeOutSize(TfLitePadding padding, int image_size, inline int ComputeOutSize(TfLitePadding padding, int image_size,
int filter_size, int stride) { int filter_size, int stride) {
@ -47,10 +61,13 @@ inline TfLitePaddingValues ComputePaddingHeightWidth(
ComputeOutSize(padding, in_height, filter_height, stride_height); ComputeOutSize(padding, in_height, filter_height, stride_height);
TfLitePaddingValues padding_values; TfLitePaddingValues padding_values;
padding_values.height = int offset = 0;
ComputePadding(stride_height, 1, in_height, filter_height, out_height); padding_values.height = ComputePaddingWithOffset(
padding_values.width = stride_height, 1, in_height, filter_height, out_height, &offset);
ComputePadding(stride_width, 1, in_width, filter_width, out_width); padding_values.height_offset = offset;
padding_values.width = ComputePaddingWithOffset(
stride_width, 1, in_width, filter_width, out_width, &offset);
padding_values.width_offset = offset;
return padding_values; return padding_values;
} }
} // namespace tflite } // namespace tflite

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/eigen_support.h" #include "tensorflow/lite/kernels/eigen_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/kernels/padding.h"
@ -49,11 +50,22 @@ const int kTensorNotAllocated = -1;
struct OpData { struct OpData {
// IDs are the arbitrary identifiers used by TF Lite to identify and access // IDs are the arbitrary identifiers used by TF Lite to identify and access
// memory buffers. // memory buffers.
int im2col_id = kTensorNotAllocated; int col2im_id = kTensorNotAllocated;
int transposed_weights_id = kTensorNotAllocated;
int scratch_tensor_id = kTensorNotAllocated;
// im2col is the only temporary currently tracked, therefore always index 0. // col2im is the temporary tensor allocated and used in optimized path for
// If more temporaries are added, they should be properly tracked. // storing col2im data:gemm result for input_matrix x filter_matrix.
int32_t im2col_index = 0; int32_t col2im_index;
// TfLiteConverter will transpose weights from HWOI to OHWI order.
// In optimized path, we will transpose them back to HWOI, this temporary
// tensor is allocated for storing transposed weights.
int32_t transposed_weights_index;
// Scratch tensor is used in the quantized path for storing accumulation
// results.
int32_t scratch_tensor_index;
TfLitePaddingValues padding; TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can // The scaling factor from input to output (aka the 'real multiplier') can
@ -66,17 +78,12 @@ struct OpData {
int32_t output_activation_min; int32_t output_activation_min;
int32_t output_activation_max; int32_t output_activation_max;
int scratch_tensor_index; bool has_col2im = false;
bool weights_are_transposed = false;
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to use as scratch space for im2col, and
// to carry information from Prepare() to Eval().
auto* data = new OpData; auto* data = new OpData;
// Populate scratch_tensor_index.
context->AddTensors(context, /*tensors_to_add=*/1,
&data->scratch_tensor_index);
eigen_support::IncrementUsageCounter(context); eigen_support::IncrementUsageCounter(context);
return data; return data;
} }
@ -104,46 +111,106 @@ TfLiteStatus ResizeTensor(TfLiteContext* context,
return context->ResizeTensor(context, tensor_to_resize, shape); return context->ResizeTensor(context, tensor_to_resize, shape);
} }
static TfLiteStatus AllocateIm2colTensorIfRequired(TfLiteContext* context, // Allocate temporary tensors if necessary.
template <KernelType kernel_type>
static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
TfLiteType input_type,
TfLiteType weights_type,
TfLiteNode* node) { TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data); OpData* data = reinterpret_cast<OpData*>(node->user_data);
if (data->im2col_id == kTensorNotAllocated) { int temporaries_count = 0;
context->AddTensors(context, 1, &data->im2col_id);
context->tensors[data->im2col_id].type = kTfLiteFloat32; // Allocate col2im tensor. Currently it's only used for optimized kernels.
if (kernel_type == kGenericOptimized) {
if (data->col2im_id == kTensorNotAllocated) {
context->AddTensors(context, 1, &data->col2im_id);
}
data->col2im_index = temporaries_count;
data->has_col2im = true;
++temporaries_count;
}
// Allocate transposed_weights tensor. Currently it's only used for optimized
// float kernels.
if (kernel_type == kGenericOptimized && input_type == kTfLiteFloat32) {
if (data->transposed_weights_id == kTensorNotAllocated) {
context->AddTensors(context, 1, &data->transposed_weights_id);
}
data->transposed_weights_index = temporaries_count;
data->weights_are_transposed = true;
++temporaries_count;
}
// Allocate scratch buffer tensor for UInt8 inputs.
if (input_type == kTfLiteUInt8) {
if (data->scratch_tensor_id == kTensorNotAllocated) {
context->AddTensors(context, 1, &data->scratch_tensor_id);
}
data->scratch_tensor_index = temporaries_count;
++temporaries_count;
} }
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(1); node->temporaries = TfLiteIntArrayCreate(temporaries_count);
node->temporaries->data[data->im2col_index] = data->im2col_id;
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context, TfLiteStatus ResizeCol2ImTensor(TfLiteContext* context,
const TfLiteTensor* output_shape, const TfLiteTensor* output_shape,
const TfLiteTensor* weights, const TfLiteTensor* weights,
const TfLiteTensor* input, const TfLiteTensor* input,
TfLiteTensor* im2col) { TfLiteTensor* col2im) {
if (output_shape->type != kTfLiteInt32) { if (output_shape->type != kTfLiteInt32) {
context->ReportError(context, "im2col shape is %d, not int32.", context->ReportError(context, "col2im shape is %d, not int32.",
output_shape->type); output_shape->type);
return kTfLiteError; return kTfLiteError;
} }
TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4); TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4); TfLiteIntArray* col2im_shape_array = TfLiteIntArrayCreate(2);
im2col_shape_array->data[0] = output_shape->data.i32[0]; const RuntimeShape& input_shape = GetTensorShape(input);
im2col_shape_array->data[1] = output_shape->data.i32[1]; const RuntimeShape& weights_shape = GetTensorShape(weights);
im2col_shape_array->data[2] = output_shape->data.i32[2]; col2im_shape_array->data[0] = input_shape.Dims(1) * input_shape.Dims(2);
const int input_depth = SizeOfDimension(input, 3); col2im_shape_array->data[1] =
const int filter_width = SizeOfDimension(weights, 2); weights_shape.Dims(0) * weights_shape.Dims(1) * weights_shape.Dims(2);
const int filter_height = SizeOfDimension(weights, 1);
im2col_shape_array->data[3] = input_depth * filter_height * filter_width;
im2col->type = input->type; col2im->type = input->type;
im2col->allocation_type = kTfLiteDynamic; col2im->allocation_type = kTfLiteDynamic;
return context->ResizeTensor(context, im2col, im2col_shape_array); return context->ResizeTensor(context, col2im, col2im_shape_array);
} }
TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
const TfLiteTensor* weights,
TfLiteTensor* transposed_weights) {
TfLiteIntArray* transposed_weights_shape_array = TfLiteIntArrayCreate(4);
const RuntimeShape& input_shape = GetTensorShape(weights);
transposed_weights_shape_array->data[0] = input_shape.Dims(1);
transposed_weights_shape_array->data[1] = input_shape.Dims(2);
transposed_weights_shape_array->data[2] = input_shape.Dims(0);
transposed_weights_shape_array->data[3] = input_shape.Dims(3);
transposed_weights->type = weights->type;
transposed_weights->allocation_type = kTfLiteDynamic;
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, transposed_weights,
transposed_weights_shape_array));
// Transpose the weights from from OHWI order to HWOI order.
TransposeParams transpose_params;
transpose_params.perm_count = 4;
transpose_params.perm[0] = 1;
transpose_params.perm[1] = 2;
transpose_params.perm[2] = 0;
transpose_params.perm[3] = 3;
optimized_ops::Transpose(transpose_params, input_shape,
GetTensorData<float>(weights),
GetTensorShape(transposed_weights),
GetTensorData<float>(transposed_weights));
return kTfLiteOk;
}
template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data); OpData* data = reinterpret_cast<OpData*>(node->user_data);
@ -151,18 +218,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
// Allocate Im2col Tensor
TF_LITE_ENSURE_STATUS(AllocateIm2colTensorIfRequired(context, node));
// Retrieve tensors // Retrieve tensors
const TfLiteTensor* output_shape = const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor); GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* im2col =
&context->tensors[node->temporaries->data[user_data->im2col_index]];
// Tensor sanity checks // Tensor sanity checks
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
@ -177,24 +238,50 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
SizeOfDimension(weights, 3)); SizeOfDimension(weights, 3));
// Allocate col2Im, transposed_weights & scratch Tensor.
TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired<kernel_type>(
context, input->type, weights->type, node));
OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* col2im = nullptr;
if (data->has_col2im) {
node->temporaries->data[data->col2im_index] = data->col2im_id;
col2im = GetTemporary(context, node, user_data->col2im_index);
}
if (!IsConstantTensor(output_shape)) { if (!IsConstantTensor(output_shape)) {
// Defer resizing until Eval(). // Defer resizing until Eval().
SetTensorToDynamic(output); SetTensorToDynamic(output);
SetTensorToDynamic(im2col); if (data->has_col2im) {
SetTensorToDynamic(col2im);
}
} else { } else {
TF_LITE_ENSURE_STATUS(ResizeTensor(context, output_shape, output)); TF_LITE_ENSURE_STATUS(ResizeTensor(context, output_shape, output));
if (data->has_col2im) {
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(
ResizeIm2ColTensor(context, output_shape, weights, input, im2col)); ResizeCol2ImTensor(context, output_shape, weights, input, col2im));
}
}
if (data->weights_are_transposed) {
node->temporaries->data[data->transposed_weights_index] =
data->transposed_weights_id;
TfLiteTensor* transposed_weights =
GetTemporary(context, node, user_data->transposed_weights_index);
if (!IsConstantTensor(weights)) {
SetTensorToDynamic(transposed_weights);
} else {
ResizeAndTransposeWeights(context, weights, transposed_weights);
}
} }
if (input->type == kTfLiteUInt8) { if (input->type == kTfLiteUInt8) {
// Set up a scratch buffer tensor. node->temporaries->data[data->scratch_tensor_index] =
TfLiteIntArrayFree(node->temporaries); data->scratch_tensor_id;
node->temporaries = TfLiteIntArrayCreate(1); TfLiteTensor* scratch_buffer =
node->temporaries->data[0] = data->scratch_tensor_index; GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
scratch_buffer->type = kTfLiteInt32; scratch_buffer->type = kTfLiteInt32;
scratch_buffer->allocation_type = kTfLiteArenaRw; scratch_buffer->allocation_type = kTfLiteDynamic;
if (!IsConstantTensor(output_shape)) { if (!IsConstantTensor(output_shape)) {
SetTensorToDynamic(scratch_buffer); SetTensorToDynamic(scratch_buffer);
} else { } else {
@ -221,11 +308,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type> template <KernelType kernel_type>
void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data, void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
const TfLiteTensor* input, const TfLiteTensor* weights, const TfLiteTensor* input, const TfLiteTensor* weights,
TfLiteTensor* im2col, TfLiteTensor* output) { const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
TfLiteTensor* output) {
tflite::ConvParams op_params; tflite::ConvParams op_params;
op_params.padding_type = PaddingType::kSame; op_params.padding_type = PaddingType::kSame;
op_params.padding_values.width = data->padding.width; op_params.padding_values.width = data->padding.width;
op_params.padding_values.height = data->padding.height; op_params.padding_values.height = data->padding.height;
op_params.padding_values.width_offset = data->padding.width_offset;
op_params.padding_values.height_offset = data->padding.height_offset;
op_params.stride_width = params->stride_width; op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height; op_params.stride_height = params->stride_height;
switch (kernel_type) { switch (kernel_type) {
@ -234,15 +324,16 @@ void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
op_params, GetTensorShape(input), GetTensorData<float>(input), op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(weights), GetTensorData<float>(weights), GetTensorShape(weights), GetTensorData<float>(weights),
GetTensorShape(output), GetTensorData<float>(output), GetTensorShape(output), GetTensorData<float>(output),
GetTensorShape(im2col), GetTensorData<float>(im2col)); GetTensorShape(col2im), GetTensorData<float>(col2im));
break; break;
} }
case kGenericOptimized: { case kGenericOptimized: {
optimized_ops::TransposeConv( optimized_ops::TransposeConvV2(
op_params, GetTensorShape(input), GetTensorData<float>(input), op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(weights), GetTensorData<float>(weights), GetTensorShape(transposed_weights),
GetTensorShape(output), GetTensorData<float>(output), GetTensorData<float>(transposed_weights), GetTensorShape(output),
GetTensorShape(im2col), GetTensorData<float>(im2col)); GetTensorData<float>(output), GetTensorShape(col2im),
GetTensorData<float>(col2im));
break; break;
} }
} }
@ -250,7 +341,7 @@ void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
void EvalQuantized(const TfLiteTransposeConvParams* params, OpData* data, void EvalQuantized(const TfLiteTransposeConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* weights, const TfLiteTensor* input, const TfLiteTensor* weights,
TfLiteTensor* im2col, TfLiteTensor* output, TfLiteTensor* col2im, TfLiteTensor* output,
TfLiteTensor* scratch_buffer) { TfLiteTensor* scratch_buffer) {
int32_t input_offset = -input->params.zero_point; int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -weights->params.zero_point; int32_t filter_offset = -weights->params.zero_point;
@ -275,7 +366,7 @@ void EvalQuantized(const TfLiteTransposeConvParams* params, OpData* data,
op_params, GetTensorShape(input), GetTensorData<uint8>(input), op_params, GetTensorShape(input), GetTensorData<uint8>(input),
GetTensorShape(weights), GetTensorData<uint8>(weights), GetTensorShape(weights), GetTensorData<uint8>(weights),
GetTensorShape(output), GetTensorData<uint8>(output), GetTensorShape(output), GetTensorData<uint8>(output),
GetTensorShape(im2col), GetTensorData<uint8>(im2col), GetTensorShape(col2im), GetTensorData<uint8>(col2im),
GetTensorData<int32_t>(scratch_buffer)); GetTensorData<int32_t>(scratch_buffer));
} }
@ -288,8 +379,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
OpData* data = reinterpret_cast<OpData*>(node->user_data); OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* im2col = TfLiteTensor* col2im = data->has_col2im
&context->tensors[node->temporaries->data[data->im2col_index]]; ? GetTemporary(context, node, data->col2im_index)
: nullptr;
TfLiteTensor* transposed_weights =
data->weights_are_transposed
? GetTemporary(context, node, data->transposed_weights_index)
: nullptr;
const auto* params = const auto* params =
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data); reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
@ -297,9 +393,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (IsDynamicTensor(output)) { if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context, ResizeTensor(context, output_shape, output)); TF_LITE_ENSURE_OK(context, ResizeTensor(context, output_shape, output));
} }
if (IsDynamicTensor(im2col)) { if (data->has_col2im && IsDynamicTensor(col2im)) {
TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape, TF_LITE_ENSURE_OK(context, ResizeCol2ImTensor(context, output_shape,
weights, input, im2col)); weights, input, col2im));
} }
// Get height and width of the output image. // Get height and width of the output image.
@ -315,18 +411,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently support float32 and uint8. // Currently support float32 and uint8.
switch (input->type) { switch (input->type) {
case kTfLiteFloat32: { case kTfLiteFloat32: {
EvalFloat<kernel_type>(params, data, input, weights, im2col, output); // Only for GenericOptimized path, we use transposed weights.
if (data->weights_are_transposed) {
if (!IsConstantTensor(weights)) {
ResizeAndTransposeWeights(context, weights, transposed_weights);
}
}
EvalFloat<kernel_type>(params, data, input, weights, transposed_weights,
col2im, output);
break; break;
} }
case kTfLiteUInt8: { case kTfLiteUInt8: {
// TODO(haoliang): support optimized implementation for quantized // TODO(haoliang): support optimized implementation for quantized
// TransposeConv. // TransposeConv.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index*/ 0); TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
if (IsDynamicTensor(scratch_buffer)) { if (IsDynamicTensor(scratch_buffer)) {
TF_LITE_ENSURE_OK(context, TF_LITE_ENSURE_OK(context,
ResizeTensor(context, output_shape, scratch_buffer)); ResizeTensor(context, output_shape, scratch_buffer));
} }
EvalQuantized(params, data, input, weights, im2col, output, EvalQuantized(params, data, input, weights, col2im, output,
scratch_buffer); scratch_buffer);
break; break;
} }
@ -342,14 +446,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteRegistration* Register_TRANSPOSECONV_REF() { TfLiteRegistration* Register_TRANSPOSECONV_REF() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
transpose_conv::Init, transpose_conv::Free, transpose_conv::Prepare, transpose_conv::Init, transpose_conv::Free,
transpose_conv::Prepare<transpose_conv::kReference>,
transpose_conv::Eval<transpose_conv::kReference>}; transpose_conv::Eval<transpose_conv::kReference>};
return &r; return &r;
} }
TfLiteRegistration* Register_TRANSPOSECONV_GENERIC_OPT() { TfLiteRegistration* Register_TRANSPOSECONV_GENERIC_OPT() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
transpose_conv::Init, transpose_conv::Free, transpose_conv::Prepare, transpose_conv::Init, transpose_conv::Free,
transpose_conv::Prepare<transpose_conv::kGenericOptimized>,
transpose_conv::Eval<transpose_conv::kGenericOptimized>}; transpose_conv::Eval<transpose_conv::kGenericOptimized>};
return &r; return &r;
} }

View File

@ -3818,6 +3818,17 @@ def make_conv2d_transpose_tests(options):
"input_shape": [[1, 50, 54, 3]], "input_shape": [[1, 50, 54, 3]],
"filter_shape": [[1, 1, 8, 3], [1, 2, 8, 3], [1, 3, 8, 3], [1, 4, 8, 3]], "filter_shape": [[1, 1, 8, 3], [1, 2, 8, 3], [1, 3, 8, 3], [1, 4, 8, 3]],
"output_shape": [[1, 100, 108, 8]], "output_shape": [[1, 100, 108, 8]],
"dynamic_output_shape": [True, False],
}, {
"input_shape": [[1, 16, 1, 512]],
"filter_shape": [[4, 1, 512, 512]],
"output_shape": [[1, 32, 1, 512]],
"dynamic_output_shape": [True, False],
}, {
"input_shape": [[1, 128, 128, 1]],
"filter_shape": [[4, 4, 1, 1]],
"output_shape": [[1, 256, 256, 1]],
"dynamic_output_shape": [True, False],
}] }]
def build_graph(parameters): def build_graph(parameters):
@ -3828,14 +3839,21 @@ def make_conv2d_transpose_tests(options):
filter_tensor = tf.placeholder( filter_tensor = tf.placeholder(
dtype=tf.float32, name="filter", shape=parameters["filter_shape"]) dtype=tf.float32, name="filter", shape=parameters["filter_shape"])
input_tensors = [input_tensor, filter_tensor]
if parameters["dynamic_output_shape"]:
output_shape = tf.placeholder(dtype=tf.int32, shape=[4])
input_tensors.append(output_shape)
else:
output_shape = parameters["output_shape"]
out = tf.nn.conv2d_transpose( out = tf.nn.conv2d_transpose(
input_tensor, input_tensor,
filter_tensor, filter_tensor,
output_shape=parameters["output_shape"], output_shape=output_shape,
padding="SAME", padding="SAME",
strides=(1, 2, 2, 1)) strides=(1, 2, 2, 1))
input_tensors = [input_tensor, filter_tensor]
return input_tensors, [out] return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs): def build_inputs(parameters, sess, inputs, outputs):
@ -3843,6 +3861,9 @@ def make_conv2d_transpose_tests(options):
create_tensor_data(np.float32, parameters["input_shape"]), create_tensor_data(np.float32, parameters["input_shape"]),
create_tensor_data(np.float32, parameters["filter_shape"]) create_tensor_data(np.float32, parameters["filter_shape"])
] ]
if parameters["dynamic_output_shape"]:
values.append(np.array(parameters["output_shape"]))
return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
make_zip_of_tests(options, test_parameters, build_graph, build_inputs) make_zip_of_tests(options, test_parameters, build_graph, build_inputs)