diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 628a8a8ee4d..4abef7b6ec5 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -822,6 +822,24 @@ def TF_Reset(target, containers=None, config=None):
   $1 = &types_local;
 }
 
+%unignore TF_CreatePlaceholders;
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_CreatePlaceholders;
+
+// Build a Python list of TF_Output and return it.
+%typemap(out) std::vector<TF_Output> tensorflow::TF_CreatePlaceholders {
+  $result = PyList_New($1.size());
+  if (!$result) {
+    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+  }
+
+  // Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
+  const std::vector<TF_Output>& tf_outputs = $1;
+  for (size_t i = 0; i < tf_outputs.size(); ++i) {
+    PyList_SET_ITEM($result, i, CreateWrappedTFOutput(tf_outputs[i]));
+  }
+}
+
 %unignore TF_NewSessionRef;
 %unignore SetRequireShapeInferenceFns;
 %unignore TF_TryEvaluateConstant_wrapper;
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 1c7c42aea1b..78a1613c86c 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -637,6 +637,48 @@ void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
       types.data(), status);
 }
 
+void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
+                       TF_DataType dtype, TF_Output* output) {
+  TF_OperationDescription* desc =
+      TF_NewOperation(graph, "Placeholder", name.data());
+  TF_SetAttrType(desc, "dtype", dtype);
+  TF_Operation* op = TF_FinishOperation(desc, s);
+  output->oper = op;
+  output->index = 0;
+}
+
+std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
+                                             const char* prefix,
+                                             TF_Status* status) {
+  std::vector<TF_Output> outputs;
+  dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
+  if (dtypes == nullptr) {
+    Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
+    return outputs;
+  }
+  Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
+  Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
+  outputs.reserve(len);
+  for (size_t i = 0; i < len; i++) {
+    PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
+    if (!dtype) {
+      Set_TF_Status_from_Status(status,
+                                errors::Internal("Could not get dtype ", i));
+      return outputs;
+    }
+#if PY_MAJOR_VERSION >= 3
+    TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
+#else
+    TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
+#endif
+    outputs.push_back(TF_Output());
+    CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
+                      &outputs.back());
+    if (!status->status.ok()) break;
+  }
+  return outputs;
+}
+
 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
                                     const std::vector<int64_t>& dims,
                                     bool unknown_shape, TF_Status* status) {
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 6e2f501a39a..fc57add79e1 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -227,6 +227,14 @@ void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
     const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
     TF_Status* status);
 
+// Creates Placeholders with specified types in the Graph.
+//
+// This is an internal API used to speed up creation of unused placeholders
+// in the while_v2 cond graph and is subject to change/removal.
+std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
+                                             const char* prefix,
+                                             TF_Status* status);
+
 // Set the shape of output. If unknown is true, `num_dims` must be set to
 // -1 and `dims` is set to nullptr.
 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 859eeefc185..d240885c1a9 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -45,6 +46,7 @@ from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import while_v2_indexed_slices_rewriter
+from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util import object_identity
 
@@ -213,9 +215,8 @@ def while_loop(cond,
               body_graph.external_captures[:num_cond_captures])
       cond_graph_captures = object_identity.ObjectIdentitySet(
           cond_graph.external_captures)
-      for body_capture in body_graph.external_captures[num_cond_captures:]:
-        assert body_capture not in cond_graph_captures
-        cond_graph.capture(body_capture)
+      _duplicate_body_captures_in_cond(
+          cond_graph, body_graph.external_captures[num_cond_captures:])
 
     # Make sure that the shapes of the loop outputs are compatible with the
     # shape invariants, or the shapes of the loop vars if the invariants are not
@@ -1096,6 +1097,50 @@ def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
                           loop_var.name, inp.dtype, out.dtype))
 
 
+def _build_cond_placeholders_name_prefix(cond_graph):
+  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")
+
+
+def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
+  """Creates placeholders for body captures in cond_graph.
+
+  This is needed to match the signatures of the cond and body graphs.
+
+  Args:
+    cond_graph: The cond branch graph.
+    body_graph_captures: Tensors which were captured when building the
+      `body_graph`.
+  """
+  types = [t.dtype.as_datatype_enum for t in body_graph_captures]
+  # TODO(srbs): Providing a unique prefix does not ensure that there is no
+  # conflict between the placeholder names and existing nodes in the graph.
+  # However, passing a list of strings may not be performant.
+  # Ideally we should move `Graph.unique_name` to C++ or make
+  # `Graph._names_in_use` a trie so that we can find a unique prefix.
+  # TODO(b/143286622): This should not be required once captures are separated
+  # from regular loop vars.
+  placeholders = c_api.TF_CreatePlaceholders(
+      cond_graph._c_graph, types,
+      compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
+  placeholder_ops = [
+      _OperationWithOutputs(ph.oper, cond_graph)
+      for ph in placeholders
+  ]
+
+  tensors = []
+  for op, ph, dtype in zip(placeholder_ops, placeholders, types):
+    tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
+    op._outputs = [tensor]
+    tensors.append(tensor)
+
+  # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
+  # newly created placeholders.
+  tuples = zip(body_graph_captures, tensors)
+  keys = [id(t) for t in body_graph_captures]
+  cond_graph._captures.update(zip(keys, tuples))
+  cond_graph.inputs.extend(tensors)
+
+
 def _copy_handle_data(src_tensors, tgt_tensors):
   for src_t, tgt_t in zip(src_tensors, tgt_tensors):
     custom_gradient.copy_handle_data(src_t, tgt_t)
@@ -1151,4 +1196,29 @@ def _build_accumulator_name(tensor):
 def _is_loop_invariant(tensor, inputs, outputs):
   return tensor in inputs and tensor in outputs
 
+
+class _OperationWithOutputs(ops.Operation):
+  """Operation with pre-built `TF_Output`s.
+
+  The C API for creating the extra placeholders for the cond graph returns
+  SWIG-wrapped TF_Output* pointers which we can use directly for
+  `Operation.outputs`. The default constructor for `Operation` does not provide
+  a way of specifying pre-built output tensors and always creates them. This is
+  a performance overhead. It is not clear if adding that feature to the
+  `Operation` API would be generally useful, so for now we just have our own
+  lightweight `Operation` implementation. Note that this also does not
+  extract a stacktrace, since we don't expect this operation to be used.
+
+  TODO(b/143286622): This should not be required once captures are separated
+  from regular loop vars.
+  """
+
+  def __init__(self, c_op, g):
+    self._c_op = c_op
+    self._graph = g
+    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
+    self._id_value = g._add_op(self, self.name)
+    self._is_stateful = False
+
+
 # pylint: enable=protected-access
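Reviewer note (not part of the patch): the sketch below is a rough, hypothetical micro-benchmark that mirrors the calling pattern used in `_duplicate_body_captures_in_cond` above and illustrates why batching placeholder creation in C++ helps. It assumes this revision's internals (`pywrap_tensorflow`, `Graph._c_graph`, and the datatype-enum sequence expected by the wrapper); the file name and helper names (`_time_batched`, `_time_one_by_one`) are illustrative only, and timings will vary.

# bench_create_placeholders.py (hypothetical; assumes this revision of TF)
import time

from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import compat


def _time_batched(n):
  """Creates n float32 placeholders with a single call into the C API."""
  graph = ops.Graph()
  # TF_CreatePlaceholders takes a sequence of DataType enum values; the
  # TF_Status out-parameter is handled by the SWIG typemaps.
  types = [dtypes.float32.as_datatype_enum] * n
  start = time.time()
  c_api.TF_CreatePlaceholders(
      graph._c_graph, types, compat.as_str("batched_placeholder"))
  return time.time() - start


def _time_one_by_one(n):
  """Creates n placeholders through the regular Python Operation path."""
  graph = ops.Graph()
  start = time.time()
  with graph.as_default():
    for i in range(n):
      array_ops.placeholder(dtypes.float32, name="placeholder_%d" % i)
  return time.time() - start


if __name__ == "__main__":
  n = 1000
  print("batched C API call:        %.4fs" % _time_batched(n))
  print("one placeholder at a time: %.4fs" % _time_one_by_one(n))

Note that the batched call only creates nodes in the underlying TF_Graph and returns wrapped TF_Outputs; no Python `Operation`/`Tensor` wrappers exist afterwards. That is exactly the gap `_OperationWithOutputs` and `ops.Tensor._create_with_tf_output` fill in `_duplicate_body_captures_in_cond`.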