Add `offset` argument to `Rescaling`.

PiperOrigin-RevId: 313101675
Change-Id: Id59e6dcbe4f038d627c7d71fdf4dfeb58e8e05cd
This commit is contained in:
Francois Chollet 2020-05-25 14:02:01 -07:00 committed by TensorFlower Gardener
parent 55c1176fe2
commit 83ed5aad57
4 changed files with 21 additions and 11 deletions

View File

@ -292,11 +292,16 @@ class RandomCrop(Layer):
@keras_export('keras.layers.experimental.preprocessing.Rescaling')
class Rescaling(Layer):
"""Multiply inputs by `scale`.
"""Multiply inputs by `scale` and adds `offset`.
For instance, to rescale an input in the `[0, 255]` range
For instance:
1. To rescale an input in the `[0, 255]` range
to be in the `[0, 1]` range, you would pass `scale=1./255`.
2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range,
you would pass `scale=1./127.5, offset=-1`.
The rescaling is applied both during training and inference.
Input shape:
@ -307,16 +312,20 @@ class Rescaling(Layer):
Arguments:
scale: Float, the scale to apply to the inputs.
offset: Float, the offset to apply to the inputs.
name: A string, the name of the layer.
"""
def __init__(self, scale, name=None, **kwargs):
def __init__(self, scale, offset=0., name=None, **kwargs):
self.scale = scale
self.offset = offset
super(Rescaling, self).__init__(name=name, **kwargs)
def call(self, inputs):
dtype = self._compute_dtype
return math_ops.cast(inputs, dtype) * math_ops.cast(self.scale, dtype)
scale = math_ops.cast(self.scale, dtype)
offset = math_ops.cast(self.offset, dtype)
return math_ops.cast(inputs, dtype) * scale + offset
def compute_output_shape(self, input_shape):
return input_shape
@ -324,6 +333,7 @@ class Rescaling(Layer):
def get_config(self):
config = {
'scale': self.scale,
'offset': self.offset,
}
base_config = super(Rescaling, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

View File

@ -306,7 +306,7 @@ class RescalingTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_rescaling_base(self):
kwargs = {'scale': 0.004}
kwargs = {'scale': 1./127.5, 'offset': -1.}
testing_utils.layer_test(
image_preprocessing.Rescaling,
kwargs=kwargs,
@ -315,18 +315,18 @@ class RescalingTest(keras_parameterized.TestCase):
@tf_test_util.run_v2_only
def test_rescaling_correctness_float(self):
layer = image_preprocessing.Rescaling(0.004)
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1.)
inputs = random_ops.random_uniform((2, 4, 5, 3))
outputs = layer(inputs)
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
@tf_test_util.run_v2_only
def test_rescaling_correctness_int(self):
layer = image_preprocessing.Rescaling(0.004)
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1)
inputs = random_ops.random_uniform((2, 4, 5, 3), 0, 100, dtype='int32')
outputs = layer(inputs)
self.assertEqual(outputs.dtype.name, 'float32')
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
def test_config_with_custom_name(self):
layer = image_preprocessing.Rescaling(0.5, name='rescaling')

View File

@ -113,7 +113,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
}
member_method {
name: "add_loss"

View File

@ -113,7 +113,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
}
member_method {
name: "add_loss"