Change tf_upgrade_v2.py script to wrap labels argument with tf.stop_gradients in
tf.softmax_cross_entropy_with_logits call. PiperOrigin-RevId: 228260462
This commit is contained in:
parent
b9762fed18
commit
a253b9eab5
@ -69,6 +69,7 @@ py_library(
|
||||
":ast_edits",
|
||||
":renames_v2",
|
||||
":reorders_v2",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -65,6 +65,7 @@ reorders = {
|
||||
'tf.nn.moments': ['x', 'axes', 'shift', 'name', 'keep_dims'],
|
||||
'tf.nn.pool': ['input', 'window_shape', 'pooling_type', 'padding', 'dilation_rate', 'strides', 'name', 'data_format'],
|
||||
'tf.nn.separable_conv2d': ['input', 'depthwise_filter', 'pointwise_filter', 'strides', 'padding', 'rate', 'name', 'data_format'],
|
||||
'tf.nn.softmax_cross_entropy_with_logits': ['_sentinel', 'labels', 'logits', 'dim', 'name'],
|
||||
'tf.nn.space_to_batch': ['input', 'paddings', 'block_size', 'name'],
|
||||
'tf.nn.space_to_depth': ['input', 'block_size', 'name', 'data_format'],
|
||||
'tf.nn.weighted_moments': ['x', 'axes', 'frequency_weights', 'name', 'keep_dims'],
|
||||
|
@ -70,6 +70,15 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
[0],
|
||||
tf.argmin([[1, 3, 2]], name='abc', dimension=1))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testSoftmaxCrossEntropyWithLogits(self):
|
||||
out = tf.nn.softmax_cross_entropy_with_logits(
|
||||
logits=[0.1, 0.8], labels=[0, 1])
|
||||
self.assertAllClose(out, 0.40318608)
|
||||
out = tf.nn.softmax_cross_entropy_with_logits_v2(
|
||||
logits=[0.1, 0.8], labels=[0, 1])
|
||||
self.assertAllClose(out, 0.40318608)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import ast
|
||||
|
||||
import pasta
|
||||
import six
|
||||
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import renames_v2
|
||||
@ -94,6 +95,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.convert_to_tensor": {
|
||||
"preferred_dtype": "dtype_hint"
|
||||
},
|
||||
"tf.nn.softmax_cross_entropy_with_logits": {
|
||||
"dim": "axis",
|
||||
"_sentinel": None,
|
||||
},
|
||||
"tf.nn.softmax_cross_entropy_with_logits_v2": {
|
||||
"dim": "axis"
|
||||
},
|
||||
@ -682,6 +687,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.norm",
|
||||
"tf.reverse_sequence",
|
||||
"tf.sparse_split",
|
||||
# tf.nn.softmax_cross_entropy_with_logits *must* be called with
|
||||
# keyword arguments. Add keyword arguments in rare case when they
|
||||
# are not specified.
|
||||
"tf.nn.softmax_cross_entropy_with_logits",
|
||||
}
|
||||
|
||||
# Functions that were reordered should be changed to the new keyword args
|
||||
@ -714,6 +723,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.to_float": self._cast_transformer,
|
||||
"tf.to_int32": self._cast_transformer,
|
||||
"tf.to_int64": self._cast_transformer,
|
||||
"tf.nn.softmax_cross_entropy_with_logits":
|
||||
self._softmax_cross_entropy_with_logits_transformer,
|
||||
}
|
||||
|
||||
decay_function_comment = (
|
||||
@ -950,10 +961,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"'deterministic' arguments. Now it takes a single 'seed' arg. If "
|
||||
"'seed' is zero, the execution is random and deterministic "
|
||||
"otherwise",
|
||||
"tf.nn.softmax_cross_entropy_with_logits":
|
||||
"tf.nn.softmax_cross_entropy_with_logits behavior has changed. "
|
||||
"'labels' needs to be wrapped with tf.stop_gradient to keep the "
|
||||
"old behavior. Also, 'dim' argument has been renamed to 'axis'.",
|
||||
"tf.test.assert_equal_graph_def":
|
||||
"tf.assert_equal_graph_def no longer takes 'checkpoint_v2' "
|
||||
"argument. 'checkpoint_v2' now defaults to True.",
|
||||
@ -1228,6 +1235,36 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
dtype_str)))
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def _softmax_cross_entropy_with_logits_transformer(
|
||||
parent, node, full_name, name, logs, errors):
|
||||
def _wrap_label(parent, old_value):
|
||||
"""Wrap labels with tf.stop_gradient."""
|
||||
if six.PY3:
|
||||
new_value = ast.Call(
|
||||
ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
|
||||
[old_value], [])
|
||||
else:
|
||||
new_value = ast.Call(
|
||||
ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
|
||||
[old_value], [], None, None)
|
||||
|
||||
# This copies the prefix and suffix on old_value to new_value.
|
||||
pasta.ast_utils.replace_child(parent, old_value, new_value)
|
||||
ast.copy_location(new_value, old_value)
|
||||
|
||||
# Check if we have a labels keyword arg
|
||||
for karg in node.keywords:
|
||||
if karg.arg == "labels":
|
||||
logs.append((node.lineno, node.col_offset,
|
||||
"Changing labels arg of "
|
||||
"tf.nn.softmax_cross_entropy_with_logits to "
|
||||
"tf.stop_gradient(labels). Please check this "
|
||||
"transformation.\n"))
|
||||
_wrap_label(karg, karg.value)
|
||||
return node
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def _batch_gather_transformer(parent, node, full_name, name, logs, errors):
|
||||
# Check if the call already has a batch_dims argument
|
||||
|
@ -684,26 +684,32 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
def testSoftMaxCrossEntropyWithLogitsV2(self):
|
||||
text = "tf.nn.softmax_cross_entropy_with_logits_v2(labels, logits, dim=2)"
|
||||
text = (
|
||||
"tf.nn.softmax_cross_entropy_with_logits_v2("
|
||||
"labels=labels, logits=logits, dim=2)")
|
||||
expected_text = (
|
||||
"tf.nn.softmax_cross_entropy_with_logits(labels, logits, axis=2)")
|
||||
"tf.nn.softmax_cross_entropy_with_logits("
|
||||
"labels=labels, logits=logits, axis=2)")
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
self.assertFalse(errors)
|
||||
|
||||
def testSoftMaxCrossEntropyWithLogits(self):
|
||||
text = "tf.nn.softmax_cross_entropy_with_logits(labels, logits, dim=2)"
|
||||
text = ("tf.nn.softmax_cross_entropy_with_logits("
|
||||
"labels=labels, logits=logits, dim=2)")
|
||||
expected_text = (
|
||||
"tf.nn.softmax_cross_entropy_with_logits(labels, logits, dim=2)")
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
"tf.nn.softmax_cross_entropy_with_logits("
|
||||
"labels=tf.stop_gradient(labels), logits=logits, axis=2)")
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn(
|
||||
"tf.nn.softmax_cross_entropy_with_logits requires manual check",
|
||||
errors[0])
|
||||
self.assertIn(
|
||||
"tf.nn.softmax_cross_entropy_with_logits behavior has changed. ",
|
||||
report)
|
||||
|
||||
text = ("tf.nn.softmax_cross_entropy_with_logits("
|
||||
"labels=foo(bar))")
|
||||
expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
|
||||
"labels=tf.stop_gradient(foo(bar)))")
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testSparseMatmul(self):
|
||||
text = ("tf.sparse_matmul(a, b, c, d, e, f, g)\n")
|
||||
|
Loading…
Reference in New Issue
Block a user