[BERT] Add a TF_CreatePlaceholders API and use it for capturing redundant tensors in cond_graph.

The while loop in BERT has ~1M loop vars. Almost all of these are only used in the body function. However, to make the signatures of the cond and body functions match, we create placeholders in the cond graph which are unused. Creating these sequentially in Python caused a lot of graph-building time overhead. This change creates those placeholders in a batch in C++ and also skips copy_handle_data.
This is a temporary band-aid. Long-term we want to change the While op to have explicit captures for the cond and body.

PiperOrigin-RevId: 276738206
Change-Id: I721f76f770cce9fe6f9b7619f6e3e48704a7b118
Saurabh Saxena 2019-10-25 12:27:38 -07:00 committed by TensorFlower Gardener
parent a582a54b5b
commit d97b021fb1
4 changed files with 141 additions and 3 deletions
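
For context, a rough before/after sketch of the placeholder-creation path (editor's illustration, not part of the commit; `cond_graph`, `body_graph_captures`, and the helpers referenced here appear in the while_v2.py changes below):

    # Before: one FuncGraph.capture() call per redundant tensor, i.e. one
    # Python -> C round trip plus copy_handle_data for each of the ~1M captures.
    for body_capture in body_graph_captures:
      cond_graph.capture(body_capture)

    # After: a single C++ call creates all the unused placeholders in a batch,
    # returns SWIG-wrapped TF_Output structs, and skips copy_handle_data.
    placeholders = c_api.TF_CreatePlaceholders(
        cond_graph._c_graph,
        [t.dtype.as_datatype_enum for t in body_graph_captures],
        compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))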

@@ -822,6 +822,24 @@ def TF_Reset(target, containers=None, config=None):
  $1 = &types_local;
}

%unignore TF_CreatePlaceholders;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_CreatePlaceholders;
// Build a Python list of TF_Output and return it.
%typemap(out) std::vector<TF_Output> tensorflow::TF_CreatePlaceholders {
  $result = PyList_New($1.size());
  if (!$result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }
  // Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
  const std::vector<TF_Output>& tf_outputs = $1;
  for (size_t i = 0; i < tf_outputs.size(); ++i) {
    PyList_SET_ITEM($result, i, CreateWrappedTFOutput(tf_outputs[i]));
  }
}

%unignore TF_NewSessionRef;
%unignore SetRequireShapeInferenceFns;
%unignore TF_TryEvaluateConstant_wrapper;

@@ -637,6 +637,48 @@ void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
                                               types.data(), status);
}

void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name,
                       TF_DataType dtype, TF_Output* output) {
  TF_OperationDescription* desc =
      TF_NewOperation(graph, "Placeholder", name.data());
  TF_SetAttrType(desc, "dtype", dtype);
  TF_Operation* op = TF_FinishOperation(desc, s);
  output->oper = op;
  output->index = 0;
}

std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
                                             const char* prefix,
                                             TF_Status* status) {
  std::vector<TF_Output> outputs;
  dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence");
  if (dtypes == nullptr) {
    Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr"));
    return outputs;
  }
  Safe_PyObjectPtr dtypes_holder(make_safe(dtypes));
  Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes);
  outputs.reserve(len);
  for (size_t i = 0; i < len; i++) {
    PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i);
    if (!dtype) {
      Set_TF_Status_from_Status(status,
                                errors::Internal("Could not get dtype ", i));
      return outputs;
    }
#if PY_MAJOR_VERSION >= 3
    TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype));
#else
    TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype));
#endif
    outputs.push_back(TF_Output());
    CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype,
                      &outputs.back());
    if (!status->status.ok()) break;
  }
  return outputs;
}

void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
                                    const std::vector<int64_t>& dims,
                                    bool unknown_shape, TF_Status* status) {

@@ -227,6 +227,14 @@ void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
    const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
    TF_Status* status);

// Creates Placeholders with specified types in the Graph.
//
// This is an internal API used to speed up creation of unused placeholders
// in while_v2 cond graph and is subject to change/removal.
std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes,
                                             const char* prefix,
                                             TF_Status* status);

// Set the shape of output. If unknown is true, `num_dims` must be set to
// -1 and `dims` is set to nullptr.
void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
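
Through the typemap added in the .i file above, this surfaces in Python as a function returning a plain list of SWIG-wrapped TF_Output structs. A hedged usage sketch (the graph, dtypes, and prefix below are illustrative; the real caller is while_v2, shown next):

    from tensorflow.python import pywrap_tensorflow as c_api
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops

    g = ops.Graph()
    enums = [dtypes.float32.as_datatype_enum, dtypes.int32.as_datatype_enum]
    phs = c_api.TF_CreatePlaceholders(g._c_graph, enums, "redundant_ph")
    for ph in phs:
      # Each element wraps a TF_Output: the producing Placeholder op and index 0.
      print(ph.oper, ph.index)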

@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -45,6 +46,7 @@ from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
@@ -213,9 +215,8 @@ def while_loop(cond,
            body_graph.external_captures[:num_cond_captures])
    cond_graph_captures = object_identity.ObjectIdentitySet(
        cond_graph.external_captures)
-   for body_capture in body_graph.external_captures[num_cond_captures:]:
-     assert body_capture not in cond_graph_captures
-     cond_graph.capture(body_capture)
+   _duplicate_body_captures_in_cond(
+       cond_graph, body_graph.external_captures[num_cond_captures:])

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
@@ -1096,6 +1097,50 @@ def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
            loop_var.name, inp.dtype, out.dtype))


def _build_cond_placeholders_name_prefix(cond_graph):
  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")


def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
  """Creates placeholders for body captures in cond_graph.

  This is needed to match signatures of cond and body graphs.

  Args:
    cond_graph: cond branch graph
    body_graph_captures: Tensors which were captured when building the
      `body_graph`.
  """
  types = [t.dtype.as_datatype_enum for t in body_graph_captures]
  # TODO(srbs): Providing a unique prefix does not ensure that there is no
  # conflict between the placeholder names and existing nodes in the graph.
  # However passing a list of strings may not be performant.
  # Ideally we should move `Graph.unique_name` to C++ or make
  # `Graph._names_in_use` a trie so that we can find a unique prefix.
  # TODO(b/143286622): This should not be required once captures are separated
  # from regular loop vars.
  placeholders = c_api.TF_CreatePlaceholders(
      cond_graph._c_graph, types,
      compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
  placeholder_ops = [
      _OperationWithOutputs(ph.oper, cond_graph)
      for ph in placeholders
  ]

  tensors = []
  for op, ph, dtype in zip(placeholder_ops, placeholders, types):
    tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
    op._outputs = [tensor]
    tensors.append(tensor)

  # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
  # newly created placeholders.
  tuples = zip(body_graph_captures, tensors)
  keys = [id(t) for t in body_graph_captures]
  cond_graph._captures.update(zip(keys, tuples))
  cond_graph.inputs.extend(tensors)


def _copy_handle_data(src_tensors, tgt_tensors):
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    custom_gradient.copy_handle_data(src_t, tgt_t)
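
The final `_captures` update above mirrors what per-tensor `FuncGraph.capture()` calls would have built up. A hedged sketch of the mapping it leaves behind (editor's illustration; it assumes `_captures` is the dict keyed by `id()` of the external tensor, as the update implies):

    # After _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
    for ext_t, ph_t in zip(body_graph_captures,
                           cond_graph.inputs[-len(body_graph_captures):]):
      external, internal = cond_graph._captures[id(ext_t)]
      assert external is ext_t   # the tensor captured by the body graph
      assert internal is ph_t    # the batch-created placeholder in cond_graph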
@@ -1151,4 +1196,29 @@ def _build_accumulator_name(tensor):

def _is_loop_invariant(tensor, inputs, outputs):
  return tensor in inputs and tensor in outputs


class _OperationWithOutputs(ops.Operation):
  """Operation with pre-built `TF_Output`s.

  The C API for creating the extra placeholders for the cond graph returns
  SWIG wrapped TF_Output* pointers which we can use directly for
  `Operation.outputs`. The default constructor for `Operation` does not provide
  a way of specifying pre-built output tensors and always creates them. This is
  a performance overhead. It is not clear if adding that feature to the
  `Operation` API would be generally useful so for now we just have our own
  lightweight `Operation` implementation. Note that this does not extract a
  stacktrace as well since we don't expect this operation to be used.

  TODO(b/143286622): This should not be required once captures are separated
  from regular loop vars.
  """

  def __init__(self, c_op, g):
    self._c_op = c_op
    self._graph = g
    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
    self._id_value = g._add_op(self, self.name)
    self._is_stateful = False

# pylint: enable=protected-access
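
For completeness, a hedged sketch (editor's illustration, not part of the commit) of how one of these lightweight ops behaves once `_duplicate_body_captures_in_cond` has wired up its output; `ph` and `dtype` are the per-placeholder values from that helper:

    op = _OperationWithOutputs(ph.oper, cond_graph)
    t = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
    op._outputs = [t]
    # Properties backed by the underlying C op still work as usual; only the
    # Python-side output construction and stack-trace capture are skipped.
    assert op.type == "Placeholder"
    assert op.outputs[0] is t and t.op is op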