PiperOrigin-RevId: 315290349
Change-Id: I2405c3505b6a860dd32f32d754d1a6da3f3acd29
This commit is contained in:
A. Unique TensorFlower 2020-06-08 09:30:37 -07:00 committed by TensorFlower Gardener
parent ad6ccc651c
commit b8bd7b3483
6 changed files with 24 additions and 45 deletions

View File

@ -89,12 +89,11 @@ StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor, ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
RankedTensorType type) { RankedTensorType type) {
auto flat = input_tensor.flat<bfloat16>(); auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
llvm::SmallVector<llvm::APFloat, 4> floats; input_tensor.TotalBytes());
floats.reserve(flat.size()); return mlir::DenseElementsAttr::getFromRawBuffer(
for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size())) type, buffer,
floats.push_back(llvm::APFloat(static_cast<double>(v))); /*isSplatBuffer=*/type.getNumElements() == 1);
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(floats));
} }
ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) { ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
@ -280,16 +279,11 @@ void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr, void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
protobuf::RepeatedField<int>* output) { protobuf::RepeatedField<int>* output) {
// Bfloat16 is internally represented as `double` in MLIR.
if (attr.isSplat()) { if (attr.isSplat()) {
double v = attr.getSplatValue<double>(); output->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val));
} else { } else {
for (auto v : attr.getValues<double>()) { for (const llvm::APFloat value : attr.getFloatValues())
bfloat16 bf16_val = static_cast<bfloat16>(v); output->Add(value.bitcastToAPInt().getSExtValue());
output->Add(absl::bit_cast<int16>(bf16_val));
}
} }
} }

View File

@ -44,8 +44,7 @@ template <typename CppType>
} }
mlir::APFloat ConvertToAPFloat(bfloat16 val) { mlir::APFloat ConvertToAPFloat(bfloat16 val) {
// bfloat16 values are stored as double in MLIR. return llvm::APFloat(llvm::APFloat::BFloat(), llvm::APInt(16, val.value));
return llvm::APFloat(static_cast<double>(val));
} }
mlir::APFloat ConvertToAPFloat(half val) { mlir::APFloat ConvertToAPFloat(half val) {

View File

@ -979,10 +979,10 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
values.reserve(attr.getNumElements()); values.reserve(attr.getNumElements());
for (APFloat val : attr.getValues<APFloat>()) { for (APFloat val : attr.getValues<APFloat>()) {
bool loses_info = false; bool loses_info = false;
CHECK_EQ(val.convert(llvm::APFloat::IEEEsingle(), TF_RET_CHECK(val.convert(llvm::APFloat::IEEEsingle(),
llvm::APFloat::rmTowardZero, &loses_info), llvm::APFloat::rmTowardZero,
llvm::APFloat::opOK); &loses_info) == llvm::APFloat::opOK);
CHECK(!loses_info); TF_RET_CHECK(!loses_info);
values.push_back(xla::half(val.convertToFloat())); values.push_back(xla::half(val.convertToFloat()));
} }
xla::Array<xla::half> source_data(shape.dimensions()); xla::Array<xla::half> source_data(shape.dimensions());
@ -992,10 +992,15 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
case xla::PrimitiveType::BF16: { case xla::PrimitiveType::BF16: {
xla::Array<double> source_data(shape.dimensions()); xla::Array<double> source_data(shape.dimensions());
auto attr_values = attr.getValues<APFloat>(); auto attr_values = attr.getValues<APFloat>();
std::vector<double> values_double(source_data.num_elements()); std::vector<double> values_double;
for (auto index_and_value : llvm::enumerate(attr_values)) { values_double.reserve(source_data.num_elements());
values_double[index_and_value.index()] = for (APFloat val : attr_values) {
index_and_value.value().convertToDouble(); bool loses_info = false;
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEdouble(),
llvm::APFloat::rmTowardZero,
&loses_info) == llvm::APFloat::opOK);
TF_RET_CHECK(!loses_info);
values_double.push_back(val.convertToDouble());
} }
source_data.SetValues(values_double); source_data.SetValues(values_double);
return xla::LiteralUtil::ConvertF64ToBF16( return xla::LiteralUtil::ConvertF64ToBF16(

View File

@ -191,7 +191,7 @@ func @const_f32_bf16() -> tensor<bf16> {
// CHECK-LABEL: func @const_bf16_f64 // CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> { func @const_bf16_f64() -> tensor<f64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f64> // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
%cst = xla_hlo.constant dense<4.2> : tensor<bf16> %cst = xla_hlo.constant dense<4.2> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64> %0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: return [[CST]] // CHECK-NEXT: return [[CST]]

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/casts.h" #include "absl/base/casts.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/GlobalVariable.h"

View File

@ -686,25 +686,6 @@ gentbl(
], ],
) )
gentbl(
name = "MLIRShapeCanonicalizationIncGen",
strip_include_prefix = "include/mlir/Dialect/Shape",
tbl_outs = [
(
"-gen-rewriters",
"include/mlir/Dialect/Shape/IR/ShapeCanonicalization.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "lib/Dialect/Shape/IR/ShapeCanonicalization.td",
td_srcs = [
":StdOpsTdFiles",
"include/mlir/Dialect/Shape/IR/ShapeBase.td",
"include/mlir/Dialect/Shape/IR/ShapeOps.td",
"include/mlir/Interfaces/InferTypeOpInterface.td",
],
)
cc_library( cc_library(
name = "Shape", name = "Shape",
srcs = glob( srcs = glob(
@ -723,7 +704,6 @@ cc_library(
":Dialect", ":Dialect",
":IR", ":IR",
":InferTypeOpInterface", ":InferTypeOpInterface",
":MLIRShapeCanonicalizationIncGen",
":ShapeOpsIncGen", ":ShapeOpsIncGen",
":SideEffects", ":SideEffects",
":Support", ":Support",