Fix Windows GPU build failure in resize_bilinear_op.cc.
I broke it in 67d15573a7.
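Previously, I called a function inside a runtime `if (std::is_same<Device, CPUDevice>::value)` branch, and that function cannot be linked when Device is GPUDevice. I expected dead code elimination to prevent the function from being generated when Device is GPUDevice, but apparently it is still generated on Windows. The cast is now dispatched through a `CastFloatToHalf<Device>` functor that has a dedicated GPUDevice specialization.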
PiperOrigin-RevId: 327887418
Change-Id: Ib97e1abf1680c75dc072850cc69c761e10ac3e1e
This commit is contained in:
parent
d4e7fede79
commit
dcc2e62c8e
@ -286,6 +286,25 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
|
||||
}
|
||||
}
|
||||
|
||||
// Functor that casts a flat float tensor to Eigen::half.
// Generic version: assigns the Eigen cast expression of `input` to `output`,
// evaluated on device `d`. A separate specialization handles GPUDevice.
template <typename Device>
|
||||
struct CastFloatToHalf {
|
||||
void operator()(const Device& d, typename TTypes<float>::ConstFlat input,
|
||||
typename TTypes<Eigen::half>::Flat output) {
|
||||
output.device(d) = input.template cast<Eigen::half>();
|
||||
}
|
||||
};
|
||||
|
||||
// GPUDevice specialization of CastFloatToHalf: performs the float ->
// Eigen::half cast by delegating to functor::CastFunctor instead of using an
// inline Eigen cast expression. Note that CastFunctor is invoked with the
// arguments in (device, output, input) order.
template <>
|
||||
struct CastFloatToHalf<GPUDevice> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<float>::ConstFlat input,
|
||||
typename TTypes<Eigen::half>::Flat output) {
|
||||
// Use existing cast functor instead of directly casting Eigen tensor, as
|
||||
// otherwise we need to instantiate the cast function in a .cu.cc file
|
||||
functor::CastFunctor<GPUDevice, Eigen::half, float> cast;
|
||||
cast(d, output, input);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Partial specialization of ResizeBilinear functor for a CPUDevice.
|
||||
@ -378,19 +397,10 @@ class ResizeBilinearOpGrad : public OpKernel {
|
||||
functor::ResizeBilinearGrad<Device, float>()(
|
||||
context->eigen_device<Device>(), input_grad, st.height_scale,
|
||||
st.width_scale, half_pixel_centers_, output_grad.tensor<float, 4>());
|
||||
if (std::is_same<Device, CPUDevice>::value) {
|
||||
const Device& d = context->template eigen_device<Device>();
|
||||
st.output->template flat<Eigen::half>().device(d) =
|
||||
output_grad.template flat<float>().template cast<Eigen::half>();
|
||||
} else {
|
||||
// Use cast functor instead of directly casting Eigen tensor, as
|
||||
// otherwise we need to instantiate the cast function in a .cu.cc file
|
||||
const Tensor& output_grad_const = output_grad;
|
||||
functor::CastFunctor<Device, Eigen::half, float> cast;
|
||||
const Device& device = context->template eigen_device<Device>();
|
||||
cast(device, st.output->template flat<Eigen::half>(),
|
||||
output_grad_const.template flat<float>());
|
||||
}
|
||||
const Tensor& output_grad_const = output_grad;
|
||||
CastFloatToHalf<Device>{}(context->template eigen_device<Device>(),
|
||||
output_grad_const.template flat<float>(),
|
||||
st.output->template flat<Eigen::half>());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user