Experimental Conv2D implementation specialized for deep convolutions (i.e. large in_depth * out_depth, see cost model in deep_conv2d.cc for details).

Can be enabled/disabled with the TF_USE_DEEP_CONV2D environment variable (disabled by default).
Currently only 3x3 filter transforms are supported; more transforms (and performance work) are to come.

// CPU 1-thread
Benchmark                          Base (ns)  New (ns) Improvement
------------------------------------------------------------------
BM_ConvFloatFwdCPU1_conv04          130679850  90340598    +30.9%
BM_ConvFloatFwdCPU1_conv13          133216209  97856032    +26.5%
BM_ConvFloatFwdCPU1_conv23          153311103  112391095    +26.7%
BM_ConvFloatFwdCPU1_conv28          111367333  84726090    +23.9%
BM_ConvFloatFwdCPU1_conv38          61715027  50013699    +19.0%
BM_ConvFloatFwdCPU1_conv43          100892227  92104830     +8.7%
BM_ConvFloatFwdCPU1_conv48          156439547  130304041    +16.7%
BM_ConvFloatFwdCPU1_conv52          769133078  647485239    +15.8%

// CPU 16-threads
Benchmark                          Base (ns)  New (ns) Improvement
------------------------------------------------------------------
BM_ConvFloatFwdCPU16_conv04         219668015  175358604    +20.2%
BM_ConvFloatFwdCPU16_conv13         199494286  129051236    +35.3%
BM_ConvFloatFwdCPU16_conv23         227566760  145085849    +36.2%
BM_ConvFloatFwdCPU16_conv28         162253350  107485116    +33.8%
BM_ConvFloatFwdCPU16_conv38         85966024  61742460    +28.2%
BM_ConvFloatFwdCPU16_conv43         152369560  108266340    +28.9%
BM_ConvFloatFwdCPU16_conv48         235960229  158102948    +33.0%
BM_ConvFloatFwdCPU16_conv52         1082339398  810291519    +25.1%
Change: 131850132
This commit is contained in:
A. Unique TensorFlower 2016-08-31 10:09:16 -08:00 committed by TensorFlower Gardener
parent 5e10944c63
commit 529631603a
7 changed files with 1936 additions and 0 deletions

View File

@ -481,6 +481,16 @@ tf_cc_test(
],
)
# Unit test target for the DeepConv2D transform matrices; links against the
# ":conv_ops" kernel library plus the core test harness.
tf_cc_test(
name = "deep_conv2d_test",
size = "small",
deps = [
":conv_ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "example_parsing_ops_test",
size = "large",
@ -1310,9 +1320,12 @@ tf_kernel_library(
srcs = [
"conv_grad_ops.cc",
"conv_grad_ops_3d.cc",
"deep_conv2d.cc",
],
hdrs = [
"conv_grad_ops.h",
"deep_conv2d.h",
"winograd_transform.h",
],
prefix = "conv_ops",
deps = [
@ -2002,7 +2015,10 @@ filegroup(
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
"deep_conv2d.cc",
"deep_conv2d.h",
"dynamic_partition_op.cc",
"winograd_transform.h",
":android_extended_ops_headers",
],
)

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/deep_conv2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@ -104,6 +105,58 @@ class LaunchConv2DOp<CPUDevice, T> {
}
};
// Primary template for the DeepConv2D launcher. It never runs DeepConv2D:
// Run() ignores all of its arguments and returns false, which tells the
// caller that no output was produced and that it must launch the regular
// convolution instead. Device/type combinations that support DeepConv2D
// provide a specialization.
template <typename Device, typename T>
class LaunchDeepConvOp {
public:
// Always returns false (DeepConv2D was not launched).
static bool Run(OpKernelContext* ctx, const Tensor& input,
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
int out_cols, int out_depth, int stride_rows, int stride_cols,
Tensor* output, TensorFormat data_format) {
return false;
}
};
// Conditionally launches DeepConv operation based on convolution parameters.
template <>
class LaunchDeepConvOp<CPUDevice, float> {
public:
static bool Run(OpKernelContext* ctx, const Tensor& input,
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
int out_cols, int out_depth, int stride_rows, int stride_cols,
Tensor* output, TensorFormat data_format) {
if (data_format != FORMAT_NHWC ||
!CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
in_depth, out_depth, out_rows, out_cols)) {
return false;
}
Conv2DArgs args;
args.batch = batch;
args.in_rows = input_rows;
args.in_cols = input_cols;
args.in_depth = in_depth;
args.filter_rows = filter_rows;
args.filter_cols = filter_cols;
args.pad_rows = pad_rows;
args.pad_cols = pad_cols;
args.out_rows = out_rows;
args.out_cols = out_cols;
args.out_depth = out_depth;
auto input_ptr = input.template flat<float>().data();
auto filter_ptr = filter.template flat<float>().data();
auto output_ptr = output->template flat<float>().data();
functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
output_ptr);
return true;
}
};
template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
@ -221,6 +274,14 @@ class Conv2DOp : public BinaryOp<T> {
if (out_shape.num_elements() == 0) {
return;
}
if (LaunchDeepConvOp<Device, T>::Run(
context, input, filter, batch, input_rows, input_cols, in_depth,
filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
out_depth, stride_rows, stride_cols, output, data_format_)) {
return;
}
launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
stride_rows, stride_cols,
BrainPadding2EigenPadding(padding_), output, data_format_);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,117 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
class OpKernelContext;
// DeepConv2D is a Conv2D implementation specialized for deep (i.e. large
// in_depth * out_depth product) convolutions (see deep_conv2d.cc for details).
// DeepConv2DTransform is an interface for implementing transforms for
// DeepConv2D. Implementations must specify transform matrices and
// input/output/filter shapes. DeepConv2D computes:
//
// y = C[Ad * Bg]
//
// C: output transform matrix
// A: input data transform matrix
// B: filter transform matrix
// d: vectorized 2D data tile
// g: vectorized 2D filter tile
// y: vectorized 2D output tile
template <typename T>
class DeepConv2DTransform {
public:
virtual ~DeepConv2DTransform() {}
virtual void GetFilterTransformMatrix(const int64 rows, const int64 cols,
T* transform_matrix) const = 0;
virtual void GetInputTransformMatrix(const int64 rows, const int64 cols,
T* transform_matrix) const = 0;
virtual void GetOutputTransformMatrix(const int64 rows, const int64 cols,
T* transform_matrix) const = 0;
struct Shape {
Shape(int64 r, int64 c) : rows(r), cols(c) {}
int64 rows;
int64 cols;
};
virtual const Shape& filter_shape() const = 0;
virtual const Shape& input_shape() const = 0;
virtual const Shape& output_shape() const = 0;
};
// Plain bundle of the convolution dimensions consumed by the DeepConv2D
// implementation. Every field starts out as zero.
struct Conv2DArgs {
  // Input layer dimensions
  int batch = 0;
  int in_rows = 0;
  int in_cols = 0;
  int in_depth = 0;
  int filter_rows = 0;
  int filter_cols = 0;
  int pad_rows = 0;
  int pad_cols = 0;
  // Output layer dimensions
  int out_rows = 0;
  int out_cols = 0;
  int out_depth = 0;

  // Zero-initializes all fields through the member initializers above.
  Conv2DArgs() {}
};
// Returns true if convolution operation specified by function arguments
// can use DeepConv2D implementation, and false otherwise.
// May return false based on parameters, cost, or whether feature is disabled.
//
// The arguments describe the convolution: its row/column strides, the filter
// size, the input and output depths, and the output spatial size. A false
// result means the caller should use a different Conv2D implementation.
// Only the declaration lives in this header; the decision logic is defined
// elsewhere (presumably deep_conv2d.cc -- confirm).
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
int filter_cols, int in_depth, int out_depth,
int out_rows, int out_cols);
namespace functor {
// Calls DeepConv2D implementation (see deep_conv2d.cc for details).
//
// Functor interface: given the convolution dimensions in 'args' and raw
// pointers to the input, filter and output element buffers, writes the
// convolution result through 'output'. Only the call operator is declared
// here; its definition for a given Device/T must be provided elsewhere.
template <typename Device, typename T>
struct DeepConv2D {
void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
const T* filter, T* output);
};
} // namespace functor
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_

View File

@ -0,0 +1,159 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// Computes the Kronecker product of the row-major [rows, cols] 'matrix' with
// itself and stores it, row-major, in 'matrix_out'. The result has
// rows * rows rows and cols * cols columns, so 'matrix_out' must have room for
// rows * rows * cols * cols floats.
static void ComputeKroneckerProduct(const int rows, const int cols,
                                    const float* matrix, float* matrix_out) {
  const int out_cols = cols * cols;
  for (int block_row = 0; block_row < rows; ++block_row) {
    for (int block_col = 0; block_col < cols; ++block_col) {
      // Every entry of 'matrix' scales one full copy of 'matrix' that is
      // placed at block position (block_row, block_col) of the output.
      const float scale = matrix[block_row * cols + block_col];
      for (int r = 0; r < rows; ++r) {
        const int out_row = block_row * rows + r;
        float* dst = matrix_out + out_row * out_cols + block_col * cols;
        const float* src = matrix + r * cols;
        for (int c = 0; c < cols; ++c) {
          dst[c] = src[c] * scale;
        }
      }
    }
  }
}
TEST(DeepConv2DTransformTest, Basic) {
  // Sanity-checks the ComputeKroneckerProduct helper itself: the product of
  //
  //   [1.0 2.0]
  //   [3.0 4.0]
  //
  // with itself must equal the hand-computed 4x4 matrix below.
  const int kDim = 2;
  const int kKronSize = (kDim * kDim) * (kDim * kDim);

  float base[] = {1, 2, 3, 4};
  float actual[kKronSize];
  ComputeKroneckerProduct(kDim, kDim, &base[0], &actual[0]);

  float expected[] = {1, 2, 2, 4,  3, 4,  6,  8,
                      3, 6, 4, 8,  9, 12, 12, 16};

  for (int idx = 0; idx < kKronSize; ++idx) {
    EXPECT_FLOAT_EQ(actual[idx], expected[idx]);
  }
}
TEST(DeepConv2DTransformTest, WingradFilterTransformMatrix) {
// Test that the filter transform matrix returned is the kronecker product of
// the following matrix with itself:
//
// [ 1 0 0 ]
// [ 1/2 1/2 1/2 ]
// [ 1/2 -1/2 1/2 ]
// [ 0 0 1 ]
//
// The 4x3 base matrix yields a 16x9 product; the reference is built with
// ComputeKroneckerProduct and compared element by element.
const int rows = 4;
const int cols = 3;
float transform_matrix[] = {1, 0, 0, 0.5, 0.5, 0.5, 0.5, -0.5, 0.5, 0, 0, 1};
const int kron_rows = rows * rows;
const int kron_cols = cols * cols;
float transform_matrix_kron[kron_rows * kron_cols];
ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
&transform_matrix_kron[0]);
float transform_matrix_test[kron_rows * kron_cols];
WinogradTransform<float> t;
t.GetFilterTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);
for (int i = 0; i < kron_rows * kron_cols; ++i) {
EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
}
}
TEST(DeepConv2DTransformTest, WingradInputTransformMatrix) {
// Test that the input transform matrix returned is the kronecker product of
// the following matrix with itself:
//
// [1 0 -1 0]
// [0 1 1 0]
// [0 -1 1 0]
// [0 1 0 -1]
//
// The 4x4 base matrix yields a 16x16 product; the reference is built with
// ComputeKroneckerProduct and compared element by element.
const int rows = 4;
const int cols = 4;
float transform_matrix[] = {1, 0, -1, 0, 0, 1, 1, 0,
0, -1, 1, 0, 0, 1, 0, -1};
const int kron_rows = rows * rows;
const int kron_cols = cols * cols;
float transform_matrix_kron[kron_rows * kron_cols];
ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
&transform_matrix_kron[0]);
float transform_matrix_test[kron_rows * kron_cols];
WinogradTransform<float> t;
t.GetInputTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);
for (int i = 0; i < kron_rows * kron_cols; ++i) {
EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
}
}
TEST(DeepConv2DTransformTest, WingradOutputTransformMatrix) {
// Test that the output transform matrix returned is the kronecker product of
// the following matrix with itself:
//
// [1 1 1 0]
// [0 1 -1 -1]
//
// The 2x4 base matrix yields a 4x16 product; the reference is built with
// ComputeKroneckerProduct and compared element by element.
const int rows = 2;
const int cols = 4;
float transform_matrix[] = {1, 1, 1, 0, 0, 1, -1, -1};
const int kron_rows = rows * rows;
const int kron_cols = cols * cols;
float transform_matrix_kron[kron_rows * kron_cols];
ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
&transform_matrix_kron[0]);
float transform_matrix_test[kron_rows * kron_cols];
WinogradTransform<float> t;
t.GetOutputTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);
for (int i = 0; i < kron_rows * kron_cols; ++i) {
EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
}
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,377 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
#include "tensorflow/core/kernels/deep_conv2d.h"
namespace tensorflow {
// Winograd DeepConv2DTransform implementation for 3x3 filters.
// Details:
// *) Arithmetic complexity of computations: Shmuel Winograd
// *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray
// DeepConv2DTransform implementation whose tile shapes are fixed at
// construction: 3x3 filter tiles, 4x4 input tiles and 2x2 output tiles.
template <typename T>
class WinogradTransform : public DeepConv2DTransform<T> {
 public:
  typedef typename DeepConv2DTransform<T>::Shape Shape;

  WinogradTransform()
      : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {}

  // Each method below fills the caller-provided 'transform_matrix' buffer,
  // addressed row-major with 'cols' entries per row. They are marked
  // 'override' so that any signature drift from the base interface becomes a
  // compile error instead of silently introducing a new virtual function.
  void GetFilterTransformMatrix(const int64 rows, const int64 cols,
                                T* transform_matrix) const override;

  void GetInputTransformMatrix(const int64 rows, const int64 cols,
                               T* transform_matrix) const override;

  void GetOutputTransformMatrix(const int64 rows, const int64 cols,
                                T* transform_matrix) const override;

  // Tile shapes set by the constructor.
  const Shape& filter_shape() const override { return filter_shape_; }
  const Shape& input_shape() const override { return input_shape_; }
  const Shape& output_shape() const override { return output_shape_; }

 private:
  const Shape filter_shape_;
  const Shape input_shape_;
  const Shape output_shape_;
};
// The filter transform matrix is the kronecker product 'M * M' of the
// following matrix 'M':
//
// [ 1 0 0 ]
// [ 1/2 1/2 1/2 ]
// [ 1/2 -1/2 1/2 ]
// [ 0 0 1 ]
//
// The data layout of 'transform_matrix':
// [input_tile_spatial_size, filter_spatial_size]
//
// Fills 'transform_matrix' (row-major, 'cols' entries per row) with the
// kronecker product of the 4x3 filter base matrix with itself. The buffer is
// first zeroed over rows * cols elements; the non-zero entries then occupy
// rows [0, 16) and columns [0, 9), so the caller must supply at least that
// much room.
template <typename T>
void WinogradTransform<T>::GetFilterTransformMatrix(const int64 rows,
                                                    const int64 cols,
                                                    T* transform_matrix) const {
  CHECK_GT(rows, 0);
  CHECK_GT(cols, 0);
  memset(transform_matrix, 0, sizeof(T) * rows * cols);

  const int kBaseRows = 4;
  const int kBaseCols = 3;
  static const double kBase[4][3] = {
      {1.0, 0.0, 0.0},
      {0.5, 0.5, 0.5},
      {0.5, -0.5, 0.5},
      {0.0, 0.0, 1.0},
  };

  // Block (block_row, block_col) of the product is kBase scaled by
  // kBase[block_row][block_col]. Zero factors are skipped so that those cells
  // keep the all-zero bytes written by memset above.
  for (int block_row = 0; block_row < kBaseRows; ++block_row) {
    for (int block_col = 0; block_col < kBaseCols; ++block_col) {
      const double scale = kBase[block_row][block_col];
      if (scale == 0.0) continue;
      for (int r = 0; r < kBaseRows; ++r) {
        for (int c = 0; c < kBaseCols; ++c) {
          const double entry = kBase[r][c];
          if (entry == 0.0) continue;
          const int64 out_row = block_row * kBaseRows + r;
          const int64 out_col = block_col * kBaseCols + c;
          transform_matrix[out_row * cols + out_col] = T(scale * entry);
        }
      }
    }
  }
}
// The input transform matrix is the kronecker product 'M * M' of the
// following matrix 'M':
//
// [1 0 -1 0]
// [0 1 1 0]
// [0 -1 1 0]
// [0 1 0 -1]
//
// Data layout of 'transform_matrix':
// [tile_spatial_size, tile_spatial_size]
//
// Fills 'transform_matrix' (row-major, 'cols' entries per row) with the
// kronecker product of the 4x4 input base matrix with itself. The buffer is
// first zeroed over rows * cols elements; the non-zero entries then occupy
// rows [0, 16) and columns [0, 16), so the caller must supply at least that
// much room.
template <typename T>
void WinogradTransform<T>::GetInputTransformMatrix(const int64 rows,
                                                   const int64 cols,
                                                   T* transform_matrix) const {
  CHECK_GT(rows, 0);
  CHECK_GT(cols, 0);
  memset(transform_matrix, 0, sizeof(T) * rows * cols);

  const int kBaseRows = 4;
  const int kBaseCols = 4;
  static const double kBase[4][4] = {
      {1.0, 0.0, -1.0, 0.0},
      {0.0, 1.0, 1.0, 0.0},
      {0.0, -1.0, 1.0, 0.0},
      {0.0, 1.0, 0.0, -1.0},
  };

  // Block (block_row, block_col) of the product is kBase scaled by
  // kBase[block_row][block_col]. Zero factors are skipped so that those cells
  // keep the all-zero bytes written by memset above.
  for (int block_row = 0; block_row < kBaseRows; ++block_row) {
    for (int block_col = 0; block_col < kBaseCols; ++block_col) {
      const double scale = kBase[block_row][block_col];
      if (scale == 0.0) continue;
      for (int r = 0; r < kBaseRows; ++r) {
        for (int c = 0; c < kBaseCols; ++c) {
          const double entry = kBase[r][c];
          if (entry == 0.0) continue;
          const int64 out_row = block_row * kBaseRows + r;
          const int64 out_col = block_col * kBaseCols + c;
          transform_matrix[out_row * cols + out_col] = T(scale * entry);
        }
      }
    }
  }
}
// The output transform matrix is the kronecker product 'M * M' of the
// following matrix 'M':
//
// [1 1 1 0]
// [0 1 -1 -1]
//
// Data layout of 'transform_matrix':
// [out_tile_spatial_size, tile_spatial_size]
//
// Fills 'transform_matrix' (row-major, 'cols' entries per row) with the
// kronecker product of the 2x4 output base matrix with itself. The buffer is
// first zeroed over rows * cols elements; the non-zero entries then occupy
// rows [0, 4) and columns [0, 16), so the caller must supply at least that
// much room.
template <typename T>
void WinogradTransform<T>::GetOutputTransformMatrix(const int64 rows,
                                                    const int64 cols,
                                                    T* transform_matrix) const {
  CHECK_GT(rows, 0);
  CHECK_GT(cols, 0);
  memset(transform_matrix, 0, sizeof(T) * rows * cols);

  const int kBaseRows = 2;
  const int kBaseCols = 4;
  static const double kBase[2][4] = {
      {1.0, 1.0, 1.0, 0.0},
      {0.0, 1.0, -1.0, -1.0},
  };

  // Block (block_row, block_col) of the product is kBase scaled by
  // kBase[block_row][block_col]. Zero factors are skipped so that those cells
  // keep the all-zero bytes written by memset above.
  for (int block_row = 0; block_row < kBaseRows; ++block_row) {
    for (int block_col = 0; block_col < kBaseCols; ++block_col) {
      const double scale = kBase[block_row][block_col];
      if (scale == 0.0) continue;
      for (int r = 0; r < kBaseRows; ++r) {
        for (int c = 0; c < kBaseCols; ++c) {
          const double entry = kBase[r][c];
          if (entry == 0.0) continue;
          const int64 out_row = block_row * kBaseRows + r;
          const int64 out_col = block_col * kBaseCols + c;
          transform_matrix[out_row * cols + out_col] = T(scale * entry);
        }
      }
    }
  }
}
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
@ -1148,6 +1150,53 @@ class SeparableConv2DTest(tf.test.TestCase):
expected=None)
class DeepConv2DTest(tf.test.TestCase):
  """Compares conv2d results with TF_USE_DEEP_CONV2D set to "0" and to "1"."""

  def _CompareFwdConv2D(self, tensor_in_sizes, filter_in_sizes,
                        conv_strides, padding):
    """Verifies that DeepConv2D and Conv2D produce the same values.

    The same conv2d op is run twice on random float32 data, once with the
    TF_USE_DEEP_CONV2D environment variable set to "0" and once with it set
    to "1", and the two results are compared within a small tolerance. The
    variable is restored to its previous state afterwards so the setting does
    not leak into other tests running in the same process.

    Args:
      tensor_in_sizes: Input tensor dimensions in
        [batch, input_rows, input_cols, input_depth].
      filter_in_sizes: Filter tensor dimensions in
        [kernel_rows, kernel_cols, input_depth, output_depth].
      conv_strides: [row_stride, col_stride] for the convolution;
      padding: Padding type.
    """
    x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
    x2 = np.random.rand(*filter_in_sizes).astype(np.float32)

    # Remember the caller's setting (None means the variable was unset).
    saved_flag = os.environ.get("TF_USE_DEEP_CONV2D")
    try:
      with self.test_session(use_gpu=False) as sess:
        t1 = tf.constant(x1, shape=tensor_in_sizes)
        t2 = tf.constant(x2, shape=filter_in_sizes)
        strides = [1] + conv_strides + [1]

        conv = tf.nn.conv2d(t1, t2, strides=strides, padding=padding)

        os.environ["TF_USE_DEEP_CONV2D"] = "0"
        values_expect = sess.run([conv])

        os.environ["TF_USE_DEEP_CONV2D"] = "1"
        values_test = sess.run([conv])

        self.assertAllClose(values_expect, values_test, rtol=1e-5, atol=1e-5)
    finally:
      # Put the environment back exactly as it was found.
      if saved_flag is None:
        os.environ.pop("TF_USE_DEEP_CONV2D", None)
      else:
        os.environ["TF_USE_DEEP_CONV2D"] = saved_flag

  def _RunTestCases(self, conv_strides, padding):
    """Runs the comparison over a fixed list of input/3x3-filter shape pairs."""
    input_sizes = [[5, 5, 5, 1248], [3, 17, 17, 192], [2, 35, 35, 288],
                   [2, 6, 8, 517], [2, 7, 4, 81], [3, 11, 3, 77]]
    filter_sizes = [[3, 3, 1248, 128], [3, 3, 192, 192], [3, 3, 288, 384],
                    [3, 3, 517, 64], [3, 3, 81, 77], [3, 3, 77, 181]]
    for input_shape, filter_shape in zip(input_sizes, filter_sizes):
      self._CompareFwdConv2D(input_shape, filter_shape, conv_strides, padding)

  def testConv2D3x3FilterStride1x1Valid(self):
    self._RunTestCases([1, 1], "VALID")

  def testConv2D3x3FilterStride1x1Same(self):
    self._RunTestCases([1, 1], "SAME")
def GetInceptionFwdTest(input_size, filter_size, stride, padding):
def Test(self):
tf.logging.info("Testing InceptionFwd %s", (input_size, filter_size,