Fix ODR violation by splitting a duplicate definition of a function template into separate overloads.
PiperOrigin-RevId: 239293552
This commit is contained in:
parent
14fbe25f3e
commit
6ec89893ce
@ -26,9 +26,21 @@ namespace tensorflow {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Device-specific naive implementation for tile.
|
||||
template <typename Device, typename T>
|
||||
void TileSimple(const Device& d, Tensor* out, const Tensor& in);
|
||||
// Device-specific naive implementation for Tile.
|
||||
|
||||
template <typename T>
|
||||
void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out,
|
||||
const Tensor& in);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
template <typename T>
|
||||
void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in);
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <typename T>
|
||||
void TileSimple(const Eigen::SyclDevice& d, Tensor* out, const Tensor& in);
|
||||
#endif
|
||||
|
||||
template <typename Device, typename T, typename Tmultiples, int NDIM>
|
||||
void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
|
||||
@ -99,7 +111,7 @@ struct Tile {
|
||||
broadcast_array);
|
||||
break;
|
||||
default:
|
||||
internal::TileSimple<Device, T>(d, out, in);
|
||||
internal::TileSimple<T>(d, out, in);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -21,11 +21,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/tile_functor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace internal {
|
||||
namespace {
|
||||
|
||||
template <typename Device, typename T>
|
||||
void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
|
||||
void TileSimpleImpl(const Device& d, Tensor* out, const Tensor& in) {
|
||||
const int ndims = in.dims();
|
||||
const int64 nelem = out->NumElements();
|
||||
gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
|
||||
@ -44,7 +44,21 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
} // namespace
|
||||
|
||||
// TileSimple overload selected for Eigen::ThreadPoolDevice. It does no work of
// its own: it forwards d, out and in unchanged to TileSimpleImpl instantiated
// with <Eigen::ThreadPoolDevice, T>.
template <typename T>
|
||||
void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out,
|
||||
const Tensor& in) {
|
||||
// TileSimpleImpl is declared void, so this returns a void expression.
return TileSimpleImpl<Eigen::ThreadPoolDevice, T>(d, out, in);
|
||||
}
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// TileSimple overload selected for Eigen::SyclDevice (compiled only when
// TENSORFLOW_USE_SYCL is defined). It forwards d, out and in unchanged to
// TileSimpleImpl instantiated with <Eigen::SyclDevice, T>.
template <typename T>
|
||||
void TileSimple(const Eigen::SyclDevice& d, Tensor* out, const Tensor& in) {
|
||||
// TileSimpleImpl is declared void, so this returns a void expression.
return TileSimpleImpl<Eigen::SyclDevice, T>(d, out, in);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace internal
|
||||
|
||||
namespace functor {
|
||||
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/kernels/tile_functor.h"
|
||||
@ -47,8 +46,8 @@ __global__ void TileKernel(int nthreads, const T* src, const int32* buf,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
|
||||
template <typename T>
|
||||
void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in) {
|
||||
// Ensures we can use 32-bit index.
|
||||
const int64 in_nelem = in.NumElements();
|
||||
CHECK_LT(in_nelem, kint32max) << "Tensor too large to transpose on GPU";
|
||||
@ -85,6 +84,7 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
|
||||
|
||||
} // end namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_GPU_H_
|
||||
|
Loading…
Reference in New Issue
Block a user