merge fix
This commit is contained in:
parent
9dd8e7aec9
commit
7467f37092
@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
_DTYPES = set(
|
||||
@ -186,29 +185,5 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
|
||||
self._test_grad([3, 4, 12, 12])
|
||||
|
||||
|
||||
def _test_grad(self, shape_to_test):
|
||||
with self.test_session():
|
||||
test_image_shape = shape_to_test
|
||||
test_image = np.random.randn(*test_image_shape)
|
||||
test_image_tensor = constant_op.constant(test_image,
|
||||
shape=test_image_shape)
|
||||
test_transform = image_ops.angles_to_projective_transforms(np.pi / 2,
|
||||
4,
|
||||
4)
|
||||
test_transform_shape = test_transform.shape
|
||||
|
||||
output_shape = test_image_shape
|
||||
output = image_ops.transform(test_image_tensor, test_transform)
|
||||
left_err = gradient_checker.compute_gradient_error(
|
||||
test_image_tensor, test_image_shape, output, output_shape,
|
||||
x_init_value=test_image)
|
||||
self.assertLess(left_err, 1e-10)
|
||||
|
||||
def test_grad(self):
|
||||
self._test_grad([16, 16])
|
||||
self._test_grad([4, 12, 12])
|
||||
self._test_grad([3, 4, 12, 12])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user