Added ability to generate full range floating points from MakeFakeArgs func.
The ability to generate full range floating points has been turned into an option for the MakeFakeArguments function. Full range floating data points now have a small chance to instead pick a specific special floating point value as opposed to using the usual distribution. PiperOrigin-RevId: 257880189
This commit is contained in:
parent
296a0da4fc
commit
f5bfc07b79
@ -38,6 +38,48 @@ void PopulateWithRandomFloatingPointData(Literal* literal,
|
||||
}
|
||||
}
|
||||
|
||||
// Populates a floating point literal with random floating points sampled from a
|
||||
// uniform-log distribution spanning approximately the entire range of the
|
||||
// representable floating point.
|
||||
template <typename FloatT>
|
||||
void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
|
||||
std::minstd_rand0* engine) {
|
||||
constexpr float kSpecialValueProbability = 1e-6;
|
||||
constexpr float kSpecialValues[] = {+0.F,
|
||||
-0.F,
|
||||
1.F,
|
||||
-1.F,
|
||||
std::numeric_limits<float>::infinity(),
|
||||
-std::numeric_limits<float>::infinity()};
|
||||
constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
|
||||
std::uniform_real_distribution<float> special_value_gen(0, 1);
|
||||
|
||||
// Generates floating points with a log-uniform distribution. This causes the
|
||||
// exponent of the floating point to have a uniform distribution.
|
||||
int min_exp, max_exp;
|
||||
if (std::is_same<FloatT, bfloat16>()) {
|
||||
min_exp = std::numeric_limits<float>::min_exponent;
|
||||
max_exp = std::numeric_limits<float>::max_exponent;
|
||||
} else {
|
||||
min_exp = std::numeric_limits<FloatT>::min_exponent;
|
||||
max_exp = std::numeric_limits<FloatT>::max_exponent;
|
||||
}
|
||||
std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);
|
||||
|
||||
for (FloatT& value : literal->data<FloatT>()) {
|
||||
// Each special value has a kSpecialValueProbability chance to be generated
|
||||
// instead of sampling using the normal distributions.
|
||||
if (special_value_gen(*engine) <
|
||||
kSpecialValueProbability * kNumSpecialValues) {
|
||||
value =
|
||||
static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
|
||||
} else {
|
||||
float sign = ((*engine)() % 2 == 0) ? 1 : -1;
|
||||
value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FloatT>
|
||||
void PopulateWithIntNext(Literal* literal);
|
||||
|
||||
@ -101,12 +143,14 @@ void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
|
||||
|
||||
template <typename FloatT>
|
||||
void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
|
||||
bool no_duplicates) {
|
||||
bool no_duplicates, bool use_large_range) {
|
||||
CHECK(engine != nullptr);
|
||||
CHECK_EQ(literal->shape().element_type(),
|
||||
primitive_util::NativeToPrimitiveType<FloatT>());
|
||||
if (no_duplicates) {
|
||||
PopulateWithNoDuplicateData<FloatT>(literal, engine);
|
||||
} else if (use_large_range) {
|
||||
PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
|
||||
} else {
|
||||
PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
|
||||
}
|
||||
@ -114,7 +158,7 @@ void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
|
||||
|
||||
template <typename ComplexT>
|
||||
void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
|
||||
bool no_duplicates) {
|
||||
bool no_duplicates, bool use_large_range) {
|
||||
using InnerFloatT = typename ComplexT::value_type;
|
||||
CHECK(engine != nullptr);
|
||||
CHECK_EQ(result->shape().element_type(),
|
||||
@ -124,9 +168,10 @@ void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
|
||||
Literal real_lit(floating_point_shape);
|
||||
Literal imaginary_lit(floating_point_shape);
|
||||
|
||||
PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates);
|
||||
PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
|
||||
use_large_range);
|
||||
PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
|
||||
no_duplicates);
|
||||
no_duplicates, use_large_range);
|
||||
|
||||
absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
|
||||
absl::Span<const InnerFloatT> imaginary_data =
|
||||
@ -140,12 +185,15 @@ void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
|
||||
template <>
|
||||
void PopulateWithFloatingPointData<half>(Literal* literal,
|
||||
std::minstd_rand0* engine,
|
||||
bool no_duplicates) {
|
||||
bool no_duplicates,
|
||||
bool use_large_range) {
|
||||
CHECK(engine != nullptr);
|
||||
CHECK_EQ(literal->shape().element_type(),
|
||||
primitive_util::NativeToPrimitiveType<half>());
|
||||
if (no_duplicates) {
|
||||
PopulateWithNoDuplicateData<half>(literal, engine);
|
||||
} else if (use_large_range) {
|
||||
PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
|
||||
} else {
|
||||
PopulateWithRandomFloatingPointData<half, float>(literal, engine);
|
||||
}
|
||||
@ -154,12 +202,15 @@ void PopulateWithFloatingPointData<half>(Literal* literal,
|
||||
template <>
|
||||
void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
|
||||
std::minstd_rand0* engine,
|
||||
bool no_duplicates) {
|
||||
bool no_duplicates,
|
||||
bool use_large_range) {
|
||||
CHECK(engine != nullptr);
|
||||
CHECK_EQ(literal->shape().element_type(),
|
||||
primitive_util::NativeToPrimitiveType<bfloat16>());
|
||||
if (no_duplicates) {
|
||||
PopulateWithNoDuplicateData<bfloat16>(literal, engine);
|
||||
} else if (use_large_range) {
|
||||
PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
|
||||
} else {
|
||||
PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
|
||||
}
|
||||
@ -193,13 +244,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
|
||||
// elements exceeds the number of different values supported by the type.
|
||||
StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
|
||||
std::minstd_rand0* engine,
|
||||
bool no_duplicates) {
|
||||
bool no_duplicates,
|
||||
bool use_large_range) {
|
||||
if (shape.IsTuple()) {
|
||||
std::vector<Literal> elements;
|
||||
for (const Shape& element_shape : shape.tuple_shapes()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Literal element,
|
||||
MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
|
||||
TF_ASSIGN_OR_RETURN(Literal element, MakeFakeLiteralInternal(
|
||||
element_shape, engine,
|
||||
no_duplicates, use_large_range));
|
||||
elements.push_back(std::move(element));
|
||||
}
|
||||
return LiteralUtil::MakeTupleOwned(std::move(elements));
|
||||
@ -215,16 +267,20 @@ StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
|
||||
Literal literal(new_shape);
|
||||
switch (shape.element_type()) {
|
||||
case BF16:
|
||||
PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates);
|
||||
PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
|
||||
use_large_range);
|
||||
break;
|
||||
case F16:
|
||||
PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates);
|
||||
PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
|
||||
use_large_range);
|
||||
break;
|
||||
case F32:
|
||||
PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates);
|
||||
PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
|
||||
use_large_range);
|
||||
break;
|
||||
case F64:
|
||||
PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates);
|
||||
PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
|
||||
use_large_range);
|
||||
break;
|
||||
case S8:
|
||||
PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
|
||||
@ -251,10 +307,12 @@ StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
|
||||
PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
|
||||
break;
|
||||
case C64:
|
||||
PopulateWithComplexData<complex64>(&literal, engine, no_duplicates);
|
||||
PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
|
||||
use_large_range);
|
||||
break;
|
||||
case C128:
|
||||
PopulateWithComplexData<complex128>(&literal, engine, no_duplicates);
|
||||
PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
|
||||
use_large_range);
|
||||
break;
|
||||
case PRED: {
|
||||
std::uniform_int_distribution<int> generator(0, 1);
|
||||
@ -447,7 +505,8 @@ std::vector<HloInstruction*> FindConstrainedUses(
|
||||
// zero in the case of init_values for reductions).
|
||||
StatusOr<Literal> CreateLiteralForConstrainedUses(
|
||||
const absl::Span<HloInstruction* const> constrained_uses,
|
||||
const HloInstruction& param, std::minstd_rand0* engine) {
|
||||
const HloInstruction& param, std::minstd_rand0* engine,
|
||||
bool use_large_range) {
|
||||
int64 index_bound = INT64_MAX;
|
||||
bool no_duplicates = false;
|
||||
bool needs_constant = false;
|
||||
@ -531,10 +590,12 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
|
||||
// We want the identity element for the computation, but we don't really
|
||||
// know what it is - so any value we generate will be just as wrong.
|
||||
return MakeFakeLiteralInternal(param.shape(), engine,
|
||||
/*no_duplicates=*/false);
|
||||
/*no_duplicates=*/false,
|
||||
use_large_range);
|
||||
}
|
||||
} else {
|
||||
return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates);
|
||||
return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates,
|
||||
use_large_range);
|
||||
}
|
||||
}
|
||||
|
||||
@ -542,34 +603,41 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
|
||||
// special case literal must be created, or if we can generate fake data.
|
||||
StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
|
||||
const HloInstruction& param,
|
||||
std::minstd_rand0* engine) {
|
||||
std::minstd_rand0* engine,
|
||||
bool use_large_range) {
|
||||
const auto constrained_uses = FindConstrainedUses(dataflow, param);
|
||||
return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
|
||||
return CreateLiteralForConstrainedUses(constrained_uses, param, engine,
|
||||
use_large_range);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
|
||||
StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
|
||||
bool use_large_range) {
|
||||
auto engine =
|
||||
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
|
||||
return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
|
||||
return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
|
||||
use_large_range);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
|
||||
bool pseudo_random) {
|
||||
bool pseudo_random,
|
||||
bool use_large_range) {
|
||||
auto engine =
|
||||
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
|
||||
return MakeFakeArguments(module, engine.get());
|
||||
return MakeFakeArguments(module, engine.get(), use_large_range);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
|
||||
std::minstd_rand0* engine) {
|
||||
std::minstd_rand0* engine,
|
||||
bool use_large_range) {
|
||||
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
|
||||
const auto params = module->entry_computation()->parameter_instructions();
|
||||
std::vector<Literal> arguments(params.size());
|
||||
for (int i = 0; i < params.size(); ++i) {
|
||||
arguments[i] =
|
||||
MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
|
||||
MakeConstrainedArgument(*dataflow, *params[i], engine, use_large_range)
|
||||
.ValueOrDie();
|
||||
}
|
||||
return std::move(arguments);
|
||||
}
|
||||
|
||||
@ -54,34 +54,11 @@ class PseudorandomGenerator {
|
||||
std::mt19937 generator_;
|
||||
};
|
||||
|
||||
// Populates a floating point literal with random floating points sampled from a
|
||||
// uniform-log distribution spanning approximately the entire range of the
|
||||
// representable floating point.
|
||||
template <typename FloatT>
|
||||
void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
|
||||
std::minstd_rand0* engine) {
|
||||
// Generates floating points with a log-uniform distribution. This causes the
|
||||
// exponent of the floating point to have a uniform distribution.
|
||||
int min_exp, max_exp;
|
||||
if (std::is_same<FloatT, bfloat16>()) {
|
||||
min_exp = std::numeric_limits<float>::min_exponent;
|
||||
max_exp = std::numeric_limits<float>::max_exponent;
|
||||
} else {
|
||||
min_exp = std::numeric_limits<FloatT>::min_exponent;
|
||||
max_exp = std::numeric_limits<FloatT>::max_exponent;
|
||||
}
|
||||
std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);
|
||||
for (FloatT& value : literal->data<FloatT>()) {
|
||||
float sign = ((*engine)() % 2 == 0) ? 1 : -1;
|
||||
value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
|
||||
}
|
||||
}
|
||||
|
||||
// Generates fake data in a literal of the given shape, or returns an error
|
||||
// status if the element type is currently unhandled for fake data
|
||||
// generation. See below for documentation of pseudo_random.
|
||||
StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
|
||||
bool pseudo_random = true);
|
||||
// generation. See below for documentation of pseudo_random and use_large_range.
|
||||
StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random = true,
|
||||
bool use_large_range = false);
|
||||
|
||||
// Generates a vector of arguments containing fake data. The number, shape and
|
||||
// layout of the arguments is appropriate for given HLO module.
|
||||
@ -104,17 +81,25 @@ StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
|
||||
// will be generated in a faster way that yields less interesting data, e.g. the
|
||||
// values may all be just the same value.
|
||||
//
|
||||
// If use_large_range is false, the generated floating point numbers will be
|
||||
// sampled from a small range of possible values. If use_large_range is true,
|
||||
// the generated floating point numbers will be sampled from a uniform-log
|
||||
// distribution of most possible floats, with a small chance to instead be
|
||||
// sampled from a list of special floating point values (such as 0, inf, etc.).
|
||||
//
|
||||
// TODO(b/79942829): Make interesting argument generation fast enough that using
|
||||
// pseudo_random does not save any noticeable amount of time so that the
|
||||
// parameter can be removed.
|
||||
StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
|
||||
bool pseudo_random = true);
|
||||
bool pseudo_random = true,
|
||||
bool use_large_range = false);
|
||||
|
||||
// Overload which accepts a random number generator. This enables generation of
|
||||
// different random values with sequential calls to MakeFakeArguments by reusing
|
||||
// the same generator.
|
||||
StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
|
||||
std::minstd_rand0* engine);
|
||||
std::minstd_rand0* engine,
|
||||
bool use_large_range = false);
|
||||
|
||||
// Check that a given module satisfies various constraints before trying to
|
||||
// execute it.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user