Add `offset` argument to `Rescaling`.
PiperOrigin-RevId: 313104348 Change-Id: I5472da4856a6040e74286a5dc174a5897b8955df
This commit is contained in:
parent
83ed5aad57
commit
291125835e
|
@ -292,16 +292,11 @@ class RandomCrop(Layer):
|
||||||
|
|
||||||
@keras_export('keras.layers.experimental.preprocessing.Rescaling')
|
@keras_export('keras.layers.experimental.preprocessing.Rescaling')
|
||||||
class Rescaling(Layer):
|
class Rescaling(Layer):
|
||||||
"""Multiply inputs by `scale` and adds `offset`.
|
"""Multiply inputs by `scale`.
|
||||||
|
|
||||||
For instance:
|
For instance, to rescale an input in the `[0, 255]` range
|
||||||
|
|
||||||
1. To rescale an input in the `[0, 255]` range
|
|
||||||
to be in the `[0, 1]` range, you would pass `scale=1./255`.
|
to be in the `[0, 1]` range, you would pass `scale=1./255`.
|
||||||
|
|
||||||
2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range,
|
|
||||||
you would pass `scale=1./127.5, offset=-1`.
|
|
||||||
|
|
||||||
The rescaling is applied both during training and inference.
|
The rescaling is applied both during training and inference.
|
||||||
|
|
||||||
Input shape:
|
Input shape:
|
||||||
|
@ -312,20 +307,16 @@ class Rescaling(Layer):
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
scale: Float, the scale to apply to the inputs.
|
scale: Float, the scale to apply to the inputs.
|
||||||
offset: Float, the offset to apply to the inputs.
|
|
||||||
name: A string, the name of the layer.
|
name: A string, the name of the layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, scale, offset=0., name=None, **kwargs):
|
def __init__(self, scale, name=None, **kwargs):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.offset = offset
|
|
||||||
super(Rescaling, self).__init__(name=name, **kwargs)
|
super(Rescaling, self).__init__(name=name, **kwargs)
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
dtype = self._compute_dtype
|
dtype = self._compute_dtype
|
||||||
scale = math_ops.cast(self.scale, dtype)
|
return math_ops.cast(inputs, dtype) * math_ops.cast(self.scale, dtype)
|
||||||
offset = math_ops.cast(self.offset, dtype)
|
|
||||||
return math_ops.cast(inputs, dtype) * scale + offset
|
|
||||||
|
|
||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
return input_shape
|
return input_shape
|
||||||
|
@ -333,7 +324,6 @@ class Rescaling(Layer):
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = {
|
config = {
|
||||||
'scale': self.scale,
|
'scale': self.scale,
|
||||||
'offset': self.offset,
|
|
||||||
}
|
}
|
||||||
base_config = super(Rescaling, self).get_config()
|
base_config = super(Rescaling, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
|
@ -306,7 +306,7 @@ class RescalingTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
def test_rescaling_base(self):
|
def test_rescaling_base(self):
|
||||||
kwargs = {'scale': 1./127.5, 'offset': -1.}
|
kwargs = {'scale': 0.004}
|
||||||
testing_utils.layer_test(
|
testing_utils.layer_test(
|
||||||
image_preprocessing.Rescaling,
|
image_preprocessing.Rescaling,
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
|
@ -315,18 +315,18 @@ class RescalingTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
@tf_test_util.run_v2_only
|
@tf_test_util.run_v2_only
|
||||||
def test_rescaling_correctness_float(self):
|
def test_rescaling_correctness_float(self):
|
||||||
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1.)
|
layer = image_preprocessing.Rescaling(0.004)
|
||||||
inputs = random_ops.random_uniform((2, 4, 5, 3))
|
inputs = random_ops.random_uniform((2, 4, 5, 3))
|
||||||
outputs = layer(inputs)
|
outputs = layer(inputs)
|
||||||
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
|
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
|
||||||
|
|
||||||
@tf_test_util.run_v2_only
|
@tf_test_util.run_v2_only
|
||||||
def test_rescaling_correctness_int(self):
|
def test_rescaling_correctness_int(self):
|
||||||
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1)
|
layer = image_preprocessing.Rescaling(0.004)
|
||||||
inputs = random_ops.random_uniform((2, 4, 5, 3), 0, 100, dtype='int32')
|
inputs = random_ops.random_uniform((2, 4, 5, 3), 0, 100, dtype='int32')
|
||||||
outputs = layer(inputs)
|
outputs = layer(inputs)
|
||||||
self.assertEqual(outputs.dtype.name, 'float32')
|
self.assertEqual(outputs.dtype.name, 'float32')
|
||||||
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
|
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
|
||||||
|
|
||||||
def test_config_with_custom_name(self):
|
def test_config_with_custom_name(self):
|
||||||
layer = image_preprocessing.Rescaling(0.5, name='rescaling')
|
layer = image_preprocessing.Rescaling(0.5, name='rescaling')
|
||||||
|
|
|
@ -113,7 +113,7 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
|
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
|
|
@ -113,7 +113,7 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
|
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
|
Loading…
Reference in New Issue