Change tf_upgrade_v2.py script to wrap the labels argument with tf.stop_gradient in tf.nn.softmax_cross_entropy_with_logits calls.

PiperOrigin-RevId: 228260462
Anna R 2019-01-07 16:54:22 -08:00 committed by TensorFlower Gardener
parent b9762fed18
commit a253b9eab5
5 changed files with 69 additions and 15 deletions
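Background for this change: the v1 op tf.nn.softmax_cross_entropy_with_logits did not backpropagate into labels, while the v2 op does, so converted calls must wrap labels in tf.stop_gradient to keep the old behavior. A minimal sketch of the preserved semantics, assuming TF 2.x eager execution (illustration only, not part of the commit):

import tensorflow as tf

labels = tf.constant([0.0, 1.0])
logits = tf.constant([0.1, 0.8])

with tf.GradientTape() as tape:
  tape.watch(labels)
  # The upgraded call: wrapping labels reproduces the v1 behavior of
  # not backpropagating into them.
  loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.stop_gradient(labels), logits=logits)

print(tape.gradient(loss, labels))  # None: no gradient flows into labels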

tensorflow/tools/compatibility/BUILD

@@ -69,6 +69,7 @@ py_library(
         ":ast_edits",
         ":renames_v2",
         ":reorders_v2",
+        "@six_archive//:six",
     ],
 )

tensorflow/tools/compatibility/reorders_v2.py

@@ -65,6 +65,7 @@ reorders = {
     'tf.nn.moments': ['x', 'axes', 'shift', 'name', 'keep_dims'],
     'tf.nn.pool': ['input', 'window_shape', 'pooling_type', 'padding', 'dilation_rate', 'strides', 'name', 'data_format'],
     'tf.nn.separable_conv2d': ['input', 'depthwise_filter', 'pointwise_filter', 'strides', 'padding', 'rate', 'name', 'data_format'],
+    'tf.nn.softmax_cross_entropy_with_logits': ['_sentinel', 'labels', 'logits', 'dim', 'name'],
     'tf.nn.space_to_batch': ['input', 'paddings', 'block_size', 'name'],
     'tf.nn.space_to_depth': ['input', 'block_size', 'name', 'data_format'],
     'tf.nn.weighted_moments': ['x', 'axes', 'frequency_weights', 'name', 'keep_dims'],
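These lists record each function's old positional argument order; the upgrade script uses them to rewrite positional calls as keyword calls before applying any renames. A simplified, hypothetical illustration of that matching step (the zip is stand-in logic, not the script's actual code):

old_arg_order = ['_sentinel', 'labels', 'logits', 'dim', 'name']

def to_keyword_args(positional_arg_sources):
  # Pair each positional argument's source text with the old v1
  # parameter name; '_sentinel' exists only to force keyword calls and
  # is dropped later by the "_sentinel": None rename in tf_upgrade_v2.py.
  return dict(zip(old_arg_order, positional_arg_sources))

print(to_keyword_args(['y_true', 'y_pred']))
# {'_sentinel': 'y_true', 'labels': 'y_pred'}  -- why v1 forced keyword calls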

tensorflow/tools/compatibility/testdata/test_file_v1_12.py

@@ -70,6 +70,15 @@ class TestUpgrade(test_util.TensorFlowTestCase):
         [0],
         tf.argmin([[1, 3, 2]], name='abc', dimension=1))
+
+  @test_util.run_v1_only("b/120545219")
+  def testSoftmaxCrossEntropyWithLogits(self):
+    out = tf.nn.softmax_cross_entropy_with_logits(
+        logits=[0.1, 0.8], labels=[0, 1])
+    self.assertAllClose(out, 0.40318608)
+    out = tf.nn.softmax_cross_entropy_with_logits_v2(
+        logits=[0.1, 0.8], labels=[0, 1])
+    self.assertAllClose(out, 0.40318608)
 
 
 if __name__ == "__main__":
   test_lib.main()
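The expected value in the new test can be checked by hand: with labels [0, 1], the loss reduces to -log(softmax(logits)[1]) = log(1 + e^(0.1 - 0.8)) ≈ 0.40318608. A quick check:

import math

logits = [0.1, 0.8]
exps = [math.exp(x) for x in logits]
loss = -math.log(exps[1] / sum(exps))  # cross entropy against labels [0, 1]
print(loss)  # ~0.40318608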

tensorflow/tools/compatibility/tf_upgrade_v2.py

@@ -21,6 +21,7 @@ from __future__ import print_function
 import ast
 
 import pasta
+import six
 
 from tensorflow.tools.compatibility import ast_edits
 from tensorflow.tools.compatibility import renames_v2
@@ -94,6 +95,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
         "tf.convert_to_tensor": {
             "preferred_dtype": "dtype_hint"
         },
+        "tf.nn.softmax_cross_entropy_with_logits": {
+            "dim": "axis",
+            "_sentinel": None,
+        },
         "tf.nn.softmax_cross_entropy_with_logits_v2": {
             "dim": "axis"
         },
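The new entry renames dim to axis and deletes any explicit _sentinel argument (a None target tells the renamer to drop the keyword). Combined with the transformer added further down, the net rewrite looks like this (input/output mirroring the tests at the end of this commit):

# Before:
tf.nn.softmax_cross_entropy_with_logits(labels=l, logits=x, dim=2)
# After:
tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(l), logits=x, axis=2)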
@@ -682,6 +687,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
         "tf.norm",
         "tf.reverse_sequence",
         "tf.sparse_split",
+        # tf.nn.softmax_cross_entropy_with_logits *must* be called with
+        # keyword arguments. Add keyword arguments in the rare case when
+        # they are not specified.
+        "tf.nn.softmax_cross_entropy_with_logits",
     }
 
     # Functions that were reordered should be changed to the new keyword args
@@ -714,6 +723,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
         "tf.to_float": self._cast_transformer,
         "tf.to_int32": self._cast_transformer,
         "tf.to_int64": self._cast_transformer,
+        "tf.nn.softmax_cross_entropy_with_logits":
+            self._softmax_cross_entropy_with_logits_transformer,
     }
 
     decay_function_comment = (
@@ -950,10 +961,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
         "'deterministic' arguments. Now it takes a single 'seed' arg. If "
         "'seed' is zero, the execution is random and deterministic "
         "otherwise",
-        "tf.nn.softmax_cross_entropy_with_logits":
-        "tf.nn.softmax_cross_entropy_with_logits behavior has changed. "
-        "'labels' needs to be wrapped with tf.stop_gradient to keep the "
-        "old behavior. Also, 'dim' argument has been renamed to 'axis'.",
         "tf.test.assert_equal_graph_def":
         "tf.assert_equal_graph_def no longer takes 'checkpoint_v2' "
         "argument. 'checkpoint_v2' now defaults to True.",
@@ -1228,6 +1235,36 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
                      dtype_str)))
     return node
 
+  @staticmethod
+  def _softmax_cross_entropy_with_logits_transformer(
+      parent, node, full_name, name, logs, errors):
+    def _wrap_label(parent, old_value):
+      """Wrap labels with tf.stop_gradient."""
+      if six.PY3:
+        new_value = ast.Call(
+            ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
+            [old_value], [])
+      else:
+        new_value = ast.Call(
+            ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
+            [old_value], [], None, None)
+
+      # This copies the prefix and suffix on old_value to new_value.
+      pasta.ast_utils.replace_child(parent, old_value, new_value)
+      ast.copy_location(new_value, old_value)
+
+    # Check if we have a labels keyword arg
+    for karg in node.keywords:
+      if karg.arg == "labels":
+        logs.append((node.lineno, node.col_offset,
+                     "Changing labels arg of "
+                     "tf.nn.softmax_cross_entropy_with_logits to "
+                     "tf.stop_gradient(labels). Please check this "
+                     "transformation.\n"))
+        _wrap_label(karg, karg.value)
+        return node
+    return node
+
   @staticmethod
   def _batch_gather_transformer(parent, node, full_name, name, logs, errors):
     # Check if the call already has a batch_dims argument
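The transformer leans on pasta so that the surrounding formatting survives the rewrite. A self-contained sketch of the same technique outside the upgrader (assumes Python 3 and the pasta package; the dotted name inside ast.Name is the same trick the commit uses, and is valid only because pasta prints it verbatim):

import ast
import pasta

tree = pasta.parse(
    "loss = tf.nn.softmax_cross_entropy_with_logits(labels=l, logits=x)\n")
call = tree.body[0].value  # the Call node on the right-hand side

for karg in call.keywords:
  if karg.arg == "labels":
    new_value = ast.Call(
        ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
        [karg.value], [])
    # replace_child copies the old node's prefix/suffix onto new_value.
    pasta.ast_utils.replace_child(karg, karg.value, new_value)
    ast.copy_location(new_value, karg.value)

print(pasta.dump(tree))
# loss = tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(l), logits=x)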

tensorflow/tools/compatibility/tf_upgrade_v2_test.py

@@ -684,26 +684,32 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
     self.assertEqual(new_text, expected_text)
 
   def testSoftMaxCrossEntropyWithLogitsV2(self):
-    text = "tf.nn.softmax_cross_entropy_with_logits_v2(labels, logits, dim=2)"
+    text = (
+        "tf.nn.softmax_cross_entropy_with_logits_v2("
+        "labels=labels, logits=logits, dim=2)")
     expected_text = (
-        "tf.nn.softmax_cross_entropy_with_logits(labels, logits, axis=2)")
+        "tf.nn.softmax_cross_entropy_with_logits("
+        "labels=labels, logits=logits, axis=2)")
     _, unused_report, errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, expected_text)
     self.assertFalse(errors)
 
   def testSoftMaxCrossEntropyWithLogits(self):
-    text = "tf.nn.softmax_cross_entropy_with_logits(labels, logits, dim=2)"
+    text = ("tf.nn.softmax_cross_entropy_with_logits("
+            "labels=labels, logits=logits, dim=2)")
     expected_text = (
-        "tf.nn.softmax_cross_entropy_with_logits(labels, logits, dim=2)")
-    _, report, errors, new_text = self._upgrade(text)
+        "tf.nn.softmax_cross_entropy_with_logits("
+        "labels=tf.stop_gradient(labels), logits=logits, axis=2)")
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, expected_text)
-    self.assertIn(
-        "tf.nn.softmax_cross_entropy_with_logits requires manual check",
-        errors[0])
-    self.assertIn(
-        "tf.nn.softmax_cross_entropy_with_logits behavior has changed. ",
-        report)
+
+    text = ("tf.nn.softmax_cross_entropy_with_logits("
+            "labels=foo(bar))")
+    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
+                     "labels=tf.stop_gradient(foo(bar)))")
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(expected_text, new_text)
 
   def testSparseMatmul(self):
     text = ("tf.sparse_matmul(a, b, c, d, e, f, g)\n")