diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 8db2ad8e5ab..e8c428a80d0 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -96,20 +96,21 @@ void CastOpBase::Compute(OpKernelContext* ctx) { const Tensor& inp = ctx->input(0); if (work_ == nullptr) { ctx->set_output(0, inp); - } else { + } else if (external_src_dtype_ != src_dtype_ || + external_dst_dtype_ != dst_dtype_) { Tensor in; - if (external_src_dtype_ != src_dtype_) { - // If the type is a quantized type we need to do a bitcast since the - // src_dtype_ is different from external_src_type_. - OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape())); - } else { - in = inp; - } + // If the type is a quantized type we need to do a bitcast since the + // src_dtype_ is different from external_src_type_. + OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape())); Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); out->set_dtype(dst_dtype_); work_(ctx, in, out, use_truncation_); out->set_dtype(external_dst_dtype_); + } else { + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); + work_(ctx, inp, out, use_truncation_); } }