[TF:XLA] Add EncodePrimitiveTypeAsDataType into tf2xla/type_util.{h,cc}.

Fix bug where unsigned XLA types were mapped to signed TensorFlow types.

For a long time I resisted the existence of this method in tf2xla, but it makes perfect sense to have it for use cases where you want a way to encode XLA data in TensorFlow (XRT), not to encode TensorFlow data in XLA (TF2XLA).

PiperOrigin-RevId: 256274286
This commit is contained in:
Peter Hawkins 2019-07-02 17:41:16 -07:00 committed by TensorFlower Gardener
parent 9189deb241
commit a1b8e4cc50
3 changed files with 39 additions and 7 deletions

View File

@ -264,6 +264,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
],
)

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
@ -79,4 +80,31 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
}
}
xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
static const absl::flat_hash_map<xla::PrimitiveType, DataType>&
data_type_map = *new absl::flat_hash_map<xla::PrimitiveType, DataType>({
{xla::PRED, DT_BOOL},
{xla::BF16, DT_BFLOAT16},
{xla::F16, DT_HALF},
{xla::F32, DT_FLOAT},
{xla::F64, DT_DOUBLE},
{xla::C64, DT_COMPLEX64},
{xla::S8, DT_INT8},
{xla::S16, DT_INT16},
{xla::S32, DT_INT32},
{xla::S64, DT_INT64},
{xla::U8, DT_UINT8},
{xla::U16, DT_UINT16},
{xla::U32, DT_UINT32},
{xla::U64, DT_UINT64},
});
auto it = data_type_map.find(type);
if (it == data_type_map.end()) {
return errors::InvalidArgument(
"Unsupported type in PrimitiveTypeToDataType ", type);
}
return it->second;
}
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
@ -25,13 +26,15 @@ namespace tensorflow {
// Converts a Tensorflow DataType to an XLA PrimitiveType.
Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
// N.B.: there is intentionally no function to convert an XLA PrimitiveType to
// a TensorFlow DataType. The mapping from TF types to XLA types is not
// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the
// inverse would not be a well-defined function. If you find that you want the
// inverse mapping, then most likely you should be preserving the original
// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow
// type.
// Converts an XLA PrimitiveType to a TensorFlow DataType.
// Caution: The mapping from TF types to XLA types is not one-to-one: for
// example, both DT_INT8 and DT_QINT8 map to xla::S8. So the inverse is not a
// uniquely defined function. This is fine if you want a way to encode an XLA
// object as a TensorFlow object (e.g., in XRT); whereas if you started with a
// TensorFlow object in the first place, you most likely should preserve the
// original TensorFlow type, rather than trying to convert an XLA type back into
// a TensorFlow type.
xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type);
} // namespace tensorflow