Add a few tests for upgrade script, fix tf.argmin/tf.argmax to rename

dimension to axis.

PiperOrigin-RevId: 223061849
This commit is contained in:
Anna R 2018-11-27 14:47:43 -08:00 committed by TensorFlower Gardener
parent a69210ede1
commit 891bf2d9fa
4 changed files with 79 additions and 4 deletions

View File

@ -124,6 +124,16 @@ genrule(
tools = [":tf_upgrade_v2"],
)
py_test(
name = "test_file_v1_10",
size = "small",
srcs = ["testdata/test_file_v1_10.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "test_file_v2_0",
size = "small",

View File

@ -25,10 +25,47 @@ from tensorflow.python.platform import test as test_lib
class TestUpgrade(test_util.TensorFlowTestCase):
"""Test various APIs that have been changed in 2.0."""
def setUp(self):
tf.enable_eager_execution()
def testRenames(self):
with self.cached_session():
self.assertAllClose(1.04719755, tf.acos(0.5).eval())
self.assertAllClose(0.5, tf.rsqrt(4.0).eval())
self.assertAllClose(1.04719755, tf.acos(0.5))
self.assertAllClose(0.5, tf.rsqrt(4.0))
def testSerializeSparseTensor(self):
sp_input = tf.SparseTensor(
indices=tf.constant([[1]], dtype=tf.int64),
values=tf.constant([2], dtype=tf.int64),
dense_shape=[2])
with self.cached_session():
serialized_sp = tf.serialize_sparse(sp_input, 'serialize_name', tf.string)
self.assertEqual((3,), serialized_sp.shape)
self.assertTrue(serialized_sp[0].numpy()) # check non-empty
def testSerializeManySparse(self):
sp_input = tf.SparseTensor(
indices=tf.constant([[0, 1]], dtype=tf.int64),
values=tf.constant([2], dtype=tf.int64),
dense_shape=[1, 2])
with self.cached_session():
serialized_sp = tf.serialize_many_sparse(
sp_input, 'serialize_name', tf.string)
self.assertEqual((1, 3), serialized_sp.shape)
def testArgMaxMin(self):
self.assertAllClose(
[1],
tf.argmax([[1, 3, 2]], name='abc', dimension=1))
self.assertAllClose(
[0, 0, 0],
tf.argmax([[1, 3, 2]], dimension=0))
self.assertAllClose(
[0],
tf.argmin([[1, 3, 2]], name='abc', dimension=1))
if __name__ == "__main__":
test_lib.main()

View File

@ -31,6 +31,12 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Maps from a function name to a dictionary that describes how to
# map from an old argument keyword to the new argument keyword.
self.function_keyword_renames = {
"tf.argmin": {
"dimension": "axis",
},
"tf.argmax": {
"dimension": "axis",
},
"tf.image.crop_and_resize": {
"box_ind": "box_indices",
},
@ -408,8 +414,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
self.function_reorders = {
"tf.io.serialize_sparse": ["sp_input", "name", "out_type"],
"tf.io.serialize_many_sparse": ["sp_input", "name", "out_type"],
"tf.argmax": ["input", "axis", "name", "dimension", "output_type"],
"tf.argmin": ["input", "axis", "name", "dimension", "output_type"],
"tf.argmax": ["input", "axis", "name", "axis", "output_type"],
"tf.argmin": ["input", "axis", "name", "axis", "output_type"],
"tf.batch_to_space": ["input", "crops", "block_size", "name"],
"tf.boolean_mask": ["tensor", "mask", "name", "axis"],
"tf.convert_to_tensor": ["value", "dtype", "name", "preferred_dtype"],

View File

@ -219,6 +219,28 @@ class TestUpgrade(test_util.TensorFlowTestCase):
"rename the function export_savedmodel() to export_saved_model()",
report)
def testArgmin(self):
text = "tf.argmin(input, name=n, dimension=1, output_type=type)"
expected_text = "tf.argmin(input=input, name=n, axis=1, output_type=type)"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
text = "tf.argmin(input, 0)"
expected_text = "tf.argmin(input=input, axis=0)"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
def testArgmax(self):
text = "tf.argmax(input, name=n, dimension=1, output_type=type)"
expected_text = "tf.argmax(input=input, name=n, axis=1, output_type=type)"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
text = "tf.argmax(input, 0)"
expected_text = "tf.argmax(input=input, axis=0)"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
class TestUpgradeFiles(test_util.TensorFlowTestCase):