Introduce a few new literal conversion functions for various floating point types
PiperOrigin-RevId: 291720187 Change-Id: I07cded6d91eb6aa9101a01705108ab33d68e7d78
This commit is contained in:
parent
2e98e89091
commit
e1a66cff30
tensorflow/compiler/xla
@ -93,16 +93,31 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
return ConvertType<bfloat16, float>(bf16_literal);
|
||||
}
|
||||
|
||||
/* static */ Literal LiteralUtil::ConvertBF16ToF64(
|
||||
const LiteralSlice& bf16_literal) {
|
||||
return ConvertType<bfloat16, double>(bf16_literal);
|
||||
}
|
||||
|
||||
/* static */ Literal LiteralUtil::ConvertF32ToBF16(
|
||||
const LiteralSlice& f32_literal) {
|
||||
return ConvertType<float, bfloat16>(f32_literal);
|
||||
}
|
||||
|
||||
/* static */ Literal LiteralUtil::ConvertF32ToF64(
|
||||
const LiteralSlice& f32_literal) {
|
||||
return ConvertType<float, double>(f32_literal);
|
||||
}
|
||||
|
||||
/* static */ Literal LiteralUtil::ConvertF64ToBF16(
|
||||
const LiteralSlice& f64_literal) {
|
||||
return ConvertType<double, bfloat16>(f64_literal);
|
||||
}
|
||||
|
||||
/* static */ Literal LiteralUtil::ConvertF64ToF32(
|
||||
const LiteralSlice& f64_literal) {
|
||||
return ConvertType<double, float>(f64_literal);
|
||||
}
|
||||
|
||||
/* static */ Literal LiteralUtil::CreateToken() {
|
||||
return Literal(ShapeUtil::MakeTokenShape());
|
||||
}
|
||||
|
@ -218,16 +218,31 @@ class LiteralUtil {
|
||||
// recursively converts its elements.
|
||||
static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
|
||||
|
||||
// If the given literal's data type is bfloat16, converts it to a double
|
||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||
// recursively converts its elements.
|
||||
static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
|
||||
|
||||
// If the given literal's data type is float, converts it to a bfloat16
|
||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||
// recursively converts its elements.
|
||||
static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
|
||||
|
||||
// If the given literal's data type is float, converts it to a double
|
||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||
// recursively converts its elements.
|
||||
static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
|
||||
|
||||
// If the given literal's data type is double, converts it to a bfloat16
|
||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||
// recursively converts its elements.
|
||||
static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal);
|
||||
|
||||
// If the given literal's data type is double, converts it to a bfloat16
|
||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||
// recursively converts its elements.
|
||||
static Literal ConvertF64ToF32(const LiteralSlice& f64_literal);
|
||||
|
||||
// Creates a literal with a new shape with the given new dimensions using the
|
||||
// data in the given input literal. For reshaping purposes the (flat) data
|
||||
// buffer of the input literal is assumed to have the given minor_to_major
|
||||
|
Loading…
Reference in New Issue
Block a user