Fix GPU build failure
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
e75409c2fe
commit
33cd7b88ac
@ -443,6 +443,38 @@ struct BatchSelectFunctor<CPUDevice, T> {
|
||||
d.parallelFor(batch, cost, work);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int NDIMS>
|
||||
struct BCastSelectFunctorBase {
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
||||
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
||||
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
||||
.select(then_tensor.broadcast(then_bcast),
|
||||
else_tensor.broadcast(else_bcast));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int NDIMS>
|
||||
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
|
||||
: BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <typename T>
|
||||
struct BatchSelectFunctor<SYCLDevice, T>
|
||||
: BatchSelectFunctorBase<SYCLDevice, T> {};
|
||||
|
||||
template <typename T, int NDIMS>
|
||||
struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
|
||||
: BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -1217,11 +1217,7 @@ struct BCastSelectFunctor {
|
||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
||||
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
||||
.select(then_tensor.broadcast(then_bcast),
|
||||
else_tensor.broadcast(else_bcast));
|
||||
}
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
|
||||
};
|
||||
|
||||
} // end namespace functor
|
||||
|
Loading…
Reference in New Issue
Block a user