PiperOrigin-RevId: 315290349
Change-Id: I2405c3505b6a860dd32f32d754d1a6da3f3acd29
This commit is contained in:
A. Unique TensorFlower 2020-06-08 09:30:37 -07:00 committed by TensorFlower Gardener
parent ad6ccc651c
commit b8bd7b3483
6 changed files with 24 additions and 45 deletions

View File

@ -89,12 +89,11 @@ StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
RankedTensorType type) {
auto flat = input_tensor.flat<bfloat16>();
llvm::SmallVector<llvm::APFloat, 4> floats;
floats.reserve(flat.size());
for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size()))
floats.push_back(llvm::APFloat(static_cast<double>(v)));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(floats));
auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
input_tensor.TotalBytes());
return mlir::DenseElementsAttr::getFromRawBuffer(
type, buffer,
/*isSplatBuffer=*/type.getNumElements() == 1);
}
ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
@ -280,16 +279,11 @@ void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
protobuf::RepeatedField<int>* output) {
// Bfloat16 is internally represented as `double` in MLIR.
if (attr.isSplat()) {
double v = attr.getSplatValue<double>();
bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val));
output->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
} else {
for (auto v : attr.getValues<double>()) {
bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val));
}
for (const llvm::APFloat value : attr.getFloatValues())
output->Add(value.bitcastToAPInt().getSExtValue());
}
}

View File

@ -44,8 +44,7 @@ template <typename CppType>
}
mlir::APFloat ConvertToAPFloat(bfloat16 val) {
// bfloat16 values are stored as double in MLIR.
return llvm::APFloat(static_cast<double>(val));
return llvm::APFloat(llvm::APFloat::BFloat(), llvm::APInt(16, val.value));
}
mlir::APFloat ConvertToAPFloat(half val) {

View File

@ -979,10 +979,10 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
values.reserve(attr.getNumElements());
for (APFloat val : attr.getValues<APFloat>()) {
bool loses_info = false;
CHECK_EQ(val.convert(llvm::APFloat::IEEEsingle(),
llvm::APFloat::rmTowardZero, &loses_info),
llvm::APFloat::opOK);
CHECK(!loses_info);
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEsingle(),
llvm::APFloat::rmTowardZero,
&loses_info) == llvm::APFloat::opOK);
TF_RET_CHECK(!loses_info);
values.push_back(xla::half(val.convertToFloat()));
}
xla::Array<xla::half> source_data(shape.dimensions());
@ -992,10 +992,15 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
case xla::PrimitiveType::BF16: {
xla::Array<double> source_data(shape.dimensions());
auto attr_values = attr.getValues<APFloat>();
std::vector<double> values_double(source_data.num_elements());
for (auto index_and_value : llvm::enumerate(attr_values)) {
values_double[index_and_value.index()] =
index_and_value.value().convertToDouble();
std::vector<double> values_double;
values_double.reserve(source_data.num_elements());
for (APFloat val : attr_values) {
bool loses_info = false;
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEdouble(),
llvm::APFloat::rmTowardZero,
&loses_info) == llvm::APFloat::opOK);
TF_RET_CHECK(!loses_info);
values_double.push_back(val.convertToDouble());
}
source_data.SetValues(values_double);
return xla::LiteralUtil::ConvertF64ToBF16(

View File

@ -191,7 +191,7 @@ func @const_f32_bf16() -> tensor<bf16> {
// CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f64>
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: return [[CST]]

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"

View File

@ -686,25 +686,6 @@ gentbl(
],
)
gentbl(
name = "MLIRShapeCanonicalizationIncGen",
strip_include_prefix = "include/mlir/Dialect/Shape",
tbl_outs = [
(
"-gen-rewriters",
"include/mlir/Dialect/Shape/IR/ShapeCanonicalization.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "lib/Dialect/Shape/IR/ShapeCanonicalization.td",
td_srcs = [
":StdOpsTdFiles",
"include/mlir/Dialect/Shape/IR/ShapeBase.td",
"include/mlir/Dialect/Shape/IR/ShapeOps.td",
"include/mlir/Interfaces/InferTypeOpInterface.td",
],
)
cc_library(
name = "Shape",
srcs = glob(
@ -723,7 +704,6 @@ cc_library(
":Dialect",
":IR",
":InferTypeOpInterface",
":MLIRShapeCanonicalizationIncGen",
":ShapeOpsIncGen",
":SideEffects",
":Support",