From 3e33d444c6c0861d448755627982284736ed01a0 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving <geoffreyi@google.com> Date: Mon, 8 Feb 2016 12:02:44 -0800 Subject: [PATCH] Generalize tf.image.random_crop to dimension-independent tf.random_crop The C++ 3-D-only RandomCrop op is now deprecated at GraphDef version 8, replaced with a python tf.random_crop that works for any dimension. This will allow random_crop to be used for other purposes. Unfortunately, tf.image.random_crop took 2 sizes rather than 3 for 3-D tensors. The new tf.random_crop always takes n sizes for rank n tensors; pass 3 as the last element if you want to not crop a last dimension of size 3. Change: 114135451 --- RELEASE.md | 4 + tensorflow/core/kernels/random_crop_op.cc | 1 + .../core/kernels/random_crop_op_test.cc | 75 ------------------ tensorflow/core/public/version.h | 3 +- .../models/image/cifar10/cifar10_input.py | 2 +- tensorflow/python/BUILD | 1 + .../python/kernel_tests/random_crop_test.py | 77 +++++++++++++++++++ tensorflow/python/ops/constant_op.py | 1 + tensorflow/python/ops/image_ops.py | 41 ---------- tensorflow/python/ops/image_ops_test.py | 51 ------------ tensorflow/python/ops/random_ops.py | 40 ++++++++++ 11 files changed, 127 insertions(+), 169 deletions(-) delete mode 100644 tensorflow/core/kernels/random_crop_op_test.cc create mode 100644 tensorflow/python/kernel_tests/random_crop_test.py diff --git a/RELEASE.md b/RELEASE.md index fd46bcdff36..b8146610c80 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -33,6 +33,10 @@ maintained for short-term compatibility but will be removed. * The non-public `nn.rnn` and the various `nn.seq2seq` methods now return just the final state instead of the list of all states. +* `tf.image.random_crop(image, [height, width])` is now + `tf.random_crop(image, [height, width, depth])`, and `tf.random_crop` works + for any rank (not just 3-D images). The C++ `RandomCrop` op has been replaced + with pure Python. ## Bug fixes diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc index e27de3b4e04..80b4041e05f 100644 --- a/tensorflow/core/kernels/random_crop_op.cc +++ b/tensorflow/core/kernels/random_crop_op.cc @@ -28,6 +28,7 @@ template <typename T> class RandomCropOp : public OpKernel { public: explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) { + OP_DEPRECATED(context, 8, "Random crop is now pure Python"); OP_REQUIRES_OK(context, generator_.Init(context)); } diff --git a/tensorflow/core/kernels/random_crop_op_test.cc b/tensorflow/core/kernels/random_crop_op_test.cc deleted file mode 100644 index 18bdbd4b6b5..00000000000 --- a/tensorflow/core/kernels/random_crop_op_test.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2015 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { - -class RandomCropOpTest : public OpsTestBase { - protected: - RandomCropOpTest() { - RequireDefaultOps(); - TF_EXPECT_OK(NodeDefBuilder("random_crop_op", "RandomCrop") - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_INT64)) - .Attr("T", DT_UINT8) - .Finalize(node_def())); - TF_EXPECT_OK(InitOp()); - } -}; - -TEST_F(RandomCropOpTest, Basic) { - AddInputFromArray<uint8>(TensorShape({1, 2, 1}), {2, 2}); - AddInputFromArray<int64>(TensorShape({2}), {1, 1}); - TF_ASSERT_OK(RunOpKernel()); - - Tensor expected(allocator(), DT_UINT8, TensorShape({1, 1, 1})); - test::FillValues<uint8>(&expected, {2}); - test::ExpectTensorEqual<uint8>(expected, *GetOutput(0)); -} - -TEST_F(RandomCropOpTest, SameSizeOneChannel) { - AddInputFromArray<uint8>(TensorShape({2, 1, 1}), {1, 2}); - AddInputFromArray<int64>(TensorShape({2}), {2, 1}); - TF_ASSERT_OK(RunOpKernel()); - - Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 1})); - test::FillValues<uint8>(&expected, {1, 2}); - test::ExpectTensorEqual<uint8>(expected, *GetOutput(0)); -} - -TEST_F(RandomCropOpTest, SameSizeMultiChannel) { - AddInputFromArray<uint8>(TensorShape({2, 1, 3}), {1, 2, 3, 4, 5, 6}); - AddInputFromArray<int64>(TensorShape({2}), {2, 1}); - TF_ASSERT_OK(RunOpKernel()); - - Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 3})); - test::FillValues<uint8>(&expected, {1, 2, 3, 4, 5, 6}); - test::ExpectTensorEqual<uint8>(expected, *GetOutput(0)); -} - -} // namespace tensorflow diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 9ec1aa9b7f2..39cf5b06e9b 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -61,8 +61,9 @@ limitations under the License. // 5. Graphs are wholly-validated during Session::Create() (7jan2016). // 6. TensorFlow is scalar strict within Google (27jan2016). // 7. Remove TopK in favor of TopKV2 (5feb2016). +// 8. Replace RandomCrop from C++ with pure Python (5feb2016). #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 7 +#define TF_GRAPH_DEF_VERSION 8 #endif // TENSORFLOW_CORE_PUBLIC_VERSION_H_ diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py index ffe8facd27d..f7d7083d735 100644 --- a/tensorflow/models/image/cifar10/cifar10_input.py +++ b/tensorflow/models/image/cifar10/cifar10_input.py @@ -161,7 +161,7 @@ def distorted_inputs(data_dir, batch_size): # distortions applied to the image. # Randomly crop a [height, width] section of the image. - distorted_image = tf.image.random_crop(reshaped_image, [height, width]) + distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) # Randomly flip the image horizontally. distorted_image = tf.image.random_flip_left_right(distorted_image) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 979897c985a..ee4dae35453 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -509,6 +509,7 @@ tf_gen_op_wrapper_py( tf_gen_op_wrapper_py( name = "image_ops", hidden = [ + "RandomCrop", "ResizeBilinearGrad", "ResizeNearestNeighborGrad", "AdjustContrastv2", diff --git a/tensorflow/python/kernel_tests/random_crop_test.py b/tensorflow/python/kernel_tests/random_crop_test.py new file mode 100644 index 00000000000..b682b22b88b --- /dev/null +++ b/tensorflow/python/kernel_tests/random_crop_test.py @@ -0,0 +1,77 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for random_crop.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class RandomCropTest(tf.test.TestCase): + + def testNoOp(self): + # No random cropping is performed since the size is value.shape. + for shape in (2, 1, 1), (2, 1, 3), (4, 5, 3): + value = np.arange(0, np.prod(shape), dtype=np.int32).reshape(shape) + with self.test_session(): + crop = tf.random_crop(value, shape).eval() + self.assertAllEqual(crop, value) + + def testContains(self): + with self.test_session(): + shape = (3, 5, 7) + target = (2, 3, 4) + value = np.random.randint(1000000, size=shape) + value_set = set(tuple(value[i:i + 2, j:j + 3, k:k + 4].ravel()) + for i in range(2) for j in range(3) for k in range(4)) + crop = tf.random_crop(value, size=target) + for _ in range(20): + y = crop.eval() + self.assertAllEqual(y.shape, target) + self.assertTrue(tuple(y.ravel()) in value_set) + + def testRandomization(self): + # Run 1x1 crop num_samples times in an image and ensure that one finds each + # pixel 1/size of the time. + num_samples = 1000 + shape = [5, 4, 1] + size = np.prod(shape) + single = [1, 1, 1] + value = np.arange(size).reshape(shape) + + with self.test_session(): + crop = tf.random_crop(value, single, seed=7) + counts = np.zeros(size, dtype=np.int32) + for _ in range(num_samples): + y = crop.eval() + self.assertAllEqual(y.shape, single) + counts[y] += 1 + + # Calculate the mean and 4 * standard deviation. + mean = np.repeat(num_samples / size, size) + four_stddev = 4.0 * np.sqrt(mean) + + # Ensure that each entry is observed in 1/size of the samples + # within 4 standard deviations. + self.assertAllClose(counts, mean, atol=four_stddev) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py index a5f1a9c8ee1..5d04389c731 100644 --- a/tensorflow/python/ops/constant_op.py +++ b/tensorflow/python/ops/constant_op.py @@ -91,6 +91,7 @@ print(sess.run(var)) @@truncated_normal @@random_uniform @@random_shuffle +@@random_crop @@set_random_seed """ diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 2eeef95d99d..d278b53d709 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -70,7 +70,6 @@ resized_image = tf.image.resize_images(image, 299, 299) @@pad_to_bounding_box @@crop_to_bounding_box -@@random_crop @@extract_glimpse ## Flipping and Transposing @@ -156,7 +155,6 @@ import tensorflow.python.platform from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -832,45 +830,6 @@ def _ImageEncodeShape(op): return [tensor_shape.scalar()] -@ops.RegisterShape('RandomCrop') -def _random_cropShape(op): - """Shape function for the random_crop op.""" - input_shape = op.inputs[0].get_shape().with_rank(3) - unused_size_shape = op.inputs[1].get_shape().merge_with( - tensor_shape.vector(2)) - size = tensor_util.constant_value(op.inputs[1]) - if size is not None: - height = size[0] - width = size[1] - else: - height = None - width = None - channels = input_shape[2] - return [tensor_shape.TensorShape([height, width, channels])] - - -def random_crop(image, size, seed=None, name=None): - """Randomly crops `image` to size `[target_height, target_width]`. - - The offset of the output within `image` is uniformly random. `image` always - fully contains the result. - - Args: - image: 3-D tensor of shape `[height, width, channels]` - size: 1-D tensor with two elements, specifying target `[height, width]` - seed: A Python integer. Used to create a random seed. See - [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) - for behavior. - name: A name for this operation (optional). - - Returns: - A cropped 3-D tensor of shape `[target_height, target_width, channels]`. - """ - seed1, seed2 = random_seed.get_seed(seed) - return gen_image_ops.random_crop(image, size, seed=seed1, seed2=seed2, - name=name) - - def saturate_cast(image, dtype): """Performs a safe cast of image data to `dtype`. diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index d09004556ca..bb19902960c 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -443,57 +443,6 @@ class AdjustBrightnessTest(test_util.TensorFlowTestCase): self._testBrightness(x_np, y_np, delta=-10. / 255.) -class RandomCropTest(test_util.TensorFlowTestCase): - - def testNoOp(self): - # No random cropping is performed since the target width and height - # are match the image dimensions. - height = 4 - width = 5 - x_shape = [height, width, 3] - x_np = np.arange(0, np.prod(x_shape), dtype=np.int32).reshape(x_shape) - target_shape_np = np.array([height, width], dtype=np.int64) - - with self.test_session(): - x = constant_op.constant(x_np, shape=x_shape) - target_shape = constant_op.constant(target_shape_np, shape=[2]) - y = image_ops.random_crop(x, target_shape) - y_tf = y.eval() - self.assertAllEqual(y_tf, x_np) - - def testRandomization(self): - # Run 1x1 crop num_samples times in an image and ensure that one finds each - # pixel 1/num_pixels of the time. - num_samples = 1000 - height = 5 - width = 4 - - num_pixels = height * width - data = np.arange(num_pixels).reshape([height, width, 1]) - x_np = np.array(data).astype(np.int32) - - target_shape_np = np.array([1, 1], dtype=np.int64) - - y = [] - with self.test_session(): - x = constant_op.constant(x_np, shape=x_np.shape) - target_shape = constant_op.constant(target_shape_np, shape=[2]) - y_tf = image_ops.random_crop(x, target_shape) - for _ in xrange(num_samples): - y_np = y_tf.eval() - self.assertAllEqual(y_np.shape, [1, 1, 1]) - y.extend(y_np.flatten()) - - # Calculate the mean and 4 * standard deviation. - mean = [num_samples / num_pixels] * num_pixels - four_stddev = 4.0 * np.sqrt(mean) - - # Ensure that each entry is observed in 1/num_pixels of the samples - # within 4 standard deviations. - counts = np.bincount(y) - self.assertAllClose(counts, mean, atol=four_stddev) - - class PerImageWhiteningTest(test_util.TensorFlowTestCase): def _NumpyPerImageWhitening(self, x): diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 1416e4a8c14..15cfcebc291 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -24,8 +24,11 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_random_ops +from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_random_ops import * @@ -209,6 +212,43 @@ def random_shuffle(value, seed=None, name=None): name=name) +def random_crop(value, size, seed=None, name=None): + """Randomly crops a tensor to a given size. + + Slices a shape `size` portion out of `value` at a uniformly chosen offset. + Requires `value.shape >= size`. + + If a dimension should not be cropped, pass the full size of that dimension. + For example, RGB images can be cropped with + `size = [crop_height, crop_width, 3]`. + + Args: + value: Input tensor to crop. + size: 1-D tensor with size the rank of `value`. + seed: Python integer. Used to create a random seed. See + [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) + for behavior. + name: A name for this operation (optional). + + Returns: + A cropped tensor of the same rank as `value` and shape `size`. + """ + # TODO(shlens): Implement edge case to guarantee output size dimensions. + # If size > value.shape, zero pad the result so that it always has shape + # exactly size. + with ops.op_scope([value, size], name, "random_crop") as name: + value = ops.convert_to_tensor(value, name="value") + size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size") + shape = array_ops.shape(value) + check = logging_ops.Assert(math_ops.reduce_all(shape >= size), + ["Need value.shape >= size, got ", shape, size]) + shape = control_flow_ops.with_dependencies([check], shape) + limit = shape - size + 1 + offset = random_uniform(array_ops.shape(shape), dtype=size.dtype, + maxval=size.dtype.max, seed=seed) % limit + return array_ops.slice(value, offset, size, name=name) + + ops.NoGradient("RandomUniform")