Add bfloat16 support to more image grad ops.

Bfloat16 support for CPU/GPU was added to ResizeNearestNeighborGrad.
Bfloat16 support for CPU was added to ResizeBilinearGrad.

PiperOrigin-RevId: 350264170
Change-Id: If014d60a33c644addc955fc29b3a43589ea12903

Parent: fe70900e00
Commit: 2d94d904df
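Before the diffs, a minimal sketch of what the change enables at the Python level, assuming a TensorFlow build that includes this commit. tf.image.resize with method='nearest' preserves the input dtype, so its backward pass exercises ResizeNearestNeighborGrad with T = bfloat16:

```python
import tensorflow as tf

x = tf.cast(tf.random.uniform([1, 4, 4, 1]), tf.bfloat16)

with tf.GradientTape() as tape:
  tape.watch(x)
  # Nearest-neighbor resize keeps the bfloat16 dtype, so the gradient
  # below is computed by the ResizeNearestNeighborGrad kernel.
  y = tf.image.resize(x, [8, 8], method='nearest')

dx = tape.gradient(y, x)
print(dx.dtype)  # tf.bfloat16
```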
In the MLIR TensorFlow dialect's TableGen definition of ResizeNearestNeighborGrad, TF_Bfloat16 joins the allowed element types for $grads (first of two hunks):

```diff
@@ -10035,7 +10035,7 @@ def TF_ResizeNearestNeighborGradOp : TF_Op<"ResizeNearestNeighborGrad", [NoSideE
   let summary = "Computes the gradient of nearest neighbor interpolation.";

   let arguments = (ins
-    TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$grads,
+    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$grads,
     TF_Int32Tensor:$size,

     DefaultValuedAttr<BoolAttr, "false">:$align_corners,
```
and likewise for the $output result:

```diff
@@ -10043,7 +10043,7 @@ def TF_ResizeNearestNeighborGradOp : TF_Op<"ResizeNearestNeighborGrad", [NoSideE
   );

   let results = (outs
-    TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$output
+    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int8, TF_Uint8]>:$output
   );

   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
```
In the resize-bilinear gradient kernel (C++), the half-only cast helper is generalized to CastFloatTo<Device, T>; the GPU specialization keeps reusing the existing CastFunctor so no new cast has to be instantiated in a .cu.cc file:

```diff
@@ -286,21 +286,22 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
   }
 }

-template <typename Device>
-struct CastFloatToHalf {
+// Casts from float16 to T.
+template <typename Device, typename T>
+struct CastFloatTo {
   void operator()(const Device& d, typename TTypes<float>::ConstFlat input,
-                  typename TTypes<Eigen::half>::Flat output) {
-    output.device(d) = input.template cast<Eigen::half>();
+                  typename TTypes<T>::Flat output) {
+    output.device(d) = input.template cast<T>();
   }
 };

-template <>
-struct CastFloatToHalf<GPUDevice> {
+template <typename T>
+struct CastFloatTo<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<float>::ConstFlat input,
-                  typename TTypes<Eigen::half>::Flat output) {
+                  typename TTypes<T>::Flat output) {
     // Use existing cast functor instead of directly casting Eigen tensor, as
     // otherwise we need to instantiate the cast function in a .cu.cc file
-    functor::CastFunctor<GPUDevice, Eigen::half, float> cast;
+    functor::CastFunctor<GPUDevice, T, float> cast;
     cast(d, output, input);
   }
 };
```
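At the Python level the functor amounts to an element-wise cast of a float32 scratch buffer into the output dtype; a rough, hedged analogue:

```python
import tensorflow as tf

# Rough analogue of CastFloatTo<Device, T> above: cast a float32 scratch
# tensor to each output dtype T the kernel now supports.
scratch = tf.ones([2, 3], dtype=tf.float32)
for t in (tf.half, tf.bfloat16):
  out = tf.cast(scratch, t)  # mirrors output.device(d) = input.cast<T>()
  print(out.dtype)
```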
The kernel's gradient path now routes bfloat16, like half, through the float32 accumulation branch:

```diff
@@ -380,17 +381,18 @@ class ResizeBilinearOpGrad : public OpKernel {

     TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();

-    if (!std::is_same<T, Eigen::half>::value) {
+    if (!std::is_same<T, Eigen::half>::value &&
+        !std::is_same<T, Eigen::bfloat16>::value) {
       typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
       functor::ResizeBilinearGrad<Device, T>()(
           context->eigen_device<Device>(), input_grad, st.height_scale,
           st.width_scale, half_pixel_centers_, output_grad);
     } else {
-      // Accumulate output to float instead of half tensor, since float
+      // Accumulate output to float instead of half/bfloat16 tensor, since float
       // accumulation is more numerically stable and GPU half implementation is
       // slow.
-      // TODO(b/165759037): Create optimized and numerically stable half
-      // implementation
+      // TODO(b/165759037): Create optimized and numerically stable half and
+      // bfloat16 implementation
       Tensor output_grad;
       OP_REQUIRES_OK(context, context->allocate_temp(
                                   DT_FLOAT, st.output->shape(), &output_grad));
```
and the float32 result is cast back to T at the end:

```diff
@@ -398,9 +400,9 @@ class ResizeBilinearOpGrad : public OpKernel {
           context->eigen_device<Device>(), input_grad, st.height_scale,
           st.width_scale, half_pixel_centers_, output_grad.tensor<float, 4>());
       const Tensor& output_grad_const = output_grad;
-      CastFloatToHalf<Device>{}(context->template eigen_device<Device>(),
-                                output_grad_const.template flat<float>(),
-                                st.output->template flat<Eigen::half>());
+      CastFloatTo<Device, T>{}(context->template eigen_device<Device>(),
+                               output_grad_const.template flat<float>(),
+                               st.output->template flat<T>());
     }
   }

```
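The comment in the branch above is the heart of the numeric design: contributions are summed into a float32 temporary and cast to T once at the end. A hedged numpy sketch of why direct bfloat16 accumulation fails (ml_dtypes is an assumption here; any numpy bfloat16 dtype behaves the same):

```python
import numpy as np
from ml_dtypes import bfloat16  # assumed stand-in dtype, not used by the kernel

# 4096 small gradient contributions whose true sum is 4.096.
contributions = np.full(4096, 1e-3, dtype=np.float32)

# Accumulating directly in bfloat16: each add rounds to 8 mantissa bits,
# and once the running sum reaches 0.5 the addend is below half a ulp,
# so the sum stops growing entirely.
acc = bfloat16(0.0)
for c in contributions:
  acc = bfloat16(acc + bfloat16(c))
print(acc)              # 0.5, far from 4.096

# Accumulating in float32 and casting once at the end (what the kernel does).
print(bfloat16(contributions.sum(dtype=np.float32)))  # ~4.09
```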
The gradient kernel is then registered for bfloat16:

```diff
@@ -509,6 +511,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
 TF_CALL_half(REGISTER_GRAD_KERNEL);
 TF_CALL_float(REGISTER_GRAD_KERNEL);
 TF_CALL_double(REGISTER_GRAD_KERNEL);
+TF_CALL_bfloat16(REGISTER_GRAD_KERNEL);

 #undef REGISTER_GRAD_KERNEL

```
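With that registration in place, the ResizeBilinearGrad CPU kernel can be instantiated with T = bfloat16. A hedged sketch against the raw op; note its asymmetric signature, where grads is always float32 while original_image and the output carry T:

```python
import tensorflow as tf
from tensorflow.python.ops import gen_image_ops

original_image = tf.zeros([1, 2, 2, 1], dtype=tf.bfloat16)  # forward input
grads = tf.ones([1, 4, 4, 1], dtype=tf.float32)             # upstream grad

out = gen_image_ops.resize_bilinear_grad(
    grads=grads, original_image=original_image,
    align_corners=False, half_pixel_centers=False)
print(out.dtype)  # tf.bfloat16, taken from original_image
```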
The op registration for ResizeNearestNeighborGrad adds bfloat16 to the allowed types:

```diff
@@ -387,7 +387,7 @@ REGISTER_OP("ResizeNearestNeighborGrad")
     .Input("grads: T")
     .Input("size: int32")
     .Output("output: T")
-    .Attr("T: {uint8, int8, int32, half, float, double}")
+    .Attr("T: {uint8, int8, int32, half, float, double, bfloat16}")
     .Attr("align_corners: bool = false")
     .Attr("half_pixel_centers: bool = false")
     .SetShapeFn([](InferenceContext* c) {
```
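Unlike ResizeBilinearGrad, this op types its grads input as T, so the incoming gradient itself can be bfloat16; a hedged sketch:

```python
import tensorflow as tf
from tensorflow.python.ops import gen_image_ops

grads = tf.ones([1, 4, 4, 1], dtype=tf.bfloat16)  # upstream gradient in T
out = gen_image_ops.resize_nearest_neighbor_grad(
    grads=grads, size=[2, 2],  # spatial size of the original input
    align_corners=False, half_pixel_centers=False)
print(out.dtype)  # tf.bfloat16
```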
In the Python gradient tests, dtypes is imported for its bfloat16 numpy dtype:

```diff
@@ -24,6 +24,7 @@ import numpy as np
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_image_ops
```
The nearest-neighbor test dtypes now include float16 and bfloat16:

```diff
@@ -37,7 +38,7 @@ from tensorflow.python.platform import test
                                 'align_corners=False not supported by XLA')
 class ResizeNearestNeighborOpTestBase(test.TestCase):

-  TYPES = [np.float32, np.float64]
+  TYPES = [np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype]

   def testShapeIsCorrectAfterOp(self):
     in_shape = [1, 2, 2, 1]
```
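dtypes.bfloat16.as_numpy_dtype, newly used in TYPES, is a real numpy dtype, so the existing numpy-based test machinery works unchanged:

```python
import numpy as np
from tensorflow.python.framework import dtypes

bf16 = dtypes.bfloat16.as_numpy_dtype
x = (np.arange(8, dtype=np.float32) / 8).astype(bf16)
print(x.dtype)  # bfloat16
```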
The gradient checks pass an explicit finite-difference step:

```diff
@@ -67,7 +68,8 @@ class ResizeNearestNeighborOpTestBase(test.TestCase):
     with self.cached_session(use_gpu=True):
       input_tensor = constant_op.constant(x, shape=in_shape)
       err = gradient_checker_v2.max_error(
-          *gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
+          *gradient_checker_v2.compute_gradient(
+              resize_nn, [input_tensor], delta=1 / 8))
       self.assertLess(err, 1e-3)

   def testGradFromResizeToSmallerInBothDims(self):
```
```diff
@@ -83,7 +85,8 @@ class ResizeNearestNeighborOpTestBase(test.TestCase):
     with self.cached_session(use_gpu=True):
       input_tensor = constant_op.constant(x, shape=in_shape)
       err = gradient_checker_v2.max_error(
-          *gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
+          *gradient_checker_v2.compute_gradient(
+              resize_nn, [input_tensor], delta=1 / 8))
       self.assertLess(err, 1e-3)

   def testCompareGpuVsCpu(self):
```
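The new delta=1 / 8 matters for the low-precision dtypes: the finite-difference step must be exactly representable and large enough not to vanish when added to the input. A hedged sketch of the failure mode (ml_dtypes again stands in for any numpy bfloat16 dtype):

```python
import numpy as np
from ml_dtypes import bfloat16  # assumed stand-in dtype

x = bfloat16(1.0)
# A tiny step is absorbed by rounding: the ulp at 1.0 is 2**-7, so adding
# 1e-3 leaves x unchanged and a finite difference would read zero.
print(bfloat16(x + bfloat16(1e-3)) - x)   # 0.0
# 1/8 = 2**-3 is exactly representable, so the perturbation survives.
print(bfloat16(x + bfloat16(0.125)) - x)  # 0.125
```

The CPU/GPU comparison below uses the same larger step: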
```diff
@@ -101,12 +104,12 @@ class ResizeNearestNeighborOpTestBase(test.TestCase):
     with self.cached_session(use_gpu=False):
       input_tensor = constant_op.constant(x, shape=in_shape)
       grad_cpu = gradient_checker_v2.compute_gradient(
-          resize_nn, [input_tensor])
+          resize_nn, [input_tensor], delta=1 / 8)

     with self.cached_session(use_gpu=True):
       input_tensor = constant_op.constant(x, shape=in_shape)
       grad_gpu = gradient_checker_v2.compute_gradient(
-          resize_nn, [input_tensor])
+          resize_nn, [input_tensor], delta=1 / 8)

     self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)

```
The bilinear Jacobian test adds bfloat16 and extends the fp16 special-casing:

```diff
@@ -199,12 +202,14 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
     in_shape = [1, 4, 6, 1]
     out_shape = [1, 2, 3, 1]
     for use_gpu in [False, True]:
-      for dtype in [np.float16, np.float32, np.float64]:
+      for dtype in [
+          np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
+      ]:
         jacob_a, jacob_n = self._getJacobians(
             in_shape, out_shape, dtype=dtype, use_gpu=use_gpu)
-        if dtype == np.float16:
-          # Compare fp16 analytical gradients to fp32 numerical gradients,
-          # since fp16 numerical gradients are too imprecise unless great
+        if dtype in (np.float16, dtypes.bfloat16.as_numpy_dtype):
+          # Compare fp16/bf16 analytical gradients to fp32 numerical gradients,
+          # since fp16/bf16 numerical gradients are too imprecise unless great
           # care is taken with choosing the inputs and the delta. This is
           # a weaker, but pragmatic, check (in particular, it does not test
           # the op itself, only its gradient).
```
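The pattern that comment describes, sketched with the public checker (the real test goes through its _getJacobians helper and targets bilinear resize; nearest is used here only to keep the sketch self-contained in bfloat16): analytical Jacobian from a bfloat16 run, numerical Jacobian from a float32 run of the same function.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gradient_checker_v2

def resize_nn(images):
  return tf.image.resize(images, [2, 3], method='nearest')

x = np.arange(24, dtype=np.float32).reshape([1, 4, 6, 1]) / 24.0

# Analytical Jacobian in bfloat16, with the coarse delta from above.
jacob_a, _ = gradient_checker_v2.compute_gradient(
    resize_nn, [tf.constant(x, dtype=tf.bfloat16)], delta=1 / 8)
# Numerical Jacobian from the float32 run: precise enough to compare against.
_, jacob_n = gradient_checker_v2.compute_gradient(
    resize_nn, [tf.constant(x, dtype=tf.float32)])

np.testing.assert_allclose(jacob_a[0].astype(np.float32), jacob_n[0],
                           atol=1e-2)
```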