[tf.data] Add missing std::vector::reserve() calls.
PiperOrigin-RevId: 288319987 Change-Id: I07bbec3f0bf00223505226285a5ab0c20dac5569
This commit is contained in:
parent
54381429c2
commit
feffae326d
|
@ -186,6 +186,7 @@ class BatchDatasetOp::Dataset : public DatasetBase {
|
||||||
// overload that supports zero-copy, and might make sense in an
|
// overload that supports zero-copy, and might make sense in an
|
||||||
// optimization pass.
|
// optimization pass.
|
||||||
const size_t num_tuple_components = batch_elements[0].size();
|
const size_t num_tuple_components = batch_elements[0].size();
|
||||||
|
out_tensors->reserve(num_tuple_components);
|
||||||
const int64 num_batch_elements = batch_elements.size();
|
const int64 num_batch_elements = batch_elements.size();
|
||||||
for (size_t component_index = 0; component_index < num_tuple_components;
|
for (size_t component_index = 0; component_index < num_tuple_components;
|
||||||
++component_index) {
|
++component_index) {
|
||||||
|
|
|
@ -110,6 +110,7 @@ Status RunShortCircuit(const ShortCircuitInfo& info,
|
||||||
std::vector<Tensor>* rets) {
|
std::vector<Tensor>* rets) {
|
||||||
VLOG(3) << "Running function " << func->func().name() << " short circuit";
|
VLOG(3) << "Running function " << func->func().name() << " short circuit";
|
||||||
size_t num_args = args.size();
|
size_t num_args = args.size();
|
||||||
|
rets->reserve(info.indices.size());
|
||||||
for (size_t i = 0; i < info.indices.size(); ++i) {
|
for (size_t i = 0; i < info.indices.size(); ++i) {
|
||||||
if (info.indices[i] < num_args) {
|
if (info.indices[i] < num_args) {
|
||||||
rets->push_back(args[info.indices[i]]);
|
rets->push_back(args[info.indices[i]]);
|
||||||
|
@ -125,6 +126,7 @@ Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
|
||||||
std::vector<Tensor>* rets) {
|
std::vector<Tensor>* rets) {
|
||||||
VLOG(3) << "Running function " << func->func().name() << " short circuit";
|
VLOG(3) << "Running function " << func->func().name() << " short circuit";
|
||||||
size_t num_args = args.size();
|
size_t num_args = args.size();
|
||||||
|
rets->reserve(info.indices.size());
|
||||||
for (size_t i = 0; i < info.indices.size(); ++i) {
|
for (size_t i = 0; i < info.indices.size(); ++i) {
|
||||||
if (info.indices[i] < num_args) {
|
if (info.indices[i] < num_args) {
|
||||||
if (info.can_move[i]) {
|
if (info.can_move[i]) {
|
||||||
|
|
|
@ -466,6 +466,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
const size_t num_components = return_values->size();
|
const size_t num_components = return_values->size();
|
||||||
|
result->output.reserve(num_components);
|
||||||
for (size_t i = 0; i < num_components; ++i) {
|
for (size_t i = 0; i < num_components; ++i) {
|
||||||
TensorShape component_shape({dataset()->batch_size_});
|
TensorShape component_shape({dataset()->batch_size_});
|
||||||
component_shape.AppendShape(return_values->at(i).shape());
|
component_shape.AppendShape(return_values->at(i).shape());
|
||||||
|
|
|
@ -86,6 +86,7 @@ class RandomDatasetOp::Dataset : public DatasetBase {
|
||||||
Status GetNextInternal(IteratorContext* ctx,
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
std::vector<Tensor>* out_tensors,
|
std::vector<Tensor>* out_tensors,
|
||||||
bool* end_of_sequence) override {
|
bool* end_of_sequence) override {
|
||||||
|
out_tensors->reserve(1);
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
out_tensors->emplace_back(ctx->allocator({}), DT_INT64, TensorShape({}));
|
out_tensors->emplace_back(ctx->allocator({}), DT_INT64, TensorShape({}));
|
||||||
out_tensors->back().scalar<int64>()() = Random();
|
out_tensors->back().scalar<int64>()() = Random();
|
||||||
|
|
|
@ -38,6 +38,8 @@ class TensorDatasetOp::Dataset : public DatasetBase {
|
||||||
public:
|
public:
|
||||||
Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
|
Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
|
||||||
: DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
|
: DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
|
||||||
|
dtypes_.reserve(tensors_.size());
|
||||||
|
shapes_.reserve(tensors_.size());
|
||||||
for (const Tensor& t : tensors_) {
|
for (const Tensor& t : tensors_) {
|
||||||
dtypes_.push_back(t.dtype());
|
dtypes_.push_back(t.dtype());
|
||||||
shapes_.emplace_back(t.shape().dim_sizes());
|
shapes_.emplace_back(t.shape().dim_sizes());
|
||||||
|
|
|
@ -100,6 +100,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
|
||||||
Status GetNextInternal(IteratorContext* ctx,
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
std::vector<Tensor>* out_tensors,
|
std::vector<Tensor>* out_tensors,
|
||||||
bool* end_of_sequence) override {
|
bool* end_of_sequence) override {
|
||||||
|
out_tensors->reserve(1);
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
do {
|
do {
|
||||||
// We are currently processing a file, so try to read the next record.
|
// We are currently processing a file, so try to read the next record.
|
||||||
|
|
Loading…
Reference in New Issue