Add a few tests for upgrade script, fix tf.argmin/tf.argmax to rename
dimension to axis. PiperOrigin-RevId: 223061849
This commit is contained in:
parent
a69210ede1
commit
891bf2d9fa
tensorflow/tools/compatibility
@ -124,6 +124,16 @@ genrule(
|
|||||||
tools = [":tf_upgrade_v2"],
|
tools = [":tf_upgrade_v2"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "test_file_v1_10",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["testdata/test_file_v1_10.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "test_file_v2_0",
|
name = "test_file_v2_0",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -25,10 +25,47 @@ from tensorflow.python.platform import test as test_lib
|
|||||||
class TestUpgrade(test_util.TensorFlowTestCase):
|
class TestUpgrade(test_util.TensorFlowTestCase):
|
||||||
"""Test various APIs that have been changed in 2.0."""
|
"""Test various APIs that have been changed in 2.0."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
tf.enable_eager_execution()
|
||||||
|
|
||||||
def testRenames(self):
|
def testRenames(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertAllClose(1.04719755, tf.acos(0.5).eval())
|
self.assertAllClose(1.04719755, tf.acos(0.5))
|
||||||
self.assertAllClose(0.5, tf.rsqrt(4.0).eval())
|
self.assertAllClose(0.5, tf.rsqrt(4.0))
|
||||||
|
|
||||||
|
def testSerializeSparseTensor(self):
|
||||||
|
sp_input = tf.SparseTensor(
|
||||||
|
indices=tf.constant([[1]], dtype=tf.int64),
|
||||||
|
values=tf.constant([2], dtype=tf.int64),
|
||||||
|
dense_shape=[2])
|
||||||
|
|
||||||
|
with self.cached_session():
|
||||||
|
serialized_sp = tf.serialize_sparse(sp_input, 'serialize_name', tf.string)
|
||||||
|
self.assertEqual((3,), serialized_sp.shape)
|
||||||
|
self.assertTrue(serialized_sp[0].numpy()) # check non-empty
|
||||||
|
|
||||||
|
def testSerializeManySparse(self):
|
||||||
|
sp_input = tf.SparseTensor(
|
||||||
|
indices=tf.constant([[0, 1]], dtype=tf.int64),
|
||||||
|
values=tf.constant([2], dtype=tf.int64),
|
||||||
|
dense_shape=[1, 2])
|
||||||
|
|
||||||
|
with self.cached_session():
|
||||||
|
serialized_sp = tf.serialize_many_sparse(
|
||||||
|
sp_input, 'serialize_name', tf.string)
|
||||||
|
self.assertEqual((1, 3), serialized_sp.shape)
|
||||||
|
|
||||||
|
def testArgMaxMin(self):
|
||||||
|
self.assertAllClose(
|
||||||
|
[1],
|
||||||
|
tf.argmax([[1, 3, 2]], name='abc', dimension=1))
|
||||||
|
self.assertAllClose(
|
||||||
|
[0, 0, 0],
|
||||||
|
tf.argmax([[1, 3, 2]], dimension=0))
|
||||||
|
self.assertAllClose(
|
||||||
|
[0],
|
||||||
|
tf.argmin([[1, 3, 2]], name='abc', dimension=1))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_lib.main()
|
test_lib.main()
|
||||||
|
@ -31,6 +31,12 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
# Maps from a function name to a dictionary that describes how to
|
# Maps from a function name to a dictionary that describes how to
|
||||||
# map from an old argument keyword to the new argument keyword.
|
# map from an old argument keyword to the new argument keyword.
|
||||||
self.function_keyword_renames = {
|
self.function_keyword_renames = {
|
||||||
|
"tf.argmin": {
|
||||||
|
"dimension": "axis",
|
||||||
|
},
|
||||||
|
"tf.argmax": {
|
||||||
|
"dimension": "axis",
|
||||||
|
},
|
||||||
"tf.image.crop_and_resize": {
|
"tf.image.crop_and_resize": {
|
||||||
"box_ind": "box_indices",
|
"box_ind": "box_indices",
|
||||||
},
|
},
|
||||||
@ -408,8 +414,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
self.function_reorders = {
|
self.function_reorders = {
|
||||||
"tf.io.serialize_sparse": ["sp_input", "name", "out_type"],
|
"tf.io.serialize_sparse": ["sp_input", "name", "out_type"],
|
||||||
"tf.io.serialize_many_sparse": ["sp_input", "name", "out_type"],
|
"tf.io.serialize_many_sparse": ["sp_input", "name", "out_type"],
|
||||||
"tf.argmax": ["input", "axis", "name", "dimension", "output_type"],
|
"tf.argmax": ["input", "axis", "name", "axis", "output_type"],
|
||||||
"tf.argmin": ["input", "axis", "name", "dimension", "output_type"],
|
"tf.argmin": ["input", "axis", "name", "axis", "output_type"],
|
||||||
"tf.batch_to_space": ["input", "crops", "block_size", "name"],
|
"tf.batch_to_space": ["input", "crops", "block_size", "name"],
|
||||||
"tf.boolean_mask": ["tensor", "mask", "name", "axis"],
|
"tf.boolean_mask": ["tensor", "mask", "name", "axis"],
|
||||||
"tf.convert_to_tensor": ["value", "dtype", "name", "preferred_dtype"],
|
"tf.convert_to_tensor": ["value", "dtype", "name", "preferred_dtype"],
|
||||||
|
@ -219,6 +219,28 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
|||||||
"rename the function export_savedmodel() to export_saved_model()",
|
"rename the function export_savedmodel() to export_saved_model()",
|
||||||
report)
|
report)
|
||||||
|
|
||||||
|
def testArgmin(self):
|
||||||
|
text = "tf.argmin(input, name=n, dimension=1, output_type=type)"
|
||||||
|
expected_text = "tf.argmin(input=input, name=n, axis=1, output_type=type)"
|
||||||
|
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
|
||||||
|
text = "tf.argmin(input, 0)"
|
||||||
|
expected_text = "tf.argmin(input=input, axis=0)"
|
||||||
|
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
|
||||||
|
def testArgmax(self):
|
||||||
|
text = "tf.argmax(input, name=n, dimension=1, output_type=type)"
|
||||||
|
expected_text = "tf.argmax(input=input, name=n, axis=1, output_type=type)"
|
||||||
|
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
|
||||||
|
text = "tf.argmax(input, 0)"
|
||||||
|
expected_text = "tf.argmax(input=input, axis=0)"
|
||||||
|
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
|
||||||
|
|
||||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user