From 3e33d444c6c0861d448755627982284736ed01a0 Mon Sep 17 00:00:00 2001
From: Geoffrey Irving <geoffreyi@google.com>
Date: Mon, 8 Feb 2016 12:02:44 -0800
Subject: [PATCH] Generalize tf.image.random_crop to dimension-independent
 tf.random_crop

The C++ 3-D-only RandomCrop op is now deprecated at GraphDef version 8, replaced
with a python tf.random_crop that works for any dimension.  This will allow
random_crop to be used for other purposes.

Unfortunately, tf.image.random_crop took 2 sizes rather than 3 for 3-D tensors.
The new tf.random_crop always takes n sizes for rank n tensors; pass 3 as the
last element if you want to not crop a last dimension of size 3.
Change: 114135451
---
 RELEASE.md                                    |  4 +
 tensorflow/core/kernels/random_crop_op.cc     |  1 +
 .../core/kernels/random_crop_op_test.cc       | 75 ------------------
 tensorflow/core/public/version.h              |  3 +-
 .../models/image/cifar10/cifar10_input.py     |  2 +-
 tensorflow/python/BUILD                       |  1 +
 .../python/kernel_tests/random_crop_test.py   | 77 +++++++++++++++++++
 tensorflow/python/ops/constant_op.py          |  1 +
 tensorflow/python/ops/image_ops.py            | 41 ----------
 tensorflow/python/ops/image_ops_test.py       | 51 ------------
 tensorflow/python/ops/random_ops.py           | 40 ++++++++++
 11 files changed, 127 insertions(+), 169 deletions(-)
 delete mode 100644 tensorflow/core/kernels/random_crop_op_test.cc
 create mode 100644 tensorflow/python/kernel_tests/random_crop_test.py

diff --git a/RELEASE.md b/RELEASE.md
index fd46bcdff36..b8146610c80 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -33,6 +33,10 @@
   maintained for short-term compatibility but will be removed.
 * The non-public `nn.rnn` and the various `nn.seq2seq` methods now return
   just the final state instead of the list of all states.
+* `tf.image.random_crop(image, [height, width])` is now
+  `tf.random_crop(image, [height, width, depth])`, and `tf.random_crop` works
+  for any rank (not just 3-D images).  The C++ `RandomCrop` op has been replaced
+  with pure Python.
 
 
 ## Bug fixes
diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc
index e27de3b4e04..80b4041e05f 100644
--- a/tensorflow/core/kernels/random_crop_op.cc
+++ b/tensorflow/core/kernels/random_crop_op.cc
@@ -28,6 +28,7 @@ template <typename T>
 class RandomCropOp : public OpKernel {
  public:
   explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) {
+    OP_DEPRECATED(context, 8, "Random crop is now pure Python");
     OP_REQUIRES_OK(context, generator_.Init(context));
   }
 
diff --git a/tensorflow/core/kernels/random_crop_op_test.cc b/tensorflow/core/kernels/random_crop_op_test.cc
deleted file mode 100644
index 18bdbd4b6b5..00000000000
--- a/tensorflow/core/kernels/random_crop_op_test.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/fake_input.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/kernels/ops_testutil.h"
-#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-
-class RandomCropOpTest : public OpsTestBase {
- protected:
-  RandomCropOpTest() {
-    RequireDefaultOps();
-    TF_EXPECT_OK(NodeDefBuilder("random_crop_op", "RandomCrop")
-                     .Input(FakeInput(DT_UINT8))
-                     .Input(FakeInput(DT_INT64))
-                     .Attr("T", DT_UINT8)
-                     .Finalize(node_def()));
-    TF_EXPECT_OK(InitOp());
-  }
-};
-
-TEST_F(RandomCropOpTest, Basic) {
-  AddInputFromArray<uint8>(TensorShape({1, 2, 1}), {2, 2});
-  AddInputFromArray<int64>(TensorShape({2}), {1, 1});
-  TF_ASSERT_OK(RunOpKernel());
-
-  Tensor expected(allocator(), DT_UINT8, TensorShape({1, 1, 1}));
-  test::FillValues<uint8>(&expected, {2});
-  test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
-}
-
-TEST_F(RandomCropOpTest, SameSizeOneChannel) {
-  AddInputFromArray<uint8>(TensorShape({2, 1, 1}), {1, 2});
-  AddInputFromArray<int64>(TensorShape({2}), {2, 1});
-  TF_ASSERT_OK(RunOpKernel());
-
-  Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 1}));
-  test::FillValues<uint8>(&expected, {1, 2});
-  test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
-}
-
-TEST_F(RandomCropOpTest, SameSizeMultiChannel) {
-  AddInputFromArray<uint8>(TensorShape({2, 1, 3}), {1, 2, 3, 4, 5, 6});
-  AddInputFromArray<int64>(TensorShape({2}), {2, 1});
-  TF_ASSERT_OK(RunOpKernel());
-
-  Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 3}));
-  test::FillValues<uint8>(&expected, {1, 2, 3, 4, 5, 6});
-  test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 9ec1aa9b7f2..39cf5b06e9b 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -61,8 +61,9 @@ limitations under the License.
 // 5. Graphs are wholly-validated during Session::Create() (7jan2016).
 // 6. TensorFlow is scalar strict within Google (27jan2016).
 // 7. Remove TopK in favor of TopKV2 (5feb2016).
+// 8. Replace RandomCrop from C++ with pure Python (5feb2016).
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 7
+#define TF_GRAPH_DEF_VERSION 8
 
 #endif  // TENSORFLOW_CORE_PUBLIC_VERSION_H_
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py
index ffe8facd27d..f7d7083d735 100644
--- a/tensorflow/models/image/cifar10/cifar10_input.py
+++ b/tensorflow/models/image/cifar10/cifar10_input.py
@@ -161,7 +161,7 @@ def distorted_inputs(data_dir, batch_size):
   # distortions applied to the image.
 
   # Randomly crop a [height, width] section of the image.
-  distorted_image = tf.image.random_crop(reshaped_image, [height, width])
+  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
 
   # Randomly flip the image horizontally.
   distorted_image = tf.image.random_flip_left_right(distorted_image)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 979897c985a..ee4dae35453 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -509,6 +509,7 @@ tf_gen_op_wrapper_py(
 tf_gen_op_wrapper_py(
     name = "image_ops",
     hidden = [
+        "RandomCrop",
         "ResizeBilinearGrad",
         "ResizeNearestNeighborGrad",
         "AdjustContrastv2",
diff --git a/tensorflow/python/kernel_tests/random_crop_test.py b/tensorflow/python/kernel_tests/random_crop_test.py
new file mode 100644
index 00000000000..b682b22b88b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/random_crop_test.py
@@ -0,0 +1,77 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for random_crop."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class RandomCropTest(tf.test.TestCase):
+
+  def testNoOp(self):
+    # No random cropping is performed since the size is value.shape.
+    for shape in (2, 1, 1), (2, 1, 3), (4, 5, 3):
+      value = np.arange(0, np.prod(shape), dtype=np.int32).reshape(shape)
+      with self.test_session():
+        crop = tf.random_crop(value, shape).eval()
+        self.assertAllEqual(crop, value)
+
+  def testContains(self):
+    with self.test_session():
+      shape = (3, 5, 7)
+      target = (2, 3, 4)
+      value = np.random.randint(1000000, size=shape)
+      value_set = set(tuple(value[i:i + 2, j:j + 3, k:k + 4].ravel())
+                      for i in range(2) for j in range(3) for k in range(4))
+      crop = tf.random_crop(value, size=target)
+      for _ in range(20):
+        y = crop.eval()
+        self.assertAllEqual(y.shape, target)
+        self.assertTrue(tuple(y.ravel()) in value_set)
+
+  def testRandomization(self):
+    # Run 1x1 crop num_samples times in an image and ensure that one finds each
+    # pixel 1/size of the time.
+    num_samples = 1000
+    shape = [5, 4, 1]
+    size = np.prod(shape)
+    single = [1, 1, 1]
+    value = np.arange(size).reshape(shape)
+
+    with self.test_session():
+      crop = tf.random_crop(value, single, seed=7)
+      counts = np.zeros(size, dtype=np.int32)
+      for _ in range(num_samples):
+        y = crop.eval()
+        self.assertAllEqual(y.shape, single)
+        counts[y] += 1
+
+    # Calculate the mean and 4 * standard deviation.
+    mean = np.repeat(num_samples / size, size)
+    four_stddev = 4.0 * np.sqrt(mean)
+
+    # Ensure that each entry is observed in 1/size of the samples
+    # within 4 standard deviations.
+    self.assertAllClose(counts, mean, atol=four_stddev)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
index a5f1a9c8ee1..5d04389c731 100644
--- a/tensorflow/python/ops/constant_op.py
+++ b/tensorflow/python/ops/constant_op.py
@@ -91,6 +91,7 @@ print(sess.run(var))
 @@truncated_normal
 @@random_uniform
 @@random_shuffle
+@@random_crop
 @@set_random_seed
 
 """
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 2eeef95d99d..d278b53d709 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -70,7 +70,6 @@ resized_image = tf.image.resize_images(image, 299, 299)
 
 @@pad_to_bounding_box
 @@crop_to_bounding_box
-@@random_crop
 @@extract_glimpse
 
 ## Flipping and Transposing
@@ -156,7 +155,6 @@ import tensorflow.python.platform
 
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
@@ -832,45 +830,6 @@ def _ImageEncodeShape(op):
   return [tensor_shape.scalar()]
 
 
-@ops.RegisterShape('RandomCrop')
-def _random_cropShape(op):
-  """Shape function for the random_crop op."""
-  input_shape = op.inputs[0].get_shape().with_rank(3)
-  unused_size_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.vector(2))
-  size = tensor_util.constant_value(op.inputs[1])
-  if size is not None:
-    height = size[0]
-    width = size[1]
-  else:
-    height = None
-    width = None
-  channels = input_shape[2]
-  return [tensor_shape.TensorShape([height, width, channels])]
-
-
-def random_crop(image, size, seed=None, name=None):
-  """Randomly crops `image` to size `[target_height, target_width]`.
-
-  The offset of the output within `image` is uniformly random. `image` always
-  fully contains the result.
-
-  Args:
-    image: 3-D tensor of shape `[height, width, channels]`
-    size: 1-D tensor with two elements, specifying target `[height, width]`
-    seed: A Python integer. Used to create a random seed. See
-      [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
-      for behavior.
-    name: A name for this operation (optional).
-
-  Returns:
-    A cropped 3-D tensor of shape `[target_height, target_width, channels]`.
-  """
-  seed1, seed2 = random_seed.get_seed(seed)
-  return gen_image_ops.random_crop(image, size, seed=seed1, seed2=seed2,
-                                   name=name)
-
-
 def saturate_cast(image, dtype):
   """Performs a safe cast of image data to `dtype`.
 
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index d09004556ca..bb19902960c 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -443,57 +443,6 @@ class AdjustBrightnessTest(test_util.TensorFlowTestCase):
     self._testBrightness(x_np, y_np, delta=-10. / 255.)
 
 
-class RandomCropTest(test_util.TensorFlowTestCase):
-
-  def testNoOp(self):
-    # No random cropping is performed since the target width and height
-    # are match the image dimensions.
-    height = 4
-    width = 5
-    x_shape = [height, width, 3]
-    x_np = np.arange(0, np.prod(x_shape), dtype=np.int32).reshape(x_shape)
-    target_shape_np = np.array([height, width], dtype=np.int64)
-
-    with self.test_session():
-      x = constant_op.constant(x_np, shape=x_shape)
-      target_shape = constant_op.constant(target_shape_np, shape=[2])
-      y = image_ops.random_crop(x, target_shape)
-      y_tf = y.eval()
-      self.assertAllEqual(y_tf, x_np)
-
-  def testRandomization(self):
-    # Run 1x1 crop num_samples times in an image and ensure that one finds each
-    # pixel 1/num_pixels of the time.
-    num_samples = 1000
-    height = 5
-    width = 4
-
-    num_pixels = height * width
-    data = np.arange(num_pixels).reshape([height, width, 1])
-    x_np = np.array(data).astype(np.int32)
-
-    target_shape_np = np.array([1, 1], dtype=np.int64)
-
-    y = []
-    with self.test_session():
-      x = constant_op.constant(x_np, shape=x_np.shape)
-      target_shape = constant_op.constant(target_shape_np, shape=[2])
-      y_tf = image_ops.random_crop(x, target_shape)
-      for _ in xrange(num_samples):
-        y_np = y_tf.eval()
-        self.assertAllEqual(y_np.shape, [1, 1, 1])
-        y.extend(y_np.flatten())
-
-    # Calculate the mean and 4 * standard deviation.
-    mean = [num_samples / num_pixels] * num_pixels
-    four_stddev = 4.0 * np.sqrt(mean)
-
-    # Ensure that each entry is observed in 1/num_pixels of the samples
-    # within 4 standard deviations.
-    counts = np.bincount(y)
-    self.assertAllClose(counts, mean, atol=four_stddev)
-
-
 class PerImageWhiteningTest(test_util.TensorFlowTestCase):
 
   def _NumpyPerImageWhitening(self, x):
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 1416e4a8c14..15cfcebc291 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -24,8 +24,11 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_random_ops
+from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_random_ops import *
@@ -209,6 +212,43 @@ def random_shuffle(value, seed=None, name=None):
                                         name=name)
 
 
+def random_crop(value, size, seed=None, name=None):
+  """Randomly crops a tensor to a given size.
+
+  Slices a shape `size` portion out of `value` at a uniformly chosen offset.
+  Requires `value.shape >= size`.
+
+  If a dimension should not be cropped, pass the full size of that dimension.
+  For example, RGB images can be cropped with
+  `size = [crop_height, crop_width, 3]`.
+
+  Args:
+    value: Input tensor to crop.
+    size: 1-D tensor with size the rank of `value`.
+    seed: Python integer. Used to create a random seed. See
+      [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
+      for behavior.
+    name: A name for this operation (optional).
+
+  Returns:
+    A cropped tensor of the same rank as `value` and shape `size`.
+  """
+  # TODO(shlens): Implement edge case to guarantee output size dimensions.
+  # If size > value.shape, zero pad the result so that it always has shape
+  # exactly size.
+  with ops.op_scope([value, size], name, "random_crop") as name:
+    value = ops.convert_to_tensor(value, name="value")
+    size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size")
+    shape = array_ops.shape(value)
+    check = logging_ops.Assert(math_ops.reduce_all(shape >= size),
+                               ["Need value.shape >= size, got ", shape, size])
+    shape = control_flow_ops.with_dependencies([check], shape)
+    limit = shape - size + 1
+    offset = random_uniform(array_ops.shape(shape), dtype=size.dtype,
+                            maxval=size.dtype.max, seed=seed) % limit
+    return array_ops.slice(value, offset, size, name=name)
+
+
 ops.NoGradient("RandomUniform")