[tfdbg2] Ensure that op_callbacks capture Placeholders for tf.functions

- The Placeholder ops created for input args to tf.functions use a separate
  code path from the one currently covered by op_callbacks. The code path is
  in graph_only_ops.py. This CL adds the op_callbacks invocation in that module.
- Unit tests are added.
- Some existing unit tests are to accommodate the newly-tracked Placeholder ops.

PiperOrigin-RevId: 290661147
Change-Id: I6352134a42473392e08258c215ae9db91812b604
This commit is contained in:
Shanqing Cai 2020-01-20 17:16:38 -08:00 committed by TensorFlower Gardener
parent b26e1efece
commit 8ff9774650
7 changed files with 255 additions and 65 deletions

View File

@ -225,6 +225,11 @@ class CheckNumericsCallback(object):
def __init__(self, stack_height_limit, path_length_limit):
self._stack_height_limit = stack_height_limit
self._path_length_limit = path_length_limit
# A dict mapping Placeholder tensors to their instrumenting debug tensors.
# Used only under V1 graph mode, where we can't rely on auto control
# dependency to execute the debug tensors and hence need to attach the debug
# tensors as control dependencies of the ops that consume the Placeholder.
self._placeholder_to_debug_tensor = dict()
def callback(self,
op_type,
@ -243,6 +248,11 @@ class CheckNumericsCallback(object):
if graph:
# Under graph mode. Insert check_numerics op.
instrumented_outputs = []
if is_v1_graph_mode:
for input_tensor in inputs:
if input_tensor in self._placeholder_to_debug_tensor and outputs:
outputs[0].op._add_control_input( # pylint: disable=protected-access
self._placeholder_to_debug_tensor[input_tensor].op)
for slot, output in enumerate(outputs):
if (output.dtype.is_floating and
(op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
@ -262,8 +272,8 @@ class CheckNumericsCallback(object):
graph=graph,
traceback=output.op.traceback))
_CHECK_NUMERICS_INPUT_LOOKUP[graph][checked_output.name] = output
instrumented_outputs.append(
checked_output if is_v1_graph_mode else output)
instrumented_outputs.append(self._get_output_tensor(
op_type_bytes, output, checked_output, is_v1_graph_mode))
else:
instrumented_outputs.append(output)
return instrumented_outputs
@ -283,6 +293,40 @@ class CheckNumericsCallback(object):
stack_height_limit=self._stack_height_limit,
path_length_limit=self._path_length_limit))
def _get_output_tensor(self,
op_type,
tensor,
checked_tensor,
is_v1_graph_mode):
"""Determine what tensor to output from callback.
Args:
op_type: Type of the op that outputs the original symbolic tensor, as
`bytes`.
tensor: The original output symbolic tensor.
checked_tensor: The debugger-instrumented, numerics-checking tensor.
is_v1_graph_mode: Whether the debugged proggram is running under V1 graph
mode.
Returns:
A symbolic tensor to be returned by the dumping op_callback.
"""
if is_v1_graph_mode:
# Placeholders need special treatment under V1 graph mode. The
# callback can't simply override the Placeholder tensor to the debug
# tensor, as that would cause the Placeholder op to lack a value.
# The debug tensor is remembered and will be attached as control
# inputs to ops that consumer the Placeholders later.
if op_type == b"Placeholder":
self._placeholder_to_debug_tensor[tensor] = checked_tensor
return tensor
else:
return checked_tensor
else:
# Under non-v1 graph mode, rely on auto control dependency to run the
# checked tensor.
return tensor
@tf_export("debugging.enable_check_numerics")
def enable_check_numerics(stack_height_limit=30,

View File

@ -399,7 +399,10 @@ class DebuggedGraph(object):
graph_op_creation_digest: A GraphOpCreationDigest data object describing
the creation of an op inside this graph.
"""
assert graph_op_creation_digest.op_name not in self._op_by_name
if graph_op_creation_digest.op_name in self._op_by_name:
raise ValueError(
"Duplicate op name: %s (op type: %s)" %
(graph_op_creation_digest.op_name, graph_op_creation_digest.op_type))
self._op_by_name[
graph_op_creation_digest.op_name] = graph_op_creation_digest

View File

@ -102,6 +102,11 @@ class _DumpingCallback(object):
self._stack_frame_to_id_lock = threading.Lock()
self._context_lock = threading.Lock()
self._symbolic_tensor_counter_lock = threading.Lock()
# A dict mapping Placeholder tensors to their instrumenting debug tensors.
# Used only under V1 graph mode, where we can't rely on auto control
# dependency to execute the debug tensors and hence need to attach the debug
# tensors as control dependencies of the ops that consume the Placeholder.
self._placeholder_to_debug_tensor = dict()
self._writer = None
def function_callback(self, function):
@ -256,6 +261,40 @@ class _DumpingCallback(object):
host_name=self._hostname, stack_frame_ids=stack_frame_ids)
return code_location
def _process_v1_graph_mode_tensor(self,
op_type,
tensor,
debug_tensor,
tensor_debug_mode):
"""For V1 graph mode, determine what tensor to output from callback.
Args:
op_type: Type of the op that outputs the original symbolic tensor.
tensor: The original output symbolic tensor.
debug_tensor: The debugger-instrumented tensor.
tensor_debug_mode: Debug mode used, a tfdbg TensorDebugMode enum.
Returns:
A symbolic tensor to be returned by the dumping op_callback.
"""
# Placeholders need special treatment under V1 graph mode. The
# callback can't simply override the Placeholder tensor to a debug tensor,
# as that would cause the Placeholder op to lack a value.
if op_type in ("Placeholder", "PlaceholderWithDefault"):
self._placeholder_to_debug_tensor[tensor] = debug_tensor
return tensor
else:
# TODO(cais): Evaluate performance optimization options. For the
# `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
# control dependency of `tensor.op` without an additional identity op.
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
return debug_tensor
else:
identity = array_ops.identity(tensor)
identity.op._add_control_input( # pylint: disable=protected-access
debug_tensor.op)
return identity
def _instrument_symbolic_tensors(self,
tensors,
op_type,
@ -287,8 +326,6 @@ class _DumpingCallback(object):
automatic control dependencies (see `auto_control_deps.py`) instead of
tensor overriding.
"""
# TODO(b/144441464, b/144440920, b/144440922): Make use of it.
tensor_debug_mode = self._tensor_debug_mode
debug_urls = ["file://%s" % self._dump_root]
is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
@ -297,16 +334,16 @@ class _DumpingCallback(object):
for output_slot, tensor in enumerate(tensors):
if (not self._should_dump_tensor(op_type, tensor.dtype) or
not tensor.dtype.is_numpy_compatible):
# Instrumenting DT_VARIANT and DT_RESOURCE type tensors under
# V1 graph mode is known to have issues. TODO(cais): Investigate.
if is_v1_graph_mode:
instrumented_tensors.append(tensor)
continue
if is_v1_graph_mode and not tensor.dtype.is_numpy_compatible:
# Avoid instrumenting Placeholder under is_v1_graph_mode. Doing that
# would cause runtime complaint about Placeholders not being fed.
instrumented_tensors.append(tensor)
continue
# Except in V1 graph mode + control flow, debug_identity_v2 trigger auto
# control dependency because it's a stateful op.
# Except in V1 graph mode + control flow, debug_identity_v2 triggers
# auto control dependency because it's a stateful op.
debug_tensor = gen_debug_ops.debug_identity_v2(
# Use an empty (shape=[0]) float32 tensor for the NO_TENSOR mode
# as a low-overhead placeholder, since no actual tensor value is
@ -318,13 +355,8 @@ class _DumpingCallback(object):
tensor_debug_mode=self._tensor_debug_mode,
debug_urls=debug_urls)
if is_v1_graph_mode:
# TODO(cais): Evaluate performance optimization options. For the
# `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
# control dependency of `tensor.op` without an additional identity op.
identity = array_ops.identity(tensor)
identity.op._add_control_input( # pylint: disable=protected-access
debug_tensor.op)
instrumented_tensors.append(identity)
instrumented_tensors.append(self._process_v1_graph_mode_tensor(
op_type, tensor, debug_tensor, tensor_debug_mode))
return instrumented_tensors
elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
@ -355,10 +387,8 @@ class _DumpingCallback(object):
tensor_debug_mode=self._tensor_debug_mode,
debug_urls=debug_urls)
if is_v1_graph_mode:
identity = array_ops.identity(tensor)
identity.op._add_control_input( # pylint: disable=protected-access
debug_tensor.op)
instrumented_tensors.append(identity)
instrumented_tensors.append(self._process_v1_graph_mode_tensor(
op_type, tensor, debug_tensor, tensor_debug_mode))
return instrumented_tensors
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
for output_slot, tensor in enumerate(tensors):
@ -377,7 +407,8 @@ class _DumpingCallback(object):
tensor_debug_mode=self._tensor_debug_mode,
debug_urls=debug_urls)
if is_v1_graph_mode:
instrumented_tensors.append(debug_tensor)
instrumented_tensors.append(self._process_v1_graph_mode_tensor(
op_type, tensor, debug_tensor, tensor_debug_mode))
return instrumented_tensors
else:
raise NotImplementedError(
@ -487,9 +518,21 @@ class _DumpingCallback(object):
writer = self.get_writer()
if graph:
is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
context_id = self._get_context_id(graph) # Innermost context ID.
assert op_name is not None
output_tensor_ids = self._get_symbolic_tensor_ids(len(outputs))
if op_type in ("Placeholder", "PlaceholderWithDefault"):
# In some cases, the op name of a Placeholder op in a graph
# can be duplicate (e.g., with the name "resource").
# When this happens, we give the op an debugger-generated name
# in order to prevent problems and check failures down the pipe.
op_name = "%s_%d" % (op_name, self._symbolic_tensor_counter)
if is_v1_graph_mode:
for input_tensor in inputs:
# TODO(cais):
if input_tensor in self._placeholder_to_debug_tensor and outputs:
outputs[0].op._add_control_input( # pylint: disable=protected-access
self._placeholder_to_debug_tensor[input_tensor].op)
graph_op_creation = debug_event_pb2.GraphOpCreation(
op_type=op_type,
op_name=op_name,

View File

@ -270,7 +270,9 @@ class TracingCallbackTest(
reader.update()
graph_exec_traces = reader.graph_execution_traces()
executed_op_types = [trace.op_type for trace in graph_exec_traces]
self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"])
self.assertCountEqual(
executed_op_types,
["Placeholder", "Placeholder", "AddV2", "Sub", "RealDiv"])
if tensor_debug_mode == "CURT_HEALTH":
for trace in graph_exec_traces:
# 1st element: tensor_id, should be >= 0.
@ -330,7 +332,9 @@ class TracingCallbackTest(
reader.update()
graph_exec_traces = reader.graph_execution_traces()
executed_op_types = [trace.op_type for trace in graph_exec_traces]
self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"])
self.assertEqual(
executed_op_types,
["Placeholder", "Placeholder", "LogicalAnd", "LogicalNot"])
for trace in graph_exec_traces:
tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
self.assertGreaterEqual(tensor_id, 0)
@ -424,6 +428,7 @@ class TracingCallbackTest(
set(reader.device_name_map().values()))
# Verify the recorded graph-building history.
placeholder_op_digests = reader.graph_op_digests(op_type="Placeholder")
add_op_digests = reader.graph_op_digests(op_type="AddV2")
self.assertLen(add_op_digests, 2)
self.assertEqual(
@ -449,30 +454,57 @@ class TracingCallbackTest(
graph_exec_traces = reader.graph_execution_traces()
executed_op_types = [digest.op_type for digest in graph_exec_traces]
self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
self.assertEqual(
executed_op_types,
["Placeholder", "Placeholder", "Placeholder", "Placeholder",
"AddV2", "Log", "AddV2", "Sin"])
placeholder_traces = graph_exec_traces[:4]
non_placeholder_traces = graph_exec_traces[4:]
# Verify the graph ID stack of each op.
# 1st AddV2 op.
# The outer function's 1st Placeholder.
self.assertEqual(
reader.graph_by_id(graph_exec_traces[0].graph_ids[-1]).name,
reader.graph_by_id(placeholder_traces[0].graph_ids[-1]).name,
"sin1p_log_sum")
# The outer function's 2nd Placeholder.
self.assertEqual(
reader.graph_by_id(placeholder_traces[1].graph_ids[-1]).name,
"sin1p_log_sum")
# The inner function's 1st Placeholder.
self.assertEqual(
reader.graph_by_id(placeholder_traces[2].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(graph_exec_traces[0].graph_ids[-2]).name,
reader.graph_by_id(placeholder_traces[2].graph_ids[-2]).name,
"sin1p_log_sum")
# The inner function's 2nd Placeholder.
self.assertEqual(
reader.graph_by_id(placeholder_traces[3].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(placeholder_traces[3].graph_ids[-2]).name,
"sin1p_log_sum")
# 1st AddV2 op.
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[0].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(non_placeholder_traces[0].graph_ids[-2]).name,
"sin1p_log_sum")
# Log op.
self.assertEqual(
reader.graph_by_id(graph_exec_traces[1].graph_ids[-1]).name,
reader.graph_by_id(non_placeholder_traces[1].graph_ids[-1]).name,
"log_sum")
self.assertEqual(
reader.graph_by_id(graph_exec_traces[1].graph_ids[-2]).name,
reader.graph_by_id(non_placeholder_traces[1].graph_ids[-2]).name,
"sin1p_log_sum")
# 2nd AddV2 op.
self.assertEqual(
reader.graph_by_id(graph_exec_traces[2].graph_ids[-1]).name,
reader.graph_by_id(non_placeholder_traces[2].graph_ids[-1]).name,
"sin1p_log_sum")
# Sin op.
self.assertEqual(
reader.graph_by_id(graph_exec_traces[3].graph_ids[-1]).name,
reader.graph_by_id(non_placeholder_traces[3].graph_ids[-1]).name,
"sin1p_log_sum")
if tensor_debug_mode == "NO_TENSOR":
@ -485,37 +517,61 @@ class TracingCallbackTest(
# In each case, the 1st element of debug_tensor_value is the ID of the
# symbolic tenosr and the 2nd element is a zero indicating there is no
# inf or nan.
self.assertAllClose(
graph_exec_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0], 0.0]) # 1st AddV2 op.
self.assertAllClose(
graph_exec_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0], 0.0]) # Log op.
self.assertAllClose(
graph_exec_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0], 0.0]) # 2nd AddV2 op.
self.assertAllClose(
graph_exec_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0], 0.0]) # Sin op.
self.assertAllClose( # 1st outer placeholder.
placeholder_traces[0].debug_tensor_value,
[placeholder_op_digests[0].output_tensor_ids[0], 0.0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[1].debug_tensor_value,
[placeholder_op_digests[1].output_tensor_ids[0], 0.0])
self.assertAllClose( # 1st inner placeholder.
placeholder_traces[2].debug_tensor_value,
[placeholder_op_digests[2].output_tensor_ids[0], 0.0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[3].debug_tensor_value,
[placeholder_op_digests[3].output_tensor_ids[0], 0.0])
self.assertAllClose( # 1st AddV2 op.
non_placeholder_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0], 0.0])
self.assertAllClose( # Log op.
non_placeholder_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0], 0.0])
self.assertAllClose( # 2nd AddV2 op.
non_placeholder_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0], 0.0])
self.assertAllClose( # Sin op.
non_placeholder_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0], 0.0])
elif tensor_debug_mode == "CONCISE_HEALTH":
# 1st element: tensor_id, should be >= 0.
# 1st element: tensor_id.
# 2nd element: element count. Remaining elements: all zero because there
# is no -inf, inf or nan.
self.assertAllClose( # 1st outer placeholder.
placeholder_traces[0].debug_tensor_value,
[placeholder_op_digests[0].output_tensor_ids[0], 1., 0., 0., 0.])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[1].debug_tensor_value,
[placeholder_op_digests[1].output_tensor_ids[0], 1., 0., 0., 0.])
self.assertAllClose( # 1st inner placeholder.
placeholder_traces[2].debug_tensor_value,
[placeholder_op_digests[2].output_tensor_ids[0], 1., 0., 0., 0.])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[3].debug_tensor_value,
[placeholder_op_digests[3].output_tensor_ids[0], 1., 0., 0., 0.])
# 1st AddV2 op.
self.assertAllClose(
graph_exec_traces[0].debug_tensor_value,
non_placeholder_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
# Log op.
self.assertAllClose(
graph_exec_traces[1].debug_tensor_value,
non_placeholder_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
# 2nd AddV2 op.
self.assertAllClose(
graph_exec_traces[2].debug_tensor_value,
non_placeholder_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
# Sin op.
self.assertAllClose(
graph_exec_traces[3].debug_tensor_value,
non_placeholder_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
elif tensor_debug_mode == "SHAPE":
# 1st element: tensor_id.
@ -523,32 +579,59 @@ class TracingCallbackTest(
# 3rd element: rank (scalar).
# 4th element: element count (1).
# Remaining elements: shape padded to fixed length (6).
self.assertAllClose( # 1st outer placeholder.
placeholder_traces[0].debug_tensor_value,
[placeholder_op_digests[0].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[1].debug_tensor_value,
[placeholder_op_digests[1].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
self.assertAllClose( # 1st inner placeholder.
placeholder_traces[2].debug_tensor_value,
[placeholder_op_digests[2].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
self.assertAllClose( # 2nd outer placeholder.
placeholder_traces[3].debug_tensor_value,
[placeholder_op_digests[3].output_tensor_ids[0],
1, 0, 1, 0, 0, 0, 0, 0, 0])
# 1st AddV2 op.
self.assertAllClose(
graph_exec_traces[0].debug_tensor_value,
non_placeholder_traces[0].debug_tensor_value,
[add_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
# Log op.
self.assertAllClose(
graph_exec_traces[1].debug_tensor_value,
non_placeholder_traces[1].debug_tensor_value,
[log_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
# 2nd AddV2 op.
self.assertAllClose(
graph_exec_traces[2].debug_tensor_value,
non_placeholder_traces[2].debug_tensor_value,
[add_op_digests[1].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
# Sin op.
self.assertAllClose(
graph_exec_traces[3].debug_tensor_value,
non_placeholder_traces[3].debug_tensor_value,
[sin_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
else: # FULL_TENSOR.
full_tensor_values = [
placeholder_full_tensor_values = [
reader.graph_execution_trace_to_tensor_value(trace)
for trace in graph_exec_traces]
self.assertAllClose(full_tensor_values[0], 5.0) # 1st AddV2 op.
self.assertAllClose(full_tensor_values[1], np.log(5.0)) # Log op.
for trace in placeholder_traces]
self.assertAllClose(placeholder_full_tensor_values[0], x) # Input x.
self.assertAllClose(placeholder_full_tensor_values[1], y) # Input y.
self.assertAllClose(placeholder_full_tensor_values[2], x) # Input x.
self.assertAllClose(placeholder_full_tensor_values[3], y) # Input y.
non_placeholder_full_tensor_values = [
reader.graph_execution_trace_to_tensor_value(trace)
for trace in non_placeholder_traces]
self.assertAllClose(
full_tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op.
non_placeholder_full_tensor_values[0], 5.0) # 1st AddV2 op.
self.assertAllClose(
full_tensor_values[3], np.sin(np.log(5.0) + 1.0)) # Sin op.
non_placeholder_full_tensor_values[1], np.log(5.0)) # Log op.
self.assertAllClose(
non_placeholder_full_tensor_values[2],
np.log(5.0) + 1.0) # 2nd AddV2 op.
self.assertAllClose(
non_placeholder_full_tensor_values[3],
np.sin(np.log(5.0) + 1.0)) # Sin op.
def testCapturingExecutedGraphIdsOfTwoCompilationsOfSameFunction(self):
"""Test correct executed IDs of two FuncGraphs from the same Py function."""
@ -738,9 +821,11 @@ class TracingCallbackTest(
with debug_events_reader.DebugDataReader(self.dump_root) as reader:
reader.update()
graph_exec_digests = reader.graph_execution_traces(digest=True)
executed_op_types = [digest.op_type for digest in graph_exec_digests]
executed_op_types = [digest.op_type for digest in graph_exec_digests
if digest.op_type != "Placeholder"]
tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
for digest in graph_exec_digests]
for digest in graph_exec_digests
if digest.op_type != "Placeholder"]
if tensor_dtypes == [dtypes.float32] and not op_regex:
self.assertEqual(executed_op_types, ["Unique", "Sum"])

View File

@ -443,6 +443,7 @@ py_library(
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:op_callbacks",
"//tensorflow/python:tensor_shape",
],
)

View File

@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -33,8 +34,17 @@ def graph_placeholder(dtype, shape, name=None):
shape = tensor_shape.TensorShape(shape)
shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
g = ops.get_default_graph()
attrs = {"dtype": dtype_value, "shape": shape}
op = g._create_op_internal( # pylint: disable=protected-access
"Placeholder", [], [dtype], input_types=[],
attrs={"dtype": dtype_value, "shape": shape}, name=name)
attrs=attrs, name=name)
result, = op.outputs
if op_callbacks.should_invoke_op_callbacks():
# TODO(b/147670703): Once the special-op creation code paths
# are unified. Remove this `if` block.
callback_outputs = op_callbacks.invoke_op_callbacks(
"Placeholder", tuple(), attrs, tuple(op.outputs),
op_name=name, graph=g)
if callback_outputs is not None:
result, = callback_outputs
return result

View File

@ -110,7 +110,7 @@ class _NumpyFunctionCallback(object):
if compat.as_bytes(op_type) in (
_ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP, _NEXT_ITERATION_OP,
_STATELESS_IF_OP, _SWTICH_OP, _WHILE_OP, _IDENTITY_OP,
_VAR_HANDLE_OP):
_VAR_HANDLE_OP, _PLACEHOLDER_OP):
# TODO(cais): Overriding the output of StatelessIf, If and While ops
# currently fails with error. Investigate (b/139668453).
# Avoid instrumenting Identity ops as well, as they are inserted
@ -218,6 +218,7 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
# Assert that there is no cross-talk between the main thread
# and the created thread.
self.assertIn(_PLACEHOLDER_OP, instrument_1.graph_op_types)
self.assertIn(_LOG_OP, instrument_1.graph_op_types)
self.assertIn(_SQRT_OP, instrument_1.graph_op_types)
self.assertNotIn(_SIN_OP, instrument_1.graph_op_types)
@ -739,8 +740,11 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes
def testOverrideDTypeInFuncGraph(self):
def to_float64(op_type, inputs, attrs, outputs, op_name=None, graph=None):
del op_type, inputs, attrs, op_name, graph # Unused.
return [math_ops.cast(output, dtypes.float64) for output in outputs]
del inputs, attrs, op_name, graph # Unused.
if op_type == "Placeholder":
return outputs
else:
return [math_ops.cast(output, dtypes.float64) for output in outputs]
op_callbacks.add_op_callback(to_float64)