From e1a66cff303ebdbc25280b9ad6f749504ec95534 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 27 Jan 2020 07:44:39 -0800 Subject: [PATCH] Introduce a few new literal conversion functions for various floating point types PiperOrigin-RevId: 291720187 Change-Id: I07cded6d91eb6aa9101a01705108ab33d68e7d78 --- tensorflow/compiler/xla/literal_util.cc | 15 +++++++++++++++ tensorflow/compiler/xla/literal_util.h | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index e342e7a9bdb..4304c207cad 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -93,16 +93,31 @@ Literal ConvertType(LiteralSlice literal) { return ConvertType<bfloat16, float>(bf16_literal); } +/* static */ Literal LiteralUtil::ConvertBF16ToF64( + const LiteralSlice& bf16_literal) { + return ConvertType<bfloat16, double>(bf16_literal); +} + /* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType<float, bfloat16>(f32_literal); } +/* static */ Literal LiteralUtil::ConvertF32ToF64( + const LiteralSlice& f32_literal) { + return ConvertType<float, double>(f32_literal); +} + /* static */ Literal LiteralUtil::ConvertF64ToBF16( const LiteralSlice& f64_literal) { return ConvertType<double, bfloat16>(f64_literal); } +/* static */ Literal LiteralUtil::ConvertF64ToF32( + const LiteralSlice& f64_literal) { + return ConvertType<double, float>(f64_literal); +} + /* static */ Literal LiteralUtil::CreateToken() { return Literal(ShapeUtil::MakeTokenShape()); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index b22b71a2ec0..e9e4f74f47b 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -218,16 +218,31 @@ class LiteralUtil { // recursively converts its elements. static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); + // If the given literal's data type is bfloat16, converts it to a double + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal); + // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); + // If the given literal's data type is float, converts it to a double + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static Literal ConvertF32ToF64(const LiteralSlice& f32_literal); + // If the given literal's data type is double, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal); + // If the given literal's data type is double, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static Literal ConvertF64ToF32(const LiteralSlice& f64_literal); + // Creates a literal with a new shape with the given new dimensions using the // data in the given input literal. For reshaping purposes the (flat) data // buffer of the input literal is assumed to have the given minor_to_major