Add lowering of Lgamma to TensorFlow dialect.

Based on the XlaClient implementation of Lgamma at
tensorflow/compiler/xla/client/lib/math.cc.

PiperOrigin-RevId: 337562141
Change-Id: I0fa92add4d062130cbfbdc14987d5103abaae40b
This commit is contained in:
Richard Uhler 2020-10-16 13:00:15 -07:00 committed by TensorFlower Gardener
parent 8653ed4e56
commit 2e58714aee
3 changed files with 218 additions and 4 deletions

View File

@ -693,3 +693,14 @@ func @round_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Round"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @lgamma
func @lgamma(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// Smoke test: feeds a statically shaped f32 tensor through the Lgamma op and
// only verifies that the op no longer appears in the output.
// The lowering for lgamma is complicated, which makes it awkward to write a
// complete test for it here. Instead we test that Lgamma is at least being
// lowered here and rely on UnaryOpsTest.testFloatOps and other TensorFlow
// tests to check it is lowered correctly and with sufficient precision.
// CHECK-NOT: tf.Lgamma
%0 = "tf.Lgamma"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

View File

@ -477,6 +477,210 @@ class LowerInvertPermutationOp
}
};
// Approximates lgamma using Lanczos' approximation from
// "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
// Analysis series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
// t(z) = z + kLanczosGamma + 1/2
// A(z) = kBaseLanczosCoeff
// + sigma(k = 1, kLanczosCoefficients.size(),
//         kLanczosCoefficients[k - 1] / (z + k))
//
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n
// (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
// correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
// evaluated and [7, 9] seemed to be the least sensitive to the quality of the
// log function. In particular, [5, 7] is the only choice where -1.5e-5 <=
// lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
// NOTE(review): the last sentence names [5, 7] although [7, 9] is the pair
// that was selected; confirm which pair was intended.
static constexpr double kLanczosGamma = 7; // aka g
// Constant (k = 0) term of the series A(z).
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
// Numerators of the 1 / (z + k) terms of A(z), stored for k = 1..8 at
// indices 0..7.
static constexpr std::array<double, 8> kLanczosCoefficients = {
676.520368121885098567009190444019, -1259.13921672240287047156078755283,
771.3234287776530788486528258894, -176.61502916214059906584551354,
12.507343278686904814458936853, -0.13857109526572011689554707,
9.984369578019570859563e-6, 1.50563273514931155834e-7};
// Rewrites tf.Lgamma into elementwise TensorFlow ops using the Lanczos
// approximation (for inputs >= 0.5) combined with Euler's reflection formula
// (for inputs < 0.5). Inputs narrower than 32 bits are computed in float32 and
// cast back to the original type at the end.
//
// All constants are materialized as rank-0 tensors and combined with the
// input through broadcasting ops, so the pattern does not need the input to
// have a static shape (DenseElementsAttr cannot be built for an unranked or
// dynamically shaped type).
class LowerLgammaOp : public OpRewritePattern<TF::LgammaOp> {
 public:
  using OpRewritePattern<TF::LgammaOp>::OpRewritePattern;

  // Replaces `op` with the expanded computation. Always succeeds.
  LogicalResult matchAndRewrite(TF::LgammaOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.x();
    TensorType original_tensor_type = op.x().getType().cast<TensorType>();

    // The approximation is not precise enough for float16. Do the computation
    // in float32 for that case.
    TensorType tensor_type = original_tensor_type;
    FloatType float_type = tensor_type.getElementType().cast<FloatType>();
    bool needs_cast = float_type.getWidth() < 32;
    if (needs_cast) {
      MLIRContext *context = rewriter.getContext();
      float_type = FloatType::getF32(context);
      if (original_tensor_type.hasRank()) {
        tensor_type =
            RankedTensorType::get(original_tensor_type.getShape(), float_type);
      } else {
        tensor_type = UnrankedTensorType::get(float_type);
      }
      input = rewriter.create<TF::CastOp>(loc, tensor_type, input);
    }

    // Helper lambda function for creating a scalar (rank-0) ConstOp holding
    // the given constant float value. A scalar is used instead of a splat of
    // the input type because the input type may be unranked or have dynamic
    // dimensions; the ops below broadcast the scalar against the input.
    auto create_const_op = [&rewriter, loc, float_type](double value) {
      return rewriter.create<TF::ConstOp>(
          loc, DenseElementsAttr::get(RankedTensorType::get({}, float_type),
                                      FloatAttr::get(float_type, value)));
    };

    Value one_half = create_const_op(0.5);
    Value one = create_const_op(1.0);
    Value infinity = create_const_op(std::numeric_limits<double>::infinity());
    Value pi = create_const_op(M_PI);
    Value log_pi = create_const_op(std::log(M_PI));
    Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
    Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
    Value log_lanczos_gamma_plus_one_half =
        create_const_op(std::log(kLanczosGamma + 0.5));
    Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);

    Value minus_input = rewriter.create<TF::NegOp>(loc, input);
    Value input_minus_one = rewriter.create<TF::SubOp>(loc, input, one);

    // If the input is less than 0.5 use Euler's reflection formula:
    // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
    Value need_to_reflect = rewriter.create<TF::LessOp>(loc, input, one_half);
    // Boolean tensor type matching the (broadcast) input; reused for the
    // IsFinite / IsInf predicates below.
    Type tensor_bool_type = need_to_reflect.getType();
    // The Lanczos series approximates lgamma(z + 1), so z is either
    // (1 - x) - 1 = -x when reflecting, or x - 1 otherwise.
    Value z = rewriter.create<TF::SelectV2Op>(loc, need_to_reflect, minus_input,
                                              input_minus_one);

    // x = A(z) = kBaseLanczosCoeff + sum_i kLanczosCoefficients[i] / (z + i + 1)
    Value x = base_lanczos_coeff;
    for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
      Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
      Value index = create_const_op(static_cast<double>(i));
      Value z_plus_index = rewriter.create<TF::AddV2Op>(loc, z, index);
      Value z_plus_index_plus_one =
          rewriter.create<TF::AddV2Op>(loc, z_plus_index, one);
      Value incr = rewriter.create<TF::DivOp>(loc, lanczos_coefficient,
                                              z_plus_index_plus_one);
      x = rewriter.create<TF::AddV2Op>(loc, x, incr);
    }

    // To improve accuracy on platforms with less-precise log implementations,
    // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
    // the device.
    // log(t) = log(kLanczosGamma + 0.5 + z)
    //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
    Value t = rewriter.create<TF::AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
    Value z_div_lanczos_gamma_plus_one_half =
        rewriter.create<TF::DivOp>(loc, z, lanczos_gamma_plus_one_half);
    Value log1p_z_div_lanczos_gamma_plus_one_half =
        rewriter.create<TF::Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
    Value log_t =
        rewriter.create<TF::AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
                                     log1p_z_div_lanczos_gamma_plus_one_half);

    // Compute the final result (modulo reflection). t(z) may be large, and we
    // need to be careful not to overflow to infinity in the first term of
    //
    //   (z + 1/2) * log(t(z)) - t(z).
    //
    // Therefore we compute this as
    //
    //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
    //
    // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
    Value t_div_log_t = rewriter.create<TF::DivOp>(loc, t, log_t);
    Value one_half_minus_t_div_log_t =
        rewriter.create<TF::SubOp>(loc, one_half, t_div_log_t);
    Value z_plus_one_half_minus_t_div_log_t =
        rewriter.create<TF::AddV2Op>(loc, z, one_half_minus_t_div_log_t);
    Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
        rewriter.create<TF::MulOp>(loc, z_plus_one_half_minus_t_div_log_t,
                                   log_t);
    Value log_x = rewriter.create<TF::LogOp>(loc, x);
    Value log_y_rhs = rewriter.create<TF::AddV2Op>(
        loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
    Value log_y = rewriter.create<TF::AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);

    // Compute the reflected value, used when x < 0.5:
    //
    //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
    //
    // (The abs is because lgamma is the log of the absolute value of the gamma
    // function.)
    //
    // We have to be careful when computing the final term above. gamma(x) goes
    // to +/-inf at every integer x < 0, and this is controlled by the
    // sin(pi * x) term. The slope is large, so precision is particularly
    // important.
    //
    // Because abs(sin(pi * x)) has period 1, we can equivalently use
    // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
    // is more numerically accurate: It doesn't overflow to inf like pi * x can,
    // and if x is an integer, it evaluates to 0 exactly, which is significant
    // because we then take the log of this value, and log(0) is inf.
    //
    // We don't have a frac(x) primitive in XLA and computing it is tricky, but
    // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
    // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
    //
    // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
    // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
    // [0, 1] is symmetric across the line Y=0.5.
    Value abs_input = rewriter.create<TF::AbsOp>(loc, input);
    Value abs_input_floor = rewriter.create<TF::FloorOp>(loc, abs_input);
    Value abs_frac_input =
        rewriter.create<TF::SubOp>(loc, abs_input, abs_input_floor);

    // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
    // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
    Value one_minus_abs_frac_input =
        rewriter.create<TF::SubOp>(loc, one, abs_frac_input);
    Value abs_frac_input_gt_one_half =
        rewriter.create<TF::GreaterOp>(loc, abs_frac_input, one_half);
    Value reduced_frac_input = rewriter.create<TF::SelectV2Op>(
        loc, abs_frac_input_gt_one_half, one_minus_abs_frac_input,
        abs_frac_input);
    Value pi_mul_reduced_frac_input =
        rewriter.create<TF::MulOp>(loc, pi, reduced_frac_input);
    Value sin_pi_mul_reduced_frac_input =
        rewriter.create<TF::SinOp>(loc, pi_mul_reduced_frac_input);
    Value reflection_denom =
        rewriter.create<TF::LogOp>(loc, sin_pi_mul_reduced_frac_input);

    // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
    // then it "wins" and the result is +/-inf.
    Value is_finite = rewriter.create<TF::IsFiniteOp>(loc, tensor_bool_type,
                                                      reflection_denom);
    Value neg_reflection_denom =
        rewriter.create<TF::NegOp>(loc, reflection_denom);
    Value log_pi_minus_reflection_denom =
        rewriter.create<TF::SubOp>(loc, log_pi, reflection_denom);
    Value reflection_if_finite =
        rewriter.create<TF::SubOp>(loc, log_pi_minus_reflection_denom, log_y);
    Value reflection = rewriter.create<TF::SelectV2Op>(
        loc, is_finite, reflection_if_finite, neg_reflection_denom);

    Value result = rewriter.create<TF::SelectV2Op>(loc, need_to_reflect,
                                                   reflection, log_y);

    // lgamma(+/-inf) = +inf.
    Value is_inf = rewriter.create<TF::IsInfOp>(loc, tensor_bool_type, input);
    result = rewriter.create<TF::SelectV2Op>(loc, is_inf, infinity, result);

    // Return to the original (narrower) element type if we widened above.
    if (needs_cast) {
      result = rewriter.create<TF::CastOp>(loc, original_tensor_type, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
// Lowers Pack op to ConcatV2 op after changing shape of the inputs with
// ExpandDims op.
//
@ -777,9 +981,9 @@ class Lower_UnaryOpsComposition
// Adds the hand-written TensorFlow lowering patterns named in the insert<>
// list to `patterns`, then adds the generated patterns through
// populateWithGenerated().
// NOTE(review): this listing appears to be a diff rendering that shows the
// tail of the insert<> list twice: first as it was before the change and then
// as it is after the change (with LowerLgammaOp added). Confirm against the
// real file before treating this text as compilable code.
void PopulateLoweringTFPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
// Presumably the removed lines: list tail without LowerLgammaOp.
LowerDynamicStitchOp, LowerInvertPermutationOp, LowerPackOp,
LowerSpaceToBatchNDOp, LowerSparseMatMulOp,
Lower_UnaryOpsComposition>(context);
// Presumably the added lines: list tail with LowerLgammaOp registered.
LowerDynamicStitchOp, LowerInvertPermutationOp,
LowerLgammaOp, LowerPackOp, LowerSpaceToBatchNDOp,
LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context);
populateWithGenerated(context, *patterns);
}

View File

@ -161,7 +161,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::LeftShiftOp>(),
TypeID::get<TF::LessEqualOp>(),
TypeID::get<TF::LessOp>(),
TypeID::get<TF::LgammaOp>(),
TypeID::get<TF::ListDiffOp>(),
TypeID::get<TF::LogicalAndOp>(),
TypeID::get<TF::LogicalNotOp>(),