From 6f4734086ebbcb2c37b3cacfd5e30a69c1fde0c1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 29 Oct 2018 15:34:54 -0700 Subject: [PATCH] Automated rollback of commit 9a8e0c228dcf6a98d0f35b3737be9a07a43961ea PiperOrigin-RevId: 219205308 --- tensorflow/compiler/tf2xla/xla_helpers.cc | 36 +++++++++++++++++++---- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 24626099106..9a34cd8c6ae 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -139,6 +139,16 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return Status::OK(); } +template <typename T> +static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { + Tensor linspace(DataTypeToEnum<T>::v(), shape); + auto linspace_flat = linspace.flat<T>(); + for (int64 i = 0; i < depth; ++i) { + linspace_flat(i) = i; + } + return linspace; +} + xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis) { return ArgMinMax(input, output_type, axis, /*is_min=*/false); @@ -162,17 +172,33 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, // Build a Tensor populated with values 0, 1, 2, ... depth. std::vector<int64> linspace_dims(output_dims, 1); linspace_dims[axis] = depth; - xla::PrimitiveType linspace_type; - TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &linspace_type)); - auto linspace_shape = xla::ShapeUtil::MakeShape(linspace_type, linspace_dims); - xla::XlaOp linspace = xla::Iota(builder, linspace_shape, axis); + TensorShape linspace_shape(linspace_dims); + Tensor linspace; + switch (index_type) { + case DT_UINT8: + linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth); + break; + case DT_INT32: + linspace = MakeLinspaceTensor<int32>(linspace_shape, depth); + break; + case DT_INT64: + linspace = MakeLinspaceTensor<int64>(linspace_shape, depth); + break; + default: + return errors::InvalidArgument("Invalid argument type ", + DataTypeString(index_type)); + } + + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector<int64> broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = xla::Eq(indices, linspace, broadcast_dims); + xla::XlaOp one_hot_bool = xla::Eq( + indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. *one_hot = xla::Select(one_hot_bool,