From 05943492144d146e999132b124b8df484c177e1f Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Sun, 18 Aug 2019 16:10:05 -0700 Subject: [PATCH] Convert fp16 elements attribute to tensorflow tensor When fp16 elements attribute is imported from the tensorflow tensor, it is imported as OpaqueElementsAttr. When fp16 elements attribute is imported from the mlir text format with dense elements attribute, it is imported as DenseFPElementsAttr. Thus the exporter method should handle both cases. The result binary value can be verified by: ``` c1 = tf.constant(1.0, dtype=tf.float16) c2 = tf.constant(2.0, dtype=tf.float16) sess = tf.compat.v1.Session() with sess.as_default() as sess: print(sess.graph_def) ``` PiperOrigin-RevId: 264057638 --- .../tests/mlir2graphdef/convert_tensor.mlir | 16 +++++++++++++++ .../mlir/tensorflow/utils/convert_tensor.cc | 20 +++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir new file mode 100644 index 00000000000..52e4c529815 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main() -> (tensor<1x2xf16>, tensor<2xf16>) { + %0:2 = "_tf.Const"() {device = "", name = "foo", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control) + %1:2 = "_tf.Const"() {device = "", name = "bar", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control) + return %0#0, %1#0 : tensor<1x2xf16>, tensor<2xf16> + +// CHECK: node { +// CHECK-NEXT: name: "foo" +// CHECK-NEXT: op: "Const" +// CHECK: half_val: 15360 +// CHECK: name: "bar" +// CHECK-NEXT: op: "Const" +// CHECK: half_val: 15360 +// CHECK: half_val: 16384 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index d85659c01f6..df19e169d3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -175,6 +175,23 @@ Status ConvertFloatElementsAttr(const ElementsAttr attr, return ConvertOpaqueElementsAttr(attr, output_tensor); } +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with the half_val field updated. +Status ConvertHalfElementsAttr(const ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + if (elts.isSplat()) { + output_tensor->add_half_val( + (*elts.begin()).bitcastToAPInt().getSExtValue()); + } else { + for (auto value : elts.getFloatValues()) + output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); + } + return Status::OK(); + } + return ConvertOpaqueElementsAttr(attr, output_tensor); +} + // Converts an MLIR elements attribute to a TensorFlow tensor proto // with the int_val field updated. Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, @@ -231,6 +248,9 @@ Status ConvertToTensorProto(const ElementsAttr attr, switch (output_dtype) { case DT_FLOAT: return ConvertFloatElementsAttr(attr, output_tensor); + case DT_HALF: + // Handles both DenseFPElementsAttr and OpaqueElementsAttr. + return ConvertHalfElementsAttr(attr, output_tensor); case DT_QUINT8: case DT_UINT8: case DT_INT8: