Optimize transpose_conv
PiperOrigin-RevId: 243561442
parent 0c464c70ce
commit 5e52b70188
tensorflow/lite/c/builtin_op_data.h

@@ -46,9 +46,12 @@ typedef enum {
   kTfLiteMirrorPaddingSymmetric,
 } TfLiteMirrorPaddingMode;
 
+// TODO(b/130259536): We should move this out of builtin_op_data.
 typedef struct {
   int width;
   int height;
+  int width_offset;
+  int height_offset;
 } TfLitePaddingValues;
 
 typedef struct {
tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -6130,33 +6130,112 @@ void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
   }
 }
 
+// Returns in 'im_data' (assumed to be zero-initialized) an image patch in
+// storage order (height, width, depth), constructed from patches in
+// 'col_data', which is required to be in storage order (out_height *
+// out_width, filter_height, filter_width, in_depth).
+// Implementation by Yangqing Jia (jiayq).
+// Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int height,
+            const int width, const int filter_h, const int filter_w,
+            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
+            const int stride_h, const int stride_w, T* im_data) {
+  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+  int h_pad = -pad_t;
+  for (int h = 0; h < height_col; ++h) {
+    int w_pad = -pad_l;
+    for (int w = 0; w < width_col; ++w) {
+      T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
+      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
+            // TODO(andydavis) Vectorize this loop (if compiler does not).
+            for (int i = 0; i < depth; ++i) {
+              im_patch_data[i] += col_data[i];
+            }
+          }
+          im_patch_data += depth;
+          col_data += depth;
+        }
+        // Jump over the remaining depth values in this row.
+        im_patch_data += depth * (width - filter_w);
+      }
+      w_pad += stride_w;
+    }
+    h_pad += stride_h;
+  }
+}
+
+// TransposeConvV2 expects the weights in HWOI order.
+inline void TransposeConvV2(
+    const ConvParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
+    const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
+    float* output_data, const RuntimeShape& col2im_shape, float* col2im_data) {
+  gemmlowp::ScopedProfilingLabel label("TransposeConvV2");
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
+  const int batch_size = input_shape.Dims(0);
+  TFLITE_DCHECK(col2im_data);
+  TFLITE_DCHECK(hwoi_ordered_filter_data);
+
+  const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_image_size = output_height * output_width;
+  const int input_depth =
+      MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
+  const int output_depth =
+      MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
+  const int input_offset = input_image_size * input_depth;
+  const int output_offset = output_image_size * output_depth;
+
+  const int filter_height = hwoi_ordered_filter_shape.Dims(0);
+  const int filter_width = hwoi_ordered_filter_shape.Dims(1);
+  const int padding_top = params.padding_values.height;
+  const int padding_bottom =
+      params.padding_values.height + params.padding_values.height_offset;
+  const int padding_left = params.padding_values.width;
+  const int padding_right =
+      params.padding_values.width + params.padding_values.width_offset;
+  const int stride_height = params.stride_height;
+  const int stride_width = params.stride_width;
+
+  const int hwoi_ordered_filter_total_size =
+      filter_height * filter_width * output_depth;
+
+  typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
+      Matrix;
+  typedef Eigen::Map<Matrix> MatrixRef;
+  typedef Eigen::Map<const Matrix> ConstMatrixRef;
+  ConstMatrixRef hwoi_ordered_filter_matrix_map(
+      hwoi_ordered_filter_data, hwoi_ordered_filter_total_size, input_depth);
+  float* output_data_p = output_data;
+  tensor_utils::ZeroVector(output_data, output_offset * batch_size);
+  for (int i = 0; i < batch_size; ++i) {
+    ConstMatrixRef input_matrix_map(input_data + input_offset * i,
+                                    input_image_size, input_depth);
+    MatrixRef output_matrix_map(col2im_data, input_image_size,
+                                hwoi_ordered_filter_total_size);
+    Gemm(input_matrix_map, hwoi_ordered_filter_matrix_map.transpose(),
+         &output_matrix_map);
+
+    Col2im(col2im_data, output_depth, output_height, output_width,
+           filter_height, filter_width, padding_top, padding_left,
+           padding_bottom, padding_right, stride_height, stride_width,
+           output_data_p);
+    output_data_p += output_offset;
+  }
+}
+
+// TODO(renjieliu): Investigate whether we need to keep this.
 inline void TransposeConv(
     const ConvParams& params, const RuntimeShape& input_shape,
     const float* input_data, const RuntimeShape& filter_shape,
     const float* filter_data, const RuntimeShape& output_shape,
     float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
   gemmlowp::ScopedProfilingLabel label("TransposeConv");
   // The complexity of the reference implementation is input.flat_size() *
   // filter.flat_size() / in_channel.
   //
   // While the complexity of the im2col->gemm implementation is
   // batch * output_height * output_width *
   // (filter.flat_size() / out_channel)^2 * out_channel.
   //
   // So if input.flat_size() * out_channel^2 is much smaller than
   // output.flat_size() * filter.flat_size() * in_channel, we should fall back
   // to the reference implementation.
   //
   // TODO(b/122331966): optimize the intuitive implementation.
   const int out_channel = output_shape.Dims(3);
   const int in_channel = input_shape.Dims(3);
   if ((input_shape.FlatSize() * out_channel * out_channel * 4) <
       (filter_shape.FlatSize() * output_shape.FlatSize() * in_channel)) {
     reference_ops::TransposeConv(params, input_shape, input_data, filter_shape,
                                  filter_data, output_shape, output_data,
                                  im2col_shape, im2col_data);
     return;
   }
   // Note we could use transposed weights with forward conv for unstrided
   // cases. But we are already getting good performance with this code as-is.
   TFLITE_DCHECK(im2col_data);
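The new optimized path computes transpose convolution as one GEMM per batch followed by a Col2im scatter-add. A minimal standalone sketch of that dataflow (editor's illustration, not part of the commit; toy shapes with depth 1 and no padding, where the GEMM step degenerates to scaling the filter by each input pixel):

// Transpose convolution as "scale a filter patch per input pixel, then
// scatter-add at stride offsets" -- the one-pass form that Col2im implements.
#include <cstdio>

int main() {
  // Assumed toy shapes: 2x2 input, 2x2 filter, stride 2, no padding.
  const int in_h = 2, in_w = 2, filter_h = 2, filter_w = 2, stride = 2;
  const float input[in_h][in_w] = {{1, 2}, {3, 4}};
  const float filter[filter_h][filter_w] = {{1, 0}, {0, 1}};
  const int out_h = (in_h - 1) * stride + filter_h;  // 4
  const int out_w = (in_w - 1) * stride + filter_w;  // 4
  float output[4][4] = {};

  // After the GEMM there is one scaled filter patch per input pixel;
  // col2im scatter-adds each patch into the output at (y*stride, x*stride).
  for (int y = 0; y < in_h; ++y) {
    for (int x = 0; x < in_w; ++x) {
      for (int fy = 0; fy < filter_h; ++fy) {
        for (int fx = 0; fx < filter_w; ++fx) {
          output[y * stride + fy][x * stride + fx] +=
              input[y][x] * filter[fy][fx];
        }
      }
    }
  }

  for (int y = 0; y < out_h; ++y) {
    for (int x = 0; x < out_w; ++x) printf("%4.0f", output[y][x]);
    printf("\n");
  }
  return 0;
}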
tensorflow/lite/kernels/internal/types.h

@@ -28,6 +28,12 @@ enum class PaddingType : uint8 { kNone, kSame, kValid };
 struct PaddingValues {
   int16 width;
   int16 height;
+  // offset is used for calculating the "remaining" padding: for example, if
+  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
+  // padding_right is 1 + 1 = 2.
+  int16 width_offset;
+  // Same as width_offset except it's over the height dimension.
+  int16 height_offset;
 };
 
 // This enumeration allows for non-default formats for the weights array
tensorflow/lite/kernels/padding.h

@@ -19,6 +19,7 @@ limitations under the License.
 
 namespace tflite {
 
+// TODO(renjieliu): Migrate others to use ComputePaddingWithLeftover.
 inline int ComputePadding(int stride, int dilation_rate, int in_size,
                           int filter_size, int out_size) {
   int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
@@ -26,6 +27,19 @@ inline int ComputePadding(int stride, int dilation_rate, int in_size,
   return padding > 0 ? padding : 0;
 }
 
+// It's not guaranteed that padding is symmetric. It's important to keep the
+// offset for algorithms that need all paddings.
+inline int ComputePaddingWithOffset(int stride, int dilation_rate, int in_size,
+                                    int filter_size, int out_size,
+                                    int* offset) {
+  int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
+  int total_padding =
+      ((out_size - 1) * stride + effective_filter_size - in_size);
+  total_padding = total_padding > 0 ? total_padding : 0;
+  *offset = total_padding % 2;
+  return total_padding / 2;
+}
+
 // Matching GetWindowedOutputSize in TensorFlow.
 inline int ComputeOutSize(TfLitePadding padding, int image_size,
                           int filter_size, int stride) {
@@ -47,10 +61,13 @@ inline TfLitePaddingValues ComputePaddingHeightWidth(
       ComputeOutSize(padding, in_height, filter_height, stride_height);
 
   TfLitePaddingValues padding_values;
-  padding_values.height =
-      ComputePadding(stride_height, 1, in_height, filter_height, out_height);
-  padding_values.width =
-      ComputePadding(stride_width, 1, in_width, filter_width, out_width);
+  int offset = 0;
+  padding_values.height = ComputePaddingWithOffset(
+      stride_height, 1, in_height, filter_height, out_height, &offset);
+  padding_values.height_offset = offset;
+  padding_values.width = ComputePaddingWithOffset(
+      stride_width, 1, in_width, filter_width, out_width, &offset);
+  padding_values.width_offset = offset;
   return padding_values;
 }
 }  // namespace tflite
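The asymmetric-padding arithmetic above is easy to sanity-check in isolation. A standalone toy harness (editor's sketch, not part of the commit; ComputePaddingWithOffsetSketch is a local copy of the function above):

// For SAME padding the total padding can be odd: the quotient goes to the
// leading edge and the remainder (the offset) to the trailing edge, so
// padding_right = width + width_offset, as TransposeConvV2 computes.
#include <cstdio>

int ComputePaddingWithOffsetSketch(int stride, int dilation_rate, int in_size,
                                   int filter_size, int out_size,
                                   int* offset) {
  int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
  int total_padding =
      (out_size - 1) * stride + effective_filter_size - in_size;
  total_padding = total_padding > 0 ? total_padding : 0;
  *offset = total_padding % 2;
  return total_padding / 2;
}

int main() {
  int offset = 0;
  // in=4, filter=3, stride=2, SAME -> out=2: total padding is 1, so the
  // split is asymmetric: pad_left = 0, pad_right = 0 + 1 = 1.
  int pad = ComputePaddingWithOffsetSketch(2, 1, 4, 3, 2, &offset);
  printf("pad_left=%d pad_right=%d\n", pad, pad + offset);
  return 0;
}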
tensorflow/lite/kernels/transpose_conv.cc

@@ -24,6 +24,7 @@ limitations under the License.
+#include "tensorflow/lite/kernels/eigen_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/kernels/padding.h"
@@ -49,11 +50,22 @@ const int kTensorNotAllocated = -1;
 struct OpData {
   // IDs are the arbitrary identifiers used by TF Lite to identify and access
   // memory buffers.
-  int im2col_id = kTensorNotAllocated;
+  int col2im_id = kTensorNotAllocated;
+  int transposed_weights_id = kTensorNotAllocated;
+  int scratch_tensor_id = kTensorNotAllocated;
 
-  // im2col is the only temporary currently tracked, therefore always index 0.
-  // If more temporaries are added, they should be properly tracked.
-  int32_t im2col_index = 0;
+  // col2im is the temporary tensor allocated and used in the optimized path
+  // for storing the col2im data: the GEMM result of input_matrix x
+  // filter_matrix.
+  int32_t col2im_index;
+
+  // TfLiteConverter will transpose weights from HWOI to OHWI order.
+  // In the optimized path, we transpose them back to HWOI; this temporary
+  // tensor is allocated for storing the transposed weights.
+  int32_t transposed_weights_index;
+
+  // Scratch tensor is used in the quantized path for storing accumulation
+  // results.
+  int32_t scratch_tensor_index;
 
   TfLitePaddingValues padding;
   // The scaling factor from input to output (aka the 'real multiplier') can
@@ -66,17 +78,12 @@ struct OpData {
   int32_t output_activation_min;
   int32_t output_activation_max;
 
-  int scratch_tensor_index;
+  bool has_col2im = false;
+  bool weights_are_transposed = false;
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // This is a builtin op, so we don't use the contents in 'buffer', if any.
   // Instead, we allocate a new object to use as scratch space for im2col, and
   // to carry information from Prepare() to Eval().
   auto* data = new OpData;
-  // Populate scratch_tensor_index.
-  context->AddTensors(context, /*tensors_to_add=*/1,
-                      &data->scratch_tensor_index);
+  eigen_support::IncrementUsageCounter(context);
   return data;
 }
@@ -104,46 +111,106 @@ TfLiteStatus ResizeTensor(TfLiteContext* context,
   return context->ResizeTensor(context, tensor_to_resize, shape);
 }
 
-static TfLiteStatus AllocateIm2colTensorIfRequired(TfLiteContext* context,
-                                                   TfLiteNode* node) {
+// Allocate temporary tensors if necessary.
+template <KernelType kernel_type>
+static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
+                                                       TfLiteType input_type,
+                                                       TfLiteType weights_type,
+                                                       TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  if (data->im2col_id == kTensorNotAllocated) {
-    context->AddTensors(context, 1, &data->im2col_id);
-    context->tensors[data->im2col_id].type = kTfLiteFloat32;
-  }
+  int temporaries_count = 0;
+
+  // Allocate col2im tensor. Currently it's only used for optimized kernels.
+  if (kernel_type == kGenericOptimized) {
+    if (data->col2im_id == kTensorNotAllocated) {
+      context->AddTensors(context, 1, &data->col2im_id);
+    }
+    data->col2im_index = temporaries_count;
+    data->has_col2im = true;
+    ++temporaries_count;
+  }
+
+  // Allocate transposed_weights tensor. Currently it's only used for optimized
+  // float kernels.
+  if (kernel_type == kGenericOptimized && input_type == kTfLiteFloat32) {
+    if (data->transposed_weights_id == kTensorNotAllocated) {
+      context->AddTensors(context, 1, &data->transposed_weights_id);
+    }
+    data->transposed_weights_index = temporaries_count;
+    data->weights_are_transposed = true;
+    ++temporaries_count;
+  }
+
+  // Allocate scratch buffer tensor for UInt8 inputs.
+  if (input_type == kTfLiteUInt8) {
+    if (data->scratch_tensor_id == kTensorNotAllocated) {
+      context->AddTensors(context, 1, &data->scratch_tensor_id);
+    }
+    data->scratch_tensor_index = temporaries_count;
+    ++temporaries_count;
+  }
 
   TfLiteIntArrayFree(node->temporaries);
-  node->temporaries = TfLiteIntArrayCreate(1);
-  node->temporaries->data[data->im2col_index] = data->im2col_id;
+  node->temporaries = TfLiteIntArrayCreate(temporaries_count);
 
   return kTfLiteOk;
 }
 
-TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context,
+TfLiteStatus ResizeCol2ImTensor(TfLiteContext* context,
                                 const TfLiteTensor* output_shape,
                                 const TfLiteTensor* weights,
                                 const TfLiteTensor* input,
-                                TfLiteTensor* im2col) {
+                                TfLiteTensor* col2im) {
   if (output_shape->type != kTfLiteInt32) {
-    context->ReportError(context, "im2col shape is %d, not int32.",
+    context->ReportError(context, "col2im shape is %d, not int32.",
                          output_shape->type);
     return kTfLiteError;
   }
   TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
-  TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4);
-  im2col_shape_array->data[0] = output_shape->data.i32[0];
-  im2col_shape_array->data[1] = output_shape->data.i32[1];
-  im2col_shape_array->data[2] = output_shape->data.i32[2];
-  const int input_depth = SizeOfDimension(input, 3);
-  const int filter_width = SizeOfDimension(weights, 2);
-  const int filter_height = SizeOfDimension(weights, 1);
-  im2col_shape_array->data[3] = input_depth * filter_height * filter_width;
+  TfLiteIntArray* col2im_shape_array = TfLiteIntArrayCreate(2);
+  const RuntimeShape& input_shape = GetTensorShape(input);
+  const RuntimeShape& weights_shape = GetTensorShape(weights);
+  col2im_shape_array->data[0] = input_shape.Dims(1) * input_shape.Dims(2);
+  col2im_shape_array->data[1] =
+      weights_shape.Dims(0) * weights_shape.Dims(1) * weights_shape.Dims(2);
 
-  im2col->type = input->type;
-  im2col->allocation_type = kTfLiteDynamic;
-  return context->ResizeTensor(context, im2col, im2col_shape_array);
+  col2im->type = input->type;
+  col2im->allocation_type = kTfLiteDynamic;
+  return context->ResizeTensor(context, col2im, col2im_shape_array);
 }
 
+TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
+                                       const TfLiteTensor* weights,
+                                       TfLiteTensor* transposed_weights) {
+  TfLiteIntArray* transposed_weights_shape_array = TfLiteIntArrayCreate(4);
+  const RuntimeShape& input_shape = GetTensorShape(weights);
+  transposed_weights_shape_array->data[0] = input_shape.Dims(1);
+  transposed_weights_shape_array->data[1] = input_shape.Dims(2);
+  transposed_weights_shape_array->data[2] = input_shape.Dims(0);
+  transposed_weights_shape_array->data[3] = input_shape.Dims(3);
+
+  transposed_weights->type = weights->type;
+  transposed_weights->allocation_type = kTfLiteDynamic;
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, transposed_weights,
+                                              transposed_weights_shape_array));
+
+  // Transpose the weights from OHWI order to HWOI order.
+  TransposeParams transpose_params;
+  transpose_params.perm_count = 4;
+  transpose_params.perm[0] = 1;
+  transpose_params.perm[1] = 2;
+  transpose_params.perm[2] = 0;
+  transpose_params.perm[3] = 3;
+
+  optimized_ops::Transpose(transpose_params, input_shape,
+                           GetTensorData<float>(weights),
+                           GetTensorShape(transposed_weights),
+                           GetTensorData<float>(transposed_weights));
+
+  return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
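For reference, the perm = {1, 2, 0, 3} relayout above can be sanity-checked with a standalone loop nest (editor's sketch, not commit code; small assumed dimensions): dimension d of the destination is dimension perm[d] of the source, which maps OHWI to HWOI.

// Toy OHWI -> HWOI relayout matching the permutation used by
// ResizeAndTransposeWeights.
#include <cstdio>

int main() {
  const int O = 2, H = 2, W = 2, I = 2;  // assumed toy filter shape (OHWI)
  float src[O][H][W][I], dst[H][W][O][I];
  for (int o = 0, v = 0; o < O; ++o)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int i = 0; i < I; ++i) src[o][h][w][i] = v++;
  for (int h = 0; h < H; ++h)
    for (int w = 0; w < W; ++w)
      for (int o = 0; o < O; ++o)
        for (int i = 0; i < I; ++i) dst[h][w][o][i] = src[o][h][w][i];
  // dst now holds the HWOI layout that TransposeConvV2 expects.
  printf("dst[0][0][1][0] = %.0f (was src[1][0][0][0])\n", dst[0][0][1][0]);
  return 0;
}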
@@ -151,18 +218,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
-  // Allocate Im2col Tensor
-  TF_LITE_ENSURE_STATUS(AllocateIm2colTensorIfRequired(context, node));
-
   // Retrieve tensors
   const TfLiteTensor* output_shape =
       GetInput(context, node, kOutputShapeTensor);
   const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
   const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
-  TfLiteTensor* im2col =
-      &context->tensors[node->temporaries->data[user_data->im2col_index]];
 
   // Tensor sanity checks
   TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
@@ -177,24 +238,50 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
                     SizeOfDimension(weights, 3));
 
+  // Allocate col2im, transposed_weights & scratch tensors.
+  TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired<kernel_type>(
+      context, input->type, weights->type, node));
+
+  OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
+  TfLiteTensor* col2im = nullptr;
+  if (data->has_col2im) {
+    node->temporaries->data[data->col2im_index] = data->col2im_id;
+    col2im = GetTemporary(context, node, user_data->col2im_index);
+  }
+
   if (!IsConstantTensor(output_shape)) {
     // Defer resizing until Eval().
     SetTensorToDynamic(output);
-    SetTensorToDynamic(im2col);
+    if (data->has_col2im) {
+      SetTensorToDynamic(col2im);
+    }
   } else {
     TF_LITE_ENSURE_STATUS(ResizeTensor(context, output_shape, output));
-    TF_LITE_ENSURE_STATUS(
-        ResizeIm2ColTensor(context, output_shape, weights, input, im2col));
+    if (data->has_col2im) {
+      TF_LITE_ENSURE_STATUS(
+          ResizeCol2ImTensor(context, output_shape, weights, input, col2im));
+    }
   }
 
+  if (data->weights_are_transposed) {
+    node->temporaries->data[data->transposed_weights_index] =
+        data->transposed_weights_id;
+    TfLiteTensor* transposed_weights =
+        GetTemporary(context, node, user_data->transposed_weights_index);
+    if (!IsConstantTensor(weights)) {
+      SetTensorToDynamic(transposed_weights);
+    } else {
+      ResizeAndTransposeWeights(context, weights, transposed_weights);
+    }
+  }
+
   if (input->type == kTfLiteUInt8) {
     // Set up a scratch buffer tensor.
-    TfLiteIntArrayFree(node->temporaries);
-    node->temporaries = TfLiteIntArrayCreate(1);
-    node->temporaries->data[0] = data->scratch_tensor_index;
-    TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+    node->temporaries->data[data->scratch_tensor_index] =
+        data->scratch_tensor_id;
+    TfLiteTensor* scratch_buffer =
+        GetTemporary(context, node, data->scratch_tensor_index);
     scratch_buffer->type = kTfLiteInt32;
-    scratch_buffer->allocation_type = kTfLiteArenaRw;
+    scratch_buffer->allocation_type = kTfLiteDynamic;
     if (!IsConstantTensor(output_shape)) {
       SetTensorToDynamic(scratch_buffer);
     } else {
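The temporaries bookkeeping above hands out dense indices only for the temporaries a given kernel/dtype combination actually needs, while the ids name tensors owned by the context. A standalone model of that index assignment (editor's sketch with plain ints, not commit code):

// Mirrors AllocateTemporaryTensorsIfRequired: indices are assigned in a
// fixed order, but an index is consumed only if the temporary is enabled.
#include <cstdio>

int main() {
  const bool optimized = true;      // kGenericOptimized kernel
  const bool is_float = true;       // kTfLiteFloat32 input
  const bool is_uint8 = !is_float;  // kTfLiteUInt8 input
  int temporaries_count = 0;
  int col2im_index = -1, transposed_weights_index = -1, scratch_index = -1;
  if (optimized) col2im_index = temporaries_count++;
  if (optimized && is_float) transposed_weights_index = temporaries_count++;
  if (is_uint8) scratch_index = temporaries_count++;
  printf("count=%d col2im=%d transposed=%d scratch=%d\n", temporaries_count,
         col2im_index, transposed_weights_index, scratch_index);
  return 0;
}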
@@ -221,11 +308,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
                const TfLiteTensor* input, const TfLiteTensor* weights,
-               TfLiteTensor* im2col, TfLiteTensor* output) {
+               const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
+               TfLiteTensor* output) {
   tflite::ConvParams op_params;
   op_params.padding_type = PaddingType::kSame;
   op_params.padding_values.width = data->padding.width;
   op_params.padding_values.height = data->padding.height;
+  op_params.padding_values.width_offset = data->padding.width_offset;
+  op_params.padding_values.height_offset = data->padding.height_offset;
   op_params.stride_width = params->stride_width;
   op_params.stride_height = params->stride_height;
   switch (kernel_type) {
@@ -234,15 +324,16 @@ void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
           op_params, GetTensorShape(input), GetTensorData<float>(input),
           GetTensorShape(weights), GetTensorData<float>(weights),
           GetTensorShape(output), GetTensorData<float>(output),
-          GetTensorShape(im2col), GetTensorData<float>(im2col));
+          GetTensorShape(col2im), GetTensorData<float>(col2im));
       break;
     }
     case kGenericOptimized: {
-      optimized_ops::TransposeConv(
+      optimized_ops::TransposeConvV2(
           op_params, GetTensorShape(input), GetTensorData<float>(input),
-          GetTensorShape(weights), GetTensorData<float>(weights),
-          GetTensorShape(output), GetTensorData<float>(output),
-          GetTensorShape(im2col), GetTensorData<float>(im2col));
+          GetTensorShape(transposed_weights),
+          GetTensorData<float>(transposed_weights), GetTensorShape(output),
+          GetTensorData<float>(output), GetTensorShape(col2im),
+          GetTensorData<float>(col2im));
       break;
     }
   }
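As a sanity check on buffer sizes, take the [1, 16, 1, 512] input and 4x1x512x512 HWOI filter from the new test cases added below; the per-batch GEMM and col2im buffer work out as follows (editor's sketch, not commit code):

// Back-of-envelope shape bookkeeping for TransposeConvV2.
#include <cstdio>

int main() {
  const int input_image_size = 16 * 1;  // input H * W
  const int input_depth = 512;          // I
  const int filter_h = 4, filter_w = 1, output_depth = 512;  // H, W, O
  const int hwoi_filter_rows = filter_h * filter_w * output_depth;  // 2048
  // GEMM: [16 x 512] * [512 x 2048] -> col2im buffer of 16 * 2048 floats,
  // which Col2im then folds into the [1, 32, 1, 512] output.
  printf("GEMM: [%d x %d] * [%d x %d] -> col2im buffer of %d floats\n",
         input_image_size, input_depth, input_depth, hwoi_filter_rows,
         input_image_size * hwoi_filter_rows);
  return 0;
}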
@@ -250,7 +341,7 @@ void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
 
 void EvalQuantized(const TfLiteTransposeConvParams* params, OpData* data,
                    const TfLiteTensor* input, const TfLiteTensor* weights,
-                   TfLiteTensor* im2col, TfLiteTensor* output,
+                   TfLiteTensor* col2im, TfLiteTensor* output,
                    TfLiteTensor* scratch_buffer) {
   int32_t input_offset = -input->params.zero_point;
   int32_t filter_offset = -weights->params.zero_point;
@@ -275,7 +366,7 @@ void EvalQuantized(const TfLiteTransposeConvParams* params, OpData* data,
         op_params, GetTensorShape(input), GetTensorData<uint8>(input),
         GetTensorShape(weights), GetTensorData<uint8>(weights),
         GetTensorShape(output), GetTensorData<uint8>(output),
-        GetTensorShape(im2col), GetTensorData<uint8>(im2col),
+        GetTensorShape(col2im), GetTensorData<uint8>(col2im),
         GetTensorData<int32_t>(scratch_buffer));
   }
 
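A side note on why the quantized path accumulates into an int32 scratch tensor rather than into the uint8 output (editor's sketch with assumed sizes, not commit code):

// Each output value sums filter_h * filter_w * depth products of
// (u8 - zero_point) terms, which overflows 8/16-bit storage long before
// the final requantization step, but fits comfortably in int32 here.
#include <cstdio>

int main() {
  const int filter_h = 4, filter_w = 4, input_depth = 512;  // assumed sizes
  const int max_term = 255 * 255;  // largest |(u8 - zp) * (u8 - zp)|
  const long long worst_case =
      1LL * filter_h * filter_w * input_depth * max_term;
  printf("worst-case accumulator: %lld (int16 max is 32767)\n", worst_case);
  return 0;
}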
@@ -288,8 +379,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  TfLiteTensor* im2col =
-      &context->tensors[node->temporaries->data[data->im2col_index]];
+  TfLiteTensor* col2im = data->has_col2im
+                             ? GetTemporary(context, node, data->col2im_index)
+                             : nullptr;
+  TfLiteTensor* transposed_weights =
+      data->weights_are_transposed
+          ? GetTemporary(context, node, data->transposed_weights_index)
+          : nullptr;
   const auto* params =
       reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
 
@@ -297,9 +393,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context, ResizeTensor(context, output_shape, output));
   }
-  if (IsDynamicTensor(im2col)) {
-    TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape,
-                                                  weights, input, im2col));
+  if (data->has_col2im && IsDynamicTensor(col2im)) {
+    TF_LITE_ENSURE_OK(context, ResizeCol2ImTensor(context, output_shape,
+                                                  weights, input, col2im));
   }
 
   // Get height and width of the output image.
@@ -315,18 +411,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Currently support float32 and uint8.
   switch (input->type) {
     case kTfLiteFloat32: {
-      EvalFloat<kernel_type>(params, data, input, weights, im2col, output);
+      // Only the GenericOptimized path uses transposed weights.
+      if (data->weights_are_transposed) {
+        if (!IsConstantTensor(weights)) {
+          ResizeAndTransposeWeights(context, weights, transposed_weights);
+        }
+      }
+      EvalFloat<kernel_type>(params, data, input, weights, transposed_weights,
+                             col2im, output);
       break;
     }
     case kTfLiteUInt8: {
       // TODO(haoliang): support optimized implementation for quantized
       // TransposeConv.
-      TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index*/ 0);
+      TfLiteTensor* scratch_buffer =
+          GetTemporary(context, node, data->scratch_tensor_index);
       if (IsDynamicTensor(scratch_buffer)) {
         TF_LITE_ENSURE_OK(context,
                           ResizeTensor(context, output_shape, scratch_buffer));
       }
-      EvalQuantized(params, data, input, weights, im2col, output,
+      EvalQuantized(params, data, input, weights, col2im, output,
                     scratch_buffer);
       break;
     }
@@ -342,14 +446,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
 TfLiteRegistration* Register_TRANSPOSECONV_REF() {
   static TfLiteRegistration r = {
-      transpose_conv::Init, transpose_conv::Free, transpose_conv::Prepare,
+      transpose_conv::Init, transpose_conv::Free,
+      transpose_conv::Prepare<transpose_conv::kReference>,
       transpose_conv::Eval<transpose_conv::kReference>};
   return &r;
 }
 
 TfLiteRegistration* Register_TRANSPOSECONV_GENERIC_OPT() {
   static TfLiteRegistration r = {
-      transpose_conv::Init, transpose_conv::Free, transpose_conv::Prepare,
+      transpose_conv::Init, transpose_conv::Free,
+      transpose_conv::Prepare<transpose_conv::kGenericOptimized>,
       transpose_conv::Eval<transpose_conv::kGenericOptimized>};
   return &r;
 }
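The registrations now instantiate Prepare and Eval per kernel type. A minimal standalone model of this compile-time dispatch pattern (editor's sketch with toy types, not commit code):

// Each registration picks the template instantiations it wants, so the
// kernel-type branches are resolved at compile time.
#include <cstdio>

enum KernelType { kReference, kGenericOptimized };

template <KernelType kernel_type>
void Eval() {
  if (kernel_type == kReference) {
    printf("reference path\n");
  } else {
    printf("optimized path\n");
  }
}

int main() {
  Eval<kReference>();         // what TRANSPOSECONV_REF would register
  Eval<kGenericOptimized>();  // what TRANSPOSECONV_GENERIC_OPT would register
  return 0;
}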
tensorflow/lite/testing/generate_examples.py

@@ -3818,6 +3818,17 @@ def make_conv2d_transpose_tests(options):
       "input_shape": [[1, 50, 54, 3]],
       "filter_shape": [[1, 1, 8, 3], [1, 2, 8, 3], [1, 3, 8, 3], [1, 4, 8, 3]],
       "output_shape": [[1, 100, 108, 8]],
+      "dynamic_output_shape": [True, False],
+  }, {
+      "input_shape": [[1, 16, 1, 512]],
+      "filter_shape": [[4, 1, 512, 512]],
+      "output_shape": [[1, 32, 1, 512]],
+      "dynamic_output_shape": [True, False],
+  }, {
+      "input_shape": [[1, 128, 128, 1]],
+      "filter_shape": [[4, 4, 1, 1]],
+      "output_shape": [[1, 256, 256, 1]],
+      "dynamic_output_shape": [True, False],
   }]
 
   def build_graph(parameters):
@@ -3828,14 +3839,21 @@ def make_conv2d_transpose_tests(options):
     filter_tensor = tf.placeholder(
         dtype=tf.float32, name="filter", shape=parameters["filter_shape"])
 
+    input_tensors = [input_tensor, filter_tensor]
+
+    if parameters["dynamic_output_shape"]:
+      output_shape = tf.placeholder(dtype=tf.int32, shape=[4])
+      input_tensors.append(output_shape)
+    else:
+      output_shape = parameters["output_shape"]
+
     out = tf.nn.conv2d_transpose(
         input_tensor,
         filter_tensor,
-        output_shape=parameters["output_shape"],
+        output_shape=output_shape,
         padding="SAME",
         strides=(1, 2, 2, 1))
 
-    input_tensors = [input_tensor, filter_tensor]
     return input_tensors, [out]
 
   def build_inputs(parameters, sess, inputs, outputs):
@@ -3843,6 +3861,9 @@ def make_conv2d_transpose_tests(options):
         create_tensor_data(np.float32, parameters["input_shape"]),
         create_tensor_data(np.float32, parameters["filter_shape"])
     ]
+    if parameters["dynamic_output_shape"]:
+      values.append(np.array(parameters["output_shape"]))
+
     return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
 
   make_zip_of_tests(options, test_parameters, build_graph, build_inputs)