[CastOp] Avoid creating and destroying a temporary tensor in the common case.

Previously, we would copy the input tensor to a temporary to admit the possibility of having to bitcast a quantized input before running the cast function. This change avoids the need for a temporary in the common (non-quantized) case.

PiperOrigin-RevId: 299146069
Change-Id: Id45e86394179c4241fa2eab21ea74eeb7bf9feba
This commit is contained in:
Derek Murray 2020-03-05 11:13:07 -08:00 committed by TensorFlower Gardener
parent a0fc80b922
commit cbbbe1dc02

View File

@ -96,20 +96,21 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
const Tensor& inp = ctx->input(0);
if (work_ == nullptr) {
ctx->set_output(0, inp);
} else {
} else if (external_src_dtype_ != src_dtype_ ||
external_dst_dtype_ != dst_dtype_) {
Tensor in;
if (external_src_dtype_ != src_dtype_) {
// If the type is a quantized type we need to do a bitcast since the
// src_dtype_ is different from external_src_type_.
OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
} else {
in = inp;
}
// If the type is a quantized type we need to do a bitcast since the
// src_dtype_ is different from external_src_type_.
OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
work_(ctx, in, out, use_truncation_);
out->set_dtype(external_dst_dtype_);
} else {
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
work_(ctx, inp, out, use_truncation_);
}
}