Convert Softmax to custom op
Change the old TF2XLA bridge to convert a SoftMax op to a custom call to softmax when the custom_fake_quant_op_calls flag is enabled. This is to preserve the information for quantization. PiperOrigin-RevId: 305892297 Change-Id: Id372e7da5753dd3109d6e993c6193d155ba136bd
This commit is contained in:
parent
30ecd9e2dd
commit
dc9eec091c
@ -24,14 +24,28 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Builds a custom call to a method named 'softmax' or 'log_softmax'.
|
||||
xla::StatusOr<xla::XlaOp> BuildSoftmaxCustomCall(xla::XlaBuilder* b,
|
||||
xla::XlaOp logits, bool log) {
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape logits_shape, b->GetShape(logits));
|
||||
return xla::CustomCallWithLayout(b, log ? "log_softmax" : "softmax", {logits},
|
||||
logits_shape, {logits_shape});
|
||||
}
|
||||
|
||||
class SoftmaxOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
@ -55,6 +69,15 @@ class SoftmaxOp : public XlaOpKernel {
|
||||
auto logits = ctx->Input(0);
|
||||
|
||||
xla::XlaBuilder* const b = ctx->builder();
|
||||
|
||||
if (ctx->compiler()->options().allow_cpu_custom_calls &&
|
||||
ctx->compiler()->options().custom_fake_quant_op_calls) {
|
||||
xla::XlaOp custom_call_output =
|
||||
b->ReportErrorOrReturn(BuildSoftmaxCustomCall(b, logits, log_));
|
||||
ctx->SetOutput(0, custom_call_output);
|
||||
return;
|
||||
}
|
||||
|
||||
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
|
||||
|
||||
// Find the max in each batch, resulting in a tensor of shape [batch]
|
||||
|
Loading…
Reference in New Issue
Block a user