diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9a34cd8c6ae..24626099106 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -139,16 +139,6 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return Status::OK(); } -template -static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { - Tensor linspace(DataTypeToEnum::v(), shape); - auto linspace_flat = linspace.flat(); - for (int64 i = 0; i < depth; ++i) { - linspace_flat(i) = i; - } - return linspace; -} - xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis) { return ArgMinMax(input, output_type, axis, /*is_min=*/false); @@ -172,33 +162,17 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, // Build a Tensor populated with values 0, 1, 2, ... depth. std::vector linspace_dims(output_dims, 1); linspace_dims[axis] = depth; - TensorShape linspace_shape(linspace_dims); - Tensor linspace; - switch (index_type) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(index_type)); - } - - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); + xla::PrimitiveType linspace_type; + TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &linspace_type)); + auto linspace_shape = xla::ShapeUtil::MakeShape(linspace_type, linspace_dims); + xla::XlaOp linspace = xla::Iota(builder, linspace_shape, axis); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = xla::Eq( - indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); + xla::XlaOp one_hot_bool = xla::Eq(indices, linspace, broadcast_dims); // Selects the user-provided off_value and on_value values. *one_hot = xla::Select(one_hot_bool,