[CastOp] Avoid creating and destroying a temporary tensor in the common case.
Previously, we copied the input tensor to a temporary in case a quantized input had to be bitcast before running the cast function. This change avoids the temporary in the common (non-quantized) case.

PiperOrigin-RevId: 299146069
Change-Id: Id45e86394179c4241fa2eab21ea74eeb7bf9feba
This commit is contained in:
parent
a0fc80b922
commit
cbbbe1dc02
@ -96,20 +96,21 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
|
||||
const Tensor& inp = ctx->input(0);
|
||||
if (work_ == nullptr) {
|
||||
ctx->set_output(0, inp);
|
||||
} else {
|
||||
} else if (external_src_dtype_ != src_dtype_ ||
|
||||
external_dst_dtype_ != dst_dtype_) {
|
||||
Tensor in;
|
||||
if (external_src_dtype_ != src_dtype_) {
|
||||
// If the type is a quantized type we need to do a bitcast since the
|
||||
// src_dtype_ is different from external_src_type_.
|
||||
OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
|
||||
} else {
|
||||
in = inp;
|
||||
}
|
||||
// If the type is a quantized type we need to do a bitcast since the
|
||||
// src_dtype_ is different from external_src_type_.
|
||||
OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
|
||||
out->set_dtype(dst_dtype_);
|
||||
work_(ctx, in, out, use_truncation_);
|
||||
out->set_dtype(external_dst_dtype_);
|
||||
} else {
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
|
||||
work_(ctx, inp, out, use_truncation_);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user