[TF:XLA] Add EncodePrimitiveTypeAsDataType into tf2xla/type_util.{h,cc}.
Fix bug where unsigned XLA types were mapped to signed TensorFlow types. For a long time I resisted the existence of this method in tf2xla, but it makes perfect sense to have it for use cases where you want a way to encode XLA data in TensorFlow (XRT), not to encode TensorFlow data in XLA (TF2XLA). PiperOrigin-RevId: 256274286
This commit is contained in:
parent
9189deb241
commit
a1b8e4cc50
@ -264,6 +264,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
@ -79,4 +80,31 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
|
||||
}
|
||||
}
|
||||
|
||||
xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
|
||||
static const absl::flat_hash_map<xla::PrimitiveType, DataType>&
|
||||
data_type_map = *new absl::flat_hash_map<xla::PrimitiveType, DataType>({
|
||||
{xla::PRED, DT_BOOL},
|
||||
{xla::BF16, DT_BFLOAT16},
|
||||
{xla::F16, DT_HALF},
|
||||
{xla::F32, DT_FLOAT},
|
||||
{xla::F64, DT_DOUBLE},
|
||||
{xla::C64, DT_COMPLEX64},
|
||||
{xla::S8, DT_INT8},
|
||||
{xla::S16, DT_INT16},
|
||||
{xla::S32, DT_INT32},
|
||||
{xla::S64, DT_INT64},
|
||||
{xla::U8, DT_UINT8},
|
||||
{xla::U16, DT_UINT16},
|
||||
{xla::U32, DT_UINT32},
|
||||
{xla::U64, DT_UINT64},
|
||||
});
|
||||
|
||||
auto it = data_type_map.find(type);
|
||||
if (it == data_type_map.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Unsupported type in PrimitiveTypeToDataType ", type);
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -25,13 +26,15 @@ namespace tensorflow {
|
||||
// Converts a Tensorflow DataType to an XLA PrimitiveType.
|
||||
Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
|
||||
|
||||
// N.B.: there is intentionally no function to convert an XLA PrimitiveType to
|
||||
// a TensorFlow DataType. The mapping from TF types to XLA types is not
|
||||
// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the
|
||||
// inverse would not be a well-defined function. If you find that you want the
|
||||
// inverse mapping, then most likely you should be preserving the original
|
||||
// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow
|
||||
// type.
|
||||
// Converts an XLA PrimitiveType to a TensorFlow DataType.
|
||||
// Caution: The mapping from TF types to XLA types is not one-to-one: for
|
||||
// example, both DT_INT8 and DT_QINT8 map to xla::S8. So the inverse is not a
|
||||
// uniquely defined function. This is fine if you want a way to encode an XLA
|
||||
// object as a TensorFlow object (e.g., in XRT); whereas if you started with a
|
||||
// TensorFlow object in the first place, you most likely should preserve the
|
||||
// original TensorFlow type, rather than trying to convert an XLA type back into
|
||||
// a TensorFlow type.
|
||||
xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user