From af39a076c2fa3c3ebad5fd71fe2f8f3e63934c98 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Mon, 23 Sep 2019 20:28:13 -0700 Subject: [PATCH] [tfdbg] Check numerics callback fix: Replace _op with for multi-input ops - Some ops (e.g., tf.pad()) checks the `.op.inputs` property of resulting tensors. Some of these ops involve >1 inputs, which is broken by the fact that the CheckNumerics op that instrument the tensor has only one input. - Add a unit test with `tf.pad` to cover the case. - Exclude v1 control flow ops. - Add a unit test with the keras LSTM layer under the v1 graph-mode op to cover the case. PiperOrigin-RevId: 270819564 --- .../debug/lib/check_numerics_callback.py | 5 ++ .../debug/lib/check_numerics_callback_test.py | 18 ++++++ .../python/framework/op_callbacks_test.py | 64 +++++++++++++++++-- tensorflow/python/ops/array_ops.py | 25 +++++++- 4 files changed, 104 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/debug/lib/check_numerics_callback.py b/tensorflow/python/debug/lib/check_numerics_callback.py index fc2124af2f2..5260015850f 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback.py +++ b/tensorflow/python/debug/lib/check_numerics_callback.py @@ -50,9 +50,14 @@ def limit_string_length(string, max_len=50): _CHECK_NUMERICS_CALLBACK_SKIP_OPS = ( # TODO(b/139668453): The following skipped ops are related to a limitation # in the op callback. + b"Enter", + b"Exit", b"Identity", b"If", + b"Merge", + b"NextIteration", b"StatelessIf", + b"Switch", b"While", ) diff --git a/tensorflow/python/debug/lib/check_numerics_callback_test.py b/tensorflow/python/debug/lib/check_numerics_callback_test.py index 457d7b292dc..3f565d21c36 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback_test.py +++ b/tensorflow/python/debug/lib/check_numerics_callback_test.py @@ -271,6 +271,24 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase): history = model.fit(xs, ys, epochs=epochs, verbose=0) self.assertEqual(len(history.history["loss"]), epochs) + @test_util.run_in_graph_and_eager_modes + def testKerasModelWithRNNHealthyPredictAndFitCalls(self): + """Test a simple healthy keras recurrent model works under the callback.""" + check_numerics_callback.enable_check_numerics() + + model = models.Sequential() + model.add(layers.LSTM(1, input_shape=(2, 4))) + model.compile(loss="mse", optimizer="rmsprop") + + xs = np.zeros([8, 2, 4], dtype=np.float32) + ys = np.zeros([8, 1], dtype=np.float32) + + model.predict(xs) + + epochs = 3 + history = model.fit(xs, ys, epochs=epochs, verbose=0) + self.assertEqual(len(history.history["loss"]), epochs) + @test_util.run_in_graph_and_eager_modes def testKerasModelUnhealthyPredictAndFitCallsWithLargeLearningRate(self): """Test keras model training crashes with Infinity is caught by callback.""" diff --git a/tensorflow/python/framework/op_callbacks_test.py b/tensorflow/python/framework/op_callbacks_test.py index 0027c74ba7f..c55b9720a3b 100644 --- a/tensorflow/python/framework/op_callbacks_test.py +++ b/tensorflow/python/framework/op_callbacks_test.py @@ -49,13 +49,17 @@ _ADD_OP = b"AddV2" _ASSIGN_ADD_VARIABLE_OP = b"AssignAddVariableOp" _CONSTANT_OP = b"Const" _COS_OP = b"Cos" +_ENTER_OP = b"Enter" +_EXIT_OP = b"Exit" _GREATER_OP = b"Greater" _IDENTITY_OP = b"Identity" _IF_OP = b"If" _LESS_OP = b"Less" _LOG_OP = b"Log" +_MERGE_OP = b"Merge" _MATMUL_OP = b"MatMul" _MUL_OP = b"Mul" +_NEXT_ITERATION_OP = b"NextIteration" _PLACEHOLDER_OP = b"Placeholder" _POW_OP = b"Pow" _READ_VARIALBE_OP = b"ReadVariableOp" @@ -64,6 +68,7 @@ _SPARSE_TENSOR_DENSE_MATMUL_OP = b"SparseTensorDenseMatMul" _SQRT_OP = b"Sqrt" _SQUARE_OP = b"Square" _STATELESS_IF_OP = b"StatelessIf" +_SWTICH_OP = b"Switch" _UNIQUE_OP = b"Unique" _VAR_HANDLE_OP = b"VarHandleOp" _WHILE_OP = b"While" @@ -71,8 +76,9 @@ _WHILE_OP = b"While" class _NumpyFunctionCallback(object): - def __init__(self, instrument_graph_ops=True): + def __init__(self, instrument_graph_ops=True, float_only=False): self.instrument_graph_ops = instrument_graph_ops + self._float_only = float_only self.reset() def callback(self, op_type, inputs, attrs, outputs, op_name=None, graph=None): @@ -102,7 +108,8 @@ class _NumpyFunctionCallback(object): instrumented_outputs = [] for output in outputs: if compat.as_bytes(op_type) in ( - _IF_OP, _STATELESS_IF_OP, _WHILE_OP, _IDENTITY_OP, + _ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP, _NEXT_ITERATION_OP, + _STATELESS_IF_OP, _SWTICH_OP, _WHILE_OP, _IDENTITY_OP, _VAR_HANDLE_OP): # TODO(cais): Overriding the output of StatelessIf, If and While ops # currently fails with error. Investigate (b/139668453). @@ -110,7 +117,6 @@ class _NumpyFunctionCallback(object): # by tf.function/AutoGraph for marshalling outputs. instrumented_output = output else: - def record(ndarray_value): if compat.as_bytes(op_name) not in self.graph_internal_ndarrays: self.graph_internal_ndarrays[compat.as_bytes(op_name)] = [] @@ -118,9 +124,12 @@ class _NumpyFunctionCallback(object): ndarray_value) return ndarray_value - instrumented_output = script_ops.numpy_function( - record, [output], output.dtype) - instrumented_output.set_shape(output.shape) + if self._float_only and not output.dtype.is_floating: + instrumented_output = output + else: + instrumented_output = script_ops.numpy_function( + record, [output], output.dtype) + instrumented_output.set_shape(output.shape) instrumented_outputs.append(instrumented_output) return instrumented_outputs @@ -397,6 +406,49 @@ class OpCallbacksTest(test_util.TensorFlowTestCase): self.assertEqual(len(log_op_outputs), 1) self.assertAllClose(log_op_outputs[0], [0.0, np.log(2.0)]) + @test_util.run_in_graph_and_eager_modes + def testPadOp(self): + instrument = _NumpyFunctionCallback() + + op_callbacks.add_op_callback(instrument.callback) + + @def_function.function + def my_pad(x, padding): + return array_ops.pad(x, padding) + + x = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32) + paddings = [[1, 1], [2, 2]] + + y = my_pad(x, paddings) + expected_output = np.array([ + [0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 0, 0], + [0, 0, 3, 4, 0, 0], + [0, 0, 0, 0, 0, 0], + ], dtype=np.float32) + self.assertAllClose(y, expected_output) + self.assertAllClose( + instrument.graph_internal_ndarrays[b"Pad"][0], expected_output) + + @test_util.run_in_graph_and_eager_modes + def testKerasLSTMPredict(self): + instrument = _NumpyFunctionCallback(float_only=True) + + op_callbacks.add_op_callback(instrument.callback) + + model = keras.Sequential() + model.add(keras.layers.LSTM(1, input_shape=(2, 4))) + model.compile(loss="mse", optimizer="sgd") + + xs = np.zeros([8, 2, 4], dtype=np.float32) + ys = model.predict(xs) + + self.assertAllClose(ys, np.zeros([8, 1])) + # We avoid asserting on the internal details of the LSTM implementation. + # Instead, we just assert that some graph-internal execution states are + # recorded by the callback. + self.assertTrue(instrument.graph_internal_ndarrays) + @test_util.run_in_graph_and_eager_modes def testSimpleGraphConstructionWithCallbackReturningNone(self): """Test that callbacks that return None works.""" diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 3084efeaee7..94a5a1b0674 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2952,8 +2952,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl # Restore shape information where possible. if not context.executing_eagerly(): - paddings_constant = tensor_util.constant_value( - result.op.inputs[1], partial=True) + paddings_constant = _get_paddings_constant(paddings) input_shape = result.op.inputs[0].shape if (input_shape.ndims is not None and not result.shape.is_fully_defined() and paddings_constant is not None): @@ -2968,6 +2967,28 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl return result +def _get_paddings_constant(paddings): + """Helper to get the constant values of the paddings arg to pad(). + + Used under V1 graph mode to facilitate computation of the shape of the output + tensor of `pad()`. + + Args: + paddings: The same paddings arg as passed to pad(). Can be a Tensor, or + a nested list or tuple of Tensor and/or numbers. + + Returns: + A nested list or numbers or `None`, in which `None` indicates unknown + padding size. + """ + if isinstance(paddings, ops.Tensor): + return tensor_util.constant_value(paddings, partial=True) + elif isinstance(paddings, (list, tuple)): + return [_get_paddings_constant(x) for x in paddings] + else: + return paddings + + @tf_export("meshgrid") def meshgrid(*args, **kwargs): """Broadcasts parameters for evaluation on an N-D grid.