Experimental Conv2D implementation specialized for deep convolutions (i.e. large in_depth * out_depth, see cost model in deep_conv2d.cc for details).
Can be enabled/disabled with environment variable (disabled by default). Currently only supports 3x3 filter transforms, but more transforms (and performance work) to come.

// CPU 1-thread
Benchmark                     Base (ns)    New (ns)    Improvement
------------------------------------------------------------------
BM_ConvFloatFwdCPU1_conv04    130679850    90340598    +30.9%
BM_ConvFloatFwdCPU1_conv13    133216209    97856032    +26.5%
BM_ConvFloatFwdCPU1_conv23    153311103    112391095   +26.7%
BM_ConvFloatFwdCPU1_conv28    111367333    84726090    +23.9%
BM_ConvFloatFwdCPU1_conv38    61715027     50013699    +19.0%
BM_ConvFloatFwdCPU1_conv43    100892227    92104830    +8.7%
BM_ConvFloatFwdCPU1_conv48    156439547    130304041   +16.7%
BM_ConvFloatFwdCPU1_conv52    769133078    647485239   +15.8%

// CPU 16-threads
Benchmark                     Base (ns)    New (ns)    Improvement
------------------------------------------------------------------
BM_ConvFloatFwdCPU16_conv04   219668015    175358604   +20.2%
BM_ConvFloatFwdCPU16_conv13   199494286    129051236   +35.3%
BM_ConvFloatFwdCPU16_conv23   227566760    145085849   +36.2%
BM_ConvFloatFwdCPU16_conv28   162253350    107485116   +33.8%
BM_ConvFloatFwdCPU16_conv38   85966024     61742460    +28.2%
BM_ConvFloatFwdCPU16_conv43   152369560    108266340   +28.9%
BM_ConvFloatFwdCPU16_conv48   235960229    158102948   +33.0%
BM_ConvFloatFwdCPU16_conv52   1082339398   810291519   +25.1%

Change: 131850132
parent 5e10944c63
commit 529631603a
Changed paths: tensorflow/core/kernels, tensorflow/python/kernel_tests
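The gating environment variable is TF_USE_DEEP_CONV2D (it is read by the kernel and toggled in the Python test added below). As a rough sketch only, not part of this change, comparing the two paths on a single 3x3 convolution with the TF 1.x-style session API used elsewhere in this commit might look like the following; the shapes are illustrative assumptions:

import os
import numpy as np
import tensorflow as tf

# Arbitrary example shapes: NHWC input with a deep 3x3 filter
# (large in_depth * out_depth), the case DeepConv2D targets.
x = np.random.rand(1, 32, 32, 256).astype(np.float32)
w = np.random.rand(3, 3, 256, 256).astype(np.float32)

conv = tf.nn.conv2d(tf.constant(x), tf.constant(w),
                    strides=[1, 1, 1, 1], padding="SAME")

with tf.Session() as sess:
  os.environ["TF_USE_DEEP_CONV2D"] = "0"  # default Conv2D path
  reference = sess.run(conv)

  os.environ["TF_USE_DEEP_CONV2D"] = "1"  # Winograd-based DeepConv2D path
  deep = sess.run(conv)

print(np.allclose(reference, deep, atol=1e-5))  # expected: True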
@@ -481,6 +481,16 @@ tf_cc_test(
    ],
)

tf_cc_test(
    name = "deep_conv2d_test",
    size = "small",
    deps = [
        ":conv_ops",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "example_parsing_ops_test",
    size = "large",
@@ -1310,9 +1320,12 @@ tf_kernel_library(
    srcs = [
        "conv_grad_ops.cc",
        "conv_grad_ops_3d.cc",
        "deep_conv2d.cc",
    ],
    hdrs = [
        "conv_grad_ops.h",
        "deep_conv2d.h",
        "winograd_transform.h",
    ],
    prefix = "conv_ops",
    deps = [
@@ -2002,7 +2015,10 @@ filegroup(
        "cwise_op_squared_difference.cc",
        "cwise_op_sub.cc",
        "cwise_op_tanh.cc",
        "deep_conv2d.cc",
        "deep_conv2d.h",
        "dynamic_partition_op.cc",
        "winograd_transform.h",
        ":android_extended_ops_headers",
    ],
)
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/deep_conv2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -104,6 +105,58 @@ class LaunchConv2DOp<CPUDevice, T> {
  }
};

template <typename Device, typename T>
class LaunchDeepConvOp {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int stride_rows, int stride_cols,
                  Tensor* output, TensorFormat data_format) {
    return false;
  }
};

// Conditionally launches DeepConv operation based on convolution parameters.
template <>
class LaunchDeepConvOp<CPUDevice, float> {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int stride_rows, int stride_cols,
                  Tensor* output, TensorFormat data_format) {
    if (data_format != FORMAT_NHWC ||
        !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
                          in_depth, out_depth, out_rows, out_cols)) {
      return false;
    }

    Conv2DArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.pad_rows = pad_rows;
    args.pad_cols = pad_cols;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<float>().data();
    auto filter_ptr = filter.template flat<float>().data();
    auto output_ptr = output->template flat<float>().data();

    functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
                                            output_ptr);
    return true;
  }
};

template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
 public:
@@ -221,6 +274,14 @@ class Conv2DOp : public BinaryOp<T> {
    if (out_shape.num_elements() == 0) {
      return;
    }

    if (LaunchDeepConvOp<Device, T>::Run(
            context, input, filter, batch, input_rows, input_cols, in_depth,
            filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
            out_depth, stride_rows, stride_cols, output, data_format_)) {
      return;
    }

    launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
                     stride_rows, stride_cols,
                     BrainPadding2EigenPadding(padding_), output, data_format_);
tensorflow/core/kernels/deep_conv2d.cc (new file, 1157 lines; diff suppressed because it is too large)
tensorflow/core/kernels/deep_conv2d.h (new file, 117 lines)
@@ -0,0 +1,117 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_

#include "tensorflow/core/framework/types.h"

namespace tensorflow {

class OpKernelContext;

// DeepConv2D is a Conv2D implementation specialized for deep (i.e. large
// in_depth * out_depth product) convolutions (see deep_conv2d.cc for details).

// DeepConv2DTransform is an interface for implementing transforms for
// DeepConv2D. Implementations must specify transform matrices and
// input/output/filter shapes. DeepConv2D computes:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized 2D data tile
//   g: vectorized 2D filter tile
//   y: vectorized 2D output tile

template <typename T>
class DeepConv2DTransform {
 public:
  virtual ~DeepConv2DTransform() {}

  virtual void GetFilterTransformMatrix(const int64 rows, const int64 cols,
                                        T* transform_matrix) const = 0;

  virtual void GetInputTransformMatrix(const int64 rows, const int64 cols,
                                       T* transform_matrix) const = 0;

  virtual void GetOutputTransformMatrix(const int64 rows, const int64 cols,
                                        T* transform_matrix) const = 0;

  struct Shape {
    Shape(int64 r, int64 c) : rows(r), cols(c) {}
    int64 rows;
    int64 cols;
  };

  virtual const Shape& filter_shape() const = 0;
  virtual const Shape& input_shape() const = 0;
  virtual const Shape& output_shape() const = 0;
};

// Conv2D arguments used by DeepConv2D implementation.
struct Conv2DArgs {
  // Input layer dimensions
  int batch;
  int in_rows;
  int in_cols;
  int in_depth;
  int filter_rows;
  int filter_cols;
  int pad_rows;
  int pad_cols;

  // Output layer dimensions
  int out_rows;
  int out_cols;
  int out_depth;

  Conv2DArgs()
      : batch(0),
        in_rows(0),
        in_cols(0),
        in_depth(0),
        filter_rows(0),
        filter_cols(0),
        pad_rows(0),
        pad_cols(0),
        out_rows(0),
        out_cols(0),
        out_depth(0) {}
};

// Returns true if convolution operation specified by function arguments
// can use DeepConv2D implementation, and false otherwise.
// May return false based on parameters, cost, or whether feature is disabled.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols);

namespace functor {

// Calls DeepConv2D implementation (see deep_conv2d.cc for details).
template <typename Device, typename T>
struct DeepConv2D {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output);
};

}  // namespace functor

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
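A note on the transform description above (added commentary, not text from the header): since d, g, and y are vectorized 2D tiles, each 2D transform matrix is the Kronecker product of a 1D transform matrix with itself, which is exactly what the unit tests in deep_conv2d_test.cc below verify against WinogradTransform. For a 1D transform M applied along both dimensions of a tile X,

\operatorname{vec}(M\,X\,M^{T}) = (M \otimes M)\,\operatorname{vec}(X),

so transforming the rows and columns of a tile becomes a single matrix-vector product with M \otimes M on the flattened tile.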
tensorflow/core/kernels/deep_conv2d_test.cc (new file, 159 lines)
@@ -0,0 +1,159 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

static void ComputeKroneckerProduct(const int rows, const int cols,
                                    const float* matrix, float* matrix_out) {
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      const float v = matrix[i * cols + j];
      const int output_index_base = cols * (i * rows * cols + j);
      for (int k = 0; k < rows; ++k) {
        for (int l = 0; l < cols; ++l) {
          const int input_index = k * cols + l;
          const int output_index = k * cols * cols + l;
          matrix_out[output_index_base + output_index] =
              matrix[input_index] * v;
        }
      }
    }
  }
}

TEST(DeepConv2DTransformTest, Basic) {
  // Tests the Kronecker product of the following matrix with itself:
  //
  // [1.0 2.0]
  // [3.0 4.0]
  //
  const int rows = 2;
  const int cols = 2;

  float transform_matrix[] = {1, 2, 3, 4};

  const int kron_rows = rows * rows;
  const int kron_cols = cols * cols;
  float transform_matrix_kron[kron_rows * kron_cols];

  ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
                          &transform_matrix_kron[0]);

  float transform_matrix_test[] = {1, 2, 2, 4, 3, 4, 6, 8,
                                   3, 6, 4, 8, 9, 12, 12, 16};

  for (int i = 0; i < kron_rows * kron_cols; ++i) {
    EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
  }
}

TEST(DeepConv2DTransformTest, WinogradFilterTransformMatrix) {
  // Test that the filter transform matrix returned is the Kronecker product of
  // the following matrix with itself:
  //
  // [ 1 0 0 ]
  // [ 1/2 1/2 1/2 ]
  // [ 1/2 -1/2 1/2 ]
  // [ 0 0 1 ]
  //
  const int rows = 4;
  const int cols = 3;

  float transform_matrix[] = {1, 0, 0, 0.5, 0.5, 0.5, 0.5, -0.5, 0.5, 0, 0, 1};

  const int kron_rows = rows * rows;
  const int kron_cols = cols * cols;

  float transform_matrix_kron[kron_rows * kron_cols];

  ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
                          &transform_matrix_kron[0]);

  float transform_matrix_test[kron_rows * kron_cols];
  WinogradTransform<float> t;
  t.GetFilterTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);

  for (int i = 0; i < kron_rows * kron_cols; ++i) {
    EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
  }
}

TEST(DeepConv2DTransformTest, WinogradInputTransformMatrix) {
  // Test that the input transform matrix returned is the Kronecker product of
  // the following matrix with itself:
  //
  // [1 0 -1 0]
  // [0 1 1 0]
  // [0 -1 1 0]
  // [0 1 0 -1]
  //
  const int rows = 4;
  const int cols = 4;

  float transform_matrix[] = {1, 0, -1, 0, 0, 1, 1, 0,
                              0, -1, 1, 0, 0, 1, 0, -1};

  const int kron_rows = rows * rows;
  const int kron_cols = cols * cols;

  float transform_matrix_kron[kron_rows * kron_cols];

  ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
                          &transform_matrix_kron[0]);

  float transform_matrix_test[kron_rows * kron_cols];
  WinogradTransform<float> t;
  t.GetInputTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);

  for (int i = 0; i < kron_rows * kron_cols; ++i) {
    EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
  }
}

TEST(DeepConv2DTransformTest, WinogradOutputTransformMatrix) {
  // Test that the output transform matrix returned is the Kronecker product of
  // the following matrix with itself:
  //
  // [1 1 1 0]
  // [0 1 -1 -1]
  //
  const int rows = 2;
  const int cols = 4;

  float transform_matrix[] = {1, 1, 1, 0, 0, 1, -1, -1};

  const int kron_rows = rows * rows;
  const int kron_cols = cols * cols;

  float transform_matrix_kron[kron_rows * kron_cols];

  ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
                          &transform_matrix_kron[0]);

  float transform_matrix_test[kron_rows * kron_cols];
  WinogradTransform<float> t;
  t.GetOutputTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);

  for (int i = 0; i < kron_rows * kron_cols; ++i) {
    EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
  }
}

}  // namespace
}  // namespace tensorflow
tensorflow/core/kernels/winograd_transform.h (new file, 377 lines)
@@ -0,0 +1,377 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_

#include "tensorflow/core/kernels/deep_conv2d.h"

namespace tensorflow {

// Winograd DeepConv2DTransform implementation for 3x3 filters.
// Details:
// *) Arithmetic complexity of computations: Shmuel Winograd
// *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray
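In the notation of Lavin and Gray (the second reference above), this transform corresponds to F(2x2, 3x3): for a 4x4 input tile d, a 3x3 filter g, and a 2x2 output tile Y,

Y = A^{T}\bigl[(G\,g\,G^{T}) \odot (B^{T}\,d\,B)\bigr]\,A,

with \odot the elementwise product. The filter, input, and output transform matrices defined below are the Kronecker squares of G, B^{T}, and A^{T} respectively, so they act on vectorized tiles (this mapping is added commentary, not stated in the source).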
template <typename T>
class WinogradTransform : public DeepConv2DTransform<T> {
 public:
  typedef typename DeepConv2DTransform<T>::Shape Shape;

  WinogradTransform()
      : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {}

  virtual void GetFilterTransformMatrix(const int64 rows, const int64 cols,
                                        T* transform_matrix) const;

  virtual void GetInputTransformMatrix(const int64 rows, const int64 cols,
                                       T* transform_matrix) const;

  virtual void GetOutputTransformMatrix(const int64 rows, const int64 cols,
                                        T* transform_matrix) const;

  virtual const Shape& filter_shape() const { return filter_shape_; }
  virtual const Shape& input_shape() const { return input_shape_; }
  virtual const Shape& output_shape() const { return output_shape_; }

 private:
  const Shape filter_shape_;
  const Shape input_shape_;
  const Shape output_shape_;
};

// The filter transform matrix is the kronecker product 'M * M' of the
// following matrix 'M':
//
// [ 1 0 0 ]
// [ 1/2 1/2 1/2 ]
// [ 1/2 -1/2 1/2 ]
// [ 0 0 1 ]
//
// The data layout of 'transform_matrix':
// [input_tile_spatial_size, filter_spatial_size]
//
template <typename T>
void WinogradTransform<T>::GetFilterTransformMatrix(const int64 rows,
                                                    const int64 cols,
                                                    T* transform_matrix) const {
  CHECK_GT(rows, 0);
  CHECK_GT(cols, 0);
  memset(transform_matrix, 0, sizeof(T) * rows * cols);

  // Sub matrix [0,0]
  transform_matrix[0 * cols + 0] = T(1.0);

  transform_matrix[1 * cols + 0] = T(0.5);
  transform_matrix[1 * cols + 1] = T(0.5);
  transform_matrix[1 * cols + 2] = T(0.5);

  transform_matrix[2 * cols + 0] = T(0.5);
  transform_matrix[2 * cols + 1] = T(-0.5);
  transform_matrix[2 * cols + 2] = T(0.5);

  transform_matrix[3 * cols + 2] = T(1.0);

  // Sub matrix [1,0]
  transform_matrix[4 * cols + 0] = T(0.5);

  transform_matrix[5 * cols + 0] = T(0.25);
  transform_matrix[5 * cols + 1] = T(0.25);
  transform_matrix[5 * cols + 2] = T(0.25);

  transform_matrix[6 * cols + 0] = T(0.25);
  transform_matrix[6 * cols + 1] = T(-0.25);
  transform_matrix[6 * cols + 2] = T(0.25);

  transform_matrix[7 * cols + 2] = T(0.5);

  // Sub matrix [1,1]
  transform_matrix[4 * cols + 3] = T(0.5);

  transform_matrix[5 * cols + 3] = T(0.25);
  transform_matrix[5 * cols + 4] = T(0.25);
  transform_matrix[5 * cols + 5] = T(0.25);

  transform_matrix[6 * cols + 3] = T(0.25);
  transform_matrix[6 * cols + 4] = T(-0.25);
  transform_matrix[6 * cols + 5] = T(0.25);

  transform_matrix[7 * cols + 5] = T(0.5);

  // Sub matrix [1,2]
  transform_matrix[4 * cols + 6] = T(0.5);

  transform_matrix[5 * cols + 6] = T(0.25);
  transform_matrix[5 * cols + 7] = T(0.25);
  transform_matrix[5 * cols + 8] = T(0.25);

  transform_matrix[6 * cols + 6] = T(0.25);
  transform_matrix[6 * cols + 7] = T(-0.25);
  transform_matrix[6 * cols + 8] = T(0.25);

  transform_matrix[7 * cols + 8] = T(0.5);

  // Sub matrix [2,0]
  transform_matrix[8 * cols + 0] = T(0.5);

  transform_matrix[9 * cols + 0] = T(0.25);
  transform_matrix[9 * cols + 1] = T(0.25);
  transform_matrix[9 * cols + 2] = T(0.25);

  transform_matrix[10 * cols + 0] = T(0.25);
  transform_matrix[10 * cols + 1] = T(-0.25);
  transform_matrix[10 * cols + 2] = T(0.25);

  transform_matrix[11 * cols + 2] = T(0.5);

  // Sub matrix [2,1]
  transform_matrix[8 * cols + 3] = T(-0.5);

  transform_matrix[9 * cols + 3] = T(-0.25);
  transform_matrix[9 * cols + 4] = T(-0.25);
  transform_matrix[9 * cols + 5] = T(-0.25);

  transform_matrix[10 * cols + 3] = T(-0.25);
  transform_matrix[10 * cols + 4] = T(0.25);
  transform_matrix[10 * cols + 5] = T(-0.25);

  transform_matrix[11 * cols + 5] = T(-0.5);

  // Sub matrix [2,2]
  transform_matrix[8 * cols + 6] = T(0.5);

  transform_matrix[9 * cols + 6] = T(0.25);
  transform_matrix[9 * cols + 7] = T(0.25);
  transform_matrix[9 * cols + 8] = T(0.25);

  transform_matrix[10 * cols + 6] = T(0.25);
  transform_matrix[10 * cols + 7] = T(-0.25);
  transform_matrix[10 * cols + 8] = T(0.25);

  transform_matrix[11 * cols + 8] = T(0.5);

  // Sub matrix [3,2]
  transform_matrix[12 * cols + 6] = T(1.0);

  transform_matrix[13 * cols + 6] = T(0.5);
  transform_matrix[13 * cols + 7] = T(0.5);
  transform_matrix[13 * cols + 8] = T(0.5);

  transform_matrix[14 * cols + 6] = T(0.5);
  transform_matrix[14 * cols + 7] = T(-0.5);
  transform_matrix[14 * cols + 8] = T(0.5);

  transform_matrix[15 * cols + 8] = T(1.0);
}

// The input transform matrix is the kronecker product 'M * M' of the
// following matrix 'M':
//
// [1 0 -1 0]
// [0 1 1 0]
// [0 -1 1 0]
// [0 1 0 -1]
//
// Data layout of 'transform_matrix':
// [tile_spatial_size, tile_spatial_size]
//
template <typename T>
void WinogradTransform<T>::GetInputTransformMatrix(const int64 rows,
                                                   const int64 cols,
                                                   T* transform_matrix) const {
  CHECK_GT(rows, 0);
  CHECK_GT(cols, 0);
  memset(transform_matrix, 0, sizeof(T) * rows * cols);

  // Sub matrix [0,0]
  transform_matrix[0 * cols + 0] = T(1.0);
  transform_matrix[0 * cols + 2] = T(-1.0);

  transform_matrix[1 * cols + 1] = T(1.0);
  transform_matrix[1 * cols + 2] = T(1.0);

  transform_matrix[2 * cols + 1] = T(-1.0);
  transform_matrix[2 * cols + 2] = T(1.0);

  transform_matrix[3 * cols + 1] = T(1.0);
  transform_matrix[3 * cols + 3] = T(-1.0);

  // Sub matrix [0,2]
  transform_matrix[0 * cols + 8] = T(-1.0);
  transform_matrix[0 * cols + 10] = T(1.0);

  transform_matrix[1 * cols + 9] = T(-1.0);
  transform_matrix[1 * cols + 10] = T(-1.0);

  transform_matrix[2 * cols + 9] = T(1.0);
  transform_matrix[2 * cols + 10] = T(-1.0);

  transform_matrix[3 * cols + 9] = T(-1.0);
  transform_matrix[3 * cols + 11] = T(1.0);

  // Sub matrix [1,1]
  transform_matrix[4 * cols + 4] = T(1.0);
  transform_matrix[4 * cols + 6] = T(-1.0);

  transform_matrix[5 * cols + 5] = T(1.0);
  transform_matrix[5 * cols + 6] = T(1.0);

  transform_matrix[6 * cols + 5] = T(-1.0);
  transform_matrix[6 * cols + 6] = T(1.0);

  transform_matrix[7 * cols + 5] = T(1.0);
  transform_matrix[7 * cols + 7] = T(-1.0);

  // Sub matrix [1,2]
  transform_matrix[4 * cols + 8] = T(1.0);
  transform_matrix[4 * cols + 10] = T(-1.0);

  transform_matrix[5 * cols + 9] = T(1.0);
  transform_matrix[5 * cols + 10] = T(1.0);

  transform_matrix[6 * cols + 9] = T(-1.0);
  transform_matrix[6 * cols + 10] = T(1.0);

  transform_matrix[7 * cols + 9] = T(1.0);
  transform_matrix[7 * cols + 11] = T(-1.0);

  // Sub matrix [2,1]
  transform_matrix[8 * cols + 4] = T(-1.0);
  transform_matrix[8 * cols + 6] = T(1.0);

  transform_matrix[9 * cols + 5] = T(-1.0);
  transform_matrix[9 * cols + 6] = T(-1.0);

  transform_matrix[10 * cols + 5] = T(1.0);
  transform_matrix[10 * cols + 6] = T(-1.0);

  transform_matrix[11 * cols + 5] = T(-1.0);
  transform_matrix[11 * cols + 7] = T(1.0);

  // Sub matrix [2,2]
  transform_matrix[8 * cols + 8] = T(1.0);
  transform_matrix[8 * cols + 10] = T(-1.0);

  transform_matrix[9 * cols + 9] = T(1.0);
  transform_matrix[9 * cols + 10] = T(1.0);

  transform_matrix[10 * cols + 9] = T(-1.0);
  transform_matrix[10 * cols + 10] = T(1.0);

  transform_matrix[11 * cols + 9] = T(1.0);
  transform_matrix[11 * cols + 11] = T(-1.0);

  // Sub matrix [3,1]
  transform_matrix[12 * cols + 4] = T(1.0);
  transform_matrix[12 * cols + 6] = T(-1.0);

  transform_matrix[13 * cols + 5] = T(1.0);
  transform_matrix[13 * cols + 6] = T(1.0);

  transform_matrix[14 * cols + 5] = T(-1.0);
  transform_matrix[14 * cols + 6] = T(1.0);

  transform_matrix[15 * cols + 5] = T(1.0);
  transform_matrix[15 * cols + 7] = T(-1.0);

  // Sub matrix [3,3]
  transform_matrix[12 * cols + 12] = T(-1.0);
  transform_matrix[12 * cols + 14] = T(1.0);

  transform_matrix[13 * cols + 13] = T(-1.0);
  transform_matrix[13 * cols + 14] = T(-1.0);

  transform_matrix[14 * cols + 13] = T(1.0);
  transform_matrix[14 * cols + 14] = T(-1.0);

  transform_matrix[15 * cols + 13] = T(-1.0);
  transform_matrix[15 * cols + 15] = T(1.0);
};

// The output transform matrix is the kronecker product 'M * M' of the
// following matrix 'M':
//
// [1 1 1 0]
// [0 1 -1 -1]
//
// Data layout of 'transform_matrix':
// [out_tile_spatial_size, tile_spatial_size]
//
template <typename T>
void WinogradTransform<T>::GetOutputTransformMatrix(const int64 rows,
                                                    const int64 cols,
                                                    T* transform_matrix) const {
  CHECK_GT(rows, 0);
  CHECK_GT(cols, 0);
  memset(transform_matrix, 0, sizeof(T) * rows * cols);

  // Sub matrix [0,0]
  transform_matrix[0 * cols + 0] = T(1.0);
  transform_matrix[0 * cols + 1] = T(1.0);
  transform_matrix[0 * cols + 2] = T(1.0);

  transform_matrix[1 * cols + 1] = T(1.0);
  transform_matrix[1 * cols + 2] = T(-1.0);
  transform_matrix[1 * cols + 3] = T(-1.0);

  // Sub matrix [0,1]
  transform_matrix[0 * cols + 4] = T(1.0);
  transform_matrix[0 * cols + 5] = T(1.0);
  transform_matrix[0 * cols + 6] = T(1.0);

  transform_matrix[1 * cols + 5] = T(1.0);
  transform_matrix[1 * cols + 6] = T(-1.0);
  transform_matrix[1 * cols + 7] = T(-1.0);

  // Sub matrix [0,2]
  transform_matrix[0 * cols + 8] = T(1.0);
  transform_matrix[0 * cols + 9] = T(1.0);
  transform_matrix[0 * cols + 10] = T(1.0);

  transform_matrix[1 * cols + 9] = T(1.0);
  transform_matrix[1 * cols + 10] = T(-1.0);
  transform_matrix[1 * cols + 11] = T(-1.0);

  // Sub matrix [1,1]
  transform_matrix[2 * cols + 4] = T(1.0);
  transform_matrix[2 * cols + 5] = T(1.0);
  transform_matrix[2 * cols + 6] = T(1.0);

  transform_matrix[3 * cols + 5] = T(1.0);
  transform_matrix[3 * cols + 6] = T(-1.0);
  transform_matrix[3 * cols + 7] = T(-1.0);

  // Sub matrix [1,2]
  transform_matrix[2 * cols + 8] = T(-1.0);
  transform_matrix[2 * cols + 9] = T(-1.0);
  transform_matrix[2 * cols + 10] = T(-1.0);

  transform_matrix[3 * cols + 9] = T(-1.0);
  transform_matrix[3 * cols + 10] = T(1.0);
  transform_matrix[3 * cols + 11] = T(1.0);

  // Sub matrix [1,3]
  transform_matrix[2 * cols + 12] = T(-1.0);
  transform_matrix[2 * cols + 13] = T(-1.0);
  transform_matrix[2 * cols + 14] = T(-1.0);

  transform_matrix[3 * cols + 13] = T(-1.0);
  transform_matrix[3 * cols + 14] = T(1.0);
  transform_matrix[3 * cols + 15] = T(1.0);
};

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
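As a sanity check on the matrices documented above, here is a small numpy-only sketch (an illustration added for this write-up, not part of the commit) showing that the F(2x2, 3x3) transforms reproduce a direct 3x3 correlation on one 4x4 tile; the 1D matrices are copied from the comments in winograd_transform.h:

import numpy as np

# 1D Winograd F(2, 3) transform matrices, as documented in winograd_transform.h.
G = np.array([[1, 0, 0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0, 0, 1]], dtype=np.float32)        # filter transform
Bt = np.array([[1, 0, -1, 0],
               [0, 1, 1, 0],
               [0, -1, 1, 0],
               [0, 1, 0, -1]], dtype=np.float32)   # input transform
At = np.array([[1, 1, 1, 0],
               [0, 1, -1, -1]], dtype=np.float32)  # output transform

rng = np.random.RandomState(0)
d = rng.rand(4, 4).astype(np.float32)  # one 4x4 input tile
g = rng.rand(3, 3).astype(np.float32)  # one 3x3 filter

# Winograd: transform filter and data, multiply elementwise, transform back.
y_winograd = At @ ((G @ g @ G.T) * (Bt @ d @ Bt.T)) @ At.T

# Direct 2x2 output of a valid 3x3 correlation over the same tile.
y_direct = np.array([[np.sum(d[i:i + 3, j:j + 3] * g) for j in range(2)]
                     for i in range(2)], dtype=np.float32)

print(np.allclose(y_winograd, y_direct, atol=1e-5))  # expected: True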
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf

@@ -1148,6 +1150,53 @@ class SeparableConv2DTest(tf.test.TestCase):
                       expected=None)


class DeepConv2DTest(tf.test.TestCase):

  def _CompareFwdConv2D(self, tensor_in_sizes, filter_in_sizes,
                        conv_strides, padding):
    """Verifies that DeepConv2D and Conv2D produce the same values.

    Args:
      tensor_in_sizes: Input tensor dimensions in
        [batch, input_rows, input_cols, input_depth].
      filter_in_sizes: Filter tensor dimensions in
        [kernel_rows, kernel_cols, input_depth, output_depth].
      conv_strides: [row_stride, col_stride] for the convolution.
      padding: Padding type.
    """
    x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
    x2 = np.random.rand(*filter_in_sizes).astype(np.float32)

    with self.test_session(use_gpu=False) as sess:
      t1 = tf.constant(x1, shape=tensor_in_sizes)
      t2 = tf.constant(x2, shape=filter_in_sizes)
      strides = [1] + conv_strides + [1]

      conv = tf.nn.conv2d(t1, t2, strides=strides, padding=padding)

      os.environ["TF_USE_DEEP_CONV2D"] = "0"
      values_expect = sess.run([conv])

      os.environ["TF_USE_DEEP_CONV2D"] = "1"
      values_test = sess.run([conv])

      self.assertAllClose(values_expect, values_test, rtol=1e-5, atol=1e-5)

  def _RunTestCases(self, conv_strides, padding):
    input_sizes = [[5, 5, 5, 1248], [3, 17, 17, 192], [2, 35, 35, 288],
                   [2, 6, 8, 517], [2, 7, 4, 81], [3, 11, 3, 77]]
    filter_sizes = [[3, 3, 1248, 128], [3, 3, 192, 192], [3, 3, 288, 384],
                    [3, 3, 517, 64], [3, 3, 81, 77], [3, 3, 77, 181]]
    for input_shape, filter_shape in zip(input_sizes, filter_sizes):
      self._CompareFwdConv2D(input_shape, filter_shape, conv_strides, padding)

  def testConv2D3x3FilterStride1x1Valid(self):
    self._RunTestCases([1, 1], "VALID")

  def testConv2D3x3FilterStride1x1Same(self):
    self._RunTestCases([1, 1], "SAME")


def GetInceptionFwdTest(input_size, filter_size, stride, padding):
  def Test(self):
    tf.logging.info("Testing InceptionFwd %s", (input_size, filter_size,