STT-tensorflow/tensorflow/python/framework/op_callbacks_test.py
Shanqing Cai e6f22ee5f4 [tfdbg2] Ensure Const ops in graphs are captured by op_callbacks
Details of the changes:
- In the Python API of TensorFlow, Const ops are created by calling
  `_create_op_internal()` from constant_op.py. This differs from how most other ops
  are created, and is similar to Placeholder ops, which are already instrumented
  by tfdbg2's op_callbacks. In this CL, we add an op_callback hook to the code in
  constant_op.py that makes that call, to allow instrumentation of Const ops.
- In `_ConstantValue()` in tensor_util.py, add a special case for `CheckNumericsV2` op,
  so the `constant_value()` does not treat the `CheckNumericsV2` op as the constant
  tensor value. Similarly, add special cases for `Identity` and `DebugIdentityV2`.
- In `dumping_callback_test.py`, replace use of a deprecated Dataset API
  (`make_one_shot_iterator()`) with non-deprecated API (`iter()` and `next()`)
- Make other necessary changes to tfdbg2's tests to accommodate the Const ops
  which were previously not instrumented, but are now.
- Increase the shard_count of learning/brain/python/debug/tpu_callbacks_test.py to 6
  to avoid timeouts under the increased number of instrumented ops.

PiperOrigin-RevId: 307723353
Change-Id: Iecdbfcb439f6e04fc12c1503ad5339d42703e8bc
2020-04-21 18:42:31 -07:00

846 lines
31 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for op_callback."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import compat
# Keep all the hard-coded op type strings in one place so they are easy to
# change all at once in the face of any possible future op type name changes.
_ADD_OP = b"AddV2"
_ASSIGN_ADD_VARIABLE_OP = b"AssignAddVariableOp"
_CONSTANT_OP = b"Const"
_COS_OP = b"Cos"
_ENTER_OP = b"Enter"
_EXIT_OP = b"Exit"
_GREATER_OP = b"Greater"
_IDENTITY_OP = b"Identity"
_IF_OP = b"If"
_LESS_OP = b"Less"
_LOG_OP = b"Log"
_MERGE_OP = b"Merge"
_MATMUL_OP = b"MatMul"
_MUL_OP = b"Mul"
_NEXT_ITERATION_OP = b"NextIteration"
_PLACEHOLDER_OP = b"Placeholder"
_POW_OP = b"Pow"
_READ_VARIABLE_OP = b"ReadVariableOp"
_SIN_OP = b"Sin"
_SPARSE_TENSOR_DENSE_MATMUL_OP = b"SparseTensorDenseMatMul"
_SQRT_OP = b"Sqrt"
_SQUARE_OP = b"Square"
_STATELESS_IF_OP = b"StatelessIf"
_SWITCH_OP = b"Switch"
_UNIQUE_OP = b"Unique"
_VAR_HANDLE_OP = b"VarHandleOp"
_WHILE_OP = b"While"
class _NumpyFunctionCallback(object):
def __init__(self, instrument_graph_ops=True, float_only=False):
self.instrument_graph_ops = instrument_graph_ops
self._float_only = float_only
self.reset()
def callback(self, op_type, inputs, attrs, outputs, op_name=None, graph=None):
is_eager = not graph
if is_eager:
self.eager_op_types.append(
compat.as_bytes(op_type) if op_type else op_type)
self.eager_op_names.append(
compat.as_bytes(op_name) if op_name else op_name)
self.eager_attrs.append(attrs)
self.eager_graphs.append(graph)
self.eager_inputs.append(inputs)
else:
self.graph_op_types.append(
compat.as_bytes(op_type) if op_type else op_type)
self.graph_op_names.append(
compat.as_bytes(op_name) if op_name else op_name)
self.graph_attrs.append(attrs)
self.graph_graphs.append(graph)
self.graph_graph_versions.append(graph.version)
self.graph_inputs.append(inputs)
if not self.instrument_graph_ops:
return outputs
# Instrument the graph with numpy_function.
instrumented_outputs = []
for output in outputs:
if compat.as_bytes(op_type) in (_ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP,
_NEXT_ITERATION_OP, _STATELESS_IF_OP,
_SWITCH_OP, _WHILE_OP, _IDENTITY_OP,
_VAR_HANDLE_OP, _PLACEHOLDER_OP,
_CONSTANT_OP):
# TODO(cais): Overriding the output of StatelessIf, If and While ops
# currently fails with error. Investigate (b/139668453).
# Avoid instrumenting Identity ops as well, as they are inserted
# by tf.function/AutoGraph for marshalling outputs.
instrumented_output = output
else:
def record(ndarray_value):
if compat.as_bytes(op_name) not in self.graph_internal_ndarrays:
self.graph_internal_ndarrays[compat.as_bytes(op_name)] = []
self.graph_internal_ndarrays[compat.as_bytes(op_name)].append(
ndarray_value)
return ndarray_value
if self._float_only and not output.dtype.is_floating:
instrumented_output = output
else:
instrumented_output = script_ops.numpy_function(
record, [output], output.dtype)
instrumented_output.set_shape(output.shape)
instrumented_outputs.append(instrumented_output)
return instrumented_outputs
def reset(self):
self.eager_op_types = []
self.eager_op_names = []
self.eager_attrs = []
self.eager_graphs = []
self.eager_inputs = []
self.graph_op_types = []
self.graph_op_names = []
self.graph_attrs = []
self.graph_graphs = []
self.graph_graph_versions = []
self.graph_inputs = []
# A dict mapping tensor name (e.g., "MatMut_10") to a list of ndarrays.
# The list is the history of the tensor's computation result inside
# `tf.Graph`s (`FuncGraph`s).
# For an op with multiple output tensors, the outputs are interleaved in
# the list.
self.graph_internal_ndarrays = {}
class OpCallbacksTest(test_util.TensorFlowTestCase):
  """Tests that op callbacks observe, and can override, eager and graph ops."""

  def tearDown(self):
    # Clear the callbacks registered by the test so they do not leak into
    # subsequent tests.
    op_callbacks.clear_op_callbacks()
    super(OpCallbacksTest, self).tearDown()

  def _start_thread(self, target):
    """Runs `target` on a new thread and returns a function that joins it.

    Exceptions raised on a `threading.Thread` (including failed assertions)
    are not propagated to the thread that started it, so a failure inside
    `target` would otherwise go unnoticed and the test would pass. The
    returned function joins the thread and re-raises the first exception
    raised by `target`, if any.

    Args:
      target: A zero-argument callable to run on the new thread.

    Returns:
      A zero-argument function that joins the thread and then re-raises any
      exception raised by `target`.
    """
    errors = []

    def run_and_capture_errors():
      try:
        target()
      except Exception as e:  # pylint: disable=broad-except
        errors.append(e)

    thread = threading.Thread(target=run_and_capture_errors)
    thread.start()

    def join():
      thread.join()
      if errors:
        raise errors[0]

    return join

  def testSingleThreadedStack(self):
    """Adding/removing callbacks updates the context's callback list."""
    ctx = context.context()
    instrument_0 = _NumpyFunctionCallback()
    instrument_1 = _NumpyFunctionCallback()

    op_callbacks.add_op_callback(instrument_0.callback)
    self.assertEqual(1, len(ctx.op_callbacks))
    self.assertIn(instrument_0.callback, ctx.op_callbacks)

    op_callbacks.add_op_callback(instrument_1.callback)
    self.assertEqual(2, len(ctx.op_callbacks))
    self.assertIn(instrument_0.callback, ctx.op_callbacks)
    self.assertIn(instrument_1.callback, ctx.op_callbacks)

    op_callbacks.remove_op_callback(instrument_1.callback)
    self.assertEqual(1, len(ctx.op_callbacks))
    self.assertIn(instrument_0.callback, ctx.op_callbacks)

    op_callbacks.remove_op_callback(instrument_0.callback)
    self.assertEqual(0, len(ctx.op_callbacks))

  def testMultiThreadedStacks(self):
    """Callbacks added on different threads see only their own graph ops."""
    # Instrument for the main thread.
    instrument_0 = _NumpyFunctionCallback()
    # Instrument for the to-be-created thread.
    instrument_1 = _NumpyFunctionCallback()

    def thread1_job():
      op_callbacks.add_op_callback(instrument_1.callback)

      @def_function.function
      def func1(x):
        return math_ops.sqrt(math_ops.log(x))

      x = constant_op.constant(4.0)
      self.assertAllClose(func1(x), np.sqrt(np.log(4.0)))

    # Start job on separate thread.
    join_thread1 = self._start_thread(thread1_job)

    # Run something on the main thread.
    op_callbacks.add_op_callback(instrument_0.callback)

    @def_function.function
    def func0(x):
      return math_ops.square(math_ops.sin(x))

    x = constant_op.constant(4.0)
    self.assertAllClose(func0(x), np.square(np.sin(4.0)))

    # Re-raises any assertion failure from the other thread.
    join_thread1()

    # Assert that there is no cross-talk between the main thread
    # and the created thread.
    self.assertIn(_PLACEHOLDER_OP, instrument_1.graph_op_types)
    self.assertIn(_LOG_OP, instrument_1.graph_op_types)
    self.assertIn(_SQRT_OP, instrument_1.graph_op_types)
    self.assertNotIn(_SIN_OP, instrument_1.graph_op_types)
    self.assertNotIn(_SQUARE_OP, instrument_1.graph_op_types)

    self.assertNotIn(_LOG_OP, instrument_0.graph_op_types)
    self.assertNotIn(_SQRT_OP, instrument_0.graph_op_types)
    self.assertIn(_SIN_OP, instrument_0.graph_op_types)
    self.assertIn(_SQUARE_OP, instrument_0.graph_op_types)

  def testEagerOpExecution(self):
    """Eager ops are recorded with their inputs; no graph ops are recorded."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    x = constant_op.constant(6.0)
    y = math_ops.square(math_ops.log(x))
    self.assertAllClose(y, np.square(np.log(6.0)))

    self.assertEqual(instrument.eager_op_types, [_LOG_OP, _SQUARE_OP])
    # Op names are unavailable under eager mode.
    self.assertEqual(instrument.eager_op_names, [None, None])
    self.assertEqual(instrument.eager_graphs, [None, None])
    self.assertEqual(len(instrument.eager_inputs), 2)
    self.assertEqual(len(instrument.eager_inputs[0]), 1)
    self.assertIsInstance(instrument.eager_inputs[0], tuple)
    self.assertEqual(instrument.eager_inputs[0][0], x)
    self.assertEqual(len(instrument.eager_inputs[1]), 1)
    self.assertIsInstance(instrument.eager_inputs[1], tuple)
    self.assertAllClose(instrument.eager_inputs[1][0], np.log(6.0))
    self.assertFalse(instrument.graph_op_types)
    self.assertFalse(instrument.graph_op_names)
    self.assertFalse(instrument.graph_attrs)
    self.assertFalse(instrument.graph_graphs)
    self.assertFalse(instrument.graph_inputs)

  def testMultiThreadedEagerOpExecution(self):
    """Eager ops on different threads go only to that thread's callback."""
    # Instrument for the main thread.
    instrument_0 = _NumpyFunctionCallback()
    # Instrument for the to-be-created thread.
    instrument_1 = _NumpyFunctionCallback()

    def thread_1_job():
      op_callbacks.add_op_callback(instrument_1.callback)
      x = constant_op.constant(6.0)
      y = math_ops.square(math_ops.log(x))
      op_callbacks.remove_op_callback(instrument_1.callback)
      # A Thread target's return value is discarded, so check the result
      # here; `_start_thread` propagates any failure to the main thread.
      self.assertAllClose(y, np.square(np.log(6.0)))

    join_thread_1 = self._start_thread(thread_1_job)

    # While thread_1 is ongoing, do something on the main thread.
    op_callbacks.add_op_callback(instrument_0.callback)
    x = constant_op.constant(2.0)
    y = math_ops.cos(x)
    self.assertAllClose(y, np.cos(2.0))
    op_callbacks.remove_op_callback(instrument_0.callback)

    join_thread_1()

    self.assertEqual(instrument_0.eager_op_types, [_COS_OP])
    self.assertEqual(instrument_0.eager_op_names, [None])
    self.assertEqual(instrument_1.eager_op_types, [_LOG_OP, _SQUARE_OP])
    self.assertEqual(instrument_1.eager_op_names, [None, None])

  def testEagerFunctionExecution(self):
    """Calling an already-traced function is recorded as one eager op."""
    instrument = _NumpyFunctionCallback()

    @def_function.function
    def square_log(x):
      return math_ops.square(math_ops.log(x))

    # Call the function once, so that the graph construction won't show up
    # in the callback.
    x_float32 = constant_op.constant(6.0, dtype=dtypes.float32)
    x_float64 = constant_op.constant(6.0, dtype=dtypes.float64)
    square_log(x_float32)
    square_log(x_float64)

    op_callbacks.add_op_callback(instrument.callback)
    y = square_log(x_float32)
    self.assertAllClose(y, np.square(np.log(6.0)))
    y = square_log(x_float64)
    self.assertAllClose(y, np.square(np.log(6.0)))

    self.assertEqual(instrument.eager_op_names, [None, None])
    self.assertFalse(instrument.graph_op_types)
    self.assertFalse(instrument.graph_op_names)
    self.assertFalse(instrument.graph_inputs)

    # Each of the two dtypes should be associated with its own FuncGraph.
    self.assertIn(
        square_log.get_concrete_function(x_float32).name,
        instrument.eager_op_types)
    self.assertIn(
        square_log.get_concrete_function(x_float64).name,
        instrument.eager_op_types)

    self.assertEqual(len(instrument.eager_inputs), 2)
    self.assertIsInstance(instrument.eager_inputs[0], tuple)
    self.assertEqual(instrument.eager_inputs[0][0], x_float32)
    self.assertIsInstance(instrument.eager_inputs[1], tuple)
    self.assertEqual(instrument.eager_inputs[1][0], x_float64)

  def testMultiThreadedEagerFunctionExecution(self):
    """Function calls on different threads go to that thread's callback."""
    # Instrument for the main thread.
    instrument_0 = _NumpyFunctionCallback()
    # Instrument for the to-be-created thread.
    instrument_1 = _NumpyFunctionCallback()

    @def_function.function
    def square_log(x):
      return math_ops.square(math_ops.log(x))

    # Call the function once, so that the graph construction won't show up
    # in the callback.
    x_float32 = constant_op.constant(6.0, dtype=dtypes.float32)
    x_float64 = constant_op.constant(6.0, dtype=dtypes.float64)
    square_log(x_float32)
    square_log(x_float64)

    def thread_1_job():
      op_callbacks.add_op_callback(instrument_1.callback)
      square_log(x_float32)

    join_thread_1 = self._start_thread(thread_1_job)

    # In the meantime, run some computation on the main thread.
    op_callbacks.add_op_callback(instrument_0.callback)
    square_log(x_float64)

    join_thread_1()

    # Each of the two dtypes should be associated with its own FuncGraph.
    self.assertIn(
        square_log.get_concrete_function(x_float64).name,
        instrument_0.eager_op_types)
    self.assertEqual(instrument_0.eager_op_names, [None])
    self.assertFalse(instrument_0.graph_op_types)
    self.assertIn(
        square_log.get_concrete_function(x_float32).name,
        instrument_1.eager_op_types)
    self.assertEqual(instrument_1.eager_op_names, [None])
    self.assertFalse(instrument_1.graph_op_types)

  @test_util.run_in_graph_and_eager_modes
  def testSimpleGraphConstructionScopeOutsideFunction(self):
    """A callback added outside a tf.function instruments the traced graph."""
    instrument = _NumpyFunctionCallback()

    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def log_2plus_unique_x(x):
      unique_values, unique_pos = array_ops.unique(x)
      return math_ops.log(2.0 + unique_values), unique_pos

    x = constant_op.constant([-1.0, -1.0, 0.0], dtype=dtypes.float32)
    y1, y2 = log_2plus_unique_x(x)
    self.assertAllClose(y1, [0.0, np.log(2.0)])
    self.assertAllClose(y2, [0, 0, 1])

    self.assertIn(_UNIQUE_OP, instrument.graph_op_types)
    self.assertIn(_ADD_OP, instrument.graph_op_types)
    self.assertIn(_LOG_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    unique_op_outputs = instrument.graph_internal_ndarrays[_UNIQUE_OP]
    if context.executing_eagerly():
      # b/140810696: The run_in_graph_and_eager_modes decorator runs
      # Session.run() twice. We can't assert on the number of outputs in
      # that case.
      self.assertEqual(len(unique_op_outputs), 2)
    self.assertAllClose(unique_op_outputs[0], [-1.0, 0.0])
    self.assertAllClose(unique_op_outputs[1], [0, 0, 1])
    add_op_outputs = instrument.graph_internal_ndarrays[b"add"]
    if context.executing_eagerly():
      self.assertEqual(len(add_op_outputs), 1)
    self.assertAllClose(add_op_outputs[0], [1.0, 2.0])
    log_op_outputs = instrument.graph_internal_ndarrays[_LOG_OP]
    if context.executing_eagerly():
      self.assertEqual(len(log_op_outputs), 1)
    self.assertAllClose(log_op_outputs[0], [0.0, np.log(2.0)])

  @test_util.run_in_graph_and_eager_modes
  def testPadOp(self):
    """The output of a Pad op inside a tf.function is recorded."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_pad(x, padding):
      return array_ops.pad(x, padding)

    x = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
    paddings = [[1, 1], [2, 2]]

    y = my_pad(x, paddings)
    expected_output = np.array([
        [0, 0, 0, 0, 0, 0],
        [0, 0, 1, 2, 0, 0],
        [0, 0, 3, 4, 0, 0],
        [0, 0, 0, 0, 0, 0],
    ], dtype=np.float32)
    self.assertAllClose(y, expected_output)
    self.assertAllClose(
        instrument.graph_internal_ndarrays[b"Pad"][0], expected_output)

  @test_util.run_in_graph_and_eager_modes
  def testSimpleGraphConstructionWithCallbackReturningNone(self):
    """Test that callbacks that return None works."""
    op_types = []

    def no_return_callback(op_type,
                           inputs,
                           attrs,
                           outputs,
                           op_name=None,
                           graph=None):
      del inputs, attrs, outputs, op_name, graph  # Unused.
      op_types.append(compat.as_bytes(op_type))

    op_callbacks.add_op_callback(no_return_callback)

    @def_function.function
    def log1p(x):
      return math_ops.log(1.0 + x)

    x = constant_op.constant(3.0)
    y = log1p(x)
    self.assertAllClose(y, np.log(4.0))
    self.assertIn(_ADD_OP, op_types)
    self.assertIn(_LOG_OP, op_types)

  @test_util.run_in_graph_and_eager_modes
  def testGraphConstructionInputsAndGraphAreCapturedCorrectly(self):
    """Graph ops are recorded with the correct input tensors and graphs."""
    instrument = _NumpyFunctionCallback(instrument_graph_ops=False)
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def log_2plus_unique_x(x):
      unique_values, unique_pos = array_ops.unique(x)
      return math_ops.log(2.0 + unique_values), unique_pos

    x = constant_op.constant([-1.0, -1.0, 0.0], dtype=dtypes.float32)
    y1, y2 = log_2plus_unique_x(x)
    self.assertAllClose(y1, [0.0, np.log(2.0)])
    self.assertAllClose(y2, [0, 0, 1])

    # Check the recorded input tensors.
    self.assertEqual(
        len(instrument.graph_inputs), len(instrument.graph_op_types))
    unique_inputs = instrument.graph_inputs[instrument.graph_op_types.index(
        _UNIQUE_OP)]
    self.assertIsInstance(unique_inputs, tuple)
    self.assertEqual(len(unique_inputs), 1)
    self.assertEqual(
        compat.as_bytes(unique_inputs[0].op.op_def.name), _PLACEHOLDER_OP)

    add_inputs = instrument.graph_inputs[instrument.graph_op_types.index(
        _ADD_OP)]
    self.assertIsInstance(add_inputs, tuple)
    self.assertEqual(len(add_inputs), 2)
    self.assertEqual(
        compat.as_bytes(add_inputs[0].op.op_def.name), _CONSTANT_OP)
    self.assertEqual(compat.as_bytes(add_inputs[1].op.op_def.name), _UNIQUE_OP)

    log_inputs = instrument.graph_inputs[instrument.graph_op_types.index(
        _LOG_OP)]
    self.assertIsInstance(log_inputs, tuple)
    self.assertEqual(len(log_inputs), 1)
    self.assertEqual(compat.as_bytes(log_inputs[0].op.op_def.name), _ADD_OP)

    # Check the recorded graphs.
    self.assertEqual(
        len(instrument.graph_graphs), len(instrument.graph_op_types))
    self.assertGreater(len(instrument.graph_graph_versions), 1)
    if context.executing_eagerly():
      for i in range(len(instrument.graph_graph_versions) - 1):
        self.assertGreater(instrument.graph_graph_versions[i + 1],
                           instrument.graph_graph_versions[i])

  @test_util.run_in_graph_and_eager_modes
  def testEagerGraphOpConstructionSimpleGraphScopeInsideFunction(self):
    """A callback added inside a tf.function sees only ops built in scope."""
    instrument = _NumpyFunctionCallback()

    @def_function.function
    def log_2plus_unique_x(x):
      op_callbacks.add_op_callback(instrument.callback)
      unique_values, _ = array_ops.unique(x)
      y = math_ops.log(2.0 + unique_values)
      op_callbacks.remove_op_callback(instrument.callback)
      return math_ops.sin(y)

    x = constant_op.constant([-1.0, -1.0, 0.0], dtype=dtypes.float32)
    output = log_2plus_unique_x(x)
    self.assertAllClose(output, np.sin([0.0, np.log(2.0)]))

    # The following ops should have been captured by the callback
    # because they were constructed within the scope of `op_callback()`.
    self.assertIn(_UNIQUE_OP, instrument.graph_op_types)
    self.assertIn(_ADD_OP, instrument.graph_op_types)
    self.assertIn(_LOG_OP, instrument.graph_op_types)
    # The "Sin" op should not have been captured, because it was constructed
    # outside the scope of `op_callback()`.
    self.assertNotIn(_SIN_OP, instrument.graph_op_types)

    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    unique_op_outputs = instrument.graph_internal_ndarrays[_UNIQUE_OP]
    self.assertEqual(len(unique_op_outputs), 2)
    self.assertAllClose(unique_op_outputs[0], [-1.0, 0.0])
    self.assertAllClose(unique_op_outputs[1], [0, 0, 1])
    add_op_outputs = instrument.graph_internal_ndarrays[b"add"]
    self.assertEqual(len(add_op_outputs), 1)
    self.assertAllClose(add_op_outputs[0], [1.0, 2.0])
    log_op_outputs = instrument.graph_internal_ndarrays[_LOG_OP]
    self.assertEqual(len(log_op_outputs), 1)
    self.assertAllClose(log_op_outputs[0], [0.0, np.log(2.0)])

  def testEagerOpAttributesAreCapture(self):
    """Eager op attributes are passed to the callback as a flat tuple."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    m = constant_op.constant([[1.0, -1.0], [0.0, 1.0]])
    x = constant_op.constant([[-2.0], [3.0]])
    y = math_ops.matmul(m, x, transpose_a=True, transpose_b=False)
    self.assertAllClose(y, [[-2.0], [5.0]])

    self.assertEqual(len(instrument.eager_attrs), 1)
    self.assertIsInstance(instrument.eager_attrs[0], tuple)
    # Attrs alternate between name and value, so the value follows its name.
    self.assertEqual(
        instrument.eager_attrs[0][instrument.eager_attrs[0].index("transpose_a")
                                  + 1], True)
    self.assertEqual(
        instrument.eager_attrs[0][instrument.eager_attrs[0].index("transpose_b")
                                  + 1], False)
    self.assertEqual(len(instrument.graph_attrs), 0)

  @test_util.run_in_graph_and_eager_modes
  def testGraphOpAttributesAreCapture(self):
    """Graph op attributes are passed to the callback as a flat tuple."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_matmul(m, x):
      return math_ops.matmul(m, x, transpose_a=True, transpose_b=False)

    m = constant_op.constant([[1.0, -1.0], [0.0, 1.0]])
    x = constant_op.constant([[-2.0], [3.0]])
    y = my_matmul(m, x)
    self.assertAllClose(y, [[-2.0], [5.0]])

    index = instrument.graph_op_types.index(_MATMUL_OP)
    self.assertIsInstance(instrument.graph_attrs[index], tuple)
    self.assertEqual(
        instrument.graph_attrs[index][
            instrument.graph_attrs[index].index("transpose_a") + 1].b, True)
    self.assertEqual(
        instrument.graph_attrs[index][
            instrument.graph_attrs[index].index("transpose_b") + 1].b, False)
    if context.executing_eagerly():
      self.assertEqual(len(instrument.eager_attrs), 1)
      self.assertIsInstance(instrument.eager_attrs[0], tuple)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGraphOpConstructionIfControlFlow(self):
    """Ops in an AutoGraph-converted `if` are captured and instrumented."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_function_with_cond(x):
      if math_ops.greater(x, 0.0):
        return x**2.0
      else:
        return x**3.0

    x = constant_op.constant(-4.0)
    self.assertAllClose(my_function_with_cond(x), -64.0)

    self.assertIn(_IF_OP, instrument.graph_op_types)
    self.assertIn(_GREATER_OP, instrument.graph_op_types)
    self.assertIn(_POW_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
    self.assertEqual(len(greater_op_outputs), 1)
    self.assertAllClose(greater_op_outputs[0], False)
    # This was needed for backwards compatibility with TF2 Estimators which
    # rely on variable names.
    prefix = b"cond/" if context.executing_eagerly() else b""
    pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix]
    self.assertEqual(len(pow_op_outputs), 1)
    self.assertAllClose(pow_op_outputs[0], -64.0)

  def testEagerGraphOpConstructionWhileLoopControlFlow(self):
    """Ops in an AutoGraph-converted `while` are captured and instrumented."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def my_function_with_while(counter, lim, accum):
      while math_ops.less(counter, lim):
        accum.assign_add(accum)
        counter.assign_add(1.0)

    counter = variables.Variable(0.0)
    lim = constant_op.constant(4.0, dtype=dtypes.float32)
    accum = variables.Variable(1.0)
    my_function_with_while(counter, lim, accum)

    self.assertAllClose(accum.read_value(), 16.0)
    self.assertIn(_WHILE_OP, instrument.graph_op_types)
    self.assertIn(_LESS_OP, instrument.graph_op_types)
    self.assertIn(_ASSIGN_ADD_VARIABLE_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.graph_op_names), len(instrument.graph_op_types))

    # Check the graph internal ndarrays recorded at runtime.
    read_variable_op_outputs = instrument.graph_internal_ndarrays[
        b"while/" + _READ_VARIABLE_OP]
    self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0])
    less_op_outputs = instrument.graph_internal_ndarrays[b"while/" + _LESS_OP]
    self.assertAllClose(less_op_outputs, [True, True, True, True, False])

  # TODO(cais): The following isn't decorated with
  # `@test_util.run_in_graph_and_eager_modes` because of an apparent
  # incompatibility between `Dataset.map()` and `numpy_function()` used by
  # `_NumpyFunctionCallback`. Maybe investigate.
  def testDatasetMapTest(self):
    """Ops inside a `Dataset.map()` function are captured as graph ops."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    tensor = constant_op.constant(
        [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0])

    def map_fn(x):
      return math_ops.log(math_ops.square(x) + 1)

    dataset = dataset_ops.Dataset.from_tensor_slices(tensor).batch(2).map(
        map_fn)
    # Use the non-deprecated eager iteration API instead of
    # `make_one_shot_iterator()`.
    iterator = iter(dataset)

    self.assertAllClose(next(iterator), np.log([1.25, 2]))
    self.assertAllClose(next(iterator), np.log([3.25, 5]))

    self.assertIn(_SQUARE_OP, instrument.graph_op_types)
    self.assertIn(_ADD_OP, instrument.graph_op_types)
    self.assertIn(_LOG_OP, instrument.graph_op_types)
    self.assertEqual(
        len(instrument.eager_op_types), len(instrument.eager_op_names))

  def testSparseTensorEagerExecution(self):
    """An eager sparse-dense matmul is recorded as an eager op."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    indices = [[1, 2], [2, 0], [3, 4]]
    values = [0.0, 8.0, -2.0]
    shape = [4, 5]
    sp = sparse_tensor.SparseTensorValue(indices, values, shape)
    w = ops.convert_to_tensor(np.ones([5, 1], np.float32))

    y = sparse_ops.sparse_tensor_dense_matmul(sp, w)
    self.assertAllClose(y, [[0.0], [0.0], [8.0], [-2.0]])
    self.assertIn(_SPARSE_TENSOR_DENSE_MATMUL_OP, instrument.eager_op_types)
    self.assertFalse(instrument.graph_op_types)

  @test_util.run_in_graph_and_eager_modes
  def testSparseTensorFuncGraph(self):
    """A sparse-dense matmul inside a tf.function is captured and recorded."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def dense_matmul(sp, w):
      return sparse_ops.sparse_tensor_dense_matmul(sp, w)

    indices = [[1, 2], [2, 0], [3, 4]]
    values = [0.0, 8.0, -2.0]
    shape = [4, 5]
    sp = sparse_tensor.SparseTensorValue(indices, values, shape)
    w = ops.convert_to_tensor(np.ones([5, 1], np.float32))
    y = dense_matmul(sp, w)
    self.assertAllClose(y, [[0.0], [0.0], [8.0], [-2.0]])
    self.assertIn(_SPARSE_TENSOR_DENSE_MATMUL_OP, instrument.graph_op_types)
    if context.executing_eagerly():
      self.assertIn(
          dense_matmul.get_concrete_function(sp, w).name,
          instrument.eager_op_types)

    # Check the graph internal ndarrays recorded at runtime.
    sparse_matmul_outputs = instrument.graph_internal_ndarrays[
        _SPARSE_TENSOR_DENSE_MATMUL_OP + b"/" + _SPARSE_TENSOR_DENSE_MATMUL_OP]
    if context.executing_eagerly():
      self.assertEqual(len(sparse_matmul_outputs), 1)
    self.assertAllClose(sparse_matmul_outputs[0], [[0.0], [0.0], [8.0], [-2.0]])

  @test_util.run_in_graph_and_eager_modes
  def testOverrideDTypeInFuncGraph(self):
    """A callback can override graph op outputs with a different dtype."""

    def to_float64(op_type, inputs, attrs, outputs, op_name=None, graph=None):
      del inputs, attrs, op_name, graph  # Unused.
      # Normalize to bytes so the file-level op type constants can be used.
      if compat.as_bytes(op_type) in (_CONSTANT_OP, _PLACEHOLDER_OP):
        return outputs
      else:
        return [math_ops.cast(output, dtypes.float64) for output in outputs]

    op_callbacks.add_op_callback(to_float64)

    @def_function.function
    def add_1_times_2(x):
      return (x + 1.0) * 2.0

    x = constant_op.constant(3.0, dtype=dtypes.float32)
    y = add_1_times_2(x)
    self.assertEqual(y.dtype, dtypes.float64)
    self.assertAllClose(y, 8.0)

  def testNoOutputOpUnderEagerExecution(self):
    """Ops with no outputs (group) do not break eager op recording."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    x = constant_op.constant(10.0)
    y = constant_op.constant(20.0)
    z = x + y
    w = control_flow_ops.group([z])
    self.assertIsNone(w)
    self.assertEqual(instrument.eager_op_types, [_ADD_OP])

  def testOpCallbackCapturesConstTensors(self):
    """Const ops created inside a tf.function are seen by the callback."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    @def_function.function
    def times_two_plus_three(x):
      return x * 2.0 + 3.0

    self.assertAllClose(times_two_plus_three(constant_op.constant(10.0)), 23.0)
    # One Const op each for the literals 2.0 and 3.0.
    self.assertEqual(instrument.graph_op_types.count(_CONSTANT_OP), 2)

  @test_util.run_in_graph_and_eager_modes
  def testOpCallbackWorksWithGradientTape(self):
    """Forward and backprop ops built under a GradientTape are captured."""
    instrument = _NumpyFunctionCallback()
    op_callbacks.add_op_callback(instrument.callback)

    v = variables.Variable(3.0, dtype=dtypes.float32)
    if not context.executing_eagerly():
      self.evaluate(v.initializer)

    @def_function.function
    def get_gradients():
      with backprop.GradientTape() as tape:
        loss = math_ops.sin(math_ops.square(v))
        gradients = tape.gradient(loss, v)
      return gradients

    gradients = get_gradients()
    # Applying the chain rule.
    self.assertAllClose(gradients, np.cos(3.0 * 3.0) * 3.0 * 2.0)
    self.assertIn(_SQUARE_OP, instrument.graph_op_types)
    self.assertIn(_SIN_OP, instrument.graph_op_types)
    # The mul and cos ops are created for backprop.
    self.assertIn(_MUL_OP, instrument.graph_op_types)
    self.assertIn(_COS_OP, instrument.graph_op_types)

    # Check the ndarrays from runtime.
    cos_op_outputs = instrument.graph_internal_ndarrays[b"gradient_tape/" +
                                                        _COS_OP]
    self.assertEqual(len(cos_op_outputs), 1)
    self.assertAllClose(cos_op_outputs[0], np.cos(3.0 * 3.0))
class OpCallbacksErrorConditionsTest(test_util.TensorFlowTestCase):
  """Checks the errors raised when the op-callback API is misused."""

  def tearDown(self):
    # Drop whatever the test registered so it cannot affect later tests.
    op_callbacks.clear_op_callbacks()
    super(OpCallbacksErrorConditionsTest, self).tearDown()

  def testNonCallableObjectArgErrors(self):
    """Registering something that is not callable raises ValueError."""
    not_callable = 1337
    with self.assertRaisesRegex(ValueError, r"is expected to be callable"):
      op_callbacks.add_op_callback(not_callable)

  def testRemoveUnregisteredCallbackLeadsToError(self):
    """Removing a callback that was never added raises KeyError."""
    never_added = _NumpyFunctionCallback().callback
    with self.assertRaisesRegex(KeyError, r"has not been registered"):
      op_callbacks.remove_op_callback(never_added)

  def testRemovingCallbackTwiceLeadsToError(self):
    """A second removal of the same callback raises KeyError."""
    recorder_fn = _NumpyFunctionCallback().callback
    op_callbacks.add_op_callback(recorder_fn)
    op_callbacks.remove_op_callback(recorder_fn)
    with self.assertRaisesRegex(KeyError, r"has not been registered"):
      op_callbacks.remove_op_callback(recorder_fn)

  def testOverridingWithWrongNumberOfTensorOutputsErrors(self):
    """Returning more overriding tensors than the op has outputs fails."""

    # The op has a single output, but this callback returns two tensors.
    def wrong_outputs_callback(op_type,
                               inputs,
                               attrs,
                               outputs,
                               op_name=None,
                               graph=None):
      del op_type, inputs, attrs, op_name, graph  # Unused.
      first_output = outputs[0]
      return first_output, math_ops.negative(first_output)

    op_callbacks.add_op_callback(wrong_outputs_callback)

    @def_function.function
    def log1p(x):
      return math_ops.log(1.0 + x)

    three = constant_op.constant(3.0)
    with self.assertRaisesRegex(
        ValueError,
        r"returned 2 tensors, .* does not match .* \(1\)"):
      log1p(three)
if __name__ == "__main__":
  # Enable eager execution for the whole run. Tests without the
  # `run_in_graph_and_eager_modes` decorator (e.g. testEagerOpExecution)
  # assert on eagerly executed ops.
  ops.enable_eager_execution()
  test.main()