Add `offset` argument to `Rescaling`.

PiperOrigin-RevId: 313104348
Change-Id: I5472da4856a6040e74286a5dc174a5897b8955df
This commit is contained in:
A. Unique TensorFlower 2020-05-25 14:48:04 -07:00 committed by TensorFlower Gardener
parent 83ed5aad57
commit 291125835e
4 changed files with 11 additions and 21 deletions

View File

@ -292,16 +292,11 @@ class RandomCrop(Layer):
@keras_export('keras.layers.experimental.preprocessing.Rescaling')
class Rescaling(Layer):
"""Multiply inputs by `scale` and adds `offset`.
"""Multiply inputs by `scale`.
For instance:
1. To rescale an input in the `[0, 255]` range
For instance, to rescale an input in the `[0, 255]` range
to be in the `[0, 1]` range, you would pass `scale=1./255`.
2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range,
you would pass `scale=1./127.5, offset=-1`.
The rescaling is applied both during training and inference.
Input shape:
@ -312,20 +307,16 @@ class Rescaling(Layer):
Arguments:
scale: Float, the scale to apply to the inputs.
offset: Float, the offset to apply to the inputs.
name: A string, the name of the layer.
"""
def __init__(self, scale, offset=0., name=None, **kwargs):
def __init__(self, scale, name=None, **kwargs):
self.scale = scale
self.offset = offset
super(Rescaling, self).__init__(name=name, **kwargs)
def call(self, inputs):
dtype = self._compute_dtype
scale = math_ops.cast(self.scale, dtype)
offset = math_ops.cast(self.offset, dtype)
return math_ops.cast(inputs, dtype) * scale + offset
return math_ops.cast(inputs, dtype) * math_ops.cast(self.scale, dtype)
def compute_output_shape(self, input_shape):
return input_shape
@ -333,7 +324,6 @@ class Rescaling(Layer):
def get_config(self):
config = {
'scale': self.scale,
'offset': self.offset,
}
base_config = super(Rescaling, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

View File

@ -306,7 +306,7 @@ class RescalingTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_rescaling_base(self):
kwargs = {'scale': 1./127.5, 'offset': -1.}
kwargs = {'scale': 0.004}
testing_utils.layer_test(
image_preprocessing.Rescaling,
kwargs=kwargs,
@ -315,18 +315,18 @@ class RescalingTest(keras_parameterized.TestCase):
@tf_test_util.run_v2_only
def test_rescaling_correctness_float(self):
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1.)
layer = image_preprocessing.Rescaling(0.004)
inputs = random_ops.random_uniform((2, 4, 5, 3))
outputs = layer(inputs)
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
@tf_test_util.run_v2_only
def test_rescaling_correctness_int(self):
layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1)
layer = image_preprocessing.Rescaling(0.004)
inputs = random_ops.random_uniform((2, 4, 5, 3), 0, 100, dtype='int32')
outputs = layer(inputs)
self.assertEqual(outputs.dtype.name, 'float32')
self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1)
self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004)
def test_config_with_custom_name(self):
layer = image_preprocessing.Rescaling(0.5, name='rescaling')

View File

@ -113,7 +113,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_loss"

View File

@ -113,7 +113,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], "
argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_loss"