Split up conv_ops_fused kernels.

This improves build times by allowing the double, float, and half implementations to build in parallel.

PiperOrigin-RevId: 235576953
commit 1c6f10152f
parent c715e350ca
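For context, the mechanics behind the build-time win: the templated Conv2D fusion code moves into a shared conv_ops_fused_impl.h, and each dtype gets a small .cc file that includes the header and registers only its own instantiations, so the three translation units compile concurrently instead of as one large file. A minimal sketch of that layout (hypothetical my_kernel names, not the actual TensorFlow sources):

// my_kernel_impl.h -- shared templated implementation (expensive to compile).
#ifndef MY_KERNEL_IMPL_H_
#define MY_KERNEL_IMPL_H_

template <typename T>
struct MyKernel {
  // Stand-in for the heavy templated kernel body.
  static T Run(T x) { return x * x; }
};

// Each per-type .cc file expands this to emit its explicit instantiation.
#define REGISTER_MY_KERNEL(T) template struct MyKernel<T>

#endif  // MY_KERNEL_IMPL_H_

// my_kernel_float.cc -- one small translation unit per type; these build in parallel.
#include "my_kernel_impl.h"
REGISTER_MY_KERNEL(float);

// my_kernel_double.cc
#include "my_kernel_impl.h"
REGISTER_MY_KERNEL(double);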
@@ -43,7 +43,9 @@ tensorflow/core/kernels/conv_grad_input_ops.cc
 tensorflow/core/kernels/conv_grad_ops.cc
 tensorflow/core/kernels/conv_ops.cc
 tensorflow/core/kernels/conv_ops_3d.cc
-tensorflow/core/kernels/conv_ops_fused.cc
+tensorflow/core/kernels/conv_ops_fused_double.cc
+tensorflow/core/kernels/conv_ops_fused_float.cc
+tensorflow/core/kernels/conv_ops_fused_half.cc
 tensorflow/core/kernels/conv_ops_using_gemm.cc
 tensorflow/core/kernels/crop_and_resize_op.cc
 tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -5624,7 +5624,10 @@ filegroup(
         "conv_grad_ops.h",
         "conv_ops.cc",
         "conv_ops_3d.cc",
-        "conv_ops_fused.cc",
+        "conv_ops_fused_double.cc",
+        "conv_ops_fused_float.cc",
+        "conv_ops_fused_half.cc",
+        "conv_ops_fused_impl.h",
         "conv_ops_using_gemm.cc",
         "crop_and_resize_op.cc",
         "crop_and_resize_op.h",
tensorflow/core/kernels/conv_ops_fused_double.cc (new file, 39 lines)
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
+
+namespace tensorflow {
+
+// If we're using the alternative GEMM-based implementation of Conv2D for the
+// CPU implementation, don't register this EigenTensor-based version.
+// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
+// contractions with non-default contraction output kernels.
+#if !defined(USE_GEMM_FOR_CONV) && !defined(EIGEN_USE_LIBXSMM)
+TF_CALL_double(REGISTER_FUSED_CPU_CONV2D);
+#endif  // !USE_GEMM_FOR_CONV
+
+#if GOOGLE_CUDA
+
+namespace functor {
+DECLARE_FUNCTOR_GPU_SPEC(double);
+}  // namespace functor
+
+TF_CALL_double(REGISTER_FUSED_GPU_CONV2D);
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
tensorflow/core/kernels/conv_ops_fused_float.cc (new file, 39 lines)
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
+
+namespace tensorflow {
+
+// If we're using the alternative GEMM-based implementation of Conv2D for the
+// CPU implementation, don't register this EigenTensor-based version.
+// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
+// contractions with non-default contraction output kernels.
+#if !defined(USE_GEMM_FOR_CONV) && !defined(EIGEN_USE_LIBXSMM)
+TF_CALL_float(REGISTER_FUSED_CPU_CONV2D);
+#endif  // !USE_GEMM_FOR_CONV
+
+#if GOOGLE_CUDA
+
+namespace functor {
+DECLARE_FUNCTOR_GPU_SPEC(float);
+}  // namespace functor
+
+TF_CALL_float(REGISTER_FUSED_GPU_CONV2D);
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
tensorflow/core/kernels/conv_ops_fused_half.cc (new file, 29 lines)
@@ -0,0 +1,29 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
+
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+
+namespace functor {
+DECLARE_FUNCTOR_GPU_SPEC(Eigen::half);
+}  // namespace functor
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
tensorflow/core/kernels/conv_ops_fused_impl.h
@@ -28,6 +28,9 @@ limitations under the License.
 //
 // NOTE: GPU only supports fusion of Conv2D + BiasAdd + <optional Relu>.
 
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
+
 #define USE_EIGEN_TENSOR
 #define EIGEN_USE_THREADS
 
@@ -63,7 +66,6 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-namespace {
 // Supported Conv2D fusions. Not all of them supported on all type of devices.
 enum class FusedComputationType {
   // NOTE(ezhulenev): CuDNN `cudnnConvolutionBiasActivationForward` supports
@@ -463,12 +465,12 @@ class FusedConvParameters : public ConvParameters {
   se::dnn::ActivationMode activation_mode_;
 };
 
-bool operator==(const FusedConvParameters& lhs,
+inline bool operator==(const FusedConvParameters& lhs,
                 const FusedConvParameters& rhs) {
   return lhs.get_data_as_tuple() == rhs.get_data_as_tuple();
 }
 
-bool operator!=(const FusedConvParameters& lhs,
+inline bool operator!=(const FusedConvParameters& lhs,
                 const FusedConvParameters& rhs) {
   return !(lhs == rhs);
 }
@@ -482,7 +484,7 @@ using AutoTuneFusedConv =
     AutoTuneSingleton<FusedConvAutoTuneGroup, FusedConvParameters,
                       se::dnn::AlgorithmConfig>;
 
-int64 ConvolveScratchSize() {
+inline int64 ConvolveScratchSize() {
   static int64 convolve_scratch_size = GetDnnWorkspaceLimit(
       // default value is in bytes despite the name of the environment variable
       "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
@@ -822,8 +824,6 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
 
 #endif  // GOOGLE_CUDA
 
-}  // namespace
-
 template <typename Device, typename T>
 class FusedConv2DOp : public OpKernel {
  public:
@@ -962,22 +962,9 @@ class FusedConv2DOp : public OpKernel {
       Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       FusedConv2DOp<CPUDevice, T>);
 
-// If we're using the alternative GEMM-based implementation of Conv2D for the
-// CPU implementation, don't register this EigenTensor-based version.
-// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
-// contractions with non-default contraction output kernels.
-#if !defined(USE_GEMM_FOR_CONV) && !defined(EIGEN_USE_LIBXSMM)
-TF_CALL_float(REGISTER_FUSED_CPU_CONV2D);
-TF_CALL_double(REGISTER_FUSED_CPU_CONV2D);
-#endif  // !USE_GEMM_FOR_CONV
-
-#undef REGISTER_FUSED_CPU_CONV2D
-
 #if GOOGLE_CUDA
 
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-#define DECLARE_GPU_SPEC(T)                                        \
+#define DECLARE_FUNCTOR_GPU_SPEC(T)                                \
   template <>                                                      \
   void TransformFilter<GPUDevice, T, int, 4>::operator()(          \
       const GPUDevice& d, FilterTensorFormat dst_filter_format,    \
@@ -992,23 +979,14 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   extern template struct PadInput<GPUDevice, T, int, 4>
 
-DECLARE_GPU_SPEC(float);
-DECLARE_GPU_SPEC(Eigen::half);
-DECLARE_GPU_SPEC(double);
-#undef DECLARE_GPU_SPEC
-}  // namespace functor
-
 // Registration of the GPU implementations.
 #define REGISTER_FUSED_GPU_CONV2D(T)                                  \
   REGISTER_KERNEL_BUILDER(                                            \
       Name("_FusedConv2D").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
       FusedConv2DOp<GPUDevice, T>);
 
-TF_CALL_float(REGISTER_FUSED_GPU_CONV2D);
-TF_CALL_double(REGISTER_FUSED_GPU_CONV2D);
-
-#undef REGISTER_FUSED_GPU_CONV2D
-
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_