Integrate LLVM at https://github.com/llvm/llvm-project/commit/92cb0ce8f814
PiperOrigin-RevId: 315290349 Change-Id: I2405c3505b6a860dd32f32d754d1a6da3f3acd29
This commit is contained in:
parent
ad6ccc651c
commit
b8bd7b3483
@ -89,12 +89,11 @@ StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
|
|||||||
|
|
||||||
ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
|
ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
|
||||||
RankedTensorType type) {
|
RankedTensorType type) {
|
||||||
auto flat = input_tensor.flat<bfloat16>();
|
auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
|
||||||
llvm::SmallVector<llvm::APFloat, 4> floats;
|
input_tensor.TotalBytes());
|
||||||
floats.reserve(flat.size());
|
return mlir::DenseElementsAttr::getFromRawBuffer(
|
||||||
for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size()))
|
type, buffer,
|
||||||
floats.push_back(llvm::APFloat(static_cast<double>(v)));
|
/*isSplatBuffer=*/type.getNumElements() == 1);
|
||||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(floats));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
|
ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
|
||||||
@ -280,16 +279,11 @@ void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
|
|||||||
|
|
||||||
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
|
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
|
||||||
protobuf::RepeatedField<int>* output) {
|
protobuf::RepeatedField<int>* output) {
|
||||||
// Bfloat16 is internally represented as `double` in MLIR.
|
|
||||||
if (attr.isSplat()) {
|
if (attr.isSplat()) {
|
||||||
double v = attr.getSplatValue<double>();
|
output->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
|
||||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
|
||||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
|
||||||
} else {
|
} else {
|
||||||
for (auto v : attr.getValues<double>()) {
|
for (const llvm::APFloat value : attr.getFloatValues())
|
||||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
output->Add(value.bitcastToAPInt().getSExtValue());
|
||||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,8 +44,7 @@ template <typename CppType>
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::APFloat ConvertToAPFloat(bfloat16 val) {
|
mlir::APFloat ConvertToAPFloat(bfloat16 val) {
|
||||||
// bfloat16 values are stored as double in MLIR.
|
return llvm::APFloat(llvm::APFloat::BFloat(), llvm::APInt(16, val.value));
|
||||||
return llvm::APFloat(static_cast<double>(val));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::APFloat ConvertToAPFloat(half val) {
|
mlir::APFloat ConvertToAPFloat(half val) {
|
||||||
|
@ -979,10 +979,10 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
|||||||
values.reserve(attr.getNumElements());
|
values.reserve(attr.getNumElements());
|
||||||
for (APFloat val : attr.getValues<APFloat>()) {
|
for (APFloat val : attr.getValues<APFloat>()) {
|
||||||
bool loses_info = false;
|
bool loses_info = false;
|
||||||
CHECK_EQ(val.convert(llvm::APFloat::IEEEsingle(),
|
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEsingle(),
|
||||||
llvm::APFloat::rmTowardZero, &loses_info),
|
llvm::APFloat::rmTowardZero,
|
||||||
llvm::APFloat::opOK);
|
&loses_info) == llvm::APFloat::opOK);
|
||||||
CHECK(!loses_info);
|
TF_RET_CHECK(!loses_info);
|
||||||
values.push_back(xla::half(val.convertToFloat()));
|
values.push_back(xla::half(val.convertToFloat()));
|
||||||
}
|
}
|
||||||
xla::Array<xla::half> source_data(shape.dimensions());
|
xla::Array<xla::half> source_data(shape.dimensions());
|
||||||
@ -992,10 +992,15 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
|||||||
case xla::PrimitiveType::BF16: {
|
case xla::PrimitiveType::BF16: {
|
||||||
xla::Array<double> source_data(shape.dimensions());
|
xla::Array<double> source_data(shape.dimensions());
|
||||||
auto attr_values = attr.getValues<APFloat>();
|
auto attr_values = attr.getValues<APFloat>();
|
||||||
std::vector<double> values_double(source_data.num_elements());
|
std::vector<double> values_double;
|
||||||
for (auto index_and_value : llvm::enumerate(attr_values)) {
|
values_double.reserve(source_data.num_elements());
|
||||||
values_double[index_and_value.index()] =
|
for (APFloat val : attr_values) {
|
||||||
index_and_value.value().convertToDouble();
|
bool loses_info = false;
|
||||||
|
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEdouble(),
|
||||||
|
llvm::APFloat::rmTowardZero,
|
||||||
|
&loses_info) == llvm::APFloat::opOK);
|
||||||
|
TF_RET_CHECK(!loses_info);
|
||||||
|
values_double.push_back(val.convertToDouble());
|
||||||
}
|
}
|
||||||
source_data.SetValues(values_double);
|
source_data.SetValues(values_double);
|
||||||
return xla::LiteralUtil::ConvertF64ToBF16(
|
return xla::LiteralUtil::ConvertF64ToBF16(
|
||||||
|
@ -191,7 +191,7 @@ func @const_f32_bf16() -> tensor<bf16> {
|
|||||||
|
|
||||||
// CHECK-LABEL: func @const_bf16_f64
|
// CHECK-LABEL: func @const_bf16_f64
|
||||||
func @const_bf16_f64() -> tensor<f64> {
|
func @const_bf16_f64() -> tensor<f64> {
|
||||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f64>
|
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
|
||||||
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
|
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
|
||||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
|
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
|
||||||
// CHECK-NEXT: return [[CST]]
|
// CHECK-NEXT: return [[CST]]
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "absl/base/casts.h"
|
#include "absl/base/casts.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "llvm/ADT/Triple.h"
|
||||||
#include "llvm/IR/DerivedTypes.h"
|
#include "llvm/IR/DerivedTypes.h"
|
||||||
#include "llvm/IR/GlobalValue.h"
|
#include "llvm/IR/GlobalValue.h"
|
||||||
#include "llvm/IR/GlobalVariable.h"
|
#include "llvm/IR/GlobalVariable.h"
|
||||||
|
20
third_party/mlir/BUILD
vendored
20
third_party/mlir/BUILD
vendored
@ -686,25 +686,6 @@ gentbl(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
gentbl(
|
|
||||||
name = "MLIRShapeCanonicalizationIncGen",
|
|
||||||
strip_include_prefix = "include/mlir/Dialect/Shape",
|
|
||||||
tbl_outs = [
|
|
||||||
(
|
|
||||||
"-gen-rewriters",
|
|
||||||
"include/mlir/Dialect/Shape/IR/ShapeCanonicalization.inc",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
tblgen = ":mlir-tblgen",
|
|
||||||
td_file = "lib/Dialect/Shape/IR/ShapeCanonicalization.td",
|
|
||||||
td_srcs = [
|
|
||||||
":StdOpsTdFiles",
|
|
||||||
"include/mlir/Dialect/Shape/IR/ShapeBase.td",
|
|
||||||
"include/mlir/Dialect/Shape/IR/ShapeOps.td",
|
|
||||||
"include/mlir/Interfaces/InferTypeOpInterface.td",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "Shape",
|
name = "Shape",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
@ -723,7 +704,6 @@ cc_library(
|
|||||||
":Dialect",
|
":Dialect",
|
||||||
":IR",
|
":IR",
|
||||||
":InferTypeOpInterface",
|
":InferTypeOpInterface",
|
||||||
":MLIRShapeCanonicalizationIncGen",
|
|
||||||
":ShapeOpsIncGen",
|
":ShapeOpsIncGen",
|
||||||
":SideEffects",
|
":SideEffects",
|
||||||
":Support",
|
":Support",
|
||||||
|
Loading…
Reference in New Issue
Block a user