PR #34903: TransposeConv with Bias

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/34903

### Description of issue:
For generating .tflite file with TFLiteConverter, when model contains
Conv2DTranspose layers, bias cannot fold into Operator TRANSPOSECONV.
It will result with extra Op ADD following Op TRANSPOSECONV.
But with other CONV-like layers (Conv2D, DepthwiseConv2D),
bias will be fold into CONV layer.
(check detailed TF issue: https://github.com/tensorflow/tensorflow/issues/34622)

### PR try to resolve it:
So we try to resolve this issue by enable TransposeConv with bias for TFLite:

- Update TFLite graph_transform features with:
      fill TransposeConv bias with zero if there is no bias
      fuse bias add into preceding TransposeConv(TEST added)

- Update TransposeConv with bias:
      add bias input to TransposeConv
      add optional bias to TransposeConv kernels

### example of the results:
  TRANSPOSE_CONV inputs:

1. output_shape
2. weights
3. activation
4. bias

![fused_transposeconv](https://user-images.githubusercontent.com/55463253/70334128-bc088200-183c-11ea-9f94-a803cc80df99.png)

### Need to discuss:
~~currently this PR only update reference kernel for transpose_conv, optimised kernal is commented out.~~
~~several TEST need to be added as well, but~~ further suggestions are needed for adding additional test.
Copybara import of the project:

--
1c6eb9c98229a9e8248dc1fe913a20cc6dd89332 by Peng Sun <peng.sun@arm.com>:

Fuse TransposeConv with Bias

For generating .tflite file with TFLiteConverter, when model contains
Conv2DTranspose layers, bias cannot fold into Operator TRANSPOSECONV.
It will result with extra Op ADD following Op TRANSPOSECONV.
But with other CONV-like layers (Conv2D, DepthwiseConv2D),
bias will be fold into CONV layer.
(check TF issue: https://github.com/tensorflow/tensorflow/issues/34622)

So we try to resolve this issue by enable TransposeConv with bias for TFLite:
  Update TFLite graph_transform features with:
    fill TransposeConv bias with zero if there is no bias
    fuse bias add into preceding TransposeConv(TEST added)
  Update TransposeConv with bias:
    add bias input to TransposeConv
    add optional bias to TransposeConv kernels(TEST added)

--
22611b880c94eb753c88a0a3e2977200e55ebd2c by Peng Sun <peng.sun@arm.com>:

clang-format with google style.

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/34903 from psunn:TransposeConvWithBias 22611b880c94eb753c88a0a3e2977200e55ebd2c
PiperOrigin-RevId: 307872447
Change-Id: I367fcd65f2662f4c7846d37bc69dc43670c83961
This commit is contained in:
Peng Sun 2020-04-22 12:28:06 -07:00 committed by TensorFlower Gardener
parent ea0535f296
commit 43b8f6e710
23 changed files with 617 additions and 70 deletions

View File

@ -478,6 +478,7 @@ def TFL_TransposeConvOp:
TFL_1DTensorOf<[I32]>:$output_shape,
TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$weights,
TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$input,
TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
TFL_PaddingAttr:$padding,
I32Attr:$stride_h,
I32Attr:$stride_w

View File

@ -1296,10 +1296,12 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK-LABEL: conv2d_backprop_input
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
// CHECK: %[[CST_0:.*]] = constant unit
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
// CHECK: %[[CST_2:.*]] = constant unit
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2, %[[CST_2]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32>
// CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32>
}

View File

@ -2032,7 +2032,8 @@ func @testFullyConnectedWithBadOutputShape(%arg0: tensor<1x37xf32>, %arg1: tenso
// -----
func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> {
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32>
%cst = constant unit
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}
@ -2046,8 +2047,9 @@ func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tenso
// -----
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
%cst = constant unit
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32>
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<64x84x32xf32>
return %0 : tensor<64x84x32xf32>
}
@ -2055,8 +2057,9 @@ func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x
func @testTransposeConvBadOutputShape(%arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32> {
%cst = constant dense<[1, 64, 84, 32]> : tensor<4xi32>
%cst_1 = constant unit
// expected-error @+1 {{expect output type tensor<1x64x84x32xf32>, got tensor<1x64x84x31xf32>}}
%0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32>
%0 = "tfl.transpose_conv"(%cst, %arg1, %arg2, %cst_1) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x31xf32>
return %0 : tensor<1x64x84x31xf32>
}

View File

@ -58,6 +58,9 @@ def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
def CreateNoneValue : NativeCodeCall<
"$_builder.create<mlir::ConstantOp>($0.getLoc(), $_builder.getNoneType(), $_builder.getUnitAttr())">;
// Checks if the value has only one user.
// TODO(karimnosseir): Move to a common place?
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
@ -343,6 +346,7 @@ def : Pat<
(TF_TransposeOp $filter,
(ConstantOp ConstantAttr<I32VectorElementsAttr<4>, "{2, 0, 1, 3}">)),
$out_backprop,
/*bias=*/ (CreateNoneValue $input_sizes),
/*padding=*/ $padding,
/*stride_h=*/ ExtractI32At<1>:$strides,
/*stride_w=*/ ExtractI32At<2>:$strides)>;

View File

@ -25,16 +25,17 @@ inline void TransposeConvV2(
const ConvParams& params, const int32* output_multiplier,
const int32* output_shift, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
const int8_t* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
const int8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int8_t* output_data, const RuntimeShape& col2im_shape, int32_t* col2im_data,
int32_t* scratch_data, CpuBackendContext* cpu_backend_context) {
ruy::profiler::ScopeLabel label("TransposeConvV2/int8");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
const int batch_size = input_shape.Dims(0);
TFLITE_DCHECK(col2im_data);
TFLITE_DCHECK(hwoi_ordered_filter_data);
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
@ -93,6 +94,9 @@ inline void TransposeConvV2(
scratch_data_p += output_offset;
}
scratch_data_p = scratch_data;
optimized_ops::BiasAdd(scratch_data_p, bias_data, batch_size, output_height,
output_width, output_depth);
const int32_t output_min = std::numeric_limits<int8_t>::min();
const int32_t output_max = std::numeric_limits<int8_t>::max();

View File

@ -2946,6 +2946,18 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
output_data, DimsToShape(im2col_dims), im2col_data);
}
inline void TransposeConvV2(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
CpuBackendContext* cpu_backend_context) {
TransposeConvV2(params, input_shape, input_data, hwoi_ordered_filter_shape,
hwoi_ordered_filter_data, /*bias_shape*/ RuntimeShape(),
/*bias_data*/ nullptr, output_shape, output_data,
col2im_shape, col2im_data, cpu_backend_context);
}
template <typename T>
void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
const Dims<4>& filter_dims, int stride_width,

View File

@ -5588,20 +5588,38 @@ void Col2im(const T* col_data, const int depth, const int height,
}
}
template <typename T>
void BiasAdd(T* im_data, const T* bias_data, const int batch_size,
const int height, const int width, const int depth) {
if (bias_data) {
for (int n = 0; n < batch_size; ++n) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int d = 0; d < depth; ++d) {
im_data[d] += bias_data[d];
}
im_data += depth;
}
}
}
}
}
// TransposeConvV2 expect the weights in HWOI order.
inline void TransposeConvV2(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
CpuBackendContext* cpu_backend_context) {
const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
const float* bias_data, const RuntimeShape& output_shape,
float* const output_data, const RuntimeShape& col2im_shape,
float* col2im_data, CpuBackendContext* cpu_backend_context) {
ruy::profiler::ScopeLabel label("TransposeConvV2/float");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
const int batch_size = input_shape.Dims(0);
TFLITE_DCHECK(col2im_data);
TFLITE_DCHECK(hwoi_ordered_filter_data);
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
@ -5653,6 +5671,9 @@ inline void TransposeConvV2(
output_data_p);
output_data_p += output_offset;
}
output_data_p = output_data;
BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width,
output_depth);
}
inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
@ -5813,17 +5834,18 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift,
inline void TransposeConvV2(
const ConvParams& params, const RuntimeShape& input_shape,
const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8_t* output_data, const RuntimeShape& col2im_shape,
int32_t* col2im_data, int32_t* scratch_data,
CpuBackendContext* cpu_backend_context) {
ruy::profiler::ScopeLabel label("TransposeConvV2/uint8");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
const int batch_size = input_shape.Dims(0);
TFLITE_DCHECK(col2im_data);
TFLITE_DCHECK(hwoi_ordered_filter_data);
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
@ -5881,6 +5903,9 @@ inline void TransposeConvV2(
scratch_data_p += output_offset;
}
scratch_data_p = scratch_data;
BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width,
output_depth);
Quantize(params.output_multiplier, params.output_shift,
output_shape.FlatSize(), params.output_offset, scratch_data,

View File

@ -25,8 +25,9 @@ inline void TransposeConv(
const ConvParams& params, const int32* output_multiplier,
const int32* output_shift, const RuntimeShape& input_shape,
const int8* input_data, const RuntimeShape& filter_shape,
const int8* filter_data, const RuntimeShape& output_shape,
int8* output_data, const RuntimeShape& im2col_shape, int8* im2col_data,
const int8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
const RuntimeShape& im2col_shape, int8* im2col_data,
int32* scratch_buffer) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
@ -41,6 +42,9 @@ inline void TransposeConv(
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
@ -99,6 +103,9 @@ inline void TransposeConv(
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
int32 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
out_channel)];
if (bias_data) {
acc += bias_data[out_channel];
}
acc = MultiplyByQuantizedMultiplier(
acc, output_multiplier[out_channel], output_shift[out_channel]);
acc += output_offset;

View File

@ -387,8 +387,20 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
op_params.stride_height = stride_height;
TransposeConv(op_params, DimsToShape(input_dims), input_data,
DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
output_data, DimsToShape(im2col_dims), im2col_data);
DimsToShape(filter_dims), filter_data,
/*bias_shape*/ RuntimeShape(), /*bias*/ nullptr,
DimsToShape(output_dims), output_data, DimsToShape(im2col_dims),
im2col_data);
}
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
TransposeConv(params, input_shape, input_data, filter_shape, filter_data,
/*bias_shape*/ RuntimeShape(), /*bias*/ nullptr, output_shape,
output_data, im2col_shape, im2col_data);
}
inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,

View File

@ -2048,7 +2048,8 @@ void Transpose(const TransposeParams& params,
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& output_shape,
const float* filter_data, const RuntimeShape& bias_shape,
const float* bias_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
@ -2069,6 +2070,9 @@ inline void TransposeConv(
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}
// Although transpose convolution simplifies to convolution with transposed
// weights for strides of 1, non-unitary striding complicates matters. To
@ -2116,16 +2120,27 @@ inline void TransposeConv(
}
}
}
if (bias_data) {
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
output_data[Offset(output_shape, batch, out_y, out_x,
out_channel)] += bias_data[out_channel];
}
}
}
}
}
}
inline void TransposeConv(const ConvParams& params,
const RuntimeShape& input_shape,
const uint8* input_data,
const RuntimeShape& filter_shape,
const uint8* filter_data,
const RuntimeShape& output_shape, uint8* output_data,
const RuntimeShape& im2col_shape, uint8* im2col_data,
int32* scratch_buffer) {
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, const RuntimeShape& im2col_shape, uint8* im2col_data,
int32* scratch_buffer) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const int pad_width = params.padding_values.width;
@ -2153,6 +2168,9 @@ inline void TransposeConv(const ConvParams& params,
const int32 output_activation_min = params.quantized_activation_min;
const int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}
const int num_elements = output_shape.FlatSize();
// We need to initialize scratch_buffer to all 0s, as we apply the same
@ -2194,14 +2212,25 @@ inline void TransposeConv(const ConvParams& params,
}
}
}
for (int i = 0; i < num_elements; ++i) {
int32 acc = scratch_buffer[i];
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
acc += output_offset;
// Clamp the output before converting back to uint8.
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
output_data[i] = static_cast<uint8>(acc);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
int32 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
out_channel)];
if (bias_data) {
acc += bias_data[out_channel];
}
int32 scaled_acc = MultiplyByQuantizedMultiplier(
acc, output_multiplier, output_shift);
scaled_acc += output_offset;
scaled_acc = std::max(scaled_acc, output_activation_min);
scaled_acc = std::min(scaled_acc, output_activation_max);
output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
static_cast<uint8>(scaled_acc);
}
}
}
}
}

View File

@ -203,7 +203,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
/* min_version = */ 1,
/* max_version = */ 2);
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_TILE, Register_TILE(),
/* min_version = */ 1,
/* max_version = */ 2);

View File

@ -50,6 +50,7 @@ enum KernelType {
constexpr int kOutputShapeTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kDataInputTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kOutputTensor = 0;
const int kTensorNotAllocated = -1;
@ -232,7 +233,7 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
GetTensorData<int8>(transposed_weights));
} else {
context->ReportError(
context, "Transpose conv only support float & uint8 right now.");
context, "Transpose conv only support float & uint8 & int8 right now.");
return kTfLiteError;
}
@ -243,8 +244,10 @@ template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
bool has_bias = NumInputs(node) == 4;
// Sanity checks on op
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE(context, has_bias || NumInputs(node) == 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
// Retrieve tensors
@ -252,6 +255,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
const TfLiteTensor* bias = nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Tensor sanity checks
@ -261,7 +266,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
input->type == kTfLiteUInt8 ||
input->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, weights->type, input->type);
if (has_bias) {
bias = GetOptionalInputTensor(context, node, kBiasTensor);
if (bias) {
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
}
} else {
TF_LITE_ENSURE_EQ(context, bias->type, input->type);
}
TF_LITE_ENSURE_EQ(context, NumElements(bias),
SizeOfDimension(weights, 0));
}
}
TF_LITE_ENSURE_EQ(context, output->type, input->type);
// Ensure that weights and inputs have the same channel dimension.
// Note: TOCO will reorder weights in the following format: OHWI.
@ -330,7 +351,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->per_channel_output_multiplier.resize(number_channel);
data->per_channel_output_shift.resize(number_channel);
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
context, input, weights, nullptr, output, kTfLiteActNone,
context, input, weights, bias, output, kTfLiteActNone,
&data->output_multiplier, &data->output_shift,
&data->output_activation_min, &data->output_activation_max,
data->per_channel_output_multiplier.data(),
@ -343,7 +364,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params,
const OpData* data, const TfLiteTensor* input,
const TfLiteTensor* weights,
const TfLiteTensor* weights, const TfLiteTensor* bias,
const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
TfLiteTensor* output) {
tflite::ConvParams op_params;
@ -354,11 +375,13 @@ void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params,
op_params.padding_values.height_offset = data->padding.height_offset;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
switch (kernel_type) {
case kReference: {
reference_ops::TransposeConv(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(weights), GetTensorData<float>(weights),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output),
GetTensorShape(col2im), GetTensorData<float>(col2im));
break;
@ -367,7 +390,8 @@ void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params,
optimized_ops::TransposeConvV2(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(transposed_weights),
GetTensorData<float>(transposed_weights), GetTensorShape(output),
GetTensorData<float>(transposed_weights), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output), GetTensorShape(col2im),
GetTensorData<float>(col2im),
CpuBackendContext::GetFromContext(context));
@ -380,7 +404,8 @@ template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context,
const TfLiteTransposeConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* weights,
const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
const TfLiteTensor* transposed_weights,
const TfLiteTensor* bias, TfLiteTensor* col2im,
TfLiteTensor* output, TfLiteTensor* scratch_buffer) {
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -weights->params.zero_point;
@ -407,6 +432,7 @@ void EvalQuantized(TfLiteContext* context,
reference_ops::TransposeConv(
op_params, GetTensorShape(input), GetTensorData<uint8>(input),
GetTensorShape(weights), GetTensorData<uint8>(weights),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8>(output),
GetTensorShape(col2im), GetTensorData<uint8>(col2im),
GetTensorData<int32_t>(scratch_buffer));
@ -416,7 +442,8 @@ void EvalQuantized(TfLiteContext* context,
optimized_ops::TransposeConvV2(
op_params, GetTensorShape(input), GetTensorData<uint8>(input),
GetTensorShape(transposed_weights),
GetTensorData<uint8>(transposed_weights), GetTensorShape(output),
GetTensorData<uint8>(transposed_weights), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<uint8>(output), GetTensorShape(col2im),
GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
CpuBackendContext::GetFromContext(context));
@ -426,13 +453,11 @@ void EvalQuantized(TfLiteContext* context,
}
template <KernelType kernel_type>
void EvalQuantizedPerChannel(TfLiteContext* context,
const TfLiteTransposeConvParams* params,
OpData* data, const TfLiteTensor* input,
const TfLiteTensor* weights,
const TfLiteTensor* transposed_weights,
TfLiteTensor* col2im, TfLiteTensor* output,
TfLiteTensor* scratch_buffer) {
void EvalQuantizedPerChannel(
TfLiteContext* context, const TfLiteTransposeConvParams* params,
OpData* data, const TfLiteTensor* input, const TfLiteTensor* weights,
const TfLiteTensor* transposed_weights, const TfLiteTensor* bias,
TfLiteTensor* col2im, TfLiteTensor* output, TfLiteTensor* scratch_buffer) {
tflite::ConvParams op_params;
op_params.padding_type = PaddingType::kSame;
op_params.padding_values.width = data->padding.width;
@ -454,7 +479,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context,
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(weights),
GetTensorData<int8>(weights), GetTensorShape(output),
GetTensorData<int8>(weights), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output), GetTensorShape(col2im),
GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
break;
@ -464,7 +490,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context,
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(transposed_weights),
GetTensorData<int8>(transposed_weights), GetTensorShape(output),
GetTensorData<int8>(transposed_weights), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output), GetTensorShape(col2im),
GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
CpuBackendContext::GetFromContext(context));
@ -480,6 +507,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
const TfLiteTensor* bias =
(NumInputs(node) == 4)
? GetOptionalInputTensor(context, node, kBiasTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* col2im = data->has_col2im
@ -522,7 +553,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ResizeAndTransposeWeights(context, weights, transposed_weights);
}
}
EvalFloat<kernel_type>(context, params, data, input, weights,
EvalFloat<kernel_type>(context, params, data, input, weights, bias,
transposed_weights, col2im, output);
break;
}
@ -539,7 +570,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
}
EvalQuantized<kernel_type>(context, params, data, input, weights,
transposed_weights, col2im, output,
transposed_weights, bias, col2im, output,
scratch_buffer);
break;
}
@ -554,8 +585,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ResizeAndTransposeWeights(context, weights, transposed_weights);
}
EvalQuantizedPerChannel<kernel_type>(context, params, data, input,
weights, transposed_weights, col2im,
output, scratch_buffer);
weights, transposed_weights, bias,
col2im, output, scratch_buffer);
break;
}
default:

View File

@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -441,6 +441,236 @@ TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1}));
}
template <typename InputType>
class BaseTransposeConvBiasOpModel : public SingleOpModel {
public:
BaseTransposeConvBiasOpModel(TfLiteRegistration* registration,
std::initializer_list<int> output_shape_data,
const TensorData& filter,
std::initializer_list<InputType> filter_data,
const TensorData& input,
const TensorData& output, Padding padding,
int stride_w, int stride_h, TestType test_type,
int version = 3) {
if (test_type == TestType::kDynamic) {
output_shape_ = AddInput({TensorType_INT32, {4}});
filter_ = AddInput(filter);
} else {
output_shape_ = AddConstInput(TensorType_INT32, output_shape_data, {4});
filter_ = AddConstInput(filter, filter_data);
}
input_ = AddInput(input);
int bias_size = GetShape(filter_)[0];
if (input.type == TensorType_FLOAT32) {
bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
} else if (input.type == TensorType_INT8) {
// per channel quantization.
std::vector<float> bias_scale(
filter.per_channel_quantization_scales.size());
std::vector<int64_t> bias_zero_points(
filter.per_channel_quantization_scales.size());
for (size_t i = 0; i < filter.per_channel_quantization_scales.size();
++i) {
bias_scale[i] = input.scale * filter.per_channel_quantization_scales[i];
bias_zero_points[i] = 0;
}
TensorData bias{TensorType_INT32,
{bias_size},
/*min=*/0,
/*max=*/0,
/*scale=*/0,
/*zero_point=*/0,
true,
/*per_channel_quantization_scales=*/bias_scale,
/*per_channel_quantization_offsets=*/bias_zero_points,
/*channel_index==*/0};
bias_ = AddInput(bias);
} else {
// per tensor quantization.
auto bias_scale = GetScale(input_) * GetScale(filter_);
TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
bias_ = AddInput(bias);
}
output_ = AddOutput(output);
SetBuiltinOp(
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
.Union());
resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_TRANSPOSE_CONV, registration, version);
BuildInterpreter({GetShape(output_shape_), GetShape(filter_),
GetShape(input_), GetShape(bias_)});
if (test_type == TestType::kDynamic) {
PopulateTensor<int32_t>(output_shape_, output_shape_data);
PopulateTensor<InputType>(filter_, filter_data);
}
}
void SetInput(std::initializer_list<float> data) {
if (std::is_same<InputType, uint8_t>::value) {
QuantizeAndPopulate<uint8_t>(input_, data);
} else if (std::is_same<InputType, int8_t>::value) {
QuantizeAndPopulate<int8_t>(input_, data);
} else {
PopulateTensor(input_, data);
}
}
void SetBias(std::initializer_list<float> bias) {
if (std::is_same<InputType, uint8_t>::value) {
QuantizeAndPopulate<int32_t>(bias_, bias);
} else if (std::is_same<InputType, int8_t>::value) {
PerChannelQuantizeBias(bias_, bias);
} else {
PopulateTensor(bias_, bias);
}
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
int output_shape_;
int filter_;
int input_;
int bias_;
int output_;
};
class TransposeConvOpBiasModel : public BaseTransposeConvBiasOpModel<float> {
public:
using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel;
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
// Test case:
// input_data = np.arange(1, 5).reshape(1,2,2,1).astype(np.float32)
// filter_data = np.arange(1, 19).reshape(3,3,2,1).astype(np.float32)
// bias_data = np.array([3,4])
// input = tf.keras.layers.Input(shape=(2, 2, 1))
// output = tf.keras.layers.Convolution2DTranspose(filters=2,
// kernel_size=[3, 3],
// strides=[2, 2],
// padding="valid")(input)
// model = tf.keras.models.Model(input, output)
// model.layers[1].set_weights([filter_data, bias_data])
// output = model.predict(input_data)
TEST_P(TransposeConvOpTest, MultiChannelBiasTest) {
// TODO(b/138722124): Enable these tests on NNAPI.
if (SingleOpModel::GetForceUseNnapi()) {
return;
}
TransposeConvOpBiasModel model(
GetRegistration(), /*output_shape=*/{1, 5, 5, 2},
/*filter=*/{TensorType_FLOAT32, {2, 3, 3, 1}},
/*filter_data=*/
{1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18},
/*input=*/{TensorType_FLOAT32, {1, 2, 2, 1}},
/*output=*/{TensorType_FLOAT32, {}}, Padding_VALID,
/*stride_w=*/2, /*stride_h=*/2, GetTestType(), /* version */ 3);
model.SetInput({1, 2, 3, 4});
model.SetBias({3, 4});
model.Invoke();
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({4, 6, 6, 8, 10, 14, 9, 12, 13, 16, 10, 12, 12,
14, 28, 32, 21, 24, 25, 28, 19, 24, 27, 32, 65, 76,
45, 52, 57, 64, 24, 28, 30, 34, 64, 72, 39, 44, 47,
52, 42, 46, 48, 52, 106, 114, 63, 68, 71, 76}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
}
class QuantizedTransposeConvBiasOpModel
: public BaseTransposeConvBiasOpModel<uint8_t> {
public:
using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel;
std::vector<float> GetDequantizedOutput() {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
}
};
TEST_P(TransposeConvOpTest, SimpleBiasTestQuantized) {
// TODO(b/138722124): Enable these tests on NNAPI.
if (SingleOpModel::GetForceUseNnapi()) {
return;
}
// Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9}
std::initializer_list<uint8_t> filter_data = {129, 131, 133, 135, 137,
139, 141, 143, 145};
QuantizedTransposeConvBiasOpModel model(
GetRegistration(), {1, 4, 4, 1},
{TensorType_UINT8, {1, 3, 3, 1}, -63.5, 64}, filter_data,
{TensorType_UINT8, {1, 4, 4, 1}, -63.5, 64},
{TensorType_UINT8, {}, -508, 512}, Padding_SAME, 1, 1, GetTestType(),
/* version */ 3);
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
model.SetBias({1});
model.Invoke();
EXPECT_THAT(
model.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear({32, 64, 84, 76, 100, 192, 240, 200, 208,
372, 420, 332, 264, 448, 488, 368},
1e-5)));
// GetOutputShape() should always be same as model.SetOutputShape(...);
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
class PerChannelQuantizedTransposeConvBiasOpModel
: public BaseTransposeConvBiasOpModel<int8_t> {
public:
using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel;
std::vector<float> GetDequantizedOutput() {
return Dequantize<int8_t>(ExtractVector<int8_t>(output_), GetScale(output_),
GetZeroPoint(output_));
}
void SetInput(const std::initializer_list<float>& data) {
QuantizeAndPopulate<int8_t>(input_, data);
}
void SetFilter(const std::initializer_list<float>& data) {
PerChannelSymmetricQuantizeAndPopulate(filter_, data);
}
};
TEST_P(TransposeConvOpTest, SimpleBiasTestQuantizedPerChannelSingleChannel) {
// TODO(b/138722124): Enable these tests on NNAPI.
if (SingleOpModel::GetForceUseNnapi()) {
return;
}
const std::initializer_list<float> filter_data = {1, 2, 3, 4, 5, 6, 7, 8, 9};
PerChannelQuantizedTransposeConvBiasOpModel model(
GetRegistration(), {1, 4, 4, 1},
{TensorType_INT8, {1, 3, 3, 1}, 0, 0, 0, 0, true, {9.0 / 127}, {0}, 0},
{}, {TensorType_INT8, {1, 4, 4, 1}, 0, 0, 16.0 / 255, -128},
{TensorType_INT8, {}, 0, 0, 2, -128}, Padding_SAME, 1, 1, GetTestType(),
/* version */ 3);
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
model.SetFilter(filter_data);
model.SetBias({1});
model.Invoke();
EXPECT_THAT(
model.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear({30, 62, 84, 76, 100, 194, 238, 200, 208,
372, 418, 330, 264, 446, 486, 366},
1e-5)));
// GetOutputShape() should always be same as model.SetOutputShape(...);
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
INSTANTIATE_TEST_SUITE_P(
TransposeConvOpTest, TransposeConvOpTest,
::testing::Combine(

View File

@ -17,10 +17,10 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
@ -30,7 +30,8 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
const string& weights_name = op.inputs[1];
const auto& weights_shape = model.GetArray(weights_name).shape();
if (op.type == OperatorType::kConv ||
op.type == OperatorType::kFullyConnected) {
op.type == OperatorType::kFullyConnected ||
op.type == OperatorType::kTransposeConv) {
return weights_shape.dims(0);
}
if (op.type == OperatorType::kDepthwiseConv) {
@ -40,8 +41,19 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
return 0;
}
bool CheckOpInputSize(const Operator& op) {
if (op.type == OperatorType::kConv ||
op.type == OperatorType::kFullyConnected ||
op.type == OperatorType::kDepthwiseConv) {
return (op.inputs.size() >= 3);
} else if (op.type == OperatorType::kTransposeConv) {
return (op.inputs.size() >= 4);
}
return true;
}
bool ProcessLinearOperator(Model* model, Operator* op) {
if (op->inputs.size() >= 3) {
if (CheckOpInputSize(*op)) {
return false;
}
const string& output_name = op->outputs[0];
@ -52,7 +64,6 @@ bool ProcessLinearOperator(Model* model, Operator* op) {
const int depth = GetOutputDepthFromWeights(*model, *op);
const string& bias_name = AvailableArrayName(*model, output_name + "_bias");
op->inputs.push_back(bias_name);
DCHECK_EQ(op->inputs.size(), 3);
auto& bias_array = model->GetOrCreateArray(bias_name);
bias_array.data_type = ArrayDataType::kFloat;
bias_array.mutable_shape()->mutable_dims()->push_back(depth);
@ -68,7 +79,8 @@ bool ProcessLinearOperator(Model* model, Operator* op) {
auto* op = model->operators[op_index].get();
if (op->type == OperatorType::kConv ||
op->type == OperatorType::kDepthwiseConv ||
op->type == OperatorType::kFullyConnected) {
op->type == OperatorType::kFullyConnected ||
op->type == OperatorType::kTransposeConv) {
if (ProcessLinearOperator(model, op)) {
AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]);
*modified = true;

View File

@ -17,16 +17,28 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
namespace {
int GetBiasIndex(const Operator& op) {
if (op.type == OperatorType::kConv ||
op.type == OperatorType::kFullyConnected ||
op.type == OperatorType::kDepthwiseConv) {
return 2;
} else if (op.type == OperatorType::kTransposeConv) {
return 3;
}
LOG(FATAL) << "Unhandled operator type";
return 0;
}
void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
const Operator* add_or_sub_op,
int index_of_constant_input) {
@ -36,7 +48,8 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
if (preceding_op->inputs.size() < 3) {
LOG(FATAL) << "Missing bias parameter";
}
auto& bias = model->GetArray(preceding_op->inputs[2]);
const auto bias_ind = GetBiasIndex(*preceding_op);
auto& bias = model->GetArray(preceding_op->inputs[bias_ind]);
bias.minmax = nullptr;
const auto& operand =
model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
@ -101,7 +114,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
LOG(FATAL) << "Missing bias parameter";
}
const auto& weights_name = preceding_op->inputs[1];
const auto& bias_name = preceding_op->inputs[2];
const auto bias_ind = GetBiasIndex(*preceding_op);
const auto& bias_name = preceding_op->inputs[bias_ind];
auto& weights = model->GetArray(weights_name);
DropMinMax(model, weights_name);
auto& bias = model->GetArray(bias_name);
@ -136,7 +150,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
int output_depth;
if (preceding_op->type == OperatorType::kConv ||
preceding_op->type == OperatorType::kFullyConnected) {
preceding_op->type == OperatorType::kFullyConnected ||
preceding_op->type == OperatorType::kTransposeConv) {
output_depth = weights_shape.dims(0);
} else if (preceding_op->type == OperatorType::kDepthwiseConv) {
output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
@ -253,7 +268,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
if (preceding_op->type != OperatorType::kConv &&
preceding_op->type != OperatorType::kFullyConnected &&
preceding_op->type != OperatorType::kDepthwiseConv) {
preceding_op->type != OperatorType::kDepthwiseConv &&
preceding_op->type != OperatorType::kTransposeConv) {
AddMessageF(
"Not fusing %s because the preceding %s is not of one of the supported "
"types",
@ -261,6 +277,13 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
return ::tensorflow::Status::OK();
}
if (preceding_op->type == OperatorType::kTransposeConv &&
binary_op->type != OperatorType::kAdd) {
AddMessageF("Not fusing %s to preceding %s", LogName(*binary_op),
LogName(*preceding_op));
return ::tensorflow::Status::OK();
}
if (preceding_op->fused_activation_function !=
FusedActivationFunctionType::kNone) {
AddMessageF(
@ -278,7 +301,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
}
const auto& weights_name = preceding_op->inputs[1];
const auto& bias_name = preceding_op->inputs[2];
const auto bias_ind = GetBiasIndex(*preceding_op);
const auto& bias_name = preceding_op->inputs[bias_ind];
const auto& weights = model->GetArray(weights_name);
const auto& bias = model->GetArray(bias_name);

View File

@ -244,6 +244,13 @@ bool ChooseQuantizationForOperatorInput(
weights_input_index = 1;
}
}
if (op.type == OperatorType::kTransposeConv) {
if (input_index == 3) {
is_bias_vector = true;
activations_input_index = 2;
weights_input_index = 1;
}
}
if (op.type == OperatorType::kLstmCell) {
if (input_index == LstmCellOperator::BIASES_INPUT) {
is_bias_vector = true;

View File

@ -99,3 +99,15 @@ tf_cc_test(
"@com_google_googletest//:gtest_main",
],
)
tf_cc_test(
name = "fuse_binary_into_preceding_affine_test",
srcs = ["fuse_binary_into_preceding_affine_test.cc"],
deps = [
"//tensorflow/lite/toco:graph_transformations",
"//tensorflow/lite/toco:model",
"//tensorflow/lite/toco:tooling_util",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,115 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
namespace toco {
namespace {
// A gmock matcher that check that elements of a float vector match to a given
// tolerance.
std::vector<testing::Matcher<float>> ArrayFloatNear(
const std::vector<float>& values, float max_abs_error = 1e-5) {
std::vector<testing::Matcher<float>> matchers;
matchers.reserve(values.size());
for (const float& v : values) {
matchers.emplace_back(testing::FloatNear(v, max_abs_error));
}
return matchers;
}
} // namespace
class FuseBinaryIntoPrecedingAffineTest : public ::testing::Test {
protected:
FuseBinaryIntoPrecedingAffineTest() {}
void SetUp() override { model_.reset(new Model); }
void CreateArray(const string& name, const std::vector<int>& shape) {
Array& array = model_->GetOrCreateArray(name);
array.data_type = ArrayDataType::kFloat;
Shape* array_shape = array.mutable_shape();
*(array_shape->mutable_dims()) = shape;
}
void CreateConstantArray(const string& name, const std::vector<int>& shape,
const std::vector<float>& data) {
CreateArray(name, shape);
Array& array = model_->GetOrCreateArray(name);
auto& array_buffer = array.GetMutableBuffer<ArrayDataType::kFloat>();
int bufsize = 1;
for (int dim : shape) {
bufsize *= dim;
}
array_buffer.data.resize(bufsize);
float* buf_ptr = array_buffer.data.data();
for (int i = 0; i < bufsize; ++i) {
buf_ptr[i] = data[i];
}
}
std::unique_ptr<Model> model_;
};
TEST_F(FuseBinaryIntoPrecedingAffineTest, FuseAddIntoTransposeConv) {
// Creating a model.
{
CreateConstantArray(/*name=*/"OutputShape",
/*shape=*/{1, 2}, /*data=*/{2, 2});
CreateConstantArray("TransConvWeight", {2, 2}, {1.0, 2.0, 3.0, 4.0});
CreateConstantArray("TransConvBias", {1}, {1.0});
CreateArray(/*name=*/"TransConvInput",
/*shape=*/{2, 2});
CreateArray("TransConvOutput", {2, 2});
CreateConstantArray("AddInput2", {1}, {2.0});
CreateArray("AddOutput", {2, 2});
auto* tc_op = new TransposeConvOperator;
tc_op->inputs = {"OutputShape", "TransConvWeight", "TransConvInput",
"TransConvBias"};
tc_op->outputs = {"TransConvOutput"};
model_->operators.push_back(std::unique_ptr<Operator>(tc_op));
auto* add_op = new AddOperator;
add_op->inputs = {"TransConvOutput", "AddInput2"};
add_op->outputs = {"AddOutput"};
model_->operators.push_back(std::unique_ptr<Operator>(add_op));
}
toco::FuseBinaryIntoPrecedingAffine transformation;
bool modified;
ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok());
EXPECT_TRUE(modified);
// `Add` should be fused into `TransposeConv`. Only 1 op is left.
ASSERT_EQ(model_->operators.size(), 1);
const auto& op = model_->operators[0];
ASSERT_EQ(op->type, OperatorType::kTransposeConv);
ASSERT_EQ(op->inputs.size(), 4);
auto& weights_array = model_->GetArray(op->inputs[1]);
EXPECT_THAT(weights_array.GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear({1.0, 2.0, 3.0, 4.0})));
auto& bias_array = model_->GetArray(op->inputs[3]);
EXPECT_THAT(bias_array.GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear({3.0})));
}
} // namespace toco

View File

@ -1200,6 +1200,8 @@ struct SqueezeOperator : Operator {
// inputs[0]: required: the output shape
// inputs[1]: required: the weights
// inputs[2]: required: the input activations array
// inputs[3]: optional: the bias vector, specifying the biases for each output
// channel.
// NOTE: The input activations is NOT the first input.
//
//
@ -1212,6 +1214,7 @@ struct TransposeConvOperator : Operator {
OUTPUT_SHAPE = 0,
WEIGHTS = 1,
DATA_INPUT = 2,
BIAS = 3,
};
TransposeConvOperator() : Operator(OperatorType::kTransposeConv) {}

View File

@ -148,6 +148,8 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kArgMin, 1}, "1.9.0"},
{{OperatorType::kArgMin, 2}, "1.14.0"},
{{OperatorType::kTransposeConv, 1}, "1.9.0"},
{{OperatorType::kTransposeConv, 2}, kPendingReleaseOpVersion},
{{OperatorType::kTransposeConv, 3}, kPendingReleaseOpVersion},
{{OperatorType::kSparseToDense, 1}, "1.9.0"},
{{OperatorType::kSparseToDense, 2}, "1.14.0"},
{{OperatorType::kSparseToDense, 3}, "1.15.0"},

View File

@ -130,9 +130,10 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
tensor_property.per_axis = true;
tensor_property.per_axis_index = 0;
tensor_property.symmetric = true;
property.inputs = {{1, tensor_property}, {2, {}}};
property.inputs = {{2, {}}, {1, tensor_property}};
property.outputs = {{0, {}}};
property.version = 2;
property.biases = {3};
property.version = 3;
break;
}
case BuiltinOperator_DEPTHWISE_CONV_2D: {

View File

@ -198,6 +198,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_TRANSPOSE_CONV:
// If the op has 4 inputs, it is version 3.
if (op_sig.input_types.size() == 4) {
return 3;
}
// If the op takes int8 input, it is version 2.
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;

View File

@ -432,6 +432,13 @@ TEST(OpVersionTest, VersioningTransposeConvOperatorTest) {
.input_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_TRANSPOSE_CONV,
.input_types = std::vector<TensorType>{TensorType_INT32, TensorType_INT8,
TensorType_INT8, TensorType_INT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
}
TEST(OpVersionTest, VersioningSVDFOperatorTest) {