Merge pull request #45613 from yongtang:45324-tf.image.central_crop
PiperOrigin-RevId: 357599670 Change-Id: I4789db092979eaea931f18733ae5dbe88a356fbf
This commit is contained in:
commit
d40c29a9f9
@ -901,10 +901,17 @@ def central_crop(image, central_fraction):
|
||||
"""
|
||||
with ops.name_scope(None, 'central_crop', [image]):
|
||||
image = ops.convert_to_tensor(image, name='image')
|
||||
if central_fraction <= 0.0 or central_fraction > 1.0:
|
||||
raise ValueError('central_fraction must be within (0, 1]')
|
||||
if central_fraction == 1.0:
|
||||
return image
|
||||
central_fraction_static = tensor_util.constant_value(central_fraction)
|
||||
if central_fraction_static is not None:
|
||||
if central_fraction_static <= 0.0 or central_fraction_static > 1.0:
|
||||
raise ValueError('central_fraction must be within (0, 1]')
|
||||
if central_fraction_static == 1.0:
|
||||
return image
|
||||
else:
|
||||
assert_ops = _assert(
|
||||
math_ops.logical_or(central_fraction > 0.0, central_fraction <= 1.0),
|
||||
ValueError, 'central_fraction must be within (0, 1]')
|
||||
image = control_flow_ops.with_dependencies(assert_ops, image)
|
||||
|
||||
_AssertAtLeast3DImage(image)
|
||||
rank = image.get_shape().ndims
|
||||
@ -932,24 +939,29 @@ def central_crop(image, central_fraction):
|
||||
img_w, dynamic_w = _get_dim(image, 2)
|
||||
img_d = image.get_shape()[3]
|
||||
|
||||
dynamic_h = dynamic_h or (central_fraction_static is None)
|
||||
dynamic_w = dynamic_w or (central_fraction_static is None)
|
||||
|
||||
# Compute the bounding boxes for the crop. The type and value of the
|
||||
# bounding boxes depend on the `image` tensor's rank and whether / not the
|
||||
# dimensions are statically defined.
|
||||
if dynamic_h:
|
||||
img_hd = math_ops.cast(img_h, dtypes.float64)
|
||||
bbox_h_start = math_ops.cast((img_hd - img_hd * central_fraction) / 2,
|
||||
dtypes.int32)
|
||||
bbox_h_start = math_ops.cast(
|
||||
(img_hd - img_hd * math_ops.cast(central_fraction, dtypes.float64)) /
|
||||
2, dtypes.int32)
|
||||
else:
|
||||
img_hd = float(img_h)
|
||||
bbox_h_start = int((img_hd - img_hd * central_fraction) / 2)
|
||||
bbox_h_start = int((img_hd - img_hd * central_fraction_static) / 2)
|
||||
|
||||
if dynamic_w:
|
||||
img_wd = math_ops.cast(img_w, dtypes.float64)
|
||||
bbox_w_start = math_ops.cast((img_wd - img_wd * central_fraction) / 2,
|
||||
dtypes.int32)
|
||||
bbox_w_start = math_ops.cast(
|
||||
(img_wd - img_wd * math_ops.cast(central_fraction, dtypes.float64)) /
|
||||
2, dtypes.int32)
|
||||
else:
|
||||
img_wd = float(img_w)
|
||||
bbox_w_start = int((img_wd - img_wd * central_fraction) / 2)
|
||||
bbox_w_start = int((img_wd - img_wd * central_fraction_static) / 2)
|
||||
|
||||
bbox_h_size = img_h - bbox_h_start * 2
|
||||
bbox_w_size = img_w - bbox_w_start * 2
|
||||
|
@ -2003,6 +2003,21 @@ class CentralCropTest(test_util.TensorFlowTestCase):
|
||||
y = image_ops.central_crop(x_np, 1.0)
|
||||
self.assertTrue(y.op.name.startswith("central_crop"))
|
||||
|
||||
def testCentralFractionTensor(self):
|
||||
# Test case for GitHub issue 45324.
|
||||
x_shape = [240, 320, 3]
|
||||
y_shape = [80, 106, 3]
|
||||
|
||||
@def_function.function(autograph=False)
|
||||
def f(x, central_fraction):
|
||||
return image_ops.central_crop(x, central_fraction)
|
||||
|
||||
x_np = np.zeros(x_shape, dtype=np.int32)
|
||||
y_np = np.zeros(y_shape, dtype=np.int32)
|
||||
y_tf = self.evaluate(f(x_np, constant_op.constant(0.33)))
|
||||
self.assertAllEqual(y_tf, y_np)
|
||||
self.assertAllEqual(y_tf.shape, y_np.shape)
|
||||
|
||||
|
||||
class PadToBoundingBoxTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user