Convert fp16 elements attribute to tensorflow tensor

When fp16 elements attribute is imported from the tensorflow tensor, it is
imported as OpaqueElementsAttr. When fp16 elements attribute is imported from
the mlir text format with dense elements attribute, it is imported as
DenseFPElementsAttr. Thus the exporter method should handle both cases.

The result binary value can be verified by:
```
c1 = tf.constant(1.0, dtype=tf.float16)
c2 = tf.constant(2.0, dtype=tf.float16)
sess = tf.compat.v1.Session()
with sess.as_default() as sess:
  print(sess.graph_def)
```
PiperOrigin-RevId: 264057638
This commit is contained in:
Feng Liu 2019-08-18 16:10:05 -07:00 committed by TensorFlower Gardener
parent 4ca0fd8d36
commit 0594349214
2 changed files with 36 additions and 0 deletions

View File

@ -0,0 +1,16 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main() -> (tensor<1x2xf16>, tensor<2xf16>) {
%0:2 = "_tf.Const"() {device = "", name = "foo", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control)
%1:2 = "_tf.Const"() {device = "", name = "bar", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control)
return %0#0, %1#0 : tensor<1x2xf16>, tensor<2xf16>
// CHECK: node {
// CHECK-NEXT: name: "foo"
// CHECK-NEXT: op: "Const"
// CHECK: half_val: 15360
// CHECK: name: "bar"
// CHECK-NEXT: op: "Const"
// CHECK: half_val: 15360
// CHECK: half_val: 16384
}

View File

@ -175,6 +175,23 @@ Status ConvertFloatElementsAttr(const ElementsAttr attr,
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the half_val field updated.
Status ConvertHalfElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_half_val(
(*elts.begin()).bitcastToAPInt().getSExtValue());
} else {
for (auto value : elts.getFloatValues())
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the int_val field updated.
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
@ -231,6 +248,9 @@ Status ConvertToTensorProto(const ElementsAttr attr,
switch (output_dtype) {
case DT_FLOAT:
return ConvertFloatElementsAttr(attr, output_tensor);
case DT_HALF:
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
return ConvertHalfElementsAttr(attr, output_tensor);
case DT_QUINT8:
case DT_UINT8:
case DT_INT8: