Add bfloat16 support to more image grad ops.

Bfloat16 support for CPU/GPU was added to ResizeNearestNeighborGrad. Bfloat16 support for CPU was added to ResizeBilinearGrad.

PiperOrigin-RevId: 350264170
Change-Id: If014d60a33c644addc955fc29b3a43589ea12903
Reed Wanderman-Milne 2021-01-05 18:59:18 -08:00 committed by TensorFlower Gardener
parent fe70900e00
commit 2d94d904df
4 changed files with 35 additions and 27 deletions
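Before the diff, a minimal sketch (not part of the commit) of the user-facing path these kernels sit on, assuming an eager TensorFlow build that includes this change and existing bfloat16 support in the forward resize op:

import tensorflow as tf

images = tf.cast(tf.random.uniform([1, 4, 6, 1]), tf.bfloat16)

with tf.GradientTape() as tape:
  tape.watch(images)
  # Forward nearest-neighbor resize; its registered gradient is
  # ResizeNearestNeighborGrad, which this change allows with T = bfloat16.
  resized = tf.raw_ops.ResizeNearestNeighbor(images=images, size=[2, 3])
  loss = tf.reduce_sum(tf.cast(resized, tf.float32))

grad = tape.gradient(loss, images)
print(grad.dtype)  # tf.bfloat16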


@@ -10035,7 +10035,7 @@ def TF_ResizeNearestNeighborGradOp : TF_Op<"ResizeNearestNeighborGrad", [NoSideE
   let summary = "Computes the gradient of nearest neighbor interpolation.";
 
   let arguments = (ins
-    TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$grads,
+    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$grads,
     TF_Int32Tensor:$size,
 
     DefaultValuedAttr<BoolAttr, "false">:$align_corners,
@@ -10043,7 +10043,7 @@ def TF_ResizeNearestNeighborGradOp : TF_Op<"ResizeNearestNeighborGrad", [NoSideE
   );
 
   let results = (outs
-    TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$output
+    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$output
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;


@@ -286,21 +286,22 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
   }
 }
 
-template <typename Device>
-struct CastFloatToHalf {
+// Casts from float16 to T.
+template <typename Device, typename T>
+struct CastFloatTo {
   void operator()(const Device& d, typename TTypes<float>::ConstFlat input,
-                  typename TTypes<Eigen::half>::Flat output) {
-    output.device(d) = input.template cast<Eigen::half>();
+                  typename TTypes<T>::Flat output) {
+    output.device(d) = input.template cast<T>();
   }
 };
 
-template <>
-struct CastFloatToHalf<GPUDevice> {
+template <typename T>
+struct CastFloatTo<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<float>::ConstFlat input,
-                  typename TTypes<Eigen::half>::Flat output) {
+                  typename TTypes<T>::Flat output) {
     // Use existing cast functor instead of directly casting Eigen tensor, as
     // otherwise we need to instantiate the cast function in a .cu.cc file
-    functor::CastFunctor<GPUDevice, Eigen::half, float> cast;
+    functor::CastFunctor<GPUDevice, T, float> cast;
     cast(d, output, input);
   }
 };
@@ -380,17 +381,18 @@ class ResizeBilinearOpGrad : public OpKernel {
     TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
 
-    if (!std::is_same<T, Eigen::half>::value) {
+    if (!std::is_same<T, Eigen::half>::value &&
+        !std::is_same<T, Eigen::bfloat16>::value) {
       typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
       functor::ResizeBilinearGrad<Device, T>()(
           context->eigen_device<Device>(), input_grad, st.height_scale,
           st.width_scale, half_pixel_centers_, output_grad);
     } else {
-      // Accumulate output to float instead of half tensor, since float
+      // Accumulate output to float instead of half/bfloat16 tensor, since float
       // accumulation is more numerically stable and GPU half implementation is
       // slow.
-      // TODO(b/165759037): Create optimized and numerically stable half
-      // implementation
+      // TODO(b/165759037): Create optimized and numerically stable half and
+      // bfloat16 implementation
       Tensor output_grad;
       OP_REQUIRES_OK(context, context->allocate_temp(
                                   DT_FLOAT, st.output->shape(), &output_grad));
@@ -398,9 +400,9 @@ class ResizeBilinearOpGrad : public OpKernel {
           context->eigen_device<Device>(), input_grad, st.height_scale,
           st.width_scale, half_pixel_centers_, output_grad.tensor<float, 4>());
       const Tensor& output_grad_const = output_grad;
-      CastFloatToHalf<Device>{}(context->template eigen_device<Device>(),
-                                output_grad_const.template flat<float>(),
-                                st.output->template flat<Eigen::half>());
+      CastFloatTo<Device, T>{}(context->template eigen_device<Device>(),
+                               output_grad_const.template flat<float>(),
+                               st.output->template flat<T>());
     }
   }
@@ -509,6 +511,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
 
 TF_CALL_half(REGISTER_GRAD_KERNEL);
 TF_CALL_float(REGISTER_GRAD_KERNEL);
 TF_CALL_double(REGISTER_GRAD_KERNEL);
+TF_CALL_bfloat16(REGISTER_GRAD_KERNEL);
 
 #undef REGISTER_GRAD_KERNEL
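The kernel change above keeps the half/bfloat16 backward pass accumulating into a float32 temporary and casts to the output dtype once at the end. A small eager-mode illustration of why that matters (this mimics the rationale from the comment, not the kernel itself; the values are made up):

import tensorflow as tf

updates = tf.fill([1024], 1.0 / 256.0)  # many small gradient contributions

# Accumulate in float32 and cast once, as the kernel does: exact here.
acc_float = tf.cast(tf.reduce_sum(updates), tf.bfloat16)   # 4.0

# Accumulate directly in bfloat16, as the TODO warns against: the running sum
# stalls once each update is no more than half an ulp of the running total.
acc_bf16 = tf.constant(0.0, dtype=tf.bfloat16)
for u in tf.cast(updates, tf.bfloat16):
  acc_bf16 += u

print(float(acc_float), float(acc_bf16))  # 4.0 vs. 1.0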


@@ -387,7 +387,7 @@ REGISTER_OP("ResizeNearestNeighborGrad")
     .Input("grads: T")
     .Input("size: int32")
     .Output("output: T")
-    .Attr("T: {uint8, int8, int32, half, float, double}")
+    .Attr("T: {uint8, int8, int32, half, float, double, bfloat16}")
     .Attr("align_corners: bool = false")
     .Attr("half_pixel_centers: bool = false")
     .SetShapeFn([](InferenceContext* c) {
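The registration above is the op's interface; a small sketch of calling it directly through the generated tf.raw_ops binding (illustrative shapes and values, not taken from the commit):

import tensorflow as tf

# grads holds the incoming gradient in the resized (output) space; size is the
# [height, width] of the original image that the gradient is resized back to.
grads = tf.ones([1, 2, 3, 1], dtype=tf.bfloat16)  # bfloat16 now accepted for T
output = tf.raw_ops.ResizeNearestNeighborGrad(
    grads=grads, size=[4, 6], align_corners=False, half_pixel_centers=False)
print(output.shape, output.dtype)  # (1, 4, 6, 1) tf.bfloat16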


@@ -24,6 +24,7 @@ import numpy as np
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_image_ops
@@ -37,7 +38,7 @@ from tensorflow.python.platform import test
                                 'align_corners=False not supported by XLA')
 class ResizeNearestNeighborOpTestBase(test.TestCase):
 
-  TYPES = [np.float32, np.float64]
+  TYPES = [np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype]
 
   def testShapeIsCorrectAfterOp(self):
     in_shape = [1, 2, 2, 1]
@@ -67,7 +68,8 @@ class ResizeNearestNeighborOpTestBase(test.TestCase):
     with self.cached_session(use_gpu=True):
       input_tensor = constant_op.constant(x, shape=in_shape)
       err = gradient_checker_v2.max_error(
-          *gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
+          *gradient_checker_v2.compute_gradient(
+              resize_nn, [input_tensor], delta=1 / 8))
       self.assertLess(err, 1e-3)
 
   def testGradFromResizeToSmallerInBothDims(self):
@@ -83,7 +85,8 @@ class ResizeNearestNeighborOpTestBase(test.TestCase):
     with self.cached_session(use_gpu=True):
       input_tensor = constant_op.constant(x, shape=in_shape)
       err = gradient_checker_v2.max_error(
-          *gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
+          *gradient_checker_v2.compute_gradient(
+              resize_nn, [input_tensor], delta=1 / 8))
       self.assertLess(err, 1e-3)
 
   def testCompareGpuVsCpu(self):
@@ -101,12 +104,12 @@ class ResizeNearestNeighborOpTestBase(test.TestCase):
     with self.cached_session(use_gpu=False):
       input_tensor = constant_op.constant(x, shape=in_shape)
       grad_cpu = gradient_checker_v2.compute_gradient(
-          resize_nn, [input_tensor])
+          resize_nn, [input_tensor], delta=1 / 8)
 
     with self.cached_session(use_gpu=True):
       input_tensor = constant_op.constant(x, shape=in_shape)
       grad_gpu = gradient_checker_v2.compute_gradient(
-          resize_nn, [input_tensor])
+          resize_nn, [input_tensor], delta=1 / 8)
 
     self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)
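The tests above now pass delta=1 / 8 to the gradient checker, presumably because a power-of-two step is exactly representable in float16 and bfloat16, so perturbing a low-precision input is not rounded away the way a small decimal step would be. A quick check of that intuition (bfloat16 obtained through the dtypes module the file already imports):

import numpy as np
from tensorflow.python.framework import dtypes

bfloat16 = dtypes.bfloat16.as_numpy_dtype

for low_precision in (np.float16, bfloat16):
  # A power-of-two step survives the low-precision round trip exactly.
  print(float(low_precision(1.0 + 1 / 8)) - 1.0)   # 0.125 for both dtypes
  # A small decimal step is distorted in float16 and lost entirely in
  # bfloat16, which would make the numerical Jacobian meaningless.
  print(float(low_precision(1.0 + 1e-3)) - 1.0)    # ~0.00098, then 0.0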
@@ -199,12 +202,14 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
     in_shape = [1, 4, 6, 1]
     out_shape = [1, 2, 3, 1]
 
     for use_gpu in [False, True]:
-      for dtype in [np.float16, np.float32, np.float64]:
+      for dtype in [
+          np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
+      ]:
         jacob_a, jacob_n = self._getJacobians(
             in_shape, out_shape, dtype=dtype, use_gpu=use_gpu)
-        if dtype == np.float16:
-          # Compare fp16 analytical gradients to fp32 numerical gradients,
-          # since fp16 numerical gradients are too imprecise unless great
+        if dtype in (np.float16, dtypes.bfloat16.as_numpy_dtype):
+          # Compare fp16/bf16 analytical gradients to fp32 numerical gradients,
+          # since fp16/bf16 numerical gradients are too imprecise unless great
           # care is taken with choosing the inputs and the delta. This is
           # a weaker, but pragmatic, check (in particular, it does not test
           # the op itself, only its gradient).
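The comment above describes the weaker check: take the analytical Jacobian in the low-precision dtype, but use the float32 run for the numerical reference. A rough sketch of that idea (this is not the test's _getJacobians helper; the function choice, device pinning, and tolerances are illustrative):

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gradient_checker_v2

def resize(images):
  return tf.compat.v1.image.resize_bilinear(images, [2, 3])

x = np.arange(24, dtype=np.float32).reshape([1, 4, 6, 1]) / 24.0

with tf.device('/CPU:0'):  # the bfloat16 ResizeBilinearGrad kernel is CPU-only
  # Analytical Jacobian computed with a bfloat16 input (the dtype under test).
  jacob_a, _ = gradient_checker_v2.compute_gradient(
      resize, [tf.constant(x, dtype=tf.bfloat16)], delta=1 / 8)
  # Numerical Jacobian from the float32 run, used as the reference.
  _, jacob_n = gradient_checker_v2.compute_gradient(resize, [x])

np.testing.assert_allclose(np.asarray(jacob_a[0], np.float32), jacob_n[0],
                           rtol=1e-2, atol=1e-2)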