Avoid string copies during Substr broadcasting.

The Substr op implements broadcasting by allocating temporary tensors of the
matching (broadcast) shape and copying the inputs into them. When the input
string tensor contains large strings, these copies can easily blow up memory.

PiperOrigin-RevId: 357652876
Change-Id: I58723a2e8d653a72b0cfe32d0093c8da69cb94da
This commit is contained in:
Sung Jin Hwang 2021-02-15 22:52:16 -08:00 committed by TensorFlower Gardener
parent df608cf4ef
commit a3ca926fb2

View File

@ -151,15 +151,6 @@ class SubstrOp : public OpKernel {
auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
// Allocate temporary buffer for broadcasted input tensor
Tensor input_buffer;
OP_REQUIRES_OK(context, context->allocate_temp(
DT_STRING, output_shape, &input_buffer));
TTypes<tstring, 1>::Tensor input_bcast =
input_buffer.shaped<tstring, 1>(bcast.result_shape());
input_bcast =
input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast()));
// Allocate temporary buffer for broadcasted position tensor
Tensor pos_buffer;
OP_REQUIRES_OK(context,
@ -182,7 +173,7 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
StringPiece in(input_bcast(i));
StringPiece in(input(input.dimension(0) > 1 ? i : 0));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
T byte_pos = pos;
@ -197,8 +188,7 @@ class SubstrOp : public OpKernel {
case CharUnit::BYTE:
byte_pos = AdjustedPosIndex(byte_pos, in);
OP_REQUIRES(
context,
FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
context, FastBoundsCheck(byte_pos, in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for ",
"string b'", in, "' at index ", i));
}
@ -214,15 +204,6 @@ class SubstrOp : public OpKernel {
auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
// Allocate temporary buffer for broadcasted input tensor
Tensor input_buffer;
OP_REQUIRES_OK(context, context->allocate_temp(
DT_STRING, output_shape, &input_buffer));
TTypes<tstring, 2>::Tensor input_bcast =
input_buffer.shaped<tstring, 2>(bcast.result_shape());
input_bcast =
input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast()));
// Allocate temporary buffer for broadcasted position tensor
Tensor pos_buffer;
OP_REQUIRES_OK(context,
@ -246,7 +227,8 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
StringPiece in(input_bcast(i, j));
StringPiece in(input(input.dimension(0) > 1 ? i : 0,
input.dimension(1) > 1 ? j : 0));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =