[tfdbg] Check numerics callback fix: handle multi-input ops whose wrappers read `.op.inputs`

- Some ops (e.g., tf.pad()) check the `.op.inputs` property of their resulting
  tensors. Some of these ops take more than one input, which breaks when the
  CheckNumerics op that instruments the tensor has only one input (see the
  sketch below).
  - Add a unit test with `tf.pad` to cover the case.
- Exclude v1 control flow ops.
  - Add a unit test with the keras LSTM layer under v1 graph mode to cover the
    case.
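
  A minimal sketch of the breakage mode (illustrative only; not part of this
  change), using public v1 APIs:

      import tensorflow.compat.v1 as tf
      tf.disable_eager_execution()

      x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
      y = tf.pad(x, [[1, 1], [2, 2]])
      print(len(y.op.inputs))  # 2: the input tensor and the paddings tensor.

      y_checked = tf.debugging.check_numerics(y, "demo")
      print(len(y_checked.op.inputs))  # 1: CheckNumerics consumes a single
                                       # tensor, so code that reads
                                       # `.op.inputs[1]` off the instrumented
                                       # tensor breaks.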

PiperOrigin-RevId: 270819564
Shanqing Cai authored on 2019-09-23 20:28:13 -07:00; committed by TensorFlower Gardener
parent c9c2fcaf4c
commit af39a076c2
4 changed files with 104 additions and 8 deletions

tensorflow/python/debug/lib/check_numerics_callback.py

@@ -50,9 +50,14 @@ def limit_string_length(string, max_len=50):
 _CHECK_NUMERICS_CALLBACK_SKIP_OPS = (
     # TODO(b/139668453): The following skipped ops are related to a limitation
     # in the op callback.
+    b"Enter",
+    b"Exit",
     b"Identity",
     b"If",
+    b"Merge",
+    b"NextIteration",
     b"StatelessIf",
+    b"Switch",
     b"While",
 )
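
As an aside (not part of the diff): a minimal sketch of the kind of v1 graph
that produces the newly skipped op types, assuming only public v1 APIs:

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    i = tf.constant(0)
    cond = lambda i: tf.less(i, 10)
    body = lambda i: tf.add(i, 1)
    # Under v1 control flow, tf.while_loop is lowered to Enter, Exit, Merge,
    # NextIteration and Switch ops: exactly the op types the callback now
    # leaves uninstrumented.
    result = tf.while_loop(cond, body, [i])

    with tf.Session() as sess:
      print(sess.run(result))  # 10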

tensorflow/python/debug/lib/check_numerics_callback_test.py

@@ -271,6 +271,24 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase):
     history = model.fit(xs, ys, epochs=epochs, verbose=0)
     self.assertEqual(len(history.history["loss"]), epochs)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testKerasModelWithRNNHealthyPredictAndFitCalls(self):
+    """Test a simple healthy keras recurrent model works under the callback."""
+    check_numerics_callback.enable_check_numerics()
+
+    model = models.Sequential()
+    model.add(layers.LSTM(1, input_shape=(2, 4)))
+    model.compile(loss="mse", optimizer="rmsprop")
+
+    xs = np.zeros([8, 2, 4], dtype=np.float32)
+    ys = np.zeros([8, 1], dtype=np.float32)
+    model.predict(xs)
+
+    epochs = 3
+    history = model.fit(xs, ys, epochs=epochs, verbose=0)
+    self.assertEqual(len(history.history["loss"]), epochs)
+
   @test_util.run_in_graph_and_eager_modes
   def testKerasModelUnhealthyPredictAndFitCallsWithLargeLearningRate(self):
     """Test keras model training crashes with Infinity is caught by callback."""

tensorflow/python/framework/op_callbacks_test.py

@@ -49,13 +49,17 @@ _ADD_OP = b"AddV2"
 _ASSIGN_ADD_VARIABLE_OP = b"AssignAddVariableOp"
 _CONSTANT_OP = b"Const"
 _COS_OP = b"Cos"
+_ENTER_OP = b"Enter"
+_EXIT_OP = b"Exit"
 _GREATER_OP = b"Greater"
 _IDENTITY_OP = b"Identity"
 _IF_OP = b"If"
 _LESS_OP = b"Less"
 _LOG_OP = b"Log"
+_MERGE_OP = b"Merge"
 _MATMUL_OP = b"MatMul"
 _MUL_OP = b"Mul"
+_NEXT_ITERATION_OP = b"NextIteration"
 _PLACEHOLDER_OP = b"Placeholder"
 _POW_OP = b"Pow"
 _READ_VARIALBE_OP = b"ReadVariableOp"
@@ -64,6 +68,7 @@ _SPARSE_TENSOR_DENSE_MATMUL_OP = b"SparseTensorDenseMatMul"
 _SQRT_OP = b"Sqrt"
 _SQUARE_OP = b"Square"
 _STATELESS_IF_OP = b"StatelessIf"
+_SWTICH_OP = b"Switch"
 _UNIQUE_OP = b"Unique"
 _VAR_HANDLE_OP = b"VarHandleOp"
 _WHILE_OP = b"While"
@@ -71,8 +76,9 @@ _WHILE_OP = b"While"
 
 class _NumpyFunctionCallback(object):
 
-  def __init__(self, instrument_graph_ops=True):
+  def __init__(self, instrument_graph_ops=True, float_only=False):
     self.instrument_graph_ops = instrument_graph_ops
+    self._float_only = float_only
     self.reset()
 
   def callback(self, op_type, inputs, attrs, outputs, op_name=None, graph=None):
@@ -102,7 +108,8 @@ class _NumpyFunctionCallback(object):
       instrumented_outputs = []
       for output in outputs:
         if compat.as_bytes(op_type) in (
-            _IF_OP, _STATELESS_IF_OP, _WHILE_OP, _IDENTITY_OP,
+            _ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP, _NEXT_ITERATION_OP,
+            _STATELESS_IF_OP, _SWTICH_OP, _WHILE_OP, _IDENTITY_OP,
             _VAR_HANDLE_OP):
           # TODO(cais): Overriding the output of StatelessIf, If and While ops
           # currently fails with error. Investigate (b/139668453).
@@ -110,7 +117,6 @@ class _NumpyFunctionCallback(object):
          # by tf.function/AutoGraph for marshalling outputs.
          instrumented_output = output
        else:
-
          def record(ndarray_value):
            if compat.as_bytes(op_name) not in self.graph_internal_ndarrays:
              self.graph_internal_ndarrays[compat.as_bytes(op_name)] = []
@@ -118,6 +124,9 @@ class _NumpyFunctionCallback(object):
                ndarray_value)
            return ndarray_value
 
-          instrumented_output = script_ops.numpy_function(
-              record, [output], output.dtype)
-          instrumented_output.set_shape(output.shape)
+          if self._float_only and not output.dtype.is_floating:
+            instrumented_output = output
+          else:
+            instrumented_output = script_ops.numpy_function(
+                record, [output], output.dtype)
+            instrumented_output.set_shape(output.shape)
@@ -397,6 +406,49 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
     self.assertEqual(len(log_op_outputs), 1)
     self.assertAllClose(log_op_outputs[0], [0.0, np.log(2.0)])
 
+  @test_util.run_in_graph_and_eager_modes
+  def testPadOp(self):
+    instrument = _NumpyFunctionCallback()
+    op_callbacks.add_op_callback(instrument.callback)
+
+    @def_function.function
+    def my_pad(x, padding):
+      return array_ops.pad(x, padding)
+
+    x = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
+    paddings = [[1, 1], [2, 2]]
+
+    y = my_pad(x, paddings)
+    expected_output = np.array([
+        [0, 0, 0, 0, 0, 0],
+        [0, 0, 1, 2, 0, 0],
+        [0, 0, 3, 4, 0, 0],
+        [0, 0, 0, 0, 0, 0],
+    ], dtype=np.float32)
+    self.assertAllClose(y, expected_output)
+    self.assertAllClose(
+        instrument.graph_internal_ndarrays[b"Pad"][0], expected_output)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testKerasLSTMPredict(self):
+    instrument = _NumpyFunctionCallback(float_only=True)
+    op_callbacks.add_op_callback(instrument.callback)
+
+    model = keras.Sequential()
+    model.add(keras.layers.LSTM(1, input_shape=(2, 4)))
+    model.compile(loss="mse", optimizer="sgd")
+
+    xs = np.zeros([8, 2, 4], dtype=np.float32)
+    ys = model.predict(xs)
+    self.assertAllClose(ys, np.zeros([8, 1]))
+
+    # We avoid asserting on the internal details of the LSTM implementation.
+    # Instead, we just assert that some graph-internal execution states are
+    # recorded by the callback.
+    self.assertTrue(instrument.graph_internal_ndarrays)
+
   @test_util.run_in_graph_and_eager_modes
   def testSimpleGraphConstructionWithCallbackReturningNone(self):
     """Test that callbacks that return None works."""

tensorflow/python/ops/array_ops.py

@@ -2952,8 +2952,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0):  # pyl
 
   # Restore shape information where possible.
   if not context.executing_eagerly():
-    paddings_constant = tensor_util.constant_value(
-        result.op.inputs[1], partial=True)
+    paddings_constant = _get_paddings_constant(paddings)
     input_shape = result.op.inputs[0].shape
     if (input_shape.ndims is not None and
         not result.shape.is_fully_defined() and paddings_constant is not None):
@@ -2968,6 +2967,28 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0):  # pyl
   return result
 
 
+def _get_paddings_constant(paddings):
+  """Helper to get the constant values of the paddings arg to pad().
+
+  Used under V1 graph mode to facilitate computation of the shape of the output
+  tensor of `pad()`.
+
+  Args:
+    paddings: The same paddings arg as passed to pad(). Can be a Tensor, or
+      a nested list or tuple of Tensor and/or numbers.
+
+  Returns:
+    A nested list or numbers or `None`, in which `None` indicates unknown
+    padding size.
+  """
+  if isinstance(paddings, ops.Tensor):
+    return tensor_util.constant_value(paddings, partial=True)
+  elif isinstance(paddings, (list, tuple)):
+    return [_get_paddings_constant(x) for x in paddings]
+  else:
+    return paddings
+
+
 @tf_export("meshgrid")
 def meshgrid(*args, **kwargs):
   """Broadcasts parameters for evaluation on an N-D grid.