Automated rollback of commit 9a8e0c228d

PiperOrigin-RevId: 219205308
This commit is contained in:
A. Unique TensorFlower 2018-10-29 15:34:54 -07:00 committed by TensorFlower Gardener
parent 72c4105a28
commit 6f4734086e

View File

@ -139,6 +139,16 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
return Status::OK();
}
template <typename T>
static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
Tensor linspace(DataTypeToEnum<T>::v(), shape);
auto linspace_flat = linspace.flat<T>();
for (int64 i = 0; i < depth; ++i) {
linspace_flat(i) = i;
}
return linspace;
}
xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
int axis) {
return ArgMinMax(input, output_type, axis, /*is_min=*/false);
@ -162,17 +172,33 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
// Build a Tensor populated with values 0, 1, 2, ... depth.
std::vector<int64> linspace_dims(output_dims, 1);
linspace_dims[axis] = depth;
xla::PrimitiveType linspace_type;
TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &linspace_type));
auto linspace_shape = xla::ShapeUtil::MakeShape(linspace_type, linspace_dims);
xla::XlaOp linspace = xla::Iota(builder, linspace_shape, axis);
TensorShape linspace_shape(linspace_dims);
Tensor linspace;
switch (index_type) {
case DT_UINT8:
linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth);
break;
case DT_INT32:
linspace = MakeLinspaceTensor<int32>(linspace_shape, depth);
break;
case DT_INT64:
linspace = MakeLinspaceTensor<int64>(linspace_shape, depth);
break;
default:
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(index_type));
}
xla::BorrowingLiteral linspace_literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
// Broadcast the linspace constant across the indices along the new axis,
// and test equality at each position.
std::vector<int64> broadcast_dims(indices_shape.dims());
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
xla::XlaOp one_hot_bool = xla::Eq(indices, linspace, broadcast_dims);
xla::XlaOp one_hot_bool = xla::Eq(
indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims);
// Selects the user-provided off_value and on_value values.
*one_hot = xla::Select(one_hot_bool,