Convert Softmax to custom op

Change the old TF2XLA bridge to convert a SoftMax op to a custom call to
softmax when the custom_fake_quant_op_calls flag is enabled. This is to
preserve the information for quantization.

PiperOrigin-RevId: 305892297
Change-Id: Id372e7da5753dd3109d6e993c6193d155ba136bd
This commit is contained in:
Feng Liu 2020-04-10 09:49:36 -07:00 committed by TensorFlower Gardener
parent 30ecd9e2dd
commit dc9eec091c

View File

@ -24,14 +24,28 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
// Builds a custom call to a method named 'softmax' or 'log_softmax'.
xla::StatusOr<xla::XlaOp> BuildSoftmaxCustomCall(xla::XlaBuilder* b,
xla::XlaOp logits, bool log) {
TF_ASSIGN_OR_RETURN(xla::Shape logits_shape, b->GetShape(logits));
return xla::CustomCallWithLayout(b, log ? "log_softmax" : "softmax", {logits},
logits_shape, {logits_shape});
}
class SoftmaxOp : public XlaOpKernel {
public:
explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@ -55,6 +69,15 @@ class SoftmaxOp : public XlaOpKernel {
auto logits = ctx->Input(0);
xla::XlaBuilder* const b = ctx->builder();
if (ctx->compiler()->options().allow_cpu_custom_calls &&
ctx->compiler()->options().custom_fake_quant_op_calls) {
xla::XlaOp custom_call_output =
b->ReportErrorOrReturn(BuildSoftmaxCustomCall(b, logits, log_));
ctx->SetOutput(0, custom_call_output);
return;
}
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
// Find the max in each batch, resulting in a tensor of shape [batch]