From a1b8e4cc50ab20970095874043f814761427519b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 2 Jul 2019 17:41:16 -0700 Subject: [PATCH] [TF:XLA] Add EncodePrimitiveTypeAsDataType into tf2xla/type_util.{h,cc}. Fix bug where unsigned XLA types were mapped to signed TensorFlow types. For a long time I resisted the existence of this method in tf2xla, but it makes perfect sense to have it for use cases where you want a way to encode XLA data in TensorFlow (XRT), not to encode TensorFlow data in XLA (TF2XLA). PiperOrigin-RevId: 256274286 --- tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/compiler/tf2xla/type_util.cc | 28 +++++++++++++++++++++++++ tensorflow/compiler/tf2xla/type_util.h | 17 ++++++++------- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5a148b95d00..ec853c09cfb 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -264,6 +264,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 732f957d732..4275a4402f2 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/type_util.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" @@ -79,4 +80,31 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { } } +xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) { + static const absl::flat_hash_map<xla::PrimitiveType, DataType>& + data_type_map = *new absl::flat_hash_map<xla::PrimitiveType, DataType>({ + {xla::PRED, DT_BOOL}, + {xla::BF16, DT_BFLOAT16}, + {xla::F16, DT_HALF}, + {xla::F32, DT_FLOAT}, + {xla::F64, DT_DOUBLE}, + {xla::C64, DT_COMPLEX64}, + {xla::S8, DT_INT8}, + {xla::S16, DT_INT16}, + {xla::S32, DT_INT32}, + {xla::S64, DT_INT64}, + {xla::U8, DT_UINT8}, + {xla::U16, DT_UINT16}, + {xla::U32, DT_UINT32}, + {xla::U64, DT_UINT64}, + }); + + auto it = data_type_map.find(type); + if (it == data_type_map.end()) { + return errors::InvalidArgument( + "Unsupported type in PrimitiveTypeToDataType ", type); + } + return it->second; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h index 6354216eee7..2e93986d777 100644 --- a/tensorflow/compiler/tf2xla/type_util.h +++ b/tensorflow/compiler/tf2xla/type_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -25,13 +26,15 @@ namespace tensorflow { // Converts a Tensorflow DataType to an XLA PrimitiveType. Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type); -// N.B.: there is intentionally no function to convert an XLA PrimitiveType to -// a TensorFlow DataType. The mapping from TF types to XLA types is not -// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. 
So the -// inverse would not be a well-defined function. If you find that you want the -// inverse mapping, then most likely you should be preserving the original -// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow -// type. +// Converts an XLA PrimitiveType to a TensorFlow DataType. +// Caution: The mapping from TF types to XLA types is not one-to-one: for +// example, both DT_INT8 and DT_QINT8 map to xla::S8. So the inverse is not a +// uniquely defined function. This is fine if you want a way to encode an XLA +// object as a TensorFlow object (e.g., in XRT); whereas if you started with a +// TensorFlow object in the first place, you most likely should preserve the +// original TensorFlow type, rather than trying to convert an XLA type back into +// a TensorFlow type. +xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type); } // namespace tensorflow