From a3ca926fb20b54b08737f367ee3583b584f87861 Mon Sep 17 00:00:00 2001 From: Sung Jin Hwang Date: Mon, 15 Feb 2021 22:52:16 -0800 Subject: [PATCH] Avoid string copies during Substr broadcasting. Substr op implements broadcasting by allocating temporary tensors of the match shapes. When the input string tensor contains large string tensors, this procedure may blow up memory easily. PiperOrigin-RevId: 357652876 Change-Id: I58723a2e8d653a72b0cfe32d0093c8da69cb94da --- tensorflow/core/kernels/substr_op.cc | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index ab83efda2a2..8ca14c4de6a 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -151,15 +151,6 @@ class SubstrOp : public OpKernel { auto pos_shaped = pos_tensor.shaped(bcast.y_reshape()); auto len_shaped = len_tensor.shaped(bcast.y_reshape()); - // Allocate temporary buffer for broadcasted input tensor - Tensor input_buffer; - OP_REQUIRES_OK(context, context->allocate_temp( - DT_STRING, output_shape, &input_buffer)); - TTypes::Tensor input_bcast = - input_buffer.shaped(bcast.result_shape()); - input_bcast = - input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast())); - // Allocate temporary buffer for broadcasted position tensor Tensor pos_buffer; OP_REQUIRES_OK(context, @@ -182,7 +173,7 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { - StringPiece in(input_bcast(i)); + StringPiece in(input(input.dimension(0) > 1 ? i : 0)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); T byte_pos = pos; @@ -197,8 +188,7 @@ class SubstrOp : public OpKernel { case CharUnit::BYTE: byte_pos = AdjustedPosIndex(byte_pos, in); OP_REQUIRES( - context, - FastBoundsCheck(byte_pos, input_bcast(i).size() + 1), + context, FastBoundsCheck(byte_pos, in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for ", "string b'", in, "' at index ", i)); } @@ -214,15 +204,6 @@ class SubstrOp : public OpKernel { auto pos_shaped = pos_tensor.shaped(bcast.y_reshape()); auto len_shaped = len_tensor.shaped(bcast.y_reshape()); - // Allocate temporary buffer for broadcasted input tensor - Tensor input_buffer; - OP_REQUIRES_OK(context, context->allocate_temp( - DT_STRING, output_shape, &input_buffer)); - TTypes::Tensor input_bcast = - input_buffer.shaped(bcast.result_shape()); - input_bcast = - input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast())); - // Allocate temporary buffer for broadcasted position tensor Tensor pos_buffer; OP_REQUIRES_OK(context, @@ -246,7 +227,8 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { for (int j = 0; j < output_shape.dim_size(1); ++j) { - StringPiece in(input_bcast(i, j)); + StringPiece in(input(input.dimension(0) > 1 ? i : 0, + input.dimension(1) > 1 ? j : 0)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len =