diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index ed2e6cd129f..3508db46a7c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -478,6 +478,7 @@ def TFL_TransposeConvOp: TFL_1DTensorOf<[I32]>:$output_shape, TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$weights, TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$input, + TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_PaddingAttr:$padding, I32Attr:$stride_h, I32Attr:$stride_w diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index d6f2a83984f..91b38ab8d51 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1296,10 +1296,12 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, % // CHECK-LABEL: conv2d_backprop_input // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> - // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> + // CHECK: %[[CST_0:.*]] = constant unit + // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> - // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> + // CHECK: %[[CST_2:.*]] = constant unit + // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2, %[[CST_2]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32> // CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index efa8234ab03..697e93e582c 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -2032,7 +2032,8 @@ func @testFullyConnectedWithBadOutputShape(%arg0: tensor<1x37xf32>, %arg1: tenso // ----- func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> + %cst = constant unit + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> return %0 : tensor<1x64x84x32xf32> } @@ -2046,8 +2047,9 @@ func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tenso // ----- func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> { + %cst = constant unit // expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}} - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<64x84x32xf32> return %0 : tensor<64x84x32xf32> } @@ -2055,8 +2057,9 @@ func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x func @testTransposeConvBadOutputShape(%arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32> { %cst = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_1 = constant unit // expected-error @+1 {{expect output type tensor<1x64x84x32xf32>, got tensor<1x64x84x31xf32>}} - %0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32> + %0 = "tfl.transpose_conv"(%cst, %arg1, %arg2, %cst_1) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x31xf32> return %0 : tensor<1x64x84x31xf32> } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 586ddf6211f..12796b86b1a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -58,6 +58,9 @@ def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">; def HasSameStaticShapes : Constraint; def HasNotSameStaticShapes : Constraint, "op must have not static same input shapes">; +def CreateNoneValue : NativeCodeCall< + "$_builder.create($0.getLoc(), $_builder.getNoneType(), $_builder.getUnitAttr())">; + // Checks if the value has only one user. // TODO(karimnosseir): Move to a common place? def HasOneUse : Constraint>; @@ -343,6 +346,7 @@ def : Pat< (TF_TransposeOp $filter, (ConstantOp ConstantAttr, "{2, 0, 1, 3}">)), $out_backprop, + /*bias=*/ (CreateNoneValue $input_sizes), /*padding=*/ $padding, /*stride_h=*/ ExtractI32At<1>:$strides, /*stride_w=*/ ExtractI32At<2>:$strides)>; diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h index 123e0a0082c..36519dd606f 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h @@ -25,16 +25,17 @@ inline void TransposeConvV2( const ConvParams& params, const int32* output_multiplier, const int32* output_shift, const RuntimeShape& input_shape, const int8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape, - const int8_t* hwoi_ordered_filter_data, const RuntimeShape& output_shape, + const int8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, int8_t* output_data, const RuntimeShape& col2im_shape, int32_t* col2im_data, int32_t* scratch_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("TransposeConvV2/int8"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4); - const int batch_size = input_shape.Dims(0); TFLITE_DCHECK(col2im_data); TFLITE_DCHECK(hwoi_ordered_filter_data); + const int batch_size = MatchingDim(input_shape, 0, output_shape, 0); const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); @@ -93,6 +94,9 @@ inline void TransposeConvV2( scratch_data_p += output_offset; } + scratch_data_p = scratch_data; + optimized_ops::BiasAdd(scratch_data_p, bias_data, batch_size, output_height, + output_width, output_depth); const int32_t output_min = std::numeric_limits::min(); const int32_t output_max = std::numeric_limits::max(); diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index bc8b9b2d3ac..f206dfa9235 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -2946,6 +2946,18 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, output_data, DimsToShape(im2col_dims), im2col_data); } +inline void TransposeConvV2( + const ConvParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape, + const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& col2im_shape, float* col2im_data, + CpuBackendContext* cpu_backend_context) { + TransposeConvV2(params, input_shape, input_data, hwoi_ordered_filter_shape, + hwoi_ordered_filter_data, /*bias_shape*/ RuntimeShape(), + /*bias_data*/ nullptr, output_shape, output_data, + col2im_shape, col2im_data, cpu_backend_context); +} + template void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, const Dims<4>& filter_dims, int stride_width, diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index ce9073773a5..38bd7bd4057 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -5588,20 +5588,38 @@ void Col2im(const T* col_data, const int depth, const int height, } } +template +void BiasAdd(T* im_data, const T* bias_data, const int batch_size, + const int height, const int width, const int depth) { + if (bias_data) { + for (int n = 0; n < batch_size; ++n) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + for (int d = 0; d < depth; ++d) { + im_data[d] += bias_data[d]; + } + im_data += depth; + } + } + } + } +} + // TransposeConvV2 expect the weights in HWOI order. inline void TransposeConvV2( const ConvParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape, - const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape, - float* output_data, const RuntimeShape& col2im_shape, float* col2im_data, - CpuBackendContext* cpu_backend_context) { + const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape, + float* const output_data, const RuntimeShape& col2im_shape, + float* col2im_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("TransposeConvV2/float"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4); - const int batch_size = input_shape.Dims(0); TFLITE_DCHECK(col2im_data); TFLITE_DCHECK(hwoi_ordered_filter_data); + const int batch_size = MatchingDim(input_shape, 0, output_shape, 0); const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); @@ -5653,6 +5671,9 @@ inline void TransposeConvV2( output_data_p); output_data_p += output_offset; } + output_data_p = output_data; + BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width, + output_depth); } inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size, @@ -5813,17 +5834,18 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, inline void TransposeConvV2( const ConvParams& params, const RuntimeShape& input_shape, const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape, - const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& output_shape, + const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, uint8_t* output_data, const RuntimeShape& col2im_shape, int32_t* col2im_data, int32_t* scratch_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("TransposeConvV2/uint8"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4); - const int batch_size = input_shape.Dims(0); TFLITE_DCHECK(col2im_data); TFLITE_DCHECK(hwoi_ordered_filter_data); + const int batch_size = MatchingDim(input_shape, 0, output_shape, 0); const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); @@ -5881,6 +5903,9 @@ inline void TransposeConvV2( scratch_data_p += output_offset; } + scratch_data_p = scratch_data; + BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width, + output_depth); Quantize(params.output_multiplier, params.output_shift, output_shape.FlatSize(), params.output_offset, scratch_data, diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h index 1ad6e20f2dc..422adc2a333 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h @@ -25,8 +25,9 @@ inline void TransposeConv( const ConvParams& params, const int32* output_multiplier, const int32* output_shift, const RuntimeShape& input_shape, const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& output_shape, - int8* output_data, const RuntimeShape& im2col_shape, int8* im2col_data, + const int8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, + const RuntimeShape& im2col_shape, int8* im2col_data, int32* scratch_buffer) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; @@ -41,6 +42,9 @@ inline void TransposeConv( const int batches = MatchingDim(input_shape, 0, output_shape, 0); const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } const int input_height = input_shape.Dims(1); const int input_width = input_shape.Dims(2); const int filter_height = filter_shape.Dims(1); @@ -99,6 +103,9 @@ inline void TransposeConv( for (int out_channel = 0; out_channel < output_depth; ++out_channel) { int32 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x, out_channel)]; + if (bias_data) { + acc += bias_data[out_channel]; + } acc = MultiplyByQuantizedMultiplier( acc, output_multiplier[out_channel], output_shift[out_channel]); acc += output_offset; diff --git a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h index 2148be45590..f62c9bd197c 100644 --- a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h @@ -387,8 +387,20 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, op_params.stride_height = stride_height; TransposeConv(op_params, DimsToShape(input_dims), input_data, - DimsToShape(filter_dims), filter_data, DimsToShape(output_dims), - output_data, DimsToShape(im2col_dims), im2col_data); + DimsToShape(filter_dims), filter_data, + /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr, + DimsToShape(output_dims), output_data, DimsToShape(im2col_dims), + im2col_data); +} + +inline void TransposeConv( + const ConvParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& filter_shape, + const float* filter_data, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) { + TransposeConv(params, input_shape, input_data, filter_shape, filter_data, + /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr, output_shape, + output_data, im2col_shape, im2col_data); } inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index e991d4e758c..f40b268b443 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -2048,7 +2048,8 @@ void Transpose(const TransposeParams& params, inline void TransposeConv( const ConvParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& filter_shape, - const float* filter_data, const RuntimeShape& output_shape, + const float* filter_data, const RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape, float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; @@ -2069,6 +2070,9 @@ inline void TransposeConv( const int filter_width = filter_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } // Although transpose convolution simplifies to convolution with transposed // weights for strides of 1, non-unitary striding complicates matters. To @@ -2116,16 +2120,27 @@ inline void TransposeConv( } } } + if (bias_data) { + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + output_data[Offset(output_shape, batch, out_y, out_x, + out_channel)] += bias_data[out_channel]; + } + } + } + } + } } -inline void TransposeConv(const ConvParams& params, - const RuntimeShape& input_shape, - const uint8* input_data, - const RuntimeShape& filter_shape, - const uint8* filter_data, - const RuntimeShape& output_shape, uint8* output_data, - const RuntimeShape& im2col_shape, uint8* im2col_data, - int32* scratch_buffer) { +inline void TransposeConv( + const ConvParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data, const RuntimeShape& im2col_shape, uint8* im2col_data, + int32* scratch_buffer) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int pad_width = params.padding_values.width; @@ -2153,6 +2168,9 @@ inline void TransposeConv(const ConvParams& params, const int32 output_activation_min = params.quantized_activation_min; const int32 output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } const int num_elements = output_shape.FlatSize(); // We need to initialize scratch_buffer to all 0s, as we apply the same @@ -2194,14 +2212,25 @@ inline void TransposeConv(const ConvParams& params, } } } - for (int i = 0; i < num_elements; ++i) { - int32 acc = scratch_buffer[i]; - acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift); - acc += output_offset; - // Clamp the output before converting back to uint8. - acc = std::max(acc, output_activation_min); - acc = std::min(acc, output_activation_max); - output_data[i] = static_cast(acc); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + int32 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x, + out_channel)]; + if (bias_data) { + acc += bias_data[out_channel]; + } + int32 scaled_acc = MultiplyByQuantizedMultiplier( + acc, output_multiplier, output_shift); + scaled_acc += output_offset; + scaled_acc = std::max(scaled_acc, output_activation_min); + scaled_acc = std::min(scaled_acc, output_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] = + static_cast(scaled_acc); + } + } + } } } diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index f5608b1a820..28515ae9f77 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -203,7 +203,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_TILE, Register_TILE(), /* min_version = */ 1, /* max_version = */ 2); diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc index 114b9ae48f4..9b2767f15a9 100644 --- a/tensorflow/lite/kernels/transpose_conv.cc +++ b/tensorflow/lite/kernels/transpose_conv.cc @@ -50,6 +50,7 @@ enum KernelType { constexpr int kOutputShapeTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kDataInputTensor = 2; +constexpr int kBiasTensor = 3; constexpr int kOutputTensor = 0; const int kTensorNotAllocated = -1; @@ -232,7 +233,7 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context, GetTensorData(transposed_weights)); } else { context->ReportError( - context, "Transpose conv only support float & uint8 right now."); + context, "Transpose conv only support float & uint8 & int8 right now."); return kTfLiteError; } @@ -243,8 +244,10 @@ template TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast(node->user_data); + bool has_bias = NumInputs(node) == 4; + // Sanity checks on op - TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE(context, has_bias || NumInputs(node) == 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); // Retrieve tensors @@ -252,6 +255,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { GetInput(context, node, kOutputShapeTensor); const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); + const TfLiteTensor* bias = nullptr; + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Tensor sanity checks @@ -261,7 +266,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8 || input->type == kTfLiteInt8); - TF_LITE_ENSURE_EQ(context, weights->type, input->type); + + if (has_bias) { + bias = GetOptionalInputTensor(context, node, kBiasTensor); + if (bias) { + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } + } else { + TF_LITE_ENSURE_EQ(context, bias->type, input->type); + } + TF_LITE_ENSURE_EQ(context, NumElements(bias), + SizeOfDimension(weights, 0)); + } + } + TF_LITE_ENSURE_EQ(context, output->type, input->type); // Ensure that weights and inputs have the same channel dimension. // Note: TOCO will reorder weights in the following format: OHWI. @@ -330,7 +351,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { data->per_channel_output_multiplier.resize(number_channel); data->per_channel_output_shift.resize(number_channel); TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( - context, input, weights, nullptr, output, kTfLiteActNone, + context, input, weights, bias, output, kTfLiteActNone, &data->output_multiplier, &data->output_shift, &data->output_activation_min, &data->output_activation_max, data->per_channel_output_multiplier.data(), @@ -343,7 +364,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params, const OpData* data, const TfLiteTensor* input, - const TfLiteTensor* weights, + const TfLiteTensor* weights, const TfLiteTensor* bias, const TfLiteTensor* transposed_weights, TfLiteTensor* col2im, TfLiteTensor* output) { tflite::ConvParams op_params; @@ -354,11 +375,13 @@ void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params, op_params.padding_values.height_offset = data->padding.height_offset; op_params.stride_width = params->stride_width; op_params.stride_height = params->stride_height; + switch (kernel_type) { case kReference: { reference_ops::TransposeConv( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(weights), GetTensorData(weights), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(col2im), GetTensorData(col2im)); break; @@ -367,7 +390,8 @@ void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params, optimized_ops::TransposeConvV2( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(transposed_weights), - GetTensorData(transposed_weights), GetTensorShape(output), + GetTensorData(transposed_weights), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(col2im), GetTensorData(col2im), CpuBackendContext::GetFromContext(context)); @@ -380,7 +404,8 @@ template void EvalQuantized(TfLiteContext* context, const TfLiteTransposeConvParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* weights, - const TfLiteTensor* transposed_weights, TfLiteTensor* col2im, + const TfLiteTensor* transposed_weights, + const TfLiteTensor* bias, TfLiteTensor* col2im, TfLiteTensor* output, TfLiteTensor* scratch_buffer) { int32_t input_offset = -input->params.zero_point; int32_t filter_offset = -weights->params.zero_point; @@ -407,6 +432,7 @@ void EvalQuantized(TfLiteContext* context, reference_ops::TransposeConv( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(weights), GetTensorData(weights), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(col2im), GetTensorData(col2im), GetTensorData(scratch_buffer)); @@ -416,7 +442,8 @@ void EvalQuantized(TfLiteContext* context, optimized_ops::TransposeConvV2( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(transposed_weights), - GetTensorData(transposed_weights), GetTensorShape(output), + GetTensorData(transposed_weights), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(col2im), GetTensorData(col2im), GetTensorData(scratch_buffer), CpuBackendContext::GetFromContext(context)); @@ -426,13 +453,11 @@ void EvalQuantized(TfLiteContext* context, } template -void EvalQuantizedPerChannel(TfLiteContext* context, - const TfLiteTransposeConvParams* params, - OpData* data, const TfLiteTensor* input, - const TfLiteTensor* weights, - const TfLiteTensor* transposed_weights, - TfLiteTensor* col2im, TfLiteTensor* output, - TfLiteTensor* scratch_buffer) { +void EvalQuantizedPerChannel( + TfLiteContext* context, const TfLiteTransposeConvParams* params, + OpData* data, const TfLiteTensor* input, const TfLiteTensor* weights, + const TfLiteTensor* transposed_weights, const TfLiteTensor* bias, + TfLiteTensor* col2im, TfLiteTensor* output, TfLiteTensor* scratch_buffer) { tflite::ConvParams op_params; op_params.padding_type = PaddingType::kSame; op_params.padding_values.width = data->padding.width; @@ -454,7 +479,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context, op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), GetTensorData(input), GetTensorShape(weights), - GetTensorData(weights), GetTensorShape(output), + GetTensorData(weights), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(col2im), GetTensorData(col2im), GetTensorData(scratch_buffer)); break; @@ -464,7 +490,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context, op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), GetTensorData(input), GetTensorShape(transposed_weights), - GetTensorData(transposed_weights), GetTensorShape(output), + GetTensorData(transposed_weights), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(col2im), GetTensorData(col2im), GetTensorData(scratch_buffer), CpuBackendContext::GetFromContext(context)); @@ -480,6 +507,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetInput(context, node, kOutputShapeTensor); const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); + const TfLiteTensor* bias = + (NumInputs(node) == 4) + ? GetOptionalInputTensor(context, node, kBiasTensor) + : nullptr; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* col2im = data->has_col2im @@ -522,7 +553,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { ResizeAndTransposeWeights(context, weights, transposed_weights); } } - EvalFloat(context, params, data, input, weights, + EvalFloat(context, params, data, input, weights, bias, transposed_weights, col2im, output); break; } @@ -539,7 +570,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } EvalQuantized(context, params, data, input, weights, - transposed_weights, col2im, output, + transposed_weights, bias, col2im, output, scratch_buffer); break; } @@ -554,8 +585,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { ResizeAndTransposeWeights(context, weights, transposed_weights); } EvalQuantizedPerChannel(context, params, data, input, - weights, transposed_weights, col2im, - output, scratch_buffer); + weights, transposed_weights, bias, + col2im, output, scratch_buffer); break; } default: diff --git a/tensorflow/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc index 8f74a943b53..77dc22b13e8 100644 --- a/tensorflow/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/kernels/transpose_conv_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -441,6 +441,236 @@ TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1})); } +template +class BaseTransposeConvBiasOpModel : public SingleOpModel { + public: + BaseTransposeConvBiasOpModel(TfLiteRegistration* registration, + std::initializer_list output_shape_data, + const TensorData& filter, + std::initializer_list filter_data, + const TensorData& input, + const TensorData& output, Padding padding, + int stride_w, int stride_h, TestType test_type, + int version = 3) { + if (test_type == TestType::kDynamic) { + output_shape_ = AddInput({TensorType_INT32, {4}}); + filter_ = AddInput(filter); + } else { + output_shape_ = AddConstInput(TensorType_INT32, output_shape_data, {4}); + filter_ = AddConstInput(filter, filter_data); + } + input_ = AddInput(input); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else if (input.type == TensorType_INT8) { + // per channel quantization. + std::vector bias_scale( + filter.per_channel_quantization_scales.size()); + std::vector bias_zero_points( + filter.per_channel_quantization_scales.size()); + for (size_t i = 0; i < filter.per_channel_quantization_scales.size(); + ++i) { + bias_scale[i] = input.scale * filter.per_channel_quantization_scales[i]; + bias_zero_points[i] = 0; + } + TensorData bias{TensorType_INT32, + {bias_size}, + /*min=*/0, + /*max=*/0, + /*scale=*/0, + /*zero_point=*/0, + true, + /*per_channel_quantization_scales=*/bias_scale, + /*per_channel_quantization_offsets=*/bias_zero_points, + /*channel_index==*/0}; + bias_ = AddInput(bias); + } else { + // per tensor quantization. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions, + CreateTransposeConvOptions(builder_, padding, stride_w, stride_h) + .Union()); + resolver_ = absl::make_unique( + BuiltinOperator_TRANSPOSE_CONV, registration, version); + BuildInterpreter({GetShape(output_shape_), GetShape(filter_), + GetShape(input_), GetShape(bias_)}); + + if (test_type == TestType::kDynamic) { + PopulateTensor(output_shape_, output_shape_data); + PopulateTensor(filter_, filter_data); + } + } + + void SetInput(std::initializer_list data) { + if (std::is_same::value) { + QuantizeAndPopulate(input_, data); + } else if (std::is_same::value) { + QuantizeAndPopulate(input_, data); + } else { + PopulateTensor(input_, data); + } + } + + void SetBias(std::initializer_list bias) { + if (std::is_same::value) { + QuantizeAndPopulate(bias_, bias); + } else if (std::is_same::value) { + PerChannelQuantizeBias(bias_, bias); + } else { + PopulateTensor(bias_, bias); + } + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int output_shape_; + int filter_; + int input_; + int bias_; + int output_; +}; + +class TransposeConvOpBiasModel : public BaseTransposeConvBiasOpModel { + public: + using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +// Test case: +// input_data = np.arange(1, 5).reshape(1,2,2,1).astype(np.float32) +// filter_data = np.arange(1, 19).reshape(3,3,2,1).astype(np.float32) +// bias_data = np.array([3,4]) +// input = tf.keras.layers.Input(shape=(2, 2, 1)) +// output = tf.keras.layers.Convolution2DTranspose(filters=2, +// kernel_size=[3, 3], +// strides=[2, 2], +// padding="valid")(input) +// model = tf.keras.models.Model(input, output) +// model.layers[1].set_weights([filter_data, bias_data]) +// output = model.predict(input_data) +TEST_P(TransposeConvOpTest, MultiChannelBiasTest) { + // TODO(b/138722124): Enable these tests on NNAPI. + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + TransposeConvOpBiasModel model( + GetRegistration(), /*output_shape=*/{1, 5, 5, 2}, + /*filter=*/{TensorType_FLOAT32, {2, 3, 3, 1}}, + /*filter_data=*/ + {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + /*input=*/{TensorType_FLOAT32, {1, 2, 2, 1}}, + /*output=*/{TensorType_FLOAT32, {}}, Padding_VALID, + /*stride_w=*/2, /*stride_h=*/2, GetTestType(), /* version */ 3); + model.SetInput({1, 2, 3, 4}); + model.SetBias({3, 4}); + model.Invoke(); + + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({4, 6, 6, 8, 10, 14, 9, 12, 13, 16, 10, 12, 12, + 14, 28, 32, 21, 24, 25, 28, 19, 24, 27, 32, 65, 76, + 45, 52, 57, 64, 24, 28, 30, 34, 64, 72, 39, 44, 47, + 52, 42, 46, 48, 52, 106, 114, 63, 68, 71, 76})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2})); +} + +class QuantizedTransposeConvBiasOpModel + : public BaseTransposeConvBiasOpModel { + public: + using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST_P(TransposeConvOpTest, SimpleBiasTestQuantized) { + // TODO(b/138722124): Enable these tests on NNAPI. + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9} + std::initializer_list filter_data = {129, 131, 133, 135, 137, + 139, 141, 143, 145}; + QuantizedTransposeConvBiasOpModel model( + GetRegistration(), {1, 4, 4, 1}, + {TensorType_UINT8, {1, 3, 3, 1}, -63.5, 64}, filter_data, + {TensorType_UINT8, {1, 4, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -508, 512}, Padding_SAME, 1, 1, GetTestType(), + /* version */ 3); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + model.SetBias({1}); + model.Invoke(); + + EXPECT_THAT( + model.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({32, 64, 84, 76, 100, 192, 240, 200, 208, + 372, 420, 332, 264, 448, 488, 368}, + 1e-5))); + + // GetOutputShape() should always be same as model.SetOutputShape(...); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +class PerChannelQuantizedTransposeConvBiasOpModel + : public BaseTransposeConvBiasOpModel { + public: + using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } + + void SetInput(const std::initializer_list& data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(const std::initializer_list& data) { + PerChannelSymmetricQuantizeAndPopulate(filter_, data); + } +}; + +TEST_P(TransposeConvOpTest, SimpleBiasTestQuantizedPerChannelSingleChannel) { + // TODO(b/138722124): Enable these tests on NNAPI. + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + + const std::initializer_list filter_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + PerChannelQuantizedTransposeConvBiasOpModel model( + GetRegistration(), {1, 4, 4, 1}, + {TensorType_INT8, {1, 3, 3, 1}, 0, 0, 0, 0, true, {9.0 / 127}, {0}, 0}, + {}, {TensorType_INT8, {1, 4, 4, 1}, 0, 0, 16.0 / 255, -128}, + {TensorType_INT8, {}, 0, 0, 2, -128}, Padding_SAME, 1, 1, GetTestType(), + /* version */ 3); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + model.SetFilter(filter_data); + model.SetBias({1}); + model.Invoke(); + + EXPECT_THAT( + model.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({30, 62, 84, 76, 100, 194, 238, 200, 208, + 372, 418, 330, 264, 446, 486, 366}, + 1e-5))); + + // GetOutputShape() should always be same as model.SetOutputShape(...); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + INSTANTIATE_TEST_SUITE_P( TransposeConvOpTest, TransposeConvOpTest, ::testing::Combine( diff --git a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc index 62a4b52bbb8..fcad8bc0086 100644 --- a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { @@ -30,7 +30,8 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) { const string& weights_name = op.inputs[1]; const auto& weights_shape = model.GetArray(weights_name).shape(); if (op.type == OperatorType::kConv || - op.type == OperatorType::kFullyConnected) { + op.type == OperatorType::kFullyConnected || + op.type == OperatorType::kTransposeConv) { return weights_shape.dims(0); } if (op.type == OperatorType::kDepthwiseConv) { @@ -40,8 +41,19 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) { return 0; } +bool CheckOpInputSize(const Operator& op) { + if (op.type == OperatorType::kConv || + op.type == OperatorType::kFullyConnected || + op.type == OperatorType::kDepthwiseConv) { + return (op.inputs.size() >= 3); + } else if (op.type == OperatorType::kTransposeConv) { + return (op.inputs.size() >= 4); + } + return true; +} + bool ProcessLinearOperator(Model* model, Operator* op) { - if (op->inputs.size() >= 3) { + if (CheckOpInputSize(*op)) { return false; } const string& output_name = op->outputs[0]; @@ -52,7 +64,6 @@ bool ProcessLinearOperator(Model* model, Operator* op) { const int depth = GetOutputDepthFromWeights(*model, *op); const string& bias_name = AvailableArrayName(*model, output_name + "_bias"); op->inputs.push_back(bias_name); - DCHECK_EQ(op->inputs.size(), 3); auto& bias_array = model->GetOrCreateArray(bias_name); bias_array.data_type = ArrayDataType::kFloat; bias_array.mutable_shape()->mutable_dims()->push_back(depth); @@ -68,7 +79,8 @@ bool ProcessLinearOperator(Model* model, Operator* op) { auto* op = model->operators[op_index].get(); if (op->type == OperatorType::kConv || op->type == OperatorType::kDepthwiseConv || - op->type == OperatorType::kFullyConnected) { + op->type == OperatorType::kFullyConnected || + op->type == OperatorType::kTransposeConv) { if (ProcessLinearOperator(model, op)) { AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]); *modified = true; diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 330ce1bdf49..05a2fecf31d 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -17,16 +17,28 @@ limitations under the License. #include #include +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { namespace { +int GetBiasIndex(const Operator& op) { + if (op.type == OperatorType::kConv || + op.type == OperatorType::kFullyConnected || + op.type == OperatorType::kDepthwiseConv) { + return 2; + } else if (op.type == OperatorType::kTransposeConv) { + return 3; + } + LOG(FATAL) << "Unhandled operator type"; + return 0; +} + void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, const Operator* add_or_sub_op, int index_of_constant_input) { @@ -36,7 +48,8 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, if (preceding_op->inputs.size() < 3) { LOG(FATAL) << "Missing bias parameter"; } - auto& bias = model->GetArray(preceding_op->inputs[2]); + const auto bias_ind = GetBiasIndex(*preceding_op); + auto& bias = model->GetArray(preceding_op->inputs[bias_ind]); bias.minmax = nullptr; const auto& operand = model->GetArray(add_or_sub_op->inputs[index_of_constant_input]); @@ -101,7 +114,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, LOG(FATAL) << "Missing bias parameter"; } const auto& weights_name = preceding_op->inputs[1]; - const auto& bias_name = preceding_op->inputs[2]; + const auto bias_ind = GetBiasIndex(*preceding_op); + const auto& bias_name = preceding_op->inputs[bias_ind]; auto& weights = model->GetArray(weights_name); DropMinMax(model, weights_name); auto& bias = model->GetArray(bias_name); @@ -136,7 +150,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, int output_depth; if (preceding_op->type == OperatorType::kConv || - preceding_op->type == OperatorType::kFullyConnected) { + preceding_op->type == OperatorType::kFullyConnected || + preceding_op->type == OperatorType::kTransposeConv) { output_depth = weights_shape.dims(0); } else if (preceding_op->type == OperatorType::kDepthwiseConv) { output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1); @@ -253,7 +268,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, if (preceding_op->type != OperatorType::kConv && preceding_op->type != OperatorType::kFullyConnected && - preceding_op->type != OperatorType::kDepthwiseConv) { + preceding_op->type != OperatorType::kDepthwiseConv && + preceding_op->type != OperatorType::kTransposeConv) { AddMessageF( "Not fusing %s because the preceding %s is not of one of the supported " "types", @@ -261,6 +277,13 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, return ::tensorflow::Status::OK(); } + if (preceding_op->type == OperatorType::kTransposeConv && + binary_op->type != OperatorType::kAdd) { + AddMessageF("Not fusing %s to preceding %s", LogName(*binary_op), + LogName(*preceding_op)); + return ::tensorflow::Status::OK(); + } + if (preceding_op->fused_activation_function != FusedActivationFunctionType::kNone) { AddMessageF( @@ -278,7 +301,8 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, } const auto& weights_name = preceding_op->inputs[1]; - const auto& bias_name = preceding_op->inputs[2]; + const auto bias_ind = GetBiasIndex(*preceding_op); + const auto& bias_name = preceding_op->inputs[bias_ind]; const auto& weights = model->GetArray(weights_name); const auto& bias = model->GetArray(bias_name); diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index 421bff60a43..e6fd88c9787 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -244,6 +244,13 @@ bool ChooseQuantizationForOperatorInput( weights_input_index = 1; } } + if (op.type == OperatorType::kTransposeConv) { + if (input_index == 3) { + is_bias_vector = true; + activations_input_index = 2; + weights_input_index = 1; + } + } if (op.type == OperatorType::kLstmCell) { if (input_index == LstmCellOperator::BIASES_INPUT) { is_bias_vector = true; diff --git a/tensorflow/lite/toco/graph_transformations/tests/BUILD b/tensorflow/lite/toco/graph_transformations/tests/BUILD index 0b7b9d6471a..d83e97e1571 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/lite/toco/graph_transformations/tests/BUILD @@ -99,3 +99,15 @@ tf_cc_test( "@com_google_googletest//:gtest_main", ], ) + +tf_cc_test( + name = "fuse_binary_into_preceding_affine_test", + srcs = ["fuse_binary_into_preceding_affine_test.cc"], + deps = [ + "//tensorflow/lite/toco:graph_transformations", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc new file mode 100644 index 00000000000..b5c321c1a26 --- /dev/null +++ b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include +#include +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" + +namespace toco { + +namespace { +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. +std::vector> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(testing::FloatNear(v, max_abs_error)); + } + return matchers; +} +} // namespace + +class FuseBinaryIntoPrecedingAffineTest : public ::testing::Test { + protected: + FuseBinaryIntoPrecedingAffineTest() {} + + void SetUp() override { model_.reset(new Model); } + + void CreateArray(const string& name, const std::vector& shape) { + Array& array = model_->GetOrCreateArray(name); + array.data_type = ArrayDataType::kFloat; + Shape* array_shape = array.mutable_shape(); + *(array_shape->mutable_dims()) = shape; + } + + void CreateConstantArray(const string& name, const std::vector& shape, + const std::vector& data) { + CreateArray(name, shape); + Array& array = model_->GetOrCreateArray(name); + auto& array_buffer = array.GetMutableBuffer(); + int bufsize = 1; + for (int dim : shape) { + bufsize *= dim; + } + array_buffer.data.resize(bufsize); + float* buf_ptr = array_buffer.data.data(); + for (int i = 0; i < bufsize; ++i) { + buf_ptr[i] = data[i]; + } + } + + std::unique_ptr model_; +}; + +TEST_F(FuseBinaryIntoPrecedingAffineTest, FuseAddIntoTransposeConv) { + // Creating a model. + { + CreateConstantArray(/*name=*/"OutputShape", + /*shape=*/{1, 2}, /*data=*/{2, 2}); + CreateConstantArray("TransConvWeight", {2, 2}, {1.0, 2.0, 3.0, 4.0}); + CreateConstantArray("TransConvBias", {1}, {1.0}); + CreateArray(/*name=*/"TransConvInput", + /*shape=*/{2, 2}); + CreateArray("TransConvOutput", {2, 2}); + CreateConstantArray("AddInput2", {1}, {2.0}); + CreateArray("AddOutput", {2, 2}); + + auto* tc_op = new TransposeConvOperator; + tc_op->inputs = {"OutputShape", "TransConvWeight", "TransConvInput", + "TransConvBias"}; + tc_op->outputs = {"TransConvOutput"}; + model_->operators.push_back(std::unique_ptr(tc_op)); + + auto* add_op = new AddOperator; + add_op->inputs = {"TransConvOutput", "AddInput2"}; + add_op->outputs = {"AddOutput"}; + model_->operators.push_back(std::unique_ptr(add_op)); + } + toco::FuseBinaryIntoPrecedingAffine transformation; + bool modified; + ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok()); + EXPECT_TRUE(modified); + + // `Add` should be fused into `TransposeConv`. Only 1 op is left. + ASSERT_EQ(model_->operators.size(), 1); + const auto& op = model_->operators[0]; + ASSERT_EQ(op->type, OperatorType::kTransposeConv); + ASSERT_EQ(op->inputs.size(), 4); + + auto& weights_array = model_->GetArray(op->inputs[1]); + EXPECT_THAT(weights_array.GetBuffer().data, + ElementsAreArray(ArrayFloatNear({1.0, 2.0, 3.0, 4.0}))); + + auto& bias_array = model_->GetArray(op->inputs[3]); + EXPECT_THAT(bias_array.GetBuffer().data, + ElementsAreArray(ArrayFloatNear({3.0}))); +} +} // namespace toco diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 11a400318d1..7207496e6fc 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -1200,6 +1200,8 @@ struct SqueezeOperator : Operator { // inputs[0]: required: the output shape // inputs[1]: required: the weights // inputs[2]: required: the input activations array +// inputs[3]: optional: the bias vector, specifying the biases for each output +// channel. // NOTE: The input activations is NOT the first input. // // @@ -1212,6 +1214,7 @@ struct TransposeConvOperator : Operator { OUTPUT_SHAPE = 0, WEIGHTS = 1, DATA_INPUT = 2, + BIAS = 3, }; TransposeConvOperator() : Operator(OperatorType::kTransposeConv) {} diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 9a2842a6046..1b259b796b2 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -148,6 +148,8 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kArgMin, 1}, "1.9.0"}, {{OperatorType::kArgMin, 2}, "1.14.0"}, {{OperatorType::kTransposeConv, 1}, "1.9.0"}, + {{OperatorType::kTransposeConv, 2}, kPendingReleaseOpVersion}, + {{OperatorType::kTransposeConv, 3}, kPendingReleaseOpVersion}, {{OperatorType::kSparseToDense, 1}, "1.9.0"}, {{OperatorType::kSparseToDense, 2}, "1.14.0"}, {{OperatorType::kSparseToDense, 3}, "1.15.0"}, diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index 6a32858e357..71fdad87bd2 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -130,9 +130,10 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, tensor_property.per_axis = true; tensor_property.per_axis_index = 0; tensor_property.symmetric = true; - property.inputs = {{1, tensor_property}, {2, {}}}; + property.inputs = {{2, {}}, {1, tensor_property}}; property.outputs = {{0, {}}}; - property.version = 2; + property.biases = {3}; + property.version = 3; break; } case BuiltinOperator_DEPTHWISE_CONV_2D: { diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 622cb134198..1107f042507 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -198,6 +198,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_TRANSPOSE_CONV: + // If the op has 4 inputs, it is version 3. + if (op_sig.input_types.size() == 4) { + return 3; + } // If the op takes int8 input, it is version 2. if (op_sig.input_types.at(0) == TensorType_INT8) { return 2; diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index 2a48ddd6714..22417c79a63 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -432,6 +432,13 @@ TEST(OpVersionTest, VersioningTransposeConvOperatorTest) { .input_types = std::vector{TensorType_INT8}, }; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig = { + .op = BuiltinOperator_TRANSPOSE_CONV, + .input_types = std::vector{TensorType_INT32, TensorType_INT8, + TensorType_INT8, TensorType_INT32}, + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); } TEST(OpVersionTest, VersioningSVDFOperatorTest) {