Add lowering of Lgamma to TensorFlow dialect.
Based on the XlaClient implementation of Lgamma at tensorflow/compiler/xla/client/lib/math.cc.

PiperOrigin-RevId: 337562141
Change-Id: I0fa92add4d062130cbfbdc14987d5103abaae40b
parent 8653ed4e56
commit 2e58714aee
@ -693,3 +693,14 @@ func @round_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Round"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @lgamma
func @lgamma(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // The lowering for lgamma is complicated, which makes it awkward to write a
  // complete test for it here. Instead we test that Lgamma is at least being
  // lowered here and rely on UnaryOpsTest.testFloatOps and other TensorFlow
  // tests to check it is lowered correctly and with sufficient precision.
  // CHECK-NOT: tf.Lgamma
  %0 = "tf.Lgamma"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
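Aside: tests in this file are driven by a FileCheck RUN line at the top of the file, which the hunk above does not show. For the lower_tf tests it is presumably of the form below (the exact pass flag is an assumption based on similar tests; check the file's actual RUN line):

  // RUN: tf-opt %s -test-tf-lower-tf | FileCheck %s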
@ -477,6 +477,210 @@ class LowerInvertPermutationOp
  }
};

// Approximates lgamma using Lanczos' approximation from
// "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
// Analysis series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z)
//                 + log(A(z))
//   t(z) = z + kLanczosGamma + 1/2
//   A(z) = kBaseLanczosCoeff
//          + sigma(k = 1, n, kLanczosCoefficients[k] / (z + k))
//
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n
// (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
// correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
// evaluated and [7, 9] seemed to be the least sensitive to the quality of the
// log function. In particular, [5, 7] is the only choice where -1.5e-5 <=
// lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
static constexpr double kLanczosGamma = 7;  // aka g
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894, -176.61502916214059906584551354,
    12.507343278686904814458936853, -0.13857109526572011689554707,
    9.984369578019570859563e-6, 1.50563273514931155834e-7};
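
// [Illustrative aside, not part of this change] The series above can be
// sanity-checked against std::lgamma with a small scalar program. A minimal
// sketch, assuming only the constants defined above:
//
//   #include <array>
//   #include <cmath>
//   #include <cstdio>
//
//   double LanczosLgamma(double x) {
//     // The series is stated for lgamma(z + 1), so shift; valid for
//     // x >= 0.5 (smaller inputs go through the reflection path below).
//     double z = x - 1;
//     double a = kBaseLanczosCoeff;
//     for (int i = 0; i < static_cast<int>(kLanczosCoefficients.size()); ++i)
//       a += kLanczosCoefficients[i] / (z + i + 1);  // k = i + 1
//     double t = z + kLanczosGamma + 0.5;
//     return std::log(2 * M_PI) / 2 + (z + 0.5) * std::log(t) - t
//            + std::log(a);
//   }
//
//   int main() {
//     // Both print ~2.453736570842, the log of gamma(4.5).
//     std::printf("%.12f vs %.12f\n", LanczosLgamma(4.5), std::lgamma(4.5));
//   }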

class LowerLgammaOp : public OpRewritePattern<TF::LgammaOp> {
 public:
  using OpRewritePattern<TF::LgammaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::LgammaOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.x();
    TensorType original_tensor_type = op.x().getType().cast<TensorType>();

    // The approximation is not precise enough for float16. Do the computation
    // in float32 for that case.
    TensorType tensor_type = original_tensor_type;
    FloatType float_type = tensor_type.getElementType().cast<FloatType>();
    bool needs_cast = float_type.getWidth() < 32;
    if (needs_cast) {
      MLIRContext *context = rewriter.getContext();
      float_type = FloatType::getF32(context);
      if (original_tensor_type.hasRank()) {
        tensor_type =
            RankedTensorType::get(original_tensor_type.getShape(), float_type);
      } else {
        tensor_type = UnrankedTensorType::get(float_type);
      }
      input = rewriter.create<TF::CastOp>(loc, tensor_type, input);
    }

    // Helper lambda function for creating a ConstOp for a tensor filled with
    // the given constant float value.
    auto create_const_op = [&rewriter, loc, tensor_type,
                           float_type](double value) {
      return rewriter.create<TF::ConstOp>(
          loc, DenseElementsAttr::get(tensor_type,
                                      FloatAttr::get(float_type, value)));
    };
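    // [Illustrative aside, not part of this change] For a tensor<4xf32>
    // operand, create_const_op(0.5) materializes IR along these lines
    // (rendering approximate):
    //   %cst = "tf.Const"() {value = dense<5.000000e-01> : tensor<4xf32>}
    //          : () -> tensor<4xf32>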

    Value one_half = create_const_op(0.5);
    Value one = create_const_op(1.0);
    Value infinity = create_const_op(std::numeric_limits<double>::infinity());
    Value pi = create_const_op(M_PI);
    Value log_pi = create_const_op(std::log(M_PI));
    Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
    Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
    Value log_lanczos_gamma_plus_one_half =
        create_const_op(std::log(kLanczosGamma + 0.5));
    Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);

    Value minus_input = rewriter.create<TF::NegOp>(loc, input);
    Value input_minus_one = rewriter.create<TF::SubOp>(loc, input, one);

    // If the input is less than 0.5 use Euler's reflection formula:
    //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
    Value need_to_reflect = rewriter.create<TF::LessOp>(loc, input, one_half);
    Type tensor_bool_type = need_to_reflect.getType();
    Value z = rewriter.create<TF::SelectV2Op>(loc, need_to_reflect, minus_input,
                                              input_minus_one);

    Value x = base_lanczos_coeff;
    for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
      Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
      Value index = create_const_op(static_cast<double>(i));
      Value z_plus_index = rewriter.create<TF::AddV2Op>(loc, z, index);
      Value z_plus_index_plus_one =
          rewriter.create<TF::AddV2Op>(loc, z_plus_index, one);
      Value incr = rewriter.create<TF::DivOp>(loc, lanczos_coefficient,
                                              z_plus_index_plus_one);
      x = rewriter.create<TF::AddV2Op>(loc, x, incr);
    }

    // To improve accuracy on platforms with less-precise log implementations,
    // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p
    // on the device.
    //   log(t) = log(kLanczosGamma + 0.5 + z)
    //          = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
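    // (For example, with kLanczosGamma = 7 and z = 2 this reads
    //  log(9.5) == log(7.5) + log1p(2 / 7.5), so only the well-conditioned
    //  log1p and an add must run on the device.)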
    Value t = rewriter.create<TF::AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
    Value z_div_lanczos_gamma_plus_one_half =
        rewriter.create<TF::DivOp>(loc, z, lanczos_gamma_plus_one_half);
    Value log1p_z_div_lanczos_gamma_plus_one_half =
        rewriter.create<TF::Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
    Value log_t =
        rewriter.create<TF::AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
                                     log1p_z_div_lanczos_gamma_plus_one_half);

    // Compute the final result (modulo reflection). t(z) may be large, and we
    // need to be careful not to overflow to infinity in the first term of
    //
    //   (z + 1/2) * log(t(z)) - t(z).
    //
    // Therefore we compute this as
    //
    //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
    //
    // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
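    // (Expanding the product confirms the identity:
    //  (z + 1/2 - t/log(t)) * log(t) = (z + 1/2) * log(t) - t.)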
    Value t_div_log_t = rewriter.create<TF::DivOp>(loc, t, log_t);
    Value one_half_minus_t_div_log_t =
        rewriter.create<TF::SubOp>(loc, one_half, t_div_log_t);
    Value z_plus_one_half_minus_t_div_log_t =
        rewriter.create<TF::AddV2Op>(loc, z, one_half_minus_t_div_log_t);
    Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
        rewriter.create<TF::MulOp>(loc, z_plus_one_half_minus_t_div_log_t,
                                   log_t);
    Value log_x = rewriter.create<TF::LogOp>(loc, x);
    Value log_y_rhs = rewriter.create<TF::AddV2Op>(
        loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
    Value log_y = rewriter.create<TF::AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);

    // Compute the reflected value, used when x < 0.5:
    //
    //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
    //
    // (The abs is because lgamma is the log of the absolute value of the gamma
    // function.)
    //
    // We have to be careful when computing the final term above. gamma(x) goes
    // to +/-inf at every integer x < 0, and this is controlled by the
    // sin(pi * x) term. The slope is large, so precision is particularly
    // important.
    //
    // Because abs(sin(pi * x)) has period 1, we can equivalently use
    // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
    // is more numerically accurate: It doesn't overflow to inf like pi * x can,
    // and if x is an integer, it evaluates to 0 exactly, which is significant
    // because we then take the log of this value, and log(0) is inf.
    //
    // We don't have a frac(x) primitive in XLA and computing it is tricky, but
    // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
    // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
    //
    // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is
    // close to 1. To remedy this, we can use the fact that sin(pi * x) in the
    // domain [0, 1] is symmetric across the line Y=0.5.
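    // (Concretely, sin(pi * 0.9) == sin(pi * 0.1), so fractional parts above
    //  0.5 can be folded down to the better-conditioned lower half below.)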
    Value abs_input = rewriter.create<TF::AbsOp>(loc, input);
    Value abs_input_floor = rewriter.create<TF::FloorOp>(loc, abs_input);
    Value abs_frac_input =
        rewriter.create<TF::SubOp>(loc, abs_input, abs_input_floor);

    // Convert values of abs_frac_input > 0.5 to (1 - abs_frac_input) to
    // improve precision of pi * abs_frac_input for values of abs_frac_input
    // close to 1.
    Value one_minus_abs_frac_input =
        rewriter.create<TF::SubOp>(loc, one, abs_frac_input);
    Value abs_frac_input_gt_one_half =
        rewriter.create<TF::GreaterOp>(loc, abs_frac_input, one_half);
    Value reduced_frac_input = rewriter.create<TF::SelectV2Op>(
        loc, abs_frac_input_gt_one_half, one_minus_abs_frac_input,
        abs_frac_input);
    Value pi_mul_reduced_frac_input =
        rewriter.create<TF::MulOp>(loc, pi, reduced_frac_input);
    Value sin_pi_mul_reduced_frac_input =
        rewriter.create<TF::SinOp>(loc, pi_mul_reduced_frac_input);
    Value reflection_denom =
        rewriter.create<TF::LogOp>(loc, sin_pi_mul_reduced_frac_input);

    // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
    // then it "wins" and the result is +/-inf.
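    // (At nonpositive integers abs_frac_input is exactly 0, sin yields exactly
    //  0, and reflection_denom is -inf; the select below then produces +inf,
    //  matching lgamma's poles.)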
    Value is_finite = rewriter.create<TF::IsFiniteOp>(loc, tensor_bool_type,
                                                      reflection_denom);
    Value neg_reflection_denom =
        rewriter.create<TF::NegOp>(loc, reflection_denom);
    Value log_pi_minus_reflection_denom =
        rewriter.create<TF::SubOp>(loc, log_pi, reflection_denom);
    Value reflection_if_finite =
        rewriter.create<TF::SubOp>(loc, log_pi_minus_reflection_denom, log_y);
    Value reflection = rewriter.create<TF::SelectV2Op>(
        loc, is_finite, reflection_if_finite, neg_reflection_denom);

    Value result = rewriter.create<TF::SelectV2Op>(loc, need_to_reflect,
                                                   reflection, log_y);

    // lgamma(+/-inf) = +inf.
    Value is_inf = rewriter.create<TF::IsInfOp>(loc, tensor_bool_type, input);
    result = rewriter.create<TF::SelectV2Op>(loc, is_inf, infinity, result);

    if (needs_cast) {
      result = rewriter.create<TF::CastOp>(loc, original_tensor_type, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
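
// [Illustrative aside, not part of this change] Applied to the test above, a
// single tf.Lgamma on tensor<4xf32> expands into a DAG of elementwise TF ops
// (tf.Const, tf.Neg, tf.Sub, tf.Less, tf.SelectV2, tf.AddV2, tf.Div, tf.Log,
// tf.Log1p, tf.Mul, tf.Abs, tf.Floor, tf.Greater, tf.Sin, tf.IsFinite,
// tf.IsInf), which is why the test only checks that tf.Lgamma disappears.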

// Lowers Pack op to ConcatV2 op after changing shape of the inputs with
// ExpandDims op.
//
@ -777,9 +981,9 @@ class Lower_UnaryOpsComposition
void PopulateLoweringTFPatterns(MLIRContext *context,
                                OwningRewritePatternList *patterns) {
  patterns->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
-                  LowerDynamicStitchOp, LowerInvertPermutationOp, LowerPackOp,
-                  LowerSpaceToBatchNDOp, LowerSparseMatMulOp,
-                  Lower_UnaryOpsComposition>(context);
+                  LowerDynamicStitchOp, LowerInvertPermutationOp,
+                  LowerLgammaOp, LowerPackOp, LowerSpaceToBatchNDOp,
+                  LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context);
  populateWithGenerated(context, *patterns);
}
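For context, these patterns are consumed by a standard MLIR greedy pattern-rewrite driver. A minimal sketch of a pass body that applies them (the pass name and driver call are assumptions based on typical MLIR FunctionPass structure of this period, not part of this change):

  void LowerTF::runOnFunction() {
    OwningRewritePatternList patterns;
    // Register the hand-written lowerings above plus the generated ones.
    PopulateLoweringTFPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }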
@ -161,7 +161,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
      TypeID::get<TF::LeftShiftOp>(),
      TypeID::get<TF::LessEqualOp>(),
      TypeID::get<TF::LessOp>(),
-     TypeID::get<TF::LgammaOp>(),
      TypeID::get<TF::ListDiffOp>(),
      TypeID::get<TF::LogicalAndOp>(),
      TypeID::get<TF::LogicalNotOp>(),