Fix ODR violation by splitting a duplicate definition of a function template into separate overloads.

PiperOrigin-RevId: 239293552
A. Unique TensorFlower 2019-03-19 16:26:42 -07:00 committed by TensorFlower Gardener
parent 14fbe25f3e
commit 6ec89893ce
3 changed files with 36 additions and 10 deletions
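Background on the fix: the one-definition rule requires every definition of the same function template to consist of the identical token sequence across translation units. Previously, internal::TileSimple<Device, T> was defined with a CPU body in one translation unit and with a kernel-launching GPU body in another, so a binary linking both was ill-formed even though each file only ever instantiated its own device type. The sketch below is not the TensorFlow sources; CpuDevice and GpuDevice are stand-ins for Eigen::ThreadPoolDevice and Eigen::GpuDevice. It shows the shape of the fix: one overload per device, each with a single definition.

// Minimal sketch of the fixed layout (hypothetical names, simplified signatures).
#include <iostream>

struct CpuDevice {};  // stand-in for Eigen::ThreadPoolDevice
struct GpuDevice {};  // stand-in for Eigen::GpuDevice

// Before the change there was a single template,
//   template <typename Device, typename T>
//   void TileSimple(const Device& d, ...);
// defined with different bodies in the CPU and GPU source files, which violates
// the one-definition rule. Each device now gets its own overload.
template <typename T>
void TileSimple(const CpuDevice&, T* out, const T* in) {
  *out = *in;  // placeholder for the CPU copy loop
  std::cout << "CPU overload\n";
}

template <typename T>
void TileSimple(const GpuDevice&, T* out, const T* in) {
  *out = *in;  // placeholder for the kernel launch
  std::cout << "GPU overload\n";
}

int main() {
  float src = 1.0f, dst = 0.0f;
  TileSimple<float>(CpuDevice{}, &dst, &src);  // picks the CpuDevice overload
  TileSimple<float>(GpuDevice{}, &dst, &src);  // picks the GpuDevice overload
}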

View File

@@ -26,9 +26,21 @@ namespace tensorflow {
 namespace internal {
 
-// Device-specific naive implementation for tile.
-template <typename Device, typename T>
-void TileSimple(const Device& d, Tensor* out, const Tensor& in);
+// Device-specific naive implementation for Tile.
+template <typename T>
+void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out,
+                const Tensor& in);
+
+#if GOOGLE_CUDA
+template <typename T>
+void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in);
+#endif  // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+void TileSimple(const Eigen::SyclDevice& d, Tensor* out, const Tensor& in);
+#endif
 
 template <typename Device, typename T, typename Tmultiples, int NDIM>
 void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,

@@ -99,7 +111,7 @@ struct Tile {
           broadcast_array);
       break;
     default:
-      internal::TileSimple<Device, T>(d, out, in);
+      internal::TileSimple<T>(d, out, in);
      break;
   }
 }
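The call-site change in functor::Tile follows from the new declarations: the device is now carried by the parameter type rather than passed as an explicit template argument, so TileSimple<T>(d, out, in) lets ordinary overload resolution on d select the device-specific definition. A minimal sketch of that dispatch, using stand-in types rather than the real Tile functor:

// Sketch only: CpuDevice/GpuDevice and the return values are illustrative.
struct CpuDevice {};
struct GpuDevice {};

template <typename T>
int TileSimple(const CpuDevice&, T*) { return 1; }  // stands in for the CPU path
template <typename T>
int TileSimple(const GpuDevice&, T*) { return 2; }  // stands in for the GPU path

// Generic functor in the spirit of functor::Tile<Device, T>.
template <typename Device, typename T>
struct Tile {
  int operator()(const Device& d, T* out) const {
    // Previously: TileSimple<Device, T>(d, out). The device is now taken from
    // the static type of d by overload resolution.
    return TileSimple<T>(d, out);
  }
};

int main() {
  float x = 0.0f;
  int cpu = Tile<CpuDevice, float>()(CpuDevice{}, &x);  // selects the CpuDevice overload
  int gpu = Tile<GpuDevice, float>()(GpuDevice{}, &x);  // selects the GpuDevice overload
  return (cpu == 1 && gpu == 2) ? 0 : 1;
}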

View File

@@ -21,11 +21,11 @@ limitations under the License.
 #include "tensorflow/core/kernels/tile_functor.h"
 
 namespace tensorflow {
 namespace internal {
+namespace {
 
 template <typename Device, typename T>
-void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
+void TileSimpleImpl(const Device& d, Tensor* out, const Tensor& in) {
   const int ndims = in.dims();
   const int64 nelem = out->NumElements();
   gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());

@@ -44,7 +44,21 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
   }
 }
 
-}  // end namespace internal
+}  // namespace
+
+template <typename T>
+void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out,
+                const Tensor& in) {
+  return TileSimpleImpl<Eigen::ThreadPoolDevice, T>(d, out, in);
+}
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+void TileSimple(const Eigen::SyclDevice& d, Tensor* out, const Tensor& in) {
+  return TileSimpleImpl<Eigen::SyclDevice, T>(d, out, in);
+}
+#endif
+
+}  // namespace internal
 
 namespace functor {
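In the CPU translation unit the shared loop survives as TileSimpleImpl inside an unnamed namespace. Names in an unnamed namespace have internal linkage, so each translation unit gets its own distinct entity and the cross-TU identical-definition requirement no longer applies; the externally visible TileSimple overloads are thin forwarding wrappers. A simplified sketch of that pattern (plain arrays and a made-up element count instead of Tensor):

struct CpuDevice {};   // stand-in for Eigen::ThreadPoolDevice
struct SyclDevice {};  // stand-in for Eigen::SyclDevice

namespace {

// Shared element loop, still a template over the device type, but with internal
// linkage so it cannot collide with a same-named template in another TU.
template <typename Device, typename T>
void TileSimpleImpl(const Device& /*d*/, T* out, const T* in, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];  // placeholder for the tile loop
}

}  // namespace

// Externally visible overloads are thin forwarding wrappers.
template <typename T>
void TileSimple(const CpuDevice& d, T* out, const T* in, int n) {
  TileSimpleImpl<CpuDevice, T>(d, out, in, n);
}

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
void TileSimple(const SyclDevice& d, T* out, const T* in, int n) {
  TileSimpleImpl<SyclDevice, T>(d, out, in, n);
}
#endif

int main() {
  float in[3] = {1.0f, 2.0f, 3.0f};
  float out[3] = {0.0f, 0.0f, 0.0f};
  TileSimple<float>(CpuDevice{}, out, in, 3);
  return out[2] == 3.0f ? 0 : 1;
}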

View File

@@ -21,7 +21,6 @@ limitations under the License.
 #define EIGEN_USE_GPU
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/kernels/tile_functor.h"

@@ -47,8 +46,8 @@ __global__ void TileKernel(int nthreads, const T* src, const int32* buf,
   }
 }
 
-template <typename Device, typename T>
-void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
+template <typename T>
+void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in) {
   // Ensures we can use 32-bit index.
   const int64 in_nelem = in.NumElements();
   CHECK_LT(in_nelem, kint32max) << "Tensor too large to transpose on GPU";

@@ -85,6 +84,7 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
 }  // end namespace internal
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
+
 #endif  // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_GPU_H_
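The GPU overload keeps the existing precondition that the element count fits 32-bit indexing before launching TileKernel, via CHECK_LT(in_nelem, kint32max). A rough standalone sketch of that guard, using assert and <limits> in place of TensorFlow's CHECK_LT and kint32max, with LaunchTileKernel as a hypothetical stand-in for the real kernel launch:

// Sketch of the 32-bit index guard seen in the GPU overload above.
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

void LaunchTileKernel(std::int32_t nthreads, const float* src, float* dst) {
  for (std::int32_t i = 0; i < nthreads; ++i) dst[i] = src[i];  // pretend kernel
}

void TileSimpleSketch(const std::vector<float>& in, std::vector<float>& out) {
  // The kernel indexes elements with int32, so the element count must fit;
  // this mirrors CHECK_LT(in_nelem, kint32max) << "Tensor too large to transpose on GPU".
  const std::int64_t in_nelem = static_cast<std::int64_t>(in.size());
  assert(in_nelem < std::numeric_limits<std::int32_t>::max() &&
         "Tensor too large to tile with 32-bit indexing");
  LaunchTileKernel(static_cast<std::int32_t>(in_nelem), in.data(), out.data());
}

int main() {
  std::vector<float> in(4, 1.0f), out(4, 0.0f);
  TileSimpleSketch(in, out);
  return out[3] == 1.0f ? 0 : 1;
}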