Avoid string copies during Substr broadcasting.
The Substr op implements broadcasting by allocating temporary tensors of the matching (broadcast) shape. When the input string tensor contains large strings, copying them into these temporaries can easily blow up memory. PiperOrigin-RevId: 357652876 Change-Id: I58723a2e8d653a72b0cfe32d0093c8da69cb94da
This commit is contained in:
parent
df608cf4ef
commit
a3ca926fb2
@ -151,15 +151,6 @@ class SubstrOp : public OpKernel {
|
||||
auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
|
||||
auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
|
||||
|
||||
// Allocate temporary buffer for broadcasted input tensor
|
||||
Tensor input_buffer;
|
||||
OP_REQUIRES_OK(context, context->allocate_temp(
|
||||
DT_STRING, output_shape, &input_buffer));
|
||||
TTypes<tstring, 1>::Tensor input_bcast =
|
||||
input_buffer.shaped<tstring, 1>(bcast.result_shape());
|
||||
input_bcast =
|
||||
input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast()));
|
||||
|
||||
// Allocate temporary buffer for broadcasted position tensor
|
||||
Tensor pos_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -182,7 +173,7 @@ class SubstrOp : public OpKernel {
|
||||
|
||||
// Iterate through broadcasted tensors and perform substr
|
||||
for (int i = 0; i < output_shape.dim_size(0); ++i) {
|
||||
StringPiece in(input_bcast(i));
|
||||
StringPiece in(input(input.dimension(0) > 1 ? i : 0));
|
||||
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
|
||||
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
|
||||
T byte_pos = pos;
|
||||
@ -197,8 +188,7 @@ class SubstrOp : public OpKernel {
|
||||
case CharUnit::BYTE:
|
||||
byte_pos = AdjustedPosIndex(byte_pos, in);
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
|
||||
context, FastBoundsCheck(byte_pos, in.size() + 1),
|
||||
errors::InvalidArgument("pos ", pos, " out of range for ",
|
||||
"string b'", in, "' at index ", i));
|
||||
}
|
||||
@ -214,15 +204,6 @@ class SubstrOp : public OpKernel {
|
||||
auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
|
||||
auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
|
||||
|
||||
// Allocate temporary buffer for broadcasted input tensor
|
||||
Tensor input_buffer;
|
||||
OP_REQUIRES_OK(context, context->allocate_temp(
|
||||
DT_STRING, output_shape, &input_buffer));
|
||||
TTypes<tstring, 2>::Tensor input_bcast =
|
||||
input_buffer.shaped<tstring, 2>(bcast.result_shape());
|
||||
input_bcast =
|
||||
input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast()));
|
||||
|
||||
// Allocate temporary buffer for broadcasted position tensor
|
||||
Tensor pos_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -246,7 +227,8 @@ class SubstrOp : public OpKernel {
|
||||
// Iterate through broadcasted tensors and perform substr
|
||||
for (int i = 0; i < output_shape.dim_size(0); ++i) {
|
||||
for (int j = 0; j < output_shape.dim_size(1); ++j) {
|
||||
StringPiece in(input_bcast(i, j));
|
||||
StringPiece in(input(input.dimension(0) > 1 ? i : 0,
|
||||
input.dimension(1) > 1 ? j : 0));
|
||||
const T pos =
|
||||
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
|
||||
const T len =
|
||||
|
Loading…
x
Reference in New Issue
Block a user