Details of the changes:

- In the Python API of TensorFlow, Const ops are created by calling
  `_create_op_internal()` from constant_op.py. This differs from how most other
  ops are created, and is similar to Placeholder ops, which are already
  instrumented by tfdbg2's op_callbacks. In this CL, we add an op_callback hook
  to the code in constant_op.py that makes that call, to allow instrumentation
  of Const ops.
- In `_ConstantValue()` in tensor_util.py, add a special case for the
  `CheckNumericsV2` op, so that `constant_value()` does not treat the
  `CheckNumericsV2` op as the constant tensor value. Similarly, add special
  cases for `Identity` and `DebugIdentityV2`.
- In `dumping_callback_test.py`, replace use of a deprecated Dataset API
  (`make_one_shot_iterator()`) with the non-deprecated API (`iter()` and
  `next()`).
- Make other necessary changes to tfdbg2's tests to accommodate the Const ops
  which were previously not instrumented, but are now.
- Increase the shard_count of learning/brain/python/debug/tpu_callbacks_test.py
  to 6 to avoid timeouts under the increased number of instrumented ops.

PiperOrigin-RevId: 307723353
Change-Id: Iecdbfcb439f6e04fc12c1503ad5339d42703e8bc
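For illustration only (not part of the CL's diff): a minimal sketch, using the same internal `op_callbacks` API exercised by the test file below and a hypothetical `scale` function, of how the newly instrumented Const ops become visible to a registered callback.

# Minimal sketch; assumes the internal op_callbacks module used by the test below.
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import op_callbacks

observed_op_types = []

def observe(op_type, inputs, attrs, outputs, op_name=None, graph=None):
  del inputs, attrs, outputs, op_name, graph  # Unused; returning None keeps the op's outputs.
  observed_op_types.append(op_type)

op_callbacks.add_op_callback(observe)

@def_function.function
def scale(x):
  return x * 2.0  # The captured 2.0 becomes a Const op in the FuncGraph.

scale(constant_op.constant(3.0))
op_callbacks.remove_op_callback(observe)
# With this CL, the Const op for the captured 2.0 is reported to `observe`
# during graph construction, alongside the Mul op; previously it was not.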
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for op_callback."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import compat


# Keep all the hard-coded op type strings in one place so they are easy to
# change all at once in the face of any possible future op type name changes.
_ADD_OP = b"AddV2"
_ASSIGN_ADD_VARIABLE_OP = b"AssignAddVariableOp"
_CONSTANT_OP = b"Const"
_COS_OP = b"Cos"
_ENTER_OP = b"Enter"
_EXIT_OP = b"Exit"
_GREATER_OP = b"Greater"
_IDENTITY_OP = b"Identity"
_IF_OP = b"If"
_LESS_OP = b"Less"
_LOG_OP = b"Log"
_MERGE_OP = b"Merge"
_MATMUL_OP = b"MatMul"
_MUL_OP = b"Mul"
_NEXT_ITERATION_OP = b"NextIteration"
_PLACEHOLDER_OP = b"Placeholder"
_POW_OP = b"Pow"
_READ_VARIABLE_OP = b"ReadVariableOp"
_SIN_OP = b"Sin"
_SPARSE_TENSOR_DENSE_MATMUL_OP = b"SparseTensorDenseMatMul"
_SQRT_OP = b"Sqrt"
_SQUARE_OP = b"Square"
_STATELESS_IF_OP = b"StatelessIf"
_SWITCH_OP = b"Switch"
_UNIQUE_OP = b"Unique"
_VAR_HANDLE_OP = b"VarHandleOp"
_WHILE_OP = b"While"


class _NumpyFunctionCallback(object):
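  """Op callback that records op metadata and instruments graph ops.

  Eager op executions and graph op creations are recorded in separate lists.
  Unless `instrument_graph_ops` is False, the outputs of each eligible graph
  op are wrapped with `script_ops.numpy_function` so that their runtime
  values are collected in `graph_internal_ndarrays`.
  """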

  def __init__(self, instrument_graph_ops=True, float_only=False):
    self.instrument_graph_ops = instrument_graph_ops
    self._float_only = float_only
    self.reset()

  def callback(self, op_type, inputs, attrs, outputs, op_name=None, graph=None):
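    # The signature mirrors what `op_callbacks.add_op_callback` expects:
    # `graph` is None for eagerly executed ops, and any tensors returned here
    # replace the op's original output tensors in the graph under construction.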
    is_eager = not graph
    if is_eager:
      self.eager_op_types.append(
          compat.as_bytes(op_type) if op_type else op_type)
      self.eager_op_names.append(
          compat.as_bytes(op_name) if op_name else op_name)
      self.eager_attrs.append(attrs)
      self.eager_graphs.append(graph)
      self.eager_inputs.append(inputs)
    else:
      self.graph_op_types.append(
          compat.as_bytes(op_type) if op_type else op_type)
      self.graph_op_names.append(
          compat.as_bytes(op_name) if op_name else op_name)
      self.graph_attrs.append(attrs)
      self.graph_graphs.append(graph)
      self.graph_graph_versions.append(graph.version)
      self.graph_inputs.append(inputs)

      if not self.instrument_graph_ops:
        return outputs

      # Instrument the graph with numpy_function.
      instrumented_outputs = []
      for output in outputs:
        if compat.as_bytes(op_type) in (_ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP,
                                        _NEXT_ITERATION_OP, _STATELESS_IF_OP,
                                        _SWITCH_OP, _WHILE_OP, _IDENTITY_OP,
                                        _VAR_HANDLE_OP, _PLACEHOLDER_OP,
                                        _CONSTANT_OP):
          # TODO(cais): Overriding the output of StatelessIf, If and While ops
          # currently fails with error. Investigate (b/139668453).
          # Avoid instrumenting Identity ops as well, as they are inserted
          # by tf.function/AutoGraph for marshalling outputs.
          instrumented_output = output
        else:
          def record(ndarray_value):
            if compat.as_bytes(op_name) not in self.graph_internal_ndarrays:
              self.graph_internal_ndarrays[compat.as_bytes(op_name)] = []
            self.graph_internal_ndarrays[compat.as_bytes(op_name)].append(
                ndarray_value)
            return ndarray_value

          if self._float_only and not output.dtype.is_floating:
            instrumented_output = output
          else:
            instrumented_output = script_ops.numpy_function(
                record, [output], output.dtype)
            instrumented_output.set_shape(output.shape)
        instrumented_outputs.append(instrumented_output)

      return instrumented_outputs

  def reset(self):
    self.eager_op_types = []
    self.eager_op_names = []
    self.eager_attrs = []
    self.eager_graphs = []
    self.eager_inputs = []
    self.graph_op_types = []
    self.graph_op_names = []
    self.graph_attrs = []
    self.graph_graphs = []
    self.graph_graph_versions = []
    self.graph_inputs = []

    # A dict mapping tensor name (e.g., "MatMul_10") to a list of ndarrays.
    # The list is the history of the tensor's computation result inside
    # `tf.Graph`s (`FuncGraph`s).
    # For an op with multiple output tensors, the outputs are interleaved in
    # the list.
    self.graph_internal_ndarrays = {}


class OpCallbacksTest(test_util.TensorFlowTestCase):
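  """Tests for op callbacks under eager execution and graph construction."""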

  def tearDown(self):
    op_callbacks.clear_op_callbacks()
    super(OpCallbacksTest, self).tearDown()

  def testSingleThreadedStack(self):
    ctx = context.context()
    instrument_0 = _NumpyFunctionCallback()
    instrument_1 = _NumpyFunctionCallback()

    op_callbacks.add_op_callback(instrument_0.callback)
    self.assertEqual(1, len(ctx.op_callbacks))
    self.assertIn(instrument_0.callback, ctx.op_callbacks)

    op_callbacks.add_op_callback(instrument_1.callback)
    self.assertEqual(2, len(ctx.op_callbacks))
    self.assertIn(instrument_0.callback, ctx.op_callbacks)
    self.assertIn(instrument_1.callback, ctx.op_callbacks)

    op_callbacks.remove_op_callback(instrument_1.callback)
    self.assertEqual(1, len(ctx.op_callbacks))
    self.assertIn(instrument_0.callback, ctx.op_callbacks)

    op_callbacks.remove_op_callback(instrument_0.callback)
    self.assertEqual(0, len(ctx.op_callbacks))

  def testMultiThreadedStacks(self):
    # Instrument for the main thread.
    instrument_0 = _NumpyFunctionCallback()

    # Instrument for the to-be-created thread.
    instrument_1 = _NumpyFunctionCallback()

    def thread1_job():
      op_callbacks.add_op_callback(instrument_1.callback)

      @def_function.function
      def func1(x):
        return math_ops.sqrt(math_ops.log(x))

      x = constant_op.constant(4.0)
      self.assertAllClose(func1(x), np.sqrt(np.log(4.0)))

    thread1 = threading.Thread(target=thread1_job)

    # Start job on separate thread.
    thread1.start()

    # Run something on the main thread.
    op_callbacks.add_op_callback(instrument_0.callback)

    @def_function.function
    def func0(x):
      return math_ops.square(math_ops.sin(x))

    x = constant_op.constant(4.0)
    self.assertAllClose(func0(x), np.square(np.sin(4.0)))

    thread1.join()

    # Assert that there is no cross-talk between the main thread
    # and the created thread.
    self.assertIn(_PLACEHOLDER_OP, instrument_1.graph_op_types)
    self.assertIn(_LOG_OP, instrument_1.graph_op_types)
    self.assertIn(_SQRT_OP, instrument_1.graph_op_types)
    self.assertNotIn(_SIN_OP, instrument_1.graph_op_types)
    self.assertNotIn(_SQUARE_OP, instrument_1.graph_op_types)

    self.assertNotIn(_LOG_OP, instrument_0.graph_op_types)
    self.assertNotIn(_SQRT_OP, instrument_0.graph_op_types)
    self.assertIn(_SIN_OP, instrument_0.graph_op_types)
    self.assertIn(_SQUARE_OP, instrument_0.graph_op_types)

  def testEagerOpExecution(self):
    instrument = _NumpyFunctionCallback()

    op_callbacks.add_op_callback(instrument.callback)
    x = constant_op.constant(6.0)
    y = math_ops.square(math_ops.log(x))
    self.assertAllClose(y, np.square(np.log(6.0)))

    self.assertEqual(instrument.eager_op_types, [_LOG_OP, _SQUARE_OP])
    # Op names are unavailable under eager mode.
    self.assertEqual(instrument.eager_op_names, [None, None])
    self.assertEqual(instrument.eager_graphs, [None, None])
    self.assertEqual(len(instrument.eager_inputs), 2)
    self.assertEqual(len(instrument.eager_inputs[0]), 1)
    self.assertIsInstance(instrument.eager_inputs[0], tuple)
    self.assertEqual(instrument.eager_inputs[0][0], x)
    self.assertEqual(len(instrument.eager_inputs[1]), 1)
    self.assertIsInstance(instrument.eager_inputs[1], tuple)
    self.assertAllClose(instrument.eager_inputs[1][0], np.log(6.0))
    self.assertFalse(instrument.graph_op_types)
    self.assertFalse(instrument.graph_op_names)
    self.assertFalse(instrument.graph_attrs)
    self.assertFalse(instrument.graph_graphs)
    self.assertFalse(instrument.graph_inputs)

  def testMultiThreadedEagerOpExecution(self):
    # Instrument for the main thread.
    instrument_0 = _NumpyFunctionCallback()

    # Instrument for the to-be-created thread.
    instrument_1 = _NumpyFunctionCallback()

    def thread_1_job():
      op_callbacks.add_op_callback(instrument_1.callback)
      x = constant_op.constant(6.0)
      y = math_ops.square(math_ops.log(x))
      op_callbacks.remove_op_callback(instrument_1.callback)
      return y

    thread_1 = threading.Thread(target=thread_1_job)
    thread_1.start()

    # While thread_1 is ongoing, do something on the main thread.
    op_callbacks.add_op_callback(instrument_0.callback)
    x = constant_op.constant(2.0)
    y = math_ops.cos(x)
    self.assertAllClose(y, np.cos(2.0))
    op_callbacks.remove_op_callback(instrument_0.callback)

    thread_1.join()

    self.assertEqual(instrument_0.eager_op_types, [_COS_OP])
    self.assertEqual(instrument_0.eager_op_names, [None])
    self.assertEqual(instrument_1.eager_op_types, [_LOG_OP, _SQUARE_OP])
    self.assertEqual(instrument_1.eager_op_names, [None, None])

  def testEagerFunctionExecution(self):
    instrument = _NumpyFunctionCallback()

    @def_function.function
    def square_log(x):
      return math_ops.square(math_ops.log(x))

    # Call the function once, so that the graph construction won't show up
    # in the callback.
    x_float32 = constant_op.constant(6.0, dtype=dtypes.float32)
    x_float64 = constant_op.constant(6.0, dtype=dtypes.float64)
    square_log(x_float32)
    square_log(x_float64)

    op_callbacks.add_op_callback(instrument.callback)
    y = square_log(x_float32)
    self.assertAllClose(y, np.square(np.log(6.0)))
    y = square_log(x_float64)
    self.assertAllClose(y, np.square(np.log(6.0)))

    self.assertEqual(instrument.eager_op_names, [None, None])
    self.assertFalse(instrument.graph_op_types)
    self.assertFalse(instrument.graph_op_names)
    self.assertFalse(instrument.graph_inputs)

    # Each of the two dtypes should be associated with its own FuncGraph.
    self.assertIn(
        square_log.get_concrete_function(x_float32).name,
        instrument.eager_op_types)
    self.assertIn(
        square_log.get_concrete_function(x_float64).name,
        instrument.eager_op_types)

    self.assertEqual(len(instrument.eager_inputs), 2)
    self.assertIsInstance(instrument.eager_inputs[0], tuple)
    self.assertEqual(instrument.eager_inputs[0][0], x_float32)
    self.assertIsInstance(instrument.eager_inputs[1], tuple)
    self.assertEqual(instrument.eager_inputs[1][0], x_float64)

  def testMultiThreadedEagerFunctionExecution(self):
    # Instrument for the main thread.
    instrument_0 = _NumpyFunctionCallback()

    # Instrument for the to-be-created thread.
    instrument_1 = _NumpyFunctionCallback()

    @def_function.function
    def square_log(x):
      return math_ops.square(math_ops.log(x))

    # Call the function once, so that the graph construction won't show up
    # in the callback.
    x_float32 = constant_op.constant(6.0, dtype=dtypes.float32)
    x_float64 = constant_op.constant(6.0, dtype=dtypes.float64)
    square_log(x_float32)
    square_log(x_float64)

    def thread_1_job():
      op_callbacks.add_op_callback(instrument_1.callback)
      square_log(x_float32)

    thread_1 = threading.Thread(target=thread_1_job)
    thread_1.start()

    # In the meantime, run some computation on the main thread.
    op_callbacks.add_op_callback(instrument_0.callback)
    square_log(x_float64)

    thread_1.join()

    # Each of the two dtypes should be associated with its own FuncGraph.
    self.assertIn(
        square_log.get_concrete_function(x_float64).name,
        instrument_0.eager_op_types)
    self.assertEqual(instrument_0.eager_op_names, [None])
    self.assertFalse(instrument_0.graph_op_types)
    self.assertIn(
        square_log.get_concrete_function(x_float32).name,
        instrument_1.eager_op_types)
    self.assertEqual(instrument_1.eager_op_names, [None])
    self.assertFalse(instrument_1.graph_op_types)

  @test_util.run_in_graph_and_eager_modes
  def testSimpleGraphConstructionScopeOutsideFunction(self):
    instrument = _NumpyFunctionCallback()

    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def log_2plus_unique_x(x):
      unique_values, unique_pos = array_ops.unique(x)
      return math_ops.log(2.0 + unique_values), unique_pos

    x = constant_op.constant([-1.0, -1.0, 0.0], dtype=dtypes.float32)
    y1, y2 = log_2plus_unique_x(x)
    self.assertAllClose(y1, [0.0, np.log(2.0)])
    self.assertAllClose(y2, [0, 0, 1])

    self.assertIn(_UNIQUE_OP, instrument.graph_op_types)
    self.assertIn(_ADD_OP, instrument.graph_op_types)
    self.assertIn(_LOG_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    unique_op_outputs = instrument.graph_internal_ndarrays[_UNIQUE_OP]
    if context.executing_eagerly():
      # b/140810696: The run_in_graph_and_eager_modes decorator runs
      # Session.run() twice. We can't assert on the number of outputs in
      # that case.
      self.assertEqual(len(unique_op_outputs), 2)
    self.assertAllClose(unique_op_outputs[0], [-1.0, 0.0])
    self.assertAllClose(unique_op_outputs[1], [0, 0, 1])
    add_op_outputs = instrument.graph_internal_ndarrays[b"add"]
    if context.executing_eagerly():
      self.assertEqual(len(add_op_outputs), 1)
    self.assertAllClose(add_op_outputs[0], [1.0, 2.0])
    log_op_outputs = instrument.graph_internal_ndarrays[_LOG_OP]
    if context.executing_eagerly():
      self.assertEqual(len(log_op_outputs), 1)
    self.assertAllClose(log_op_outputs[0], [0.0, np.log(2.0)])

  @test_util.run_in_graph_and_eager_modes
  def testPadOp(self):
    instrument = _NumpyFunctionCallback()

    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_pad(x, padding):
      return array_ops.pad(x, padding)

    x = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
    paddings = [[1, 1], [2, 2]]

    y = my_pad(x, paddings)
    expected_output = np.array([
        [0, 0, 0, 0, 0, 0],
        [0, 0, 1, 2, 0, 0],
        [0, 0, 3, 4, 0, 0],
        [0, 0, 0, 0, 0, 0],
    ], dtype=np.float32)
    self.assertAllClose(y, expected_output)
    self.assertAllClose(
        instrument.graph_internal_ndarrays[b"Pad"][0], expected_output)

  @test_util.run_in_graph_and_eager_modes
  def testSimpleGraphConstructionWithCallbackReturningNone(self):
    """Test that callbacks that return None work."""
    op_types = []

    def no_return_callback(op_type,
                           inputs,
                           attrs,
                           outputs,
                           op_name=None,
                           graph=None):
      del inputs, attrs, outputs, op_name, graph  # Unused.
      op_types.append(compat.as_bytes(op_type))

    op_callbacks.add_op_callback(no_return_callback)

    @def_function.function
    def log1p(x):
      return math_ops.log(1.0 + x)

    x = constant_op.constant(3.0)
    y = log1p(x)

    self.assertAllClose(y, np.log(4.0))
    self.assertIn(_ADD_OP, op_types)
    self.assertIn(_LOG_OP, op_types)

  @test_util.run_in_graph_and_eager_modes
  def testGraphConstructionInputsAndGraphAreCapturedCorrectly(self):
    instrument = _NumpyFunctionCallback(instrument_graph_ops=False)

    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def log_2plus_unique_x(x):
      unique_values, unique_pos = array_ops.unique(x)
      return math_ops.log(2.0 + unique_values), unique_pos

    x = constant_op.constant([-1.0, -1.0, 0.0], dtype=dtypes.float32)
    y1, y2 = log_2plus_unique_x(x)
    self.assertAllClose(y1, [0.0, np.log(2.0)])
    self.assertAllClose(y2, [0, 0, 1])

    # Check the recorded input tensors.
    self.assertEqual(
        len(instrument.graph_inputs), len(instrument.graph_op_types))
    unique_inputs = instrument.graph_inputs[instrument.graph_op_types.index(
        _UNIQUE_OP)]
    self.assertIsInstance(unique_inputs, tuple)
    self.assertEqual(len(unique_inputs), 1)
    self.assertEqual(
        compat.as_bytes(unique_inputs[0].op.op_def.name), _PLACEHOLDER_OP)

    add_inputs = instrument.graph_inputs[instrument.graph_op_types.index(
        _ADD_OP)]
    self.assertIsInstance(add_inputs, tuple)
    self.assertEqual(len(add_inputs), 2)
    self.assertEqual(
        compat.as_bytes(add_inputs[0].op.op_def.name), _CONSTANT_OP)
    self.assertEqual(compat.as_bytes(add_inputs[1].op.op_def.name), _UNIQUE_OP)

    log_inputs = instrument.graph_inputs[instrument.graph_op_types.index(
        _LOG_OP)]
    self.assertIsInstance(log_inputs, tuple)
    self.assertEqual(len(log_inputs), 1)
    self.assertEqual(compat.as_bytes(log_inputs[0].op.op_def.name), _ADD_OP)

    # Check the recorded graphs.
    self.assertEqual(
        len(instrument.graph_graphs), len(instrument.graph_op_types))
    self.assertGreater(len(instrument.graph_graph_versions), 1)
    if context.executing_eagerly():
      for i in range(len(instrument.graph_graph_versions) - 1):
        self.assertGreater(instrument.graph_graph_versions[i + 1],
                           instrument.graph_graph_versions[i])

  @test_util.run_in_graph_and_eager_modes
  def testEagerGraphOpConstructionSimpleGraphScopeInsideFunction(self):
    instrument = _NumpyFunctionCallback()

    @def_function.function
    def log_2plus_unique_x(x):
      op_callbacks.add_op_callback(instrument.callback)
      unique_values, _ = array_ops.unique(x)
      y = math_ops.log(2.0 + unique_values)
      op_callbacks.remove_op_callback(instrument.callback)
      return math_ops.sin(y)

    x = constant_op.constant([-1.0, -1.0, 0.0], dtype=dtypes.float32)
    output = log_2plus_unique_x(x)
    self.assertAllClose(output, np.sin([0.0, np.log(2.0)]))

    # The following ops should have been captured by the callback
    # because they were constructed within the scope of `op_callback()`.
    self.assertIn(_UNIQUE_OP, instrument.graph_op_types)
    self.assertIn(_ADD_OP, instrument.graph_op_types)
    self.assertIn(_LOG_OP, instrument.graph_op_types)
    # The "Sin" op should not have been captured, because it was constructed
    # outside the scope of `op_callback()`.
    self.assertNotIn(_SIN_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    unique_op_outputs = instrument.graph_internal_ndarrays[_UNIQUE_OP]
    self.assertEqual(len(unique_op_outputs), 2)
    self.assertAllClose(unique_op_outputs[0], [-1.0, 0.0])
    self.assertAllClose(unique_op_outputs[1], [0, 0, 1])
    add_op_outputs = instrument.graph_internal_ndarrays[b"add"]
    self.assertEqual(len(add_op_outputs), 1)
    self.assertAllClose(add_op_outputs[0], [1.0, 2.0])
    log_op_outputs = instrument.graph_internal_ndarrays[_LOG_OP]
    self.assertEqual(len(log_op_outputs), 1)
    self.assertAllClose(log_op_outputs[0], [0.0, np.log(2.0)])

  def testEagerOpAttributesAreCapture(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)
    m = constant_op.constant([[1.0, -1.0], [0.0, 1.0]])
    x = constant_op.constant([[-2.0], [3.0]])
    y = math_ops.matmul(m, x, transpose_a=True, transpose_b=False)
    self.assertAllClose(y, [[-2.0], [5.0]])

    self.assertEqual(len(instrument.eager_attrs), 1)
    self.assertIsInstance(instrument.eager_attrs[0], tuple)
    self.assertEqual(
        instrument.eager_attrs[0][instrument.eager_attrs[0].index("transpose_a")
                                  + 1], True)
    self.assertEqual(
        instrument.eager_attrs[0][instrument.eager_attrs[0].index("transpose_b")
                                  + 1], False)
    self.assertEqual(len(instrument.graph_attrs), 0)

  @test_util.run_in_graph_and_eager_modes
  def testGraphOpAttributesAreCapture(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_matmul(m, x):
      return math_ops.matmul(m, x, transpose_a=True, transpose_b=False)

    m = constant_op.constant([[1.0, -1.0], [0.0, 1.0]])
    x = constant_op.constant([[-2.0], [3.0]])
    y = my_matmul(m, x)
    self.assertAllClose(y, [[-2.0], [5.0]])

    index = instrument.graph_op_types.index(_MATMUL_OP)
    self.assertIsInstance(instrument.graph_attrs[index], tuple)
    self.assertEqual(
        instrument.graph_attrs[index][
            instrument.graph_attrs[index].index("transpose_a") + 1].b, True)
    self.assertEqual(
        instrument.graph_attrs[index][
            instrument.graph_attrs[index].index("transpose_b") + 1].b, False)
    if context.executing_eagerly():
      self.assertEqual(len(instrument.eager_attrs), 1)
      self.assertIsInstance(instrument.eager_attrs[0], tuple)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGraphOpConstructionIfControlFlow(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_function_with_cond(x):
      if math_ops.greater(x, 0.0):
        return x**2.0
      else:
        return x**3.0

    x = constant_op.constant(-4.0)
    self.assertAllClose(my_function_with_cond(x), -64.0)

    self.assertIn(_IF_OP, instrument.graph_op_types)
    self.assertIn(_GREATER_OP, instrument.graph_op_types)
    self.assertIn(_POW_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
    self.assertEqual(len(greater_op_outputs), 1)
    self.assertAllClose(greater_op_outputs[0], False)
    # This was needed for backwards compatibility with TF2 Estimators which
    # rely on variable names.
    prefix = b"cond/" if context.executing_eagerly() else b""
    pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix]
    self.assertEqual(len(pow_op_outputs), 1)
    self.assertAllClose(pow_op_outputs[0], -64.0)

  def testEagerGraphOpConstructionWhileLoopControlFlow(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_function_with_while(counter, lim, accum):
      while math_ops.less(counter, lim):
        accum.assign_add(accum)
        counter.assign_add(1.0)

    counter = variables.Variable(0.0)
    lim = constant_op.constant(4.0, dtype=dtypes.float32)
    accum = variables.Variable(1.0)
    my_function_with_while(counter, lim, accum)

    self.assertAllClose(accum.read_value(), 16.0)
    self.assertIn(_WHILE_OP, instrument.graph_op_types)
    self.assertIn(_LESS_OP, instrument.graph_op_types)
    self.assertIn(_ASSIGN_ADD_VARIABLE_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    read_variable_op_outputs = instrument.graph_internal_ndarrays[
        b"while/" + _READ_VARIABLE_OP]
    self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0])
    less_op_outputs = instrument.graph_internal_ndarrays[b"while/" + _LESS_OP]
    self.assertAllClose(less_op_outputs, [True, True, True, True, False])

  # TODO(cais): The following isn't decorated with
  # `@test_util.run_in_graph_and_eager_modes` because of an apparent
  # incompatibility between `Dataset.map()` and the `numpy_function()` used by
  # `_NumpyFunctionCallback`. Maybe investigate.
  def testDatasetMapTest(self):
    instrument = _NumpyFunctionCallback()

    op_callbacks.add_op_callback(instrument.callback)

    tensor = constant_op.constant(
        [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0])

    def map_fn(x):
      return math_ops.log(math_ops.square(x) + 1)

    dataset = dataset_ops.Dataset.from_tensor_slices(tensor).batch(2).map(
        map_fn)
    iterator = dataset_ops.make_one_shot_iterator(dataset)

    self.assertAllClose(iterator.next(), np.log([1.25, 2]))
    self.assertAllClose(iterator.next(), np.log([3.25, 5]))

    self.assertIn(_SQUARE_OP, instrument.graph_op_types)
    self.assertIn(_ADD_OP, instrument.graph_op_types)
    self.assertIn(_LOG_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.eager_op_types), len(instrument.eager_op_names))

  def testSparseTensorEagerExecution(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    indices = [[1, 2], [2, 0], [3, 4]]
    values = [0.0, 8.0, -2.0]
    shape = [4, 5]
    sp = sparse_tensor.SparseTensorValue(indices, values, shape)
    w = ops.convert_to_tensor(np.ones([5, 1], np.float32))

    y = sparse_ops.sparse_tensor_dense_matmul(sp, w)
    self.assertAllClose(y, [[0.0], [0.0], [8.0], [-2.0]])
    self.assertIn(_SPARSE_TENSOR_DENSE_MATMUL_OP, instrument.eager_op_types)
    self.assertFalse(instrument.graph_op_types)

  @test_util.run_in_graph_and_eager_modes
  def testSparseTensorFuncGraph(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def dense_matmul(sp, w):
      return sparse_ops.sparse_tensor_dense_matmul(sp, w)

    indices = [[1, 2], [2, 0], [3, 4]]
    values = [0.0, 8.0, -2.0]
    shape = [4, 5]
    sp = sparse_tensor.SparseTensorValue(indices, values, shape)
    w = ops.convert_to_tensor(np.ones([5, 1], np.float32))
    y = dense_matmul(sp, w)
    self.assertAllClose(y, [[0.0], [0.0], [8.0], [-2.0]])
    self.assertIn(_SPARSE_TENSOR_DENSE_MATMUL_OP, instrument.graph_op_types)
    if context.executing_eagerly():
      self.assertIn(
          dense_matmul.get_concrete_function(sp, w).name,
          instrument.eager_op_types)

    # Check the graph internal ndarrays recorded at runtime.
    sparse_matmul_outputs = instrument.graph_internal_ndarrays[
        _SPARSE_TENSOR_DENSE_MATMUL_OP + b"/" + _SPARSE_TENSOR_DENSE_MATMUL_OP]
    if context.executing_eagerly():
      self.assertEqual(len(sparse_matmul_outputs), 1)
    self.assertAllClose(sparse_matmul_outputs[0], [[0.0], [0.0], [8.0], [-2.0]])

  @test_util.run_in_graph_and_eager_modes
  def testOverrideDTypeInFuncGraph(self):
    def to_float64(op_type, inputs, attrs, outputs, op_name=None, graph=None):
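      # Tensors returned from an op callback replace the op's original
      # outputs, so casting here changes the dtype seen by downstream ops.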
      del inputs, attrs, op_name, graph  # Unused.
      if op_type in ("Const", "Placeholder"):
        return outputs
      else:
        return [math_ops.cast(output, dtypes.float64) for output in outputs]

    op_callbacks.add_op_callback(to_float64)

    @def_function.function
    def add_1_times_2(x):
      return (x + 1.0) * 2.0

    x = constant_op.constant(3.0, dtype=dtypes.float32)
    y = add_1_times_2(x)
    self.assertEqual(y.dtype, dtypes.float64)
    self.assertAllClose(y, 8.0)

  def testNoOutputOpUnderEagerExecution(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    x = constant_op.constant(10.0)
    y = constant_op.constant(20.0)
    z = x + y
    w = control_flow_ops.group([z])
    self.assertIsNone(w)
    self.assertEqual(instrument.eager_op_types, [_ADD_OP])

  def testOpCallbackCapturesConstTensors(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def times_two_plus_three(x):
      return x * 2.0 + 3.0

    self.assertAllClose(times_two_plus_three(constant_op.constant(10.0)), 23.0)
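    # The two scalar constants (2.0 and 3.0) captured in the FuncGraph should
    # each be reported to the callback as a Const op.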
    self.assertEqual(instrument.graph_op_types.count(b"Const"), 2)

  @test_util.run_in_graph_and_eager_modes
  def testOpCallbackWorksWithGradientTape(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    v = variables.Variable(3.0, dtype=dtypes.float32)
    if not context.executing_eagerly():
      self.evaluate(v.initializer)

    @def_function.function
    def get_gradients():
      with backprop.GradientTape() as tape:
        loss = math_ops.sin(math_ops.square(v))
        gradients = tape.gradient(loss, v)
      return gradients

    gradients = get_gradients()
    # Applying the chain rule.
    self.assertAllClose(gradients, np.cos(3.0 * 3.0) * 3.0 * 2.0)
    self.assertIn(_SQUARE_OP, instrument.graph_op_types)
    self.assertIn(_SIN_OP, instrument.graph_op_types)
    # The mul and cos ops are created for backprop.
    self.assertIn(_MUL_OP, instrument.graph_op_types)
    self.assertIn(_COS_OP, instrument.graph_op_types)

    # Check the ndarrays from runtime.
    cos_op_outputs = instrument.graph_internal_ndarrays[b"gradient_tape/" +
                                                        _COS_OP]
    self.assertEqual(len(cos_op_outputs), 1)
    self.assertAllClose(cos_op_outputs[0], np.cos(3.0 * 3.0))


class OpCallbacksErrorConditionsTest(test_util.TensorFlowTestCase):
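  """Tests for error conditions related to op callback registration and use."""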

  def tearDown(self):
    op_callbacks.clear_op_callbacks()
    super(OpCallbacksErrorConditionsTest, self).tearDown()

  def testNonCallableObjectArgErrors(self):
    with self.assertRaisesRegex(ValueError, r"is expected to be callable"):
      op_callbacks.add_op_callback(1337)

  def testRemoveUnregisteredCallbackLeadsToError(self):
    instrument = _NumpyFunctionCallback()
    with self.assertRaisesRegex(KeyError, r"has not been registered"):
      op_callbacks.remove_op_callback(instrument.callback)

  def testRemovingCallbackTwiceLeadsToError(self):
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)
    op_callbacks.remove_op_callback(instrument.callback)
    with self.assertRaisesRegex(KeyError, r"has not been registered"):
      op_callbacks.remove_op_callback(instrument.callback)

  def testOverridingWithWrongNumberOfTensorOutputsErrors(self):
    def wrong_outputs_callback(op_type,
                               inputs,
                               attrs,
                               outputs,
                               op_name=None,
                               graph=None):
      del op_type, inputs, attrs, op_name, graph  # Unused.
      return outputs[0], math_ops.negative(outputs[0])

    op_callbacks.add_op_callback(wrong_outputs_callback)

    @def_function.function
    def log1p(x):
      return math_ops.log(1.0 + x)

    x = constant_op.constant(3.0)
    with self.assertRaisesRegex(
        ValueError,
        r"returned 2 tensors, .* does not match .* \(1\)"):
      log1p(x)


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()