[TF:XLA] Implement XlaSvdOp
PiperOrigin-RevId: 236936547
This commit is contained in:
parent
e26ab7b9e1
commit
9baeb353e1
@ -247,12 +247,23 @@ tf_xla_py_test(
|
|||||||
name = "self_adjoint_eig_op_test",
|
name = "self_adjoint_eig_op_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["self_adjoint_eig_op_test.py"],
|
srcs = ["self_adjoint_eig_op_test.py"],
|
||||||
# TODO(kuny): remove it after b/124377352 is fixed.
|
tags = ["optonly"],
|
||||||
disabled_backends = [
|
deps = [
|
||||||
"cpu",
|
":xla_test",
|
||||||
"gpu",
|
"//tensorflow/python:array_ops",
|
||||||
"cpu_ondemand",
|
"//tensorflow/python:framework",
|
||||||
|
"//tensorflow/python:map_fn",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python:training",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_xla_py_test(
|
||||||
|
name = "svd_op_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["svd_op_test.py"],
|
||||||
tags = ["optonly"],
|
tags = ["optonly"],
|
||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
|
81
tensorflow/compiler/tests/svd_op_test.py
Normal file
81
tensorflow/compiler/tests/svd_op_test.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.ops.svd."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.compiler.tests import xla_test
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import linalg_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def _compute_usvt(self, s, u, v):
|
||||||
|
m = u.shape[-1]
|
||||||
|
n = v.shape[-1]
|
||||||
|
if m <= n:
|
||||||
|
v = v[..., :m]
|
||||||
|
else:
|
||||||
|
u = u[..., :n]
|
||||||
|
|
||||||
|
return np.matmul(u * s[..., None, :], np.swapaxes(v, -1, -2))
|
||||||
|
|
||||||
|
def _testSvdCorrectness(self, dtype, shape):
|
||||||
|
np.random.seed(1)
|
||||||
|
x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
|
||||||
|
m, n = shape[-2], shape[-1]
|
||||||
|
_, s_np, _ = np.linalg.svd(x_np)
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
x_tf = array_ops.placeholder(dtype)
|
||||||
|
with self.test_scope():
|
||||||
|
s, u, v = linalg_ops.svd(x_tf, full_matrices=True)
|
||||||
|
s_val, u_val, v_val = sess.run([s, u, v], feed_dict={x_tf: x_np})
|
||||||
|
u_diff = np.matmul(u_val, np.swapaxes(u_val, -1, -2)) - np.eye(m)
|
||||||
|
v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n)
|
||||||
|
# Check u_val and v_val are orthogonal matrices.
|
||||||
|
self.assertLess(np.linalg.norm(u_diff), 1e-2)
|
||||||
|
self.assertLess(np.linalg.norm(v_diff), 1e-2)
|
||||||
|
# Check that the singular values are correct, i.e., close to the ones from
|
||||||
|
# numpy.lingal.svd.
|
||||||
|
self.assertLess(np.linalg.norm(s_val - s_np), 1e-2)
|
||||||
|
# The tolerance is set based on our tests on numpy's svd. As our tests
|
||||||
|
# have batch dimensions and all our operations are on float32, we set the
|
||||||
|
# tolerance a bit larger. Numpy's svd calls LAPACK's svd, which operates
|
||||||
|
# on double precision.
|
||||||
|
self.assertLess(
|
||||||
|
np.linalg.norm(self._compute_usvt(s_val, u_val, v_val) - x_np), 2e-2)
|
||||||
|
|
||||||
|
SIZES = [1, 2, 5, 10, 32, 64]
|
||||||
|
DTYPES = [np.float32]
|
||||||
|
PARAMS = itertools.product(SIZES, DTYPES)
|
||||||
|
|
||||||
|
@parameterized.parameters(*PARAMS)
|
||||||
|
def testSvd(self, n, dtype):
|
||||||
|
for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10):
|
||||||
|
self._testSvdCorrectness(dtype, batch_dims + (n, n))
|
||||||
|
self._testSvdCorrectness(dtype, batch_dims + (2 * n, n))
|
||||||
|
self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -110,6 +110,7 @@ tf_kernel_library(
|
|||||||
"xla_reduce_op.cc",
|
"xla_reduce_op.cc",
|
||||||
"xla_select_and_scatter_op.cc",
|
"xla_select_and_scatter_op.cc",
|
||||||
"xla_self_adjoint_eig_op.cc",
|
"xla_self_adjoint_eig_op.cc",
|
||||||
|
"xla_svd_op.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"index_ops.h",
|
"index_ops.h",
|
||||||
@ -149,7 +150,9 @@ tf_kernel_library(
|
|||||||
"//tensorflow/compiler/xla/client/lib:qr",
|
"//tensorflow/compiler/xla/client/lib:qr",
|
||||||
"//tensorflow/compiler/xla/client/lib:quantize",
|
"//tensorflow/compiler/xla/client/lib:quantize",
|
||||||
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:slicing",
|
||||||
"//tensorflow/compiler/xla/client/lib:sorting",
|
"//tensorflow/compiler/xla/client/lib:sorting",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:svd",
|
||||||
"//tensorflow/core:bitwise_ops_op_lib",
|
"//tensorflow/core:bitwise_ops_op_lib",
|
||||||
"//tensorflow/core:control_flow_ops_op_lib",
|
"//tensorflow/core:control_flow_ops_op_lib",
|
||||||
"//tensorflow/core:data_flow_ops_op_lib",
|
"//tensorflow/core:data_flow_ops_op_lib",
|
||||||
|
@ -35,10 +35,9 @@ class XlaConvOp : public XlaOpKernel {
|
|||||||
string precision_config_attr;
|
string precision_config_attr;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context, context->GetAttr("precision_config", &precision_config_attr));
|
context, context->GetAttr("precision_config", &precision_config_attr));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(context,
|
||||||
context,
|
precision_config_.ParsePartialFromString(precision_config_attr),
|
||||||
precision_config_.ParsePartialFromString(precision_config_attr),
|
errors::InvalidArgument("Error parsing precison config."));
|
||||||
errors::InvalidArgument("Error parsing convolution dimension numbers"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* context) override {
|
void Compile(XlaOpKernelContext* context) override {
|
||||||
|
95
tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc
Normal file
95
tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/svd.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class XlaSvdOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit XlaSvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
|
||||||
|
string precision_config_attr;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->GetAttr("precision_config", &precision_config_attr));
|
||||||
|
OP_REQUIRES(ctx,
|
||||||
|
precision_config_.ParsePartialFromString(precision_config_attr),
|
||||||
|
errors::InvalidArgument("Error parsing precison config."));
|
||||||
|
if (precision_config_.operand_precision_size() == 0) {
|
||||||
|
precision_config_.mutable_operand_precision()->Add(
|
||||||
|
xla::PrecisionConfig::HIGHEST);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
auto result = xla::SVD(ctx->Input(0), max_iter_, epsilon_,
|
||||||
|
precision_config_.operand_precision(0));
|
||||||
|
ctx->SetOutput(0, result.d);
|
||||||
|
ctx->SetOutput(1, result.u);
|
||||||
|
ctx->SetOutput(2, result.v);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32 max_iter_;
|
||||||
|
float epsilon_;
|
||||||
|
xla::PrecisionConfig precision_config_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SvdOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit SvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("compute_uv", &compute_uv_));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
|
||||||
|
}
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
const TensorShape input_shape = ctx->InputShape("input");
|
||||||
|
int m = input_shape.dim_size(input_shape.dims() - 2);
|
||||||
|
int n = input_shape.dim_size(input_shape.dims() - 1);
|
||||||
|
// This is based on heuristics that approx log(n) sweep updates are needed.
|
||||||
|
// Note: the heuristics provides no theoretical guarantee, max_iter=100 and
|
||||||
|
// epsilon should be used to determine exit condition.
|
||||||
|
int max_iter = 2 * tensorflow::Log2Ceiling(std::max(m, n));
|
||||||
|
auto result = xla::SVD(ctx->Input(0), max_iter, 1e-6);
|
||||||
|
ctx->SetOutput(0, result.d);
|
||||||
|
if (compute_uv_) {
|
||||||
|
int p = std::min(m, n);
|
||||||
|
if (!full_matrices_) {
|
||||||
|
if (p < m) {
|
||||||
|
result.u = xla::SliceInMinorDims(result.u, {0, 0}, {m, p});
|
||||||
|
}
|
||||||
|
if (p < n) {
|
||||||
|
result.v = xla::SliceInMinorDims(result.v, {0, 0}, {n, p});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx->SetOutput(1, result.u);
|
||||||
|
ctx->SetOutput(2, result.v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool compute_uv_;
|
||||||
|
bool full_matrices_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("XlaSvd").TypeConstraint("T", kFloatTypes), XlaSvdOp);
|
||||||
|
REGISTER_XLA_OP(Name("Svd").TypeConstraint("T", kFloatTypes), SvdOp);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -91,6 +91,40 @@ v: The column v[..., :, i] is the normalized eigenvector corresponding to the
|
|||||||
eigenvalue w[..., i].
|
eigenvalue w[..., i].
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("XlaSvd")
|
||||||
|
.Input("a: T")
|
||||||
|
.Attr("max_iter: int")
|
||||||
|
.Attr("epsilon: float")
|
||||||
|
.Attr("precision_config: string")
|
||||||
|
.Output("s: T")
|
||||||
|
.Output("u: T")
|
||||||
|
.Output("v: T")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
|
.Attr("T: numbertype")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Computes the eigen decomposition of a batch of self-adjoint matrices
|
||||||
|
(Note: Only real inputs are supported).
|
||||||
|
|
||||||
|
Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
|
||||||
|
tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
|
||||||
|
|
||||||
|
a: the input tensor.
|
||||||
|
|
||||||
|
max_iter: maximum number of sweep update, i.e., the whole lower triangular
|
||||||
|
part or upper triangular part based on parameter lower. Heuristically, it has
|
||||||
|
been argued that approximatly log(min (M, N)) sweeps are needed in practice
|
||||||
|
(Ref: Golub & van Loan "Matrix Computation").
|
||||||
|
|
||||||
|
epsilon: the tolerance ratio.
|
||||||
|
|
||||||
|
precision_config: a serialized xla::PrecisionConfig proto.
|
||||||
|
|
||||||
|
s: Singular values. The values are sorted in reverse order of magnitude, so
|
||||||
|
s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
|
||||||
|
u: Left singular vectors.
|
||||||
|
v: Right singular vectors.
|
||||||
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("XlaConv")
|
REGISTER_OP("XlaConv")
|
||||||
.Input("lhs: T")
|
.Input("lhs: T")
|
||||||
.Input("rhs: T")
|
.Input("rhs: T")
|
||||||
|
@ -295,6 +295,13 @@ def self_adjoint_eig(a, lower, max_iter, epsilon):
|
|||||||
return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
|
return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
|
||||||
|
|
||||||
|
|
||||||
|
def svd(a, max_iter, epsilon, precision_config=None):
|
||||||
|
precision_config_proto = ""
|
||||||
|
if precision_config:
|
||||||
|
precision_config_proto = precision_config.SerializeToString()
|
||||||
|
return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)
|
||||||
|
|
||||||
|
|
||||||
dynamic_slice = gen_xla_ops.xla_dynamic_slice
|
dynamic_slice = gen_xla_ops.xla_dynamic_slice
|
||||||
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
|
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
|
||||||
|
|
||||||
|
@ -750,19 +750,18 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
|
|||||||
result.v = Mul(result.v, sign, broadcast_dims);
|
result.v = Mul(result.v, sign, broadcast_dims);
|
||||||
|
|
||||||
d = BroadcastInDim(d, dimensions, broadcast_dims);
|
d = BroadcastInDim(d, dimensions, broadcast_dims);
|
||||||
auto zero = Zero(builder, S32);
|
|
||||||
|
|
||||||
// As m >= n, only first m columns vectors are needed to be permuted, and the
|
// As m >= n, only first m columns vectors are needed to be permuted, and the
|
||||||
// rest of m - n vectors are appended after the sorting is done.
|
// rest of m - n vectors are appended after the sorting is done.
|
||||||
XlaOp sort_u_result =
|
XlaOp sort_u_result =
|
||||||
Sort({-d, DynamicSliceInMinorDims(result.u, {zero, zero}, {m, n})},
|
Sort({-d, SliceInMinorDims(result.u, {0, 0}, {m, n})},
|
||||||
CreateScalarLtComputation(
|
CreateScalarLtComputation(
|
||||||
{shape.element_type(), shape.element_type()}, builder),
|
{shape.element_type(), shape.element_type()}, builder),
|
||||||
num_dims - 1);
|
num_dims - 1);
|
||||||
|
|
||||||
// TODO(kuny): using CreateScalarGtComputation after b/124862300 is fixed.
|
// TODO(kuny): using CreateScalarGtComputation after b/124862300 is fixed.
|
||||||
XlaOp sort_v_result =
|
XlaOp sort_v_result =
|
||||||
Sort({DynamicSliceInMinorDims(-d, {zero, zero}, {n, n}), result.v},
|
Sort({SliceInMinorDims(-d, {0, 0}, {n, n}), result.v},
|
||||||
CreateScalarLtComputation(
|
CreateScalarLtComputation(
|
||||||
{shape.element_type(), shape.element_type()}, builder),
|
{shape.element_type(), shape.element_type()}, builder),
|
||||||
num_dims - 1);
|
num_dims - 1);
|
||||||
@ -779,12 +778,10 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
|
|||||||
broadcast_dims);
|
broadcast_dims);
|
||||||
|
|
||||||
// Append the rest of m - n vectors.
|
// Append the rest of m - n vectors.
|
||||||
result.u =
|
result.u = ConcatInDim(builder,
|
||||||
ConcatInDim(builder,
|
{GetTupleElement(sort_u_result, 1),
|
||||||
{GetTupleElement(sort_u_result, 1),
|
SliceInMinorDims(result.u, {0, n}, {m, m})},
|
||||||
DynamicSliceInMinorDims(
|
num_dims - 1);
|
||||||
result.u, {zero, ScalarLike(zero, n)}, {m, m - n})},
|
|
||||||
num_dims - 1);
|
|
||||||
result.u = Mul(
|
result.u = Mul(
|
||||||
result.u,
|
result.u,
|
||||||
Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),
|
Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),
|
||||||
|
@ -3348,7 +3348,7 @@ cuda_py_test(
|
|||||||
"no_rocm", # flaky test
|
"no_rocm", # flaky test
|
||||||
"no_windows",
|
"no_windows",
|
||||||
],
|
],
|
||||||
# TODO(kuny): Add xla_enable_strict_auto_jit = True after b/124377352 is fixed.
|
# b/127344411: xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3385,7 +3385,7 @@ cuda_py_test(
|
|||||||
"no_oss", # b/117185141.
|
"no_oss", # b/117185141.
|
||||||
"nomsan", # TODO(b/117236102): Re-enable in msan build.
|
"nomsan", # TODO(b/117236102): Re-enable in msan build.
|
||||||
],
|
],
|
||||||
xla_enable_strict_auto_jit = True,
|
# b/127344411: xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -3405,7 +3405,7 @@ cuda_py_test(
|
|||||||
"no_windows_gpu",
|
"no_windows_gpu",
|
||||||
"nomsan",
|
"nomsan",
|
||||||
],
|
],
|
||||||
xla_enable_strict_auto_jit = True,
|
# b/127344411: xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
|
@ -89,7 +89,7 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
|
|||||||
if ((not is_matrix_norm and ord_ == "fro") or
|
if ((not is_matrix_norm and ord_ == "fro") or
|
||||||
(is_matrix_norm and is_fancy_p_norm)):
|
(is_matrix_norm and is_fancy_p_norm)):
|
||||||
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
|
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
|
||||||
if ord_ == 'euclidean' or (axis_ is None and len(shape) > 2):
|
if ord_ == "euclidean" or (axis_ is None and len(shape) > 2):
|
||||||
self.skipTest("Not supported by numpy.linalg.norm")
|
self.skipTest("Not supported by numpy.linalg.norm")
|
||||||
matrix = np.random.randn(*shape_).astype(dtype_)
|
matrix = np.random.randn(*shape_).astype(dtype_)
|
||||||
if dtype_ in (np.complex64, np.complex128):
|
if dtype_ in (np.complex64, np.complex128):
|
||||||
|
Loading…
Reference in New Issue
Block a user