[tf.data] Add missing std::vector::reserve() calls.

PiperOrigin-RevId: 288319987
Change-Id: I07bbec3f0bf00223505226285a5ab0c20dac5569
This commit is contained in:
Derek Murray 2020-01-06 09:57:11 -08:00 committed by TensorFlower Gardener
parent 54381429c2
commit feffae326d
6 changed files with 8 additions and 0 deletions

View File

@ -186,6 +186,7 @@ class BatchDatasetOp::Dataset : public DatasetBase {
// overload that supports zero-copy, and might make sense in an
// optimization pass.
const size_t num_tuple_components = batch_elements[0].size();
out_tensors->reserve(num_tuple_components);
const int64 num_batch_elements = batch_elements.size();
for (size_t component_index = 0; component_index < num_tuple_components;
++component_index) {

View File

@ -110,6 +110,7 @@ Status RunShortCircuit(const ShortCircuitInfo& info,
std::vector<Tensor>* rets) {
VLOG(3) << "Running function " << func->func().name() << " short circuit";
size_t num_args = args.size();
rets->reserve(info.indices.size());
for (size_t i = 0; i < info.indices.size(); ++i) {
if (info.indices[i] < num_args) {
rets->push_back(args[info.indices[i]]);
@ -125,6 +126,7 @@ Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) {
VLOG(3) << "Running function " << func->func().name() << " short circuit";
size_t num_args = args.size();
rets->reserve(info.indices.size());
for (size_t i = 0; i < info.indices.size(); ++i) {
if (info.indices[i] < num_args) {
if (info.can_move[i]) {

View File

@ -466,6 +466,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
const size_t num_components = return_values->size();
result->output.reserve(num_components);
for (size_t i = 0; i < num_components; ++i) {
TensorShape component_shape({dataset()->batch_size_});
component_shape.AppendShape(return_values->at(i).shape());

View File

@ -86,6 +86,7 @@ class RandomDatasetOp::Dataset : public DatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
out_tensors->reserve(1);
mutex_lock l(mu_);
out_tensors->emplace_back(ctx->allocator({}), DT_INT64, TensorShape({}));
out_tensors->back().scalar<int64>()() = Random();

View File

@ -38,6 +38,8 @@ class TensorDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
: DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
dtypes_.reserve(tensors_.size());
shapes_.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
shapes_.emplace_back(t.shape().dim_sizes());

View File

@ -100,6 +100,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
out_tensors->reserve(1);
mutex_lock l(mu_);
do {
// We are currently processing a file, so try to read the next record.