clang-format

This commit is contained in:
Hoeseong Kim 2018-09-15 12:46:58 +09:00
parent b3ec2caeee
commit 33f57bd131
4 changed files with 44 additions and 36 deletions

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
/* /*
See extract_image_patches_op* files and docs for extract_image_patches in See extract_image_patches_op* files and docs for extract_image_patches in
../ops/image_ops.cc. ../ops/image_ops.cc.
Rates are not supported as of now, but the comments hint how to edit the code Rates are not supported as of now, but the comments hint how to edit the code
@ -60,7 +60,7 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
: UnaryOp<T>(context) { : UnaryOp<T>(context) {
ParseAttributeVec5(context, "ksizes", &ksizes_); ParseAttributeVec5(context, "ksizes", &ksizes_);
ParseAttributeVec5(context, "strides", &strides_); ParseAttributeVec5(context, "strides", &strides_);
//ParseAttributeVec5(context, "rates", &rates_); // ParseAttributeVec5(context, "rates", &rates_);
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
} }
@ -88,18 +88,20 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
/* /*
// TODO(hsgkim): enable rates // TODO(hsgkim): enable rates
// Rates are disabled as of now due to Eigen's definitions of extract_volume_patch // Rates are disabled as of now due to Eigen's definitions of
// functions; none of them accept rates as its argument and rates are fixed to // `extract_volume_patch` functions; none of them accept rates
// (1, 1, 1, 1, 1). A workaround has to be found for this. // as its argument and rates are fixed to (1, 1, 1, 1, 1). A
// workaround has to be found for this.
// In order to enable rates, uncomment the following lines and use // In order to enable rates, uncomment the following lines and use
// ksize_*_eff instead of ksize_* for the second argument of GetWindowedOutputSize // ksize_*_eff instead of ksize_* for the second argument of
// calls. // GetWindowedOutputSize calls.
const int rate_planes = rates_[1]; const int rate_planes = rates_[1];
const int rate_rows = rates_[2]; const int rate_rows = rates_[2];
const int rate_cols = rates_[3]; const int rate_cols = rates_[3];
const int ksize_planes_eff = ksize_planes + (ksize_planes - 1) * (rate_planes - 1); const int ksize_planes_eff = ksize_planes +
(ksize_planes - 1) * (rate_planes - 1);
const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
*/ */
@ -116,8 +118,9 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
GetWindowedOutputSize(in_cols, ksize_cols, stride_cols, GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
padding_, &out_cols, &pad_cols)); padding_, &out_cols, &pad_cols));
const std::vector<int64> out_sizes = {batch, out_planes, out_rows, out_cols, const std::vector<int64> out_sizes = {
ksize_planes * ksize_rows * ksize_cols * depth}; batch, out_planes, out_rows, out_cols,
ksize_planes * ksize_rows * ksize_cols * depth};
TensorShape out_shape(out_sizes); TensorShape out_shape(out_sizes);
Tensor* output = nullptr; Tensor* output = nullptr;
@ -129,9 +132,8 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
} }
functor::ExtractVolumePatchesForward<Device, T>()( functor::ExtractVolumePatchesForward<Device, T>()(
context->eigen_device<Device>(), input.tensor<T, 5>(), context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
ksize_planes, ksize_rows, ksize_cols, ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
stride_planes, stride_rows, stride_cols,
/* rate_planes, rate_rows, rate_cols, */ /* rate_planes, rate_rows, rate_cols, */
BrainPadding2EigenPadding(padding_), output->tensor<T, 5>()); BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
} }
@ -161,16 +163,18 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER);
// Forward declarations of the functor specializations for GPU. // Forward declarations of the functor specializations for GPU.
namespace functor { namespace functor {
#define DECLARE_GPU_SPEC(T) \ // clang-format off
template <> \ #define DECLARE_GPU_SPEC(T) \
void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \ template <> \
const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \ void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \
int patch_planes, int patch_rows, int patch_cols, \ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \
int stride_planes, int stride_rows, int stride_cols, \ int patch_planes, int patch_rows, int patch_cols, \
/* int rate_planes, int rate_rows, int rate_cols, */ \ int stride_planes, int stride_rows, int stride_cols, \
const Eigen::PaddingType& padding, \ /* int rate_planes, int rate_rows, int rate_cols, */ \
typename TTypes<T, 5>::Tensor output); \ const Eigen::PaddingType& padding, \
typename TTypes<T, 5>::Tensor output); \
extern template struct ExtractVolumePatchesForward<GPUDevice, T>; extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
// clang-format on
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);

View File

@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_ #ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_ #define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_volume_patch.h" #include "tensorflow/core/kernels/eigen_volume_patch.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow { namespace tensorflow {
namespace functor { namespace functor {
@ -27,7 +27,7 @@ namespace functor {
template <typename Device, typename T> template <typename Device, typename T>
struct ExtractVolumePatchesForward { struct ExtractVolumePatchesForward {
void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input, void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
int patch_planes, int patch_rows, int patch_cols, int patch_planes, int patch_rows, int patch_cols,
int stride_planes, int stride_rows, int stride_cols, int stride_planes, int stride_rows, int stride_cols,
/* int rate_planes, int rate_rows, int rate_cols, */ /* int rate_planes, int rate_rows, int rate_cols, */
const Eigen::PaddingType& padding, const Eigen::PaddingType& padding,
@ -38,15 +38,15 @@ struct ExtractVolumePatchesForward {
output_32bit.device(d) = output_32bit.device(d) =
To32Bit(input) To32Bit(input)
.extract_volume_patches(patch_cols, patch_rows, patch_planes, .extract_volume_patches(patch_cols, patch_rows, patch_planes,
stride_cols, stride_rows, stride_planes, stride_cols, stride_rows, stride_planes,
padding) padding)
.reshape(output_32bit.dimensions()); .reshape(output_32bit.dimensions());
} else { } else {
output.device(d) = output.device(d) =
input input
.extract_volume_patches(patch_cols, patch_rows, patch_planes, .extract_volume_patches(patch_cols, patch_rows, patch_planes,
stride_cols, stride_rows, stride_planes, stride_cols, stride_rows, stride_planes,
padding) padding)
.reshape(output.dimensions()); .reshape(output.dimensions());
} }
} }

View File

@ -17,8 +17,8 @@ limitations under the License.
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/extract_volume_patches_op.h" #include "tensorflow/core/kernels/extract_volume_patches_op.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow { namespace tensorflow {

View File

@ -2609,7 +2609,8 @@ REGISTER_OP("ExtractVolumePatches")
int32 rate_rows = rates[2]; int32 rate_rows = rates[2];
int32 rate_cols = rates[3]; int32 rate_cols = rates[3];
int32 ksize_planes_eff = ksize_planes + (ksize_planes - 1) * (rate_planes - 1); int32 ksize_planes_eff = ksize_planes +
(ksize_planes - 1) * (rate_planes - 1);
int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
*/ */
@ -2619,10 +2620,12 @@ REGISTER_OP("ExtractVolumePatches")
DimensionHandle in_rows_dim = c->Dim(input_shape, 2); DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
DimensionHandle in_cols_dim = c->Dim(input_shape, 3); DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
DimensionHandle output_depth_dim; DimensionHandle output_depth_dim;
TF_RETURN_IF_ERROR(c->Multiply( TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
c->Dim(input_shape, 4), ksize_planes * ksize_rows * ksize_cols, &output_depth_dim)); ksize_planes * ksize_rows * ksize_cols,
&output_depth_dim));
if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) { if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
!c->ValueKnown(in_cols_dim)) {
ShapeHandle output_shape = ShapeHandle output_shape =
c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
InferenceContext::kUnknownDim, output_depth_dim}); InferenceContext::kUnknownDim, output_depth_dim});
@ -2647,8 +2650,9 @@ REGISTER_OP("ExtractVolumePatches")
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
in_cols, ksize_cols, stride_cols, padding, &output_cols, in_cols, ksize_cols, stride_cols, padding, &output_cols,
&padding_before, &padding_after)); &padding_before, &padding_after));
ShapeHandle output_shape = c->MakeShape( ShapeHandle output_shape =
{batch_size_dim, output_planes, output_rows, output_cols, output_depth_dim}); c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
output_depth_dim});
c->set_output(0, output_shape); c->set_output(0, output_shape);
return Status::OK(); return Status::OK();
}); });