Convert fp16 elements attribute to tensorflow tensor
When fp16 elements attribute is imported from the tensorflow tensor, it is imported as OpaqueElementsAttr. When fp16 elements attribute is imported from the mlir text format with dense elements attribute, it is imported as DenseFPElementsAttr. Thus the exporter method should handle both cases. The result binary value can be verified by: ``` c1 = tf.constant(1.0, dtype=tf.float16) c2 = tf.constant(2.0, dtype=tf.float16) sess = tf.compat.v1.Session() with sess.as_default() as sess: print(sess.graph_def) ``` PiperOrigin-RevId: 264057638
This commit is contained in:
parent
4ca0fd8d36
commit
0594349214
@ -0,0 +1,16 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main() -> (tensor<1x2xf16>, tensor<2xf16>) {
|
||||
%0:2 = "_tf.Const"() {device = "", name = "foo", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control)
|
||||
%1:2 = "_tf.Const"() {device = "", name = "bar", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control)
|
||||
return %0#0, %1#0 : tensor<1x2xf16>, tensor<2xf16>
|
||||
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "foo"
|
||||
// CHECK-NEXT: op: "Const"
|
||||
// CHECK: half_val: 15360
|
||||
// CHECK: name: "bar"
|
||||
// CHECK-NEXT: op: "Const"
|
||||
// CHECK: half_val: 15360
|
||||
// CHECK: half_val: 16384
|
||||
}
|
@ -175,6 +175,23 @@ Status ConvertFloatElementsAttr(const ElementsAttr attr,
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the half_val field updated.
|
||||
Status ConvertHalfElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_half_val(
|
||||
(*elts.begin()).bitcastToAPInt().getSExtValue());
|
||||
} else {
|
||||
for (auto value : elts.getFloatValues())
|
||||
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the int_val field updated.
|
||||
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
|
||||
@ -231,6 +248,9 @@ Status ConvertToTensorProto(const ElementsAttr attr,
|
||||
switch (output_dtype) {
|
||||
case DT_FLOAT:
|
||||
return ConvertFloatElementsAttr(attr, output_tensor);
|
||||
case DT_HALF:
|
||||
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
|
||||
return ConvertHalfElementsAttr(attr, output_tensor);
|
||||
case DT_QUINT8:
|
||||
case DT_UINT8:
|
||||
case DT_INT8:
|
||||
|
Loading…
x
Reference in New Issue
Block a user