clang-format

commit 33f57bd131
parent b3ec2caeee
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
 /*
 See extract_image_patches_op* files and docs for extract_image_patches in
 ../ops/image_ops.cc.
 
 Rates are not supported as of now, but the comments hint how to edit the code
@@ -60,7 +60,7 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
       : UnaryOp<T>(context) {
     ParseAttributeVec5(context, "ksizes", &ksizes_);
     ParseAttributeVec5(context, "strides", &strides_);
-    //ParseAttributeVec5(context, "rates", &rates_);
+    // ParseAttributeVec5(context, "rates", &rates_);
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
   }
 
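
Reviewer note (context for the hunks that follow): ksizes and strides are five-element attributes laid out like the op's NDHWC-style input, [batch, planes, rows, cols, depth], which is why the next hunk reads rates_[1] through rates_[3] for the three spatial dimensions. Below is a minimal standalone sketch of that layout check; CheckVec5 is an illustrative stand-in, not TensorFlow's ParseAttributeVec5, and the "batch and depth entries are 1" convention is an assumption mirroring the ExtractImagePatches attribute convention.

#include <cstdio>
#include <vector>

// Illustrative stand-in for a ParseAttributeVec5-style check: one entry per
// input dimension, [batch, planes, rows, cols, depth].
bool CheckVec5(const std::vector<int>& v, const char* name) {
  if (v.size() != 5) {
    std::fprintf(stderr, "%s must have 5 elements, got %zu\n", name, v.size());
    return false;
  }
  // Assumed convention: sliding the window across batch or depth is not
  // supported, so those entries stay at 1.
  return v[0] == 1 && v[4] == 1;
}

int main() {
  std::vector<int> ksizes = {1, 2, 3, 3, 1};   // 2x3x3 patches
  std::vector<int> strides = {1, 1, 2, 2, 1};  // step 1 in planes, 2 in rows/cols
  std::printf("ksizes ok: %d, strides ok: %d\n", CheckVec5(ksizes, "ksizes"),
              CheckVec5(strides, "strides"));
}
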
@@ -88,18 +88,20 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
 
     /*
     // TODO(hsgkim): enable rates
-    // Rates are disabled as of now due to Eigen's definitions of extract_volume_patch
-    // functions; none of them accept rates as its argument and rates are fixed to
-    // (1, 1, 1, 1, 1). A workaround has to be found for this.
+    // Rates are disabled as of now due to Eigen's definitions of
+    // `extract_volume_patch` functions; none of them accept rates
+    // as its argument and rates are fixed to (1, 1, 1, 1, 1). A
+    // workaround has to be found for this.
     // In order to enable rates, uncomment the following lines and use
-    // ksize_*_eff instead of ksize_* for the second argument of GetWindowedOutputSize
-    // calls.
+    // ksize_*_eff instead of ksize_* for the second argument of
+    // GetWindowedOutputSize calls.
 
     const int rate_planes = rates_[1];
     const int rate_rows = rates_[2];
     const int rate_cols = rates_[3];
 
-    const int ksize_planes_eff = ksize_planes + (ksize_planes - 1) * (rate_planes - 1);
+    const int ksize_planes_eff = ksize_planes +
+                                 (ksize_planes - 1) * (rate_planes - 1);
     const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
     const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
     */
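
Reviewer note on the commented-out block above: the ksize_*_eff expressions are the standard dilated-kernel arithmetic, i.e. a window of size k sampled every `rate` elements spans k + (k - 1) * (rate - 1) inputs. A small self-contained sketch of just that arithmetic (EffectiveKsize is an illustrative name, not a TensorFlow function):

#include <cstdio>

// Effective extent of a kernel of size `ksize` dilated by `rate`; the same
// expression as the commented-out ksize_*_eff lines in the hunk above.
int EffectiveKsize(int ksize, int rate) {
  return ksize + (ksize - 1) * (rate - 1);
}

int main() {
  // A 3-wide kernel with rate 2 samples offsets {0, 2, 4}, so it spans 5 inputs.
  std::printf("ksize=3 rate=1 -> %d\n", EffectiveKsize(3, 1));  // 3
  std::printf("ksize=3 rate=2 -> %d\n", EffectiveKsize(3, 2));  // 5
  std::printf("ksize=4 rate=3 -> %d\n", EffectiveKsize(4, 3));  // 10
}
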
@@ -116,8 +118,9 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
                    GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
                                          padding_, &out_cols, &pad_cols));
 
-    const std::vector<int64> out_sizes = {batch, out_planes, out_rows, out_cols,
-                                          ksize_planes * ksize_rows * ksize_cols * depth};
+    const std::vector<int64> out_sizes = {
+        batch, out_planes, out_rows, out_cols,
+        ksize_planes * ksize_rows * ksize_cols * depth};
     TensorShape out_shape(out_sizes);
 
     Tensor* output = nullptr;
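
Reviewer note on out_sizes: the op's 5-D output is [batch, out_planes, out_rows, out_cols, ksize_planes * ksize_rows * ksize_cols * depth], where each spatial extent comes from GetWindowedOutputSize. A standalone sketch of that bookkeeping, assuming the usual VALID-padding formula out = (in - ksize) / stride + 1; ValidOutputSize and the sample sizes below are illustrative, not the TensorFlow helper:

#include <cstdio>
#include <vector>

// Illustrative stand-in for GetWindowedOutputSize, VALID padding only:
// the number of window positions along one dimension.
long long ValidOutputSize(long long in, long long ksize, long long stride) {
  return (in - ksize) / stride + 1;
}

int main() {
  // Hypothetical NDHWC input [batch, planes, rows, cols, depth] = [2, 8, 16, 16, 3]
  // with 2x3x3 patches and strides of 1, 2, 2.
  long long batch = 2, in_planes = 8, in_rows = 16, in_cols = 16, depth = 3;
  long long kp = 2, kr = 3, kc = 3;
  long long sp = 1, sr = 2, sc = 2;

  const std::vector<long long> out_sizes = {
      batch,
      ValidOutputSize(in_planes, kp, sp),  // 7
      ValidOutputSize(in_rows, kr, sr),    // 7
      ValidOutputSize(in_cols, kc, sc),    // 7
      kp * kr * kc * depth};               // 54: every element of each patch
  for (long long d : out_sizes) std::printf("%lld ", d);  // prints: 2 7 7 7 54
  std::printf("\n");
}
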
@@ -129,9 +132,8 @@ class ExtractVolumePatchesOp : public UnaryOp<T> {
     }
 
     functor::ExtractVolumePatchesForward<Device, T>()(
-        context->eigen_device<Device>(), input.tensor<T, 5>(),
-        ksize_planes, ksize_rows, ksize_cols,
-        stride_planes, stride_rows, stride_cols,
+        context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
+        ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
         /* rate_planes, rate_rows, rate_cols, */
         BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
   }
@@ -161,16 +163,18 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER);
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
 
+// clang-format off
 #define DECLARE_GPU_SPEC(T)                                          \
   template <>                                                        \
   void ExtractVolumePatchesForward<GPUDevice, T>::operator()(        \
       const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input,  \
       int patch_planes, int patch_rows, int patch_cols,              \
       int stride_planes, int stride_rows, int stride_cols,           \
       /* int rate_planes, int rate_rows, int rate_cols, */           \
       const Eigen::PaddingType& padding,                             \
       typename TTypes<T, 5>::Tensor output);                         \
   extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
+// clang-format on
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 
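
Reviewer note: the only textual additions in this hunk are the two guard comments. Everything between // clang-format off and // clang-format on is skipped by the formatter, which is what preserves the hand-aligned continuation backslashes in DECLARE_GPU_SPEC. A tiny generic illustration of the same pattern (PRINT_FIELD and Point are made up for this example):

#include <cstdio>

// clang-format off
// The formatter leaves this region untouched, so the aligned backslash and
// the two-line layout of the macro stay exactly as written.
#define PRINT_FIELD(obj, field)              \
  std::printf(#field " = %d\n", (obj).field)
// clang-format on

struct Point {
  int x;
  int y;
};

int main() {
  Point p{3, 4};
  PRINT_FIELD(p, x);  // prints "x = 3"
  PRINT_FIELD(p, y);  // prints "y = 4"
}
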
@@ -16,10 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
 #define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/eigen_volume_patch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 namespace functor {
@@ -27,7 +27,7 @@ namespace functor {
 template <typename Device, typename T>
 struct ExtractVolumePatchesForward {
   void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
                   int patch_planes, int patch_rows, int patch_cols,
                   int stride_planes, int stride_rows, int stride_cols,
                   /* int rate_planes, int rate_rows, int rate_cols, */
                   const Eigen::PaddingType& padding,
@@ -38,15 +38,15 @@ struct ExtractVolumePatchesForward {
       output_32bit.device(d) =
           To32Bit(input)
              .extract_volume_patches(patch_cols, patch_rows, patch_planes,
                                      stride_cols, stride_rows, stride_planes,
                                      padding)
              .reshape(output_32bit.dimensions());
     } else {
       output.device(d) =
           input
              .extract_volume_patches(patch_cols, patch_rows, patch_planes,
                                      stride_cols, stride_rows, stride_planes,
                                      padding)
              .reshape(output.dimensions());
     }
   }
@@ -17,8 +17,8 @@ limitations under the License.
 
 #define EIGEN_USE_GPU
 
-#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include "tensorflow/core/framework/register_types.h"
 
 namespace tensorflow {
 
@@ -2609,7 +2609,8 @@ REGISTER_OP("ExtractVolumePatches")
       int32 rate_rows = rates[2];
       int32 rate_cols = rates[3];
 
-      int32 ksize_planes_eff = ksize_planes + (ksize_planes - 1) * (rate_planes - 1);
+      int32 ksize_planes_eff = ksize_planes +
+                               (ksize_planes - 1) * (rate_planes - 1);
       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
       */
@@ -2619,10 +2620,12 @@ REGISTER_OP("ExtractVolumePatches")
       DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
       DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
       DimensionHandle output_depth_dim;
-      TF_RETURN_IF_ERROR(c->Multiply(
-          c->Dim(input_shape, 4), ksize_planes * ksize_rows * ksize_cols, &output_depth_dim));
+      TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
+                                     ksize_planes * ksize_rows * ksize_cols,
+                                     &output_depth_dim));
 
-      if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
+      if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
+          !c->ValueKnown(in_cols_dim)) {
         ShapeHandle output_shape =
             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
                           InferenceContext::kUnknownDim, output_depth_dim});
@@ -2647,8 +2650,9 @@ REGISTER_OP("ExtractVolumePatches")
       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
           in_cols, ksize_cols, stride_cols, padding, &output_cols,
           &padding_before, &padding_after));
-      ShapeHandle output_shape = c->MakeShape(
-          {batch_size_dim, output_planes, output_rows, output_cols, output_depth_dim});
+      ShapeHandle output_shape =
+          c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+                        output_depth_dim});
       c->set_output(0, output_shape);
       return Status::OK();
     });