parent
72c4105a28
commit
6f4734086e
@ -139,6 +139,16 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
|
||||||
|
Tensor linspace(DataTypeToEnum<T>::v(), shape);
|
||||||
|
auto linspace_flat = linspace.flat<T>();
|
||||||
|
for (int64 i = 0; i < depth; ++i) {
|
||||||
|
linspace_flat(i) = i;
|
||||||
|
}
|
||||||
|
return linspace;
|
||||||
|
}
|
||||||
|
|
||||||
xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
|
xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
|
||||||
int axis) {
|
int axis) {
|
||||||
return ArgMinMax(input, output_type, axis, /*is_min=*/false);
|
return ArgMinMax(input, output_type, axis, /*is_min=*/false);
|
||||||
@ -162,17 +172,33 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
|
|||||||
// Build a Tensor populated with values 0, 1, 2, ... depth.
|
// Build a Tensor populated with values 0, 1, 2, ... depth.
|
||||||
std::vector<int64> linspace_dims(output_dims, 1);
|
std::vector<int64> linspace_dims(output_dims, 1);
|
||||||
linspace_dims[axis] = depth;
|
linspace_dims[axis] = depth;
|
||||||
xla::PrimitiveType linspace_type;
|
TensorShape linspace_shape(linspace_dims);
|
||||||
TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &linspace_type));
|
Tensor linspace;
|
||||||
auto linspace_shape = xla::ShapeUtil::MakeShape(linspace_type, linspace_dims);
|
switch (index_type) {
|
||||||
xla::XlaOp linspace = xla::Iota(builder, linspace_shape, axis);
|
case DT_UINT8:
|
||||||
|
linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth);
|
||||||
|
break;
|
||||||
|
case DT_INT32:
|
||||||
|
linspace = MakeLinspaceTensor<int32>(linspace_shape, depth);
|
||||||
|
break;
|
||||||
|
case DT_INT64:
|
||||||
|
linspace = MakeLinspaceTensor<int64>(linspace_shape, depth);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return errors::InvalidArgument("Invalid argument type ",
|
||||||
|
DataTypeString(index_type));
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::BorrowingLiteral linspace_literal;
|
||||||
|
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
|
||||||
|
|
||||||
// Broadcast the linspace constant across the indices along the new axis,
|
// Broadcast the linspace constant across the indices along the new axis,
|
||||||
// and test equality at each position.
|
// and test equality at each position.
|
||||||
std::vector<int64> broadcast_dims(indices_shape.dims());
|
std::vector<int64> broadcast_dims(indices_shape.dims());
|
||||||
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
|
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
|
||||||
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
|
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
|
||||||
xla::XlaOp one_hot_bool = xla::Eq(indices, linspace, broadcast_dims);
|
xla::XlaOp one_hot_bool = xla::Eq(
|
||||||
|
indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims);
|
||||||
|
|
||||||
// Selects the user-provided off_value and on_value values.
|
// Selects the user-provided off_value and on_value values.
|
||||||
*one_hot = xla::Select(one_hot_bool,
|
*one_hot = xla::Select(one_hot_bool,
|
||||||
|
Loading…
Reference in New Issue
Block a user