parent
72c4105a28
commit
6f4734086e
@ -139,6 +139,16 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
|
||||
Tensor linspace(DataTypeToEnum<T>::v(), shape);
|
||||
auto linspace_flat = linspace.flat<T>();
|
||||
for (int64 i = 0; i < depth; ++i) {
|
||||
linspace_flat(i) = i;
|
||||
}
|
||||
return linspace;
|
||||
}
|
||||
|
||||
xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
|
||||
int axis) {
|
||||
return ArgMinMax(input, output_type, axis, /*is_min=*/false);
|
||||
@ -162,17 +172,33 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
|
||||
// Build a Tensor populated with values 0, 1, 2, ... depth.
|
||||
std::vector<int64> linspace_dims(output_dims, 1);
|
||||
linspace_dims[axis] = depth;
|
||||
xla::PrimitiveType linspace_type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &linspace_type));
|
||||
auto linspace_shape = xla::ShapeUtil::MakeShape(linspace_type, linspace_dims);
|
||||
xla::XlaOp linspace = xla::Iota(builder, linspace_shape, axis);
|
||||
TensorShape linspace_shape(linspace_dims);
|
||||
Tensor linspace;
|
||||
switch (index_type) {
|
||||
case DT_UINT8:
|
||||
linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth);
|
||||
break;
|
||||
case DT_INT32:
|
||||
linspace = MakeLinspaceTensor<int32>(linspace_shape, depth);
|
||||
break;
|
||||
case DT_INT64:
|
||||
linspace = MakeLinspaceTensor<int64>(linspace_shape, depth);
|
||||
break;
|
||||
default:
|
||||
return errors::InvalidArgument("Invalid argument type ",
|
||||
DataTypeString(index_type));
|
||||
}
|
||||
|
||||
xla::BorrowingLiteral linspace_literal;
|
||||
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
|
||||
|
||||
// Broadcast the linspace constant across the indices along the new axis,
|
||||
// and test equality at each position.
|
||||
std::vector<int64> broadcast_dims(indices_shape.dims());
|
||||
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
|
||||
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
|
||||
xla::XlaOp one_hot_bool = xla::Eq(indices, linspace, broadcast_dims);
|
||||
xla::XlaOp one_hot_bool = xla::Eq(
|
||||
indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims);
|
||||
|
||||
// Selects the user-provided off_value and on_value values.
|
||||
*one_hot = xla::Select(one_hot_bool,
|
||||
|
Loading…
Reference in New Issue
Block a user