Add `offset` argument to `Rescaling`.
PiperOrigin-RevId: 313101675 Change-Id: Id59e6dcbe4f038d627c7d71fdf4dfeb58e8e05cd
This commit is contained in:
parent
55c1176fe2
commit
83ed5aad57
|
@ -292,11 +292,16 @@ class RandomCrop(Layer):
|
|||
|
||||
@keras_export('keras.layers.experimental.preprocessing.Rescaling')
|
||||
class Rescaling(Layer):
|
||||
"""Multiply inputs by `scale`.
|
||||
"""Multiply inputs by `scale` and adds `offset`.
|
||||
|
||||
For instance, to rescale an input in the `[0, 255]` range
|
||||
For instance:
|
||||
|
||||
1. To rescale an input in the `[0, 255]` range
|
||||
to be in the `[0, 1]` range, you would pass `scale=1./255`.
|
||||
|
||||
2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range,
|
||||
you would pass `scale=1./127.5, offset=-1`.
|
||||
|
||||
The rescaling is applied both during training and inference.
|
||||
|
||||
Input shape:
|
||||
|
@ -307,16 +312,20 @@ class Rescaling(Layer):
|
|||
|
||||
Arguments:
|
||||
scale: Float, the scale to apply to the inputs.
|
||||
offset: Float, the offset to apply to the inputs.
|
||||
name: A string, the name of the layer.
|
||||
"""
|
||||
|
||||
def __init__(self, scale, name=None, **kwargs):
|
||||
def __init__(self, scale, offset=0., name=None, **kwargs):
|
||||
self.scale = scale
|
||||
self.offset = offset
|
||||
super(Rescaling, self).__init__(name=name, **kwargs)
|
||||
|
||||
def call(self, inputs):
|
||||
dtype = self._compute_dtype
|
||||
return math_ops.cast(inputs, dtype) * math_ops.cast(self.scale, dtype)
|
||||
scale = math_ops.cast(self.scale, dtype)
|
||||
offset = math_ops.cast(self.offset, dtype)
|
||||
return math_ops.cast(inputs, dtype) * scale + offset
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
@ -324,6 +333,7 @@ class Rescaling(Layer):
|
|||
def get_config(self):
|
||||
config = {
|
||||
'scale': self.scale,
|
||||
'offset': self.offset,
|
||||
}
|
||||
base_config = super(Rescaling, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
|
|
@ -306,7 +306,7 @@ class RescalingTest(keras_parameterized.TestCase):
|
|||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_rescaling_base(self):
|
||||
kwargs = {'scale': 0.004}
|
||||
kwargs = {'scale': 1./127.5, 'offset': -1.}
|
||||
testing_utils.layer_test(
|
||||
image_preprocessing.Rescaling,
|
||||
kwargs=kwargs,
|
||||
|
@ -315,18 +315,18 @@ class RescalingTest(keras_parameterized.TestCase):
|
|||
|
||||
@tf_test_util.run_v2_only
|
||||
def test_rescaling_correctness_float(self):
|
||||
layer = image_preprocessing.Rescaling(0.004)
|
||||
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1.)
|
||||
inputs = random_ops.random_uniform((2, 4, 5, 3))
|
||||
outputs = layer(inputs)
|
||||
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
|
||||
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
|
||||
|
||||
@tf_test_util.run_v2_only
|
||||
def test_rescaling_correctness_int(self):
|
||||
layer = image_preprocessing.Rescaling(0.004)
|
||||
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1)
|
||||
inputs = random_ops.random_uniform((2, 4, 5, 3), 0, 100, dtype='int32')
|
||||
outputs = layer(inputs)
|
||||
self.assertEqual(outputs.dtype.name, 'float32')
|
||||
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
|
||||
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
|
||||
|
||||
def test_config_with_custom_name(self):
|
||||
layer = image_preprocessing.Rescaling(0.5, name='rescaling')
|
||||
|
|
|
@ -113,7 +113,7 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
|
|
@ -113,7 +113,7 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
|
Loading…
Reference in New Issue