Split up conv_ops_fused kernels.

This improves build times by allowing the double, float, and half
implementations to build in parallel.

PiperOrigin-RevId: 235576953
parent c715e350ca
commit 1c6f10152f
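
The mechanism behind the speedup, in miniature: the templated implementation
moves into a header, and each supported type gets its own translation unit
that instantiates and registers just that type, so a parallel build compiles
them concurrently. A minimal sketch of the pattern, with hypothetical file
and symbol names (not the actual TensorFlow code):

    // conv_impl.h -- shared, header-only templated implementation
    // (hypothetical names throughout; a sketch of the pattern only).
    #ifndef CONV_IMPL_H_
    #define CONV_IMPL_H_

    #include <cstdio>

    template <typename T>
    void RunFusedConv(T scale) {
      // In real code the expensive templated kernel body lives here; each
      // translation unit pays the compile cost only for its own T.
      std::printf("fused conv, scale = %f\n", static_cast<double>(scale));
    }

    // Registration macro left defined (no #undef) so the per-type .cc
    // files can invoke it after including this header.
    #define REGISTER_FUSED_CONV(T) template void RunFusedConv<T>(T)

    #endif  // CONV_IMPL_H_

    // conv_impl_float.cc -- builds in parallel with the double/half TUs:
    //   #include "conv_impl.h"
    //   REGISTER_FUSED_CONV(float);

    // conv_impl_double.cc:
    //   #include "conv_impl.h"
    //   REGISTER_FUSED_CONV(double);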
@@ -43,7 +43,9 @@ tensorflow/core/kernels/conv_grad_input_ops.cc
 tensorflow/core/kernels/conv_grad_ops.cc
 tensorflow/core/kernels/conv_ops.cc
 tensorflow/core/kernels/conv_ops_3d.cc
-tensorflow/core/kernels/conv_ops_fused.cc
+tensorflow/core/kernels/conv_ops_fused_double.cc
+tensorflow/core/kernels/conv_ops_fused_float.cc
+tensorflow/core/kernels/conv_ops_fused_half.cc
 tensorflow/core/kernels/conv_ops_using_gemm.cc
 tensorflow/core/kernels/crop_and_resize_op.cc
 tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -5624,7 +5624,10 @@ filegroup(
         "conv_grad_ops.h",
         "conv_ops.cc",
         "conv_ops_3d.cc",
-        "conv_ops_fused.cc",
+        "conv_ops_fused_double.cc",
+        "conv_ops_fused_float.cc",
+        "conv_ops_fused_half.cc",
+        "conv_ops_fused_impl.h",
         "conv_ops_using_gemm.cc",
         "crop_and_resize_op.cc",
         "crop_and_resize_op.h",
tensorflow/core/kernels/conv_ops_fused_double.cc (new file, 39 lines)
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
+
+namespace tensorflow {
+
+// If we're using the alternative GEMM-based implementation of Conv2D for the
+// CPU implementation, don't register this EigenTensor-based version.
+// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
+// contractions with non-default contraction output kernels.
+#if !defined(USE_GEMM_FOR_CONV) && !defined(EIGEN_USE_LIBXSMM)
+TF_CALL_double(REGISTER_FUSED_CPU_CONV2D);
+#endif  // !USE_GEMM_FOR_CONV
+
+#if GOOGLE_CUDA
+
+namespace functor {
+DECLARE_FUNCTOR_GPU_SPEC(double);
+}  // namespace functor
+
+TF_CALL_double(REGISTER_FUSED_GPU_CONV2D);
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
tensorflow/core/kernels/conv_ops_fused_float.cc (new file, 39 lines)
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
+
+namespace tensorflow {
+
+// If we're using the alternative GEMM-based implementation of Conv2D for the
+// CPU implementation, don't register this EigenTensor-based version.
+// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
+// contractions with non-default contraction output kernels.
+#if !defined(USE_GEMM_FOR_CONV) && !defined(EIGEN_USE_LIBXSMM)
+TF_CALL_float(REGISTER_FUSED_CPU_CONV2D);
+#endif  // !USE_GEMM_FOR_CONV
+
+#if GOOGLE_CUDA
+
+namespace functor {
+DECLARE_FUNCTOR_GPU_SPEC(float);
+}  // namespace functor
+
+TF_CALL_float(REGISTER_FUSED_GPU_CONV2D);
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
tensorflow/core/kernels/conv_ops_fused_half.cc (new file, 29 lines)
@@ -0,0 +1,29 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
+
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+
+namespace functor {
+DECLARE_FUNCTOR_GPU_SPEC(Eigen::half);
+}  // namespace functor
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
tensorflow/core/kernels/conv_ops_fused_impl.h
@@ -28,6 +28,9 @@ limitations under the License.
 //
 // NOTE: GPU only supports fusion of Conv2D + BiasAdd + <optional Relu>.
 
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
+
 #define USE_EIGEN_TENSOR
 #define EIGEN_USE_THREADS
 
@@ -63,7 +66,6 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-namespace {
 // Supported Conv2D fusions. Not all of them supported on all type of devices.
 enum class FusedComputationType {
   // NOTE(ezhulenev): CuDNN `cudnnConvolutionBiasActivationForward` supports
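
Dropping `namespace {` here is part of making the old .cc contents
header-safe: in a header, an anonymous namespace gives every including
translation unit its own copy of each entity, so what looks like one type is
a different type in every TU. A hypothetical illustration (not TensorFlow
code):

    // In a header, this would make `Mode` a distinct type per translation
    // unit, breaking any API that passes it across TU boundaries:
    //
    //   namespace {
    //   enum class Mode { kNone, kRelu };
    //   }  // namespace
    //
    // Header-safe version: one shared type, identical everywhere.
    enum class Mode { kNone, kRelu };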
@@ -463,12 +465,12 @@ class FusedConvParameters : public ConvParameters {
   se::dnn::ActivationMode activation_mode_;
 };
 
-bool operator==(const FusedConvParameters& lhs,
+inline bool operator==(const FusedConvParameters& lhs,
                 const FusedConvParameters& rhs) {
   return lhs.get_data_as_tuple() == rhs.get_data_as_tuple();
 }
 
-bool operator!=(const FusedConvParameters& lhs,
+inline bool operator!=(const FusedConvParameters& lhs,
                 const FusedConvParameters& rhs) {
   return !(lhs == rhs);
 }
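
The added `inline` is what makes these free functions legal in a header:
without it, every translation unit that includes the header emits a strong
definition, and the link fails under the one-definition rule. A minimal
sketch, with hypothetical names:

    // params.h (hypothetical)
    #ifndef PARAMS_H_
    #define PARAMS_H_

    struct Params {
      int filter_size = 0;
    };

    // `inline` permits identical definitions in every TU that includes
    // this header; a plain definition here would be a duplicate symbol.
    inline bool operator==(const Params& lhs, const Params& rhs) {
      return lhs.filter_size == rhs.filter_size;
    }

    inline bool operator!=(const Params& lhs, const Params& rhs) {
      return !(lhs == rhs);
    }

    #endif  // PARAMS_H_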
@@ -482,7 +484,7 @@ using AutoTuneFusedConv =
     AutoTuneSingleton<FusedConvAutoTuneGroup, FusedConvParameters,
                       se::dnn::AlgorithmConfig>;
 
-int64 ConvolveScratchSize() {
+inline int64 ConvolveScratchSize() {
   static int64 convolve_scratch_size = GetDnnWorkspaceLimit(
       // default value is in bytes despite the name of the environment variable
       "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
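
Note that `inline` is still correct here despite the function-local static:
since C++11, an inline function's static locals denote one shared object
across all translation units, so every includer sees the same cached scratch
limit. A sketch under hypothetical names:

    // scratch.h (hypothetical): callers in every TU observe the same
    // `scratch_size` object, initialized exactly once on first use.
    #include <cstdint>

    inline int64_t ScratchSize() {
      static const int64_t scratch_size = int64_t{1} << 32;  // 4 GB
      return scratch_size;
    }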
@@ -822,8 +824,6 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
 
 #endif  // GOOGLE_CUDA
 
-}  // namespace
-
 template <typename Device, typename T>
 class FusedConv2DOp : public OpKernel {
  public:
@@ -962,22 +962,9 @@ class FusedConv2DOp : public OpKernel {
       Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       FusedConv2DOp<CPUDevice, T>);
 
-// If we're using the alternative GEMM-based implementation of Conv2D for the
-// CPU implementation, don't register this EigenTensor-based version.
-// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
-// contractions with non-default contraction output kernels.
-#if !defined(USE_GEMM_FOR_CONV) && !defined(EIGEN_USE_LIBXSMM)
-TF_CALL_float(REGISTER_FUSED_CPU_CONV2D);
-TF_CALL_double(REGISTER_FUSED_CPU_CONV2D);
-#endif  // !USE_GEMM_FOR_CONV
-
-#undef REGISTER_FUSED_CPU_CONV2D
-
 #if GOOGLE_CUDA
 
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPEC(T) \
+#define DECLARE_FUNCTOR_GPU_SPEC(T) \
   template <> \
   void TransformFilter<GPUDevice, T, int, 4>::operator()( \
       const GPUDevice& d, FilterTensorFormat dst_filter_format, \
@@ -992,23 +979,14 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   extern template struct PadInput<GPUDevice, T, int, 4>
 
-DECLARE_GPU_SPEC(float);
-DECLARE_GPU_SPEC(Eigen::half);
-DECLARE_GPU_SPEC(double);
-#undef DECLARE_GPU_SPEC
 }  // namespace functor
 
 // Registration of the GPU implementations.
 #define REGISTER_FUSED_GPU_CONV2D(T) \
   REGISTER_KERNEL_BUILDER( \
       Name("_FusedConv2D").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
       FusedConv2DOp<GPUDevice, T>);
 
-TF_CALL_float(REGISTER_FUSED_GPU_CONV2D);
-TF_CALL_double(REGISTER_FUSED_GPU_CONV2D);
-
-#undef REGISTER_FUSED_GPU_CONV2D
-
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
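
The DECLARE_FUNCTOR_GPU_SPEC machinery above is the `extern template` half of
the split: the header promises an instantiation exists, and exactly one
per-type .cc supplies it, giving the build one independently compilable
action per type. In miniature, with hypothetical names (not the TensorFlow
functors):

    // functor.h (hypothetical)
    #ifndef FUNCTOR_H_
    #define FUNCTOR_H_

    template <typename T>
    struct TransformFilter {
      void operator()(const T* in, T* out, int n) const {
        for (int i = 0; i < n; ++i) out[i] = in[i];  // stand-in for work
      }
    };

    // Declaration only: suppresses implicit instantiation in includers
    // and promises a definition in some other translation unit.
    #define DECLARE_FUNCTOR_SPEC(T) extern template struct TransformFilter<T>

    DECLARE_FUNCTOR_SPEC(float);
    DECLARE_FUNCTOR_SPEC(double);

    #endif  // FUNCTOR_H_

    // functor_float.cc -- the matching explicit instantiation definition:
    //   #include "functor.h"
    //   template struct TransformFilter<float>;

    // functor_double.cc:
    //   #include "functor.h"
    //   template struct TransformFilter<double>;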