[BERT] Add a TF_CreatePlaceholders API and use that for capturing redundant tensors in
cond_graph. The while loop in BERT has ~1M loop vars. Almost all of these are only used in the body function. However in order to make the signatures of the cond and body functions to match we create placeholders in the cond graph which are unused. Creating these sequentially in python caused a lot of graph building time overhead. This changes creates those placeholders in a batch in C++ and also does not do copy_handle_data. This is a temporary band-aid. Long-term we want to change the While op to have explicit captures for the cond and body. PiperOrigin-RevId: 276738206 Change-Id: I721f76f770cce9fe6f9b7619f6e3e48704a7b118
This commit is contained in:
parent
a582a54b5b
commit
d97b021fb1
@ -822,6 +822,24 @@ def TF_Reset(target, containers=None, config=None):
|
||||
$1 = &types_local;
|
||||
}
|
||||
|
||||
%unignore TF_CreatePlaceholders;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_CreatePlaceholders;
|
||||
|
||||
// Build a Python list of TF_Output and return it.
|
||||
%typemap(out) std::vector<TF_Output> tensorflow::TF_CreatePlaceholders {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
// Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
|
||||
const std::vector<TF_Output>& tf_outputs = $1;
|
||||
for (size_t i = 0; i < tf_outputs.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, CreateWrappedTFOutput(tf_outputs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
%unignore TF_NewSessionRef;
|
||||
%unignore SetRequireShapeInferenceFns;
|
||||
%unignore TF_TryEvaluateConstant_wrapper;
|
||||
|
@ -637,6 +637,48 @@ void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
types.data(), status);
|
||||
}
|
||||
|
||||
void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
|
||||
TF_DataType dtype, TF_Output* output) {
|
||||
TF_OperationDescription* desc =
|
||||
TF_NewOperation(graph, "Placeholder", name.data());
|
||||
TF_SetAttrType(desc, "dtype", dtype);
|
||||
TF_Operation* op = TF_FinishOperation(desc, s);
|
||||
output->oper = op;
|
||||
output->index = 0;
|
||||
}
|
||||
|
||||
std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
|
||||
const char* prefix,
|
||||
TF_Status* status) {
|
||||
std::vector<TF_Output> outputs;
|
||||
dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
|
||||
if (dtypes == nullptr) {
|
||||
Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
|
||||
return outputs;
|
||||
}
|
||||
Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
|
||||
Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
|
||||
outputs.reserve(len);
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
|
||||
if (!dtype) {
|
||||
Set_TF_Status_from_Status(status,
|
||||
errors::Internal("Could not get dtype ", i));
|
||||
return outputs;
|
||||
}
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
|
||||
#else
|
||||
TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
|
||||
#endif
|
||||
outputs.push_back(TF_Output());
|
||||
CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
|
||||
&outputs.back());
|
||||
if (!status->status.ok()) break;
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
|
||||
const std::vector<int64_t>& dims,
|
||||
bool unknown_shape, TF_Status* status) {
|
||||
|
@ -227,6 +227,14 @@ void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
|
||||
TF_Status* status);
|
||||
|
||||
// Creates Placeholders with specified types in the Graph.
|
||||
//
|
||||
// This is an internal API used to speed up creation of unused placeholders
|
||||
// in while_v2 cond graph and is subject to change/removal.
|
||||
std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
|
||||
const char* prefix,
|
||||
TF_Status* status);
|
||||
|
||||
// Set the shape of output. If unknown is true, `num_dims` must be set to
|
||||
// -1 and `dims` is set to nullptr.
|
||||
void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
|
||||
|
@ -24,6 +24,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -45,6 +46,7 @@ from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
|
||||
@ -213,9 +215,8 @@ def while_loop(cond,
|
||||
body_graph.external_captures[:num_cond_captures])
|
||||
cond_graph_captures = object_identity.ObjectIdentitySet(
|
||||
cond_graph.external_captures)
|
||||
for body_capture in body_graph.external_captures[num_cond_captures:]:
|
||||
assert body_capture not in cond_graph_captures
|
||||
cond_graph.capture(body_capture)
|
||||
_duplicate_body_captures_in_cond(
|
||||
cond_graph, body_graph.external_captures[num_cond_captures:])
|
||||
|
||||
# Make sure that the shapes of the loop outputs are compatible with the
|
||||
# shape invariants, or the shapes of the loop vars if the invariants are not
|
||||
@ -1096,6 +1097,50 @@ def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
|
||||
loop_var.name, inp.dtype, out.dtype))
|
||||
|
||||
|
||||
def _build_cond_placeholders_name_prefix(cond_graph):
|
||||
return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")
|
||||
|
||||
|
||||
def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
|
||||
"""Creates placeholders for body captures in cond_graph.
|
||||
|
||||
This is needed to match signatures of cond and body graphs.
|
||||
|
||||
Args:
|
||||
cond_graph: cond branch graph
|
||||
body_graph_captures: Tensors which were captured when building the
|
||||
`body_graph`.
|
||||
"""
|
||||
types = [t.dtype.as_datatype_enum for t in body_graph_captures]
|
||||
# TODO(srbs): Providing a unique prefix does not ensure that there is no
|
||||
# conflict between the placeholder names and existing nodes in the graph.
|
||||
# However passing a list of strings may not be performant.
|
||||
# Ideally we should move `Graph.unique_name` to C++ or make
|
||||
# `Graph._names_in_use` a trie so that we can find a unique prefix.
|
||||
# TODO(b/143286622): This should not be required once captures are separated
|
||||
# from regular loop vars.
|
||||
placeholders = c_api.TF_CreatePlaceholders(
|
||||
cond_graph._c_graph, types,
|
||||
compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
|
||||
placeholder_ops = [
|
||||
_OperationWithOutputs(ph.oper, cond_graph)
|
||||
for ph in placeholders
|
||||
]
|
||||
|
||||
tensors = []
|
||||
for op, ph, dtype in zip(placeholder_ops, placeholders, types):
|
||||
tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
|
||||
op._outputs = [tensor]
|
||||
tensors.append(tensor)
|
||||
|
||||
# Update `cond_graph._captures` and `cond_graph.inputs` to contain the
|
||||
# newly created placeholders.
|
||||
tuples = zip(body_graph_captures, tensors)
|
||||
keys = [id(t) for t in body_graph_captures]
|
||||
cond_graph._captures.update(zip(keys, tuples))
|
||||
cond_graph.inputs.extend(tensors)
|
||||
|
||||
|
||||
def _copy_handle_data(src_tensors, tgt_tensors):
|
||||
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
|
||||
custom_gradient.copy_handle_data(src_t, tgt_t)
|
||||
@ -1151,4 +1196,29 @@ def _build_accumulator_name(tensor):
|
||||
def _is_loop_invariant(tensor, inputs, outputs):
|
||||
return tensor in inputs and tensor in outputs
|
||||
|
||||
|
||||
class _OperationWithOutputs(ops.Operation):
|
||||
"""Operation with pre-built `TF_Output`s.
|
||||
|
||||
The C API for creating the extra placeholders for the cond graph returns
|
||||
SWIG wrapped TF_Output* pointers which we can use directly for
|
||||
`Operation.outputs`. The default constructor for `Operation` does not provide
|
||||
a way of specifying pre-built output tensors and always creates them. This is
|
||||
a performance overhead. It is not clear if adding that feature to the
|
||||
`Operation` API would be generally useful so for now we just have our own
|
||||
lightweight `Operation` implementation. Note that this does not extract a
|
||||
stacktrace as well since we don't expect this operation to be used.
|
||||
|
||||
TODO(b/143286622): This should not be required once captures are separated
|
||||
from regular loop vars.
|
||||
"""
|
||||
|
||||
def __init__(self, c_op, g):
|
||||
self._c_op = c_op
|
||||
self._graph = g
|
||||
self._outputs = None # Initialized by _duplicate_body_captures_in_cond().
|
||||
self._id_value = g._add_op(self, self.name)
|
||||
self._is_stateful = False
|
||||
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user