From cbbbe1dc0277dc11712e9e55844fca0dc7e9c373 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 5 Mar 2020 11:13:07 -0800 Subject: [PATCH] [CastOp] Avoid creating and destroying a temporary tensor in the common case. Previously, we would copy the input tensor to a temporary to admit the possibility of having to bitcast a quantized input before running the cast function. This change avoids the need for a temporary in the common (non-quantized) case. PiperOrigin-RevId: 299146069 Change-Id: Id45e86394179c4241fa2eab21ea74eeb7bf9feba --- tensorflow/core/kernels/cast_op.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 8db2ad8e5ab..e8c428a80d0 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -96,20 +96,21 @@ void CastOpBase::Compute(OpKernelContext* ctx) { const Tensor& inp = ctx->input(0); if (work_ == nullptr) { ctx->set_output(0, inp); - } else { + } else if (external_src_dtype_ != src_dtype_ || + external_dst_dtype_ != dst_dtype_) { Tensor in; - if (external_src_dtype_ != src_dtype_) { - // If the type is a quantized type we need to do a bitcast since the - // src_dtype_ is different from external_src_type_. - OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape())); - } else { - in = inp; - } + // If the type is a quantized type we need to do a bitcast since the + // src_dtype_ is different from external_src_type_. + OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape())); Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); out->set_dtype(dst_dtype_); work_(ctx, in, out, use_truncation_); out->set_dtype(external_dst_dtype_); + } else { + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); + work_(ctx, inp, out, use_truncation_); } }