Increase the supported number of dimensions for slice, strided_slice and pad to 8
PiperOrigin-RevId: 264930991
This commit is contained in:
parent
0f65838cb9
commit
5b28aa6891
@ -246,6 +246,7 @@ tensorflow/core/kernels/slice_op_cpu_impl_4.cc
|
||||
tensorflow/core/kernels/slice_op_cpu_impl_5.cc
|
||||
tensorflow/core/kernels/slice_op_cpu_impl_6.cc
|
||||
tensorflow/core/kernels/slice_op_cpu_impl_7.cc
|
||||
tensorflow/core/kernels/slice_op_cpu_impl_8.cc
|
||||
tensorflow/core/kernels/softmax_op.cc
|
||||
tensorflow/core/kernels/softplus_op.cc
|
||||
tensorflow/core/kernels/softsign_op.cc
|
||||
@ -274,6 +275,7 @@ tensorflow/core/kernels/strided_slice_op_inst_4.cc
|
||||
tensorflow/core/kernels/strided_slice_op_inst_5.cc
|
||||
tensorflow/core/kernels/strided_slice_op_inst_6.cc
|
||||
tensorflow/core/kernels/strided_slice_op_inst_7.cc
|
||||
tensorflow/core/kernels/strided_slice_op_inst_8.cc
|
||||
tensorflow/core/kernels/string_join_op.cc
|
||||
tensorflow/core/kernels/string_util.cc
|
||||
tensorflow/core/kernels/tensor_array.cc
|
||||
|
@ -138,6 +138,7 @@ tf_kernel_library(
|
||||
"strided_slice_op_inst_5.cc",
|
||||
"strided_slice_op_inst_6.cc",
|
||||
"strided_slice_op_inst_7.cc",
|
||||
"strided_slice_op_inst_8.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"slice_op.h",
|
||||
@ -6122,6 +6123,7 @@ filegroup(
|
||||
"slice_op_cpu_impl_5.cc",
|
||||
"slice_op_cpu_impl_6.cc",
|
||||
"slice_op_cpu_impl_7.cc",
|
||||
"slice_op_cpu_impl_8.cc",
|
||||
"softmax_op.cc",
|
||||
"softmax_op_functor.h",
|
||||
"split_lib.h",
|
||||
@ -6139,6 +6141,7 @@ filegroup(
|
||||
"strided_slice_op_inst_5.cc",
|
||||
"strided_slice_op_inst_6.cc",
|
||||
"strided_slice_op_inst_7.cc",
|
||||
"strided_slice_op_inst_8.cc",
|
||||
"unpack_op.cc",
|
||||
"variable_ops.cc",
|
||||
"variable_ops.h",
|
||||
|
@ -52,7 +52,7 @@ class PadOp : public OpKernel {
|
||||
const Tensor& in1 = context->input(1);
|
||||
const int dims = in0.dims();
|
||||
static const int kMinDims = 0;
|
||||
static const int kMaxDims = 6;
|
||||
static const int kMaxDims = 8;
|
||||
OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
|
||||
errors::Unimplemented("inputs rank not in [", kMinDims, ",",
|
||||
kMaxDims, "]: ", dims));
|
||||
|
@ -34,7 +34,9 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template struct functor::Pad<GPUDevice, T, Tpadding, 3>; \
|
||||
template struct functor::Pad<GPUDevice, T, Tpadding, 4>; \
|
||||
template struct functor::Pad<GPUDevice, T, Tpadding, 5>; \
|
||||
template struct functor::Pad<GPUDevice, T, Tpadding, 6>;
|
||||
template struct functor::Pad<GPUDevice, T, Tpadding, 6>; \
|
||||
template struct functor::Pad<GPUDevice, T, Tpadding, 7>; \
|
||||
template struct functor::Pad<GPUDevice, T, Tpadding, 8>;
|
||||
|
||||
#define DEFINE_GPU_SPECS(T) \
|
||||
DEFINE_GPU_PAD_SPECS(T, int32) \
|
||||
|
@ -202,6 +202,7 @@ class SliceOp : public OpKernel {
|
||||
HANDLE_DIM(5);
|
||||
HANDLE_DIM(6);
|
||||
HANDLE_DIM(7);
|
||||
HANDLE_DIM(8);
|
||||
|
||||
#undef HANDLE_DIM
|
||||
|
||||
@ -247,7 +248,8 @@ namespace functor {
|
||||
DECLARE_CPU_SPEC(T, 4); \
|
||||
DECLARE_CPU_SPEC(T, 5); \
|
||||
DECLARE_CPU_SPEC(T, 6); \
|
||||
DECLARE_CPU_SPEC(T, 7);
|
||||
DECLARE_CPU_SPEC(T, 7); \
|
||||
DECLARE_CPU_SPEC(T, 8);
|
||||
|
||||
TF_CALL_ALL_TYPES(DECLARE_FOR_N);
|
||||
|
||||
@ -286,7 +288,8 @@ namespace functor {
|
||||
DECLARE_GPU_SPEC(T, 4); \
|
||||
DECLARE_GPU_SPEC(T, 5); \
|
||||
DECLARE_GPU_SPEC(T, 6); \
|
||||
DECLARE_GPU_SPEC(T, 7);
|
||||
DECLARE_GPU_SPEC(T, 7); \
|
||||
DECLARE_GPU_SPEC(T, 8);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
|
||||
TF_CALL_complex64(DECLARE_FOR_N);
|
||||
@ -352,7 +355,8 @@ namespace functor {
|
||||
DECLARE_SYCL_SPEC(T, 4); \
|
||||
DECLARE_SYCL_SPEC(T, 5); \
|
||||
DECLARE_SYCL_SPEC(T, 6); \
|
||||
DECLARE_SYCL_SPEC(T, 7);
|
||||
DECLARE_SYCL_SPEC(T, 7); \
|
||||
DECLARE_SYCL_SPEC(T, 8);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N);
|
||||
DECLARE_FOR_N(int32);
|
||||
|
18
tensorflow/core/kernels/slice_op_cpu_impl_8.cc
Normal file
18
tensorflow/core/kernels/slice_op_cpu_impl_8.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 8
|
||||
#include "tensorflow/core/kernels/slice_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
@ -34,7 +34,8 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template struct functor::Slice<GPUDevice, T, 4>; \
|
||||
template struct functor::Slice<GPUDevice, T, 5>; \
|
||||
template struct functor::Slice<GPUDevice, T, 6>; \
|
||||
template struct functor::Slice<GPUDevice, T, 7>;
|
||||
template struct functor::Slice<GPUDevice, T, 7>; \
|
||||
template struct functor::Slice<GPUDevice, T, 8>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
|
||||
TF_CALL_complex64(DEFINE_GPU_KERNELS);
|
||||
|
@ -170,6 +170,7 @@ class StridedSliceOp : public OpKernel {
|
||||
HANDLE_DIM(5);
|
||||
HANDLE_DIM(6);
|
||||
HANDLE_DIM(7);
|
||||
HANDLE_DIM(8);
|
||||
|
||||
#undef HANDLE_DIM
|
||||
|
||||
@ -268,6 +269,7 @@ class StridedSliceGradOp : public OpKernel {
|
||||
HANDLE_DIM(5);
|
||||
HANDLE_DIM(6);
|
||||
HANDLE_DIM(7);
|
||||
HANDLE_DIM(8);
|
||||
|
||||
#undef HANDLE_DIM
|
||||
}
|
||||
@ -384,6 +386,7 @@ class StridedSliceAssignOp : public OpKernel {
|
||||
HANDLE_DIM(5);
|
||||
HANDLE_DIM(6);
|
||||
HANDLE_DIM(7);
|
||||
HANDLE_DIM(8);
|
||||
#undef HANDLE_DIM
|
||||
|
||||
OP_REQUIRES(context, false,
|
||||
|
@ -38,6 +38,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template struct functor::StridedSlice<GPUDevice, T, 5>; \
|
||||
template struct functor::StridedSlice<GPUDevice, T, 6>; \
|
||||
template struct functor::StridedSlice<GPUDevice, T, 7>; \
|
||||
template struct functor::StridedSlice<GPUDevice, T, 8>; \
|
||||
template struct functor::StridedSliceGrad<GPUDevice, T, 1>; \
|
||||
template struct functor::StridedSliceGrad<GPUDevice, T, 2>; \
|
||||
template struct functor::StridedSliceGrad<GPUDevice, T, 3>; \
|
||||
@ -45,6 +46,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template struct functor::StridedSliceGrad<GPUDevice, T, 5>; \
|
||||
template struct functor::StridedSliceGrad<GPUDevice, T, 6>; \
|
||||
template struct functor::StridedSliceGrad<GPUDevice, T, 7>; \
|
||||
template struct functor::StridedSliceGrad<GPUDevice, T, 8>; \
|
||||
template struct functor::StridedSliceAssign<GPUDevice, T, 1>; \
|
||||
template struct functor::StridedSliceAssign<GPUDevice, T, 2>; \
|
||||
template struct functor::StridedSliceAssign<GPUDevice, T, 3>; \
|
||||
@ -52,6 +54,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template struct functor::StridedSliceAssign<GPUDevice, T, 5>; \
|
||||
template struct functor::StridedSliceAssign<GPUDevice, T, 6>; \
|
||||
template struct functor::StridedSliceAssign<GPUDevice, T, 7>; \
|
||||
template struct functor::StridedSliceAssign<GPUDevice, T, 8>; \
|
||||
template struct functor::StridedSliceAssignScalar<GPUDevice, T>;
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
23
tensorflow/core/kernels/strided_slice_op_inst_8.cc
Normal file
23
tensorflow/core/kernels/strided_slice_op_inst_8.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
#define STRIDED_SLICE_INSTANTIATE_DIM 8
|
||||
#include "tensorflow/core/kernels/strided_slice_op_impl.h"
|
||||
#undef STRIDED_SLICE_INSTANTIATE_DIM
|
Loading…
Reference in New Issue
Block a user