Add stateless_random_crop
to tf.image API; it is a deterministic version of tf.image.random_crop
. Given the same seed, stateless_random_crop
guarantees the same results independent of how many times it is called, and independent of global seed settings.
PiperOrigin-RevId: 325938094 Change-Id: Iad3132e097d71513193304d8aad45a5585656c53
This commit is contained in:
parent
6ea0d3d925
commit
ef20eb2110
@ -63,11 +63,6 @@ class RandomCropOp : public OpKernel {
|
|||||||
if ((target_height == height) && (target_width == width)) {
|
if ((target_height == height) && (target_width == width)) {
|
||||||
*output = context->input(0);
|
*output = context->input(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(shlens): Implement edge case to guarantee output size dimensions.
|
|
||||||
// Edge case. The target dimensions are larger then the image, so
|
|
||||||
// zero-pad the image. This guarantees that the image will *always*
|
|
||||||
// be [target_height, target_width] in size.
|
|
||||||
OP_REQUIRES(context, width >= target_width,
|
OP_REQUIRES(context, width >= target_width,
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
"width must be >= target_width: width = ", width,
|
"width must be >= target_width: width = ", width,
|
||||||
|
@ -77,5 +77,88 @@ class RandomCropTest(test.TestCase):
|
|||||||
self.assertAllClose(counts, mean, atol=four_stddev)
|
self.assertAllClose(counts, mean, atol=four_stddev)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
class StatelessRandomCropTest(test.TestCase):
|
||||||
|
|
||||||
|
def testNoOp(self):
|
||||||
|
# No random cropping is performed since the size is value.shape.
|
||||||
|
for shape in (2, 1, 1), (2, 1, 3), (4, 5, 3):
|
||||||
|
value = np.arange(0, np.prod(shape), dtype=np.int32).reshape(shape)
|
||||||
|
crop = random_ops.stateless_random_crop(value, shape, seed=(1, 2))
|
||||||
|
self.evaluate(crop)
|
||||||
|
self.assertAllEqual(crop, value)
|
||||||
|
|
||||||
|
def testContains(self):
|
||||||
|
with test_util.use_gpu():
|
||||||
|
shape = (3, 5, 7)
|
||||||
|
target = (2, 3, 4)
|
||||||
|
value = np.random.randint(1000000, size=shape)
|
||||||
|
iterations = 10
|
||||||
|
value_set = set(
|
||||||
|
tuple(value[i:i + 2, j:j + 3, k:k + 4].ravel()) # pylint: disable=g-complex-comprehension
|
||||||
|
for i in range(2) for j in range(3) for k in range(4))
|
||||||
|
test_seeds = [
|
||||||
|
tuple(map(lambda x, i=i: x + 1 * i, t))
|
||||||
|
for (i, t) in enumerate((1, 2) for _ in range(iterations))
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check that the result is valid by making sure that it is one of all
|
||||||
|
# possible values for randomly cropping `value` with `target` shape.
|
||||||
|
for seed in test_seeds:
|
||||||
|
crop = random_ops.stateless_random_crop(value, size=target, seed=seed)
|
||||||
|
y = self.evaluate(crop)
|
||||||
|
self.assertAllEqual(y.shape, target)
|
||||||
|
self.assertIn(tuple(y.ravel()), value_set)
|
||||||
|
|
||||||
|
# TODO(b/162345082): stateless random op generates different random number
|
||||||
|
# with xla_gpu. Update tests such that there is a single ground truth result
|
||||||
|
# to test against.
|
||||||
|
def testRandomization(self):
|
||||||
|
with test_util.use_gpu():
|
||||||
|
shape = [5, 4, 1]
|
||||||
|
size = np.prod(shape)
|
||||||
|
single = [1, 1, 1]
|
||||||
|
value = np.arange(size).reshape(shape)
|
||||||
|
iterations = 5
|
||||||
|
num_samples = 5
|
||||||
|
|
||||||
|
# Test that the same result is returned given the same seed is provided
|
||||||
|
# for each round.
|
||||||
|
test_seed = (1, 2)
|
||||||
|
observations = [[] for _ in range(iterations)]
|
||||||
|
for observation in observations:
|
||||||
|
crop = random_ops.stateless_random_crop(value, single, seed=test_seed)
|
||||||
|
counts = np.zeros(size, dtype=np.int32)
|
||||||
|
for _ in range(num_samples):
|
||||||
|
y = self.evaluate(crop)
|
||||||
|
self.assertAllEqual(y.shape, single)
|
||||||
|
counts[y] += 1
|
||||||
|
|
||||||
|
observation.append(counts)
|
||||||
|
|
||||||
|
for i in range(1, iterations):
|
||||||
|
self.assertAllEqual(observations[0], observations[i])
|
||||||
|
|
||||||
|
# Test that the same sequence of results are returned given the same
|
||||||
|
# sequence of seeds provided.
|
||||||
|
test_seeds = [
|
||||||
|
tuple(map(lambda x, i=i: x + 1 * i, t))
|
||||||
|
for (i, t) in enumerate((1, 2) for _ in range(iterations))
|
||||||
|
]
|
||||||
|
observations = [[] for _ in range(iterations)]
|
||||||
|
for observation in observations:
|
||||||
|
counts = np.zeros(size, dtype=np.int32)
|
||||||
|
for seed in test_seeds:
|
||||||
|
crop = random_ops.stateless_random_crop(
|
||||||
|
value, single, seed=seed)
|
||||||
|
y = self.evaluate(crop)
|
||||||
|
self.assertAllEqual(y.shape, single)
|
||||||
|
counts[y] += 1
|
||||||
|
|
||||||
|
observation.append(counts)
|
||||||
|
|
||||||
|
for i in range(1, iterations):
|
||||||
|
self.assertAllEqual(observations[0], observations[i])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_random_ops
|
from tensorflow.python.ops import gen_random_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import stateless_random_ops
|
||||||
|
|
||||||
# go/tf-wildcard-import
|
# go/tf-wildcard-import
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
@ -373,9 +374,6 @@ def random_crop(value, size, seed=None, name=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A cropped tensor of the same rank as `value` and shape `size`.
|
A cropped tensor of the same rank as `value` and shape `size`.
|
||||||
"""
|
"""
|
||||||
# TODO(shlens): Implement edge case to guarantee output size dimensions.
|
|
||||||
# If size > value.shape, zero pad the result so that it always has shape
|
|
||||||
# exactly size.
|
|
||||||
with ops.name_scope(name, "random_crop", [value, size]) as name:
|
with ops.name_scope(name, "random_crop", [value, size]) as name:
|
||||||
value = ops.convert_to_tensor(value, name="value")
|
value = ops.convert_to_tensor(value, name="value")
|
||||||
size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size")
|
size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size")
|
||||||
@ -394,6 +392,59 @@ def random_crop(value, size, seed=None, name=None):
|
|||||||
return array_ops.slice(value, offset, size, name=name)
|
return array_ops.slice(value, offset, size, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("image.stateless_random_crop", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
|
def stateless_random_crop(value, size, seed, name=None):
|
||||||
|
"""Randomly crops a tensor to a given size in a deterministic manner.
|
||||||
|
|
||||||
|
Slices a shape `size` portion out of `value` at a uniformly chosen offset.
|
||||||
|
Requires `value.shape >= size`.
|
||||||
|
|
||||||
|
If a dimension should not be cropped, pass the full size of that dimension.
|
||||||
|
For example, RGB images can be cropped with
|
||||||
|
`size = [crop_height, crop_width, 3]`.
|
||||||
|
|
||||||
|
Guarantees the same results given the same `seed` independent of how many
|
||||||
|
times the function is called, and independent of global seed settings (e.g.
|
||||||
|
`tf.random.set_seed`).
|
||||||
|
|
||||||
|
Usage Example:
|
||||||
|
|
||||||
|
>>> image = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
|
||||||
|
>>> seed = (1, 2)
|
||||||
|
>>> tf.image.stateless_random_crop(value=image, size=(1, 2, 3), seed=seed)
|
||||||
|
<tf.Tensor: shape=(1, 2, 3), dtype=int32, numpy=
|
||||||
|
array([[[1, 2, 3],
|
||||||
|
[4, 5, 6]]], dtype=int32)>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Input tensor to crop.
|
||||||
|
size: 1-D tensor with size the rank of `value`.
|
||||||
|
seed: A shape [2] Tensor, the seed to the random number generator. Must have
|
||||||
|
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A cropped tensor of the same rank as `value` and shape `size`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, "random_crop", [value, size]) as name:
|
||||||
|
value = ops.convert_to_tensor(value, name="value")
|
||||||
|
size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size")
|
||||||
|
shape = array_ops.shape(value)
|
||||||
|
check = control_flow_ops.Assert(
|
||||||
|
math_ops.reduce_all(shape >= size),
|
||||||
|
["Need value.shape >= size, got ", shape, size],
|
||||||
|
summarize=1000)
|
||||||
|
shape = control_flow_ops.with_dependencies([check], shape)
|
||||||
|
limit = shape - size + 1
|
||||||
|
offset = stateless_random_ops.stateless_random_uniform(
|
||||||
|
array_ops.shape(shape),
|
||||||
|
dtype=size.dtype,
|
||||||
|
maxval=size.dtype.max,
|
||||||
|
seed=seed) % limit
|
||||||
|
return array_ops.slice(value, offset, size, name=name)
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["random.multinomial", "multinomial"])
|
@tf_export(v1=["random.multinomial", "multinomial"])
|
||||||
@dispatch.add_dispatch_support
|
@dispatch.add_dispatch_support
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
|
@ -240,6 +240,10 @@ tf_module {
|
|||||||
name: "stateless_random_contrast"
|
name: "stateless_random_contrast"
|
||||||
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "stateless_random_crop"
|
||||||
|
argspec: "args=[\'value\', \'size\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "stateless_random_flip_left_right"
|
name: "stateless_random_flip_left_right"
|
||||||
argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
Reference in New Issue
Block a user