Forwardprop: Add utilities for temporarily pushing forward accumulator state
More work toward adding a function special case. An accumulator triggers function-building, then needs to work on symbolic tensors captured by the function before returning to its original task. PiperOrigin-RevId: 265120491
This commit is contained in:
parent
aba46a95e0
commit
e9ff15d98a
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
// Language-agnostic gradient tape. Does not perform backpropagation, just
|
// Language-agnostic gradient tape. Does not perform backpropagation, just
|
||||||
// maintains the data structures required to do so.
|
// maintains the data structures required to do so.
|
||||||
|
|
||||||
|
#include <stack>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
@ -209,7 +210,9 @@ class ForwardAccumulator {
|
|||||||
// ForwardAccumulator.
|
// ForwardAccumulator.
|
||||||
explicit ForwardAccumulator(
|
explicit ForwardAccumulator(
|
||||||
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
|
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
|
||||||
: vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
|
: vspace_(vspace) {
|
||||||
|
call_state_.emplace(nullptr, false);
|
||||||
|
}
|
||||||
|
|
||||||
virtual ~ForwardAccumulator() {
|
virtual ~ForwardAccumulator() {
|
||||||
for (auto accumulated : accumulated_gradients_) {
|
for (auto accumulated : accumulated_gradients_) {
|
||||||
@ -262,11 +265,11 @@ class ForwardAccumulator {
|
|||||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||||
const std::function<void(BackwardFunction*)>& backward_function_deleter);
|
const std::function<void(BackwardFunction*)>& backward_function_deleter);
|
||||||
|
|
||||||
// Returns true if `Accumulate` is active somewhere above on the stack. This
|
// Returns true if `Accumulate` is active somewhere above on the stack and
|
||||||
// is useful for ordering ForwardAccumulators, where more deeply nested
|
// there isn't an intervening PushState. This is useful for ordering
|
||||||
// accumulators should not see computations from less deeply nested
|
// ForwardAccumulators, where more deeply nested accumulators should not see
|
||||||
// accumulators.
|
// computations from less deeply nested accumulators.
|
||||||
bool BusyAccumulating() const { return this->accumulating_; }
|
bool BusyAccumulating() const { return call_state_.top().accumulating; }
|
||||||
|
|
||||||
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
|
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
|
||||||
// a nullptr if none is available.
|
// a nullptr if none is available.
|
||||||
@ -282,6 +285,15 @@ class ForwardAccumulator {
|
|||||||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
|
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
|
||||||
gtl::ArraySlice<tensorflow::DataType> dtypes);
|
gtl::ArraySlice<tensorflow::DataType> dtypes);
|
||||||
|
|
||||||
|
// Temporarily push or pop transient state for this accumulator.
|
||||||
|
//
|
||||||
|
// Allows an accumulator which is currently processing an operation to
|
||||||
|
// temporarily reset its state. Without pushing and poping, accumulators
|
||||||
|
// ignore operations executed as a direct result of their own jvp
|
||||||
|
// computations.
|
||||||
|
void PushState() { call_state_.emplace(nullptr, false); }
|
||||||
|
void PopState() { call_state_.pop(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Helper for Accumulate: uses a GradientTape to compute forward gradients
|
// Helper for Accumulate: uses a GradientTape to compute forward gradients
|
||||||
// from a backward gradient function. Fills `out_grads` corresponding to
|
// from a backward gradient function. Fills `out_grads` corresponding to
|
||||||
@ -289,7 +301,7 @@ class ForwardAccumulator {
|
|||||||
//
|
//
|
||||||
// Executes the backward function in order to trace its gradient, which will
|
// Executes the backward function in order to trace its gradient, which will
|
||||||
// waste computation if executing eagerly (when graph building the unneeded
|
// waste computation if executing eagerly (when graph building the unneeded
|
||||||
// computation is pruned). Temporarily sets `backward_tape_` so that
|
// computation is pruned). Temporarily sets `backward_tape` so that
|
||||||
// Accumulate will forward op executions to the tape while the backward
|
// Accumulate will forward op executions to the tape while the backward
|
||||||
// function is running; this effectively adds the backward tape to the active
|
// function is running; this effectively adds the backward tape to the active
|
||||||
// set (but does not require complicated callbacks to the language bindings).
|
// set (but does not require complicated callbacks to the language bindings).
|
||||||
@ -305,16 +317,25 @@ class ForwardAccumulator {
|
|||||||
// Not owned; provides operations on Tensors which are currently only
|
// Not owned; provides operations on Tensors which are currently only
|
||||||
// available in language bindings (e.g. Python).
|
// available in language bindings (e.g. Python).
|
||||||
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
|
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
|
||||||
// Set temporarily while in the Accumulate method; if backward_tape_ is not
|
|
||||||
// nullptr then we forward op executions to it so Accumulate can compute a
|
struct AccumulatorCallState {
|
||||||
// backward pass on its backward function.
|
AccumulatorCallState(
|
||||||
//
|
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
|
||||||
// Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
|
bool accumulating)
|
||||||
// keeps ownership.
|
: backward_tape(backward_tape), accumulating(accumulating) {}
|
||||||
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
|
// Set temporarily while in the Accumulate method; if backward_tape is not
|
||||||
// While the Accumulate method is running (accumulating_ is True), any op
|
// nullptr then we forward op executions to it so Accumulate can compute a
|
||||||
// executions not forwarded to backward_tape_ should be ignored.
|
// backward pass on its backward function.
|
||||||
bool accumulating_;
|
//
|
||||||
|
// Not owned by the ForwardAccumulator. The method which sets
|
||||||
|
// `backward_tape` keeps ownership.
|
||||||
|
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
|
||||||
|
// While the Accumulate method is running (accumulating is True), any op
|
||||||
|
// executions not forwarded to backward_tape should be ignored.
|
||||||
|
bool accumulating;
|
||||||
|
};
|
||||||
|
std::stack<AccumulatorCallState, std::vector<AccumulatorCallState>>
|
||||||
|
call_state_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Template instantiations here
|
// Template instantiations here
|
||||||
@ -847,12 +868,12 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
|||||||
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
|
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
|
||||||
gtl::ArraySlice<int64> tensor_ids,
|
gtl::ArraySlice<int64> tensor_ids,
|
||||||
gtl::ArraySlice<tensorflow::DataType> dtypes) {
|
gtl::ArraySlice<tensorflow::DataType> dtypes) {
|
||||||
if (backward_tape_ != nullptr) {
|
if (call_state_.top().backward_tape != nullptr) {
|
||||||
// If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
|
// If we're forwarding Accumulate calls to backward_tape's RecordOperation,
|
||||||
// we should also delegate ShouldRecord.
|
// we should also delegate ShouldRecord.
|
||||||
return backward_tape_->ShouldRecord(tensor_ids, dtypes);
|
return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
|
||||||
}
|
}
|
||||||
if (accumulating_) {
|
if (call_state_.top().accumulating) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < tensor_ids.size(); ++i) {
|
for (int i = 0; i < tensor_ids.size(); ++i) {
|
||||||
@ -884,9 +905,10 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
|
|||||||
*/
|
*/
|
||||||
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
|
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
|
||||||
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
|
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
|
||||||
backward_tape_ = tape.get();
|
AccumulatorCallState& call_state = call_state_.top();
|
||||||
|
call_state.backward_tape = tape.get();
|
||||||
auto pop_backward_tape =
|
auto pop_backward_tape =
|
||||||
gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
|
gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
|
||||||
std::vector<Gradient*> forwardprop_aids;
|
std::vector<Gradient*> forwardprop_aids;
|
||||||
std::vector<int64> sources;
|
std::vector<int64> sources;
|
||||||
std::unordered_set<int64> sources_set;
|
std::unordered_set<int64> sources_set;
|
||||||
@ -961,10 +983,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
|||||||
const ForwardFunction<Gradient>* forward_function,
|
const ForwardFunction<Gradient>* forward_function,
|
||||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||||
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
|
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
|
||||||
if (backward_tape_ != nullptr) {
|
if (call_state_.top().backward_tape != nullptr) {
|
||||||
// If backward_tape_ is not null, then this call to Accumulate is the result
|
// If backward_tape is not null, then this call to Accumulate is the result
|
||||||
// of a still-active call to Accumulate which is running operations. We
|
// of a still-active call to Accumulate which is running operations. We
|
||||||
// forward these operations to backward_tape_ so the outer Accumulate call
|
// forward these operations to backward_tape so the outer Accumulate call
|
||||||
// can do its work.
|
// can do its work.
|
||||||
//
|
//
|
||||||
// Rather than re-entering and delegating Accumulate like this, we could
|
// Rather than re-entering and delegating Accumulate like this, we could
|
||||||
@ -972,9 +994,9 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
|||||||
// (so it can deactivate itself and activate its GradientTape). Currently
|
// (so it can deactivate itself and activate its GradientTape). Currently
|
||||||
// that is managed by the language binding and would require relatively
|
// that is managed by the language binding and would require relatively
|
||||||
// messy callbacks.
|
// messy callbacks.
|
||||||
backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
|
call_state_.top().backward_tape->RecordOperation(
|
||||||
input_dtypes, backward_function_getter,
|
op_type, output_tensors, input_tensor_id, input_dtypes,
|
||||||
backward_function_deleter);
|
backward_function_getter, backward_function_deleter);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
|
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
|
||||||
@ -1012,9 +1034,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
|||||||
|
|
||||||
// Avoid infinite recursion. Whichever forward function we run, it'll end up
|
// Avoid infinite recursion. Whichever forward function we run, it'll end up
|
||||||
// executing ops, and we don't want to watch those with this accumulator.
|
// executing ops, and we don't want to watch those with this accumulator.
|
||||||
accumulating_ = true;
|
call_state_.emplace(nullptr, true);
|
||||||
auto reset_accumulating =
|
auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });
|
||||||
gtl::MakeCleanup([this] { this->accumulating_ = false; });
|
|
||||||
|
|
||||||
std::vector<Gradient*> forward_grads;
|
std::vector<Gradient*> forward_grads;
|
||||||
if (forward_function == nullptr) {
|
if (forward_function == nullptr) {
|
||||||
|
@ -67,6 +67,7 @@ py_library(
|
|||||||
":execute",
|
":execute",
|
||||||
":execution_callbacks",
|
":execution_callbacks",
|
||||||
":forwardprop",
|
":forwardprop",
|
||||||
|
":forwardprop_util",
|
||||||
":function",
|
":function",
|
||||||
":graph_only_ops",
|
":graph_only_ops",
|
||||||
":monitoring",
|
":monitoring",
|
||||||
@ -270,6 +271,7 @@ cuda_py_test(
|
|||||||
srcs = ["forwardprop_test.py"],
|
srcs = ["forwardprop_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":forwardprop",
|
":forwardprop",
|
||||||
|
":forwardprop_util",
|
||||||
":test",
|
":test",
|
||||||
],
|
],
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
@ -529,6 +531,16 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "forwardprop_util",
|
||||||
|
srcs = ["forwardprop_util.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "benchmarks_test",
|
name = "benchmarks_test",
|
||||||
srcs = ["benchmarks_test.py"],
|
srcs = ["benchmarks_test.py"],
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python import pywrap_tensorflow
|
|||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import forwardprop
|
from tensorflow.python.eager import forwardprop
|
||||||
|
from tensorflow.python.eager import forwardprop_util
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -274,6 +275,31 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, "test_error_string"):
|
with self.assertRaisesRegexp(ValueError, "test_error_string"):
|
||||||
f(c)
|
f(c)
|
||||||
|
|
||||||
|
def testPushPopAccumulatorState(self):
|
||||||
|
# Note that this example is somewhat contrived. push_forwardprop_state is
|
||||||
|
# probably only useful in practice for building functions that compute jvps
|
||||||
|
# alongside their usual outputs.
|
||||||
|
with forwardprop.ForwardGradientAccumulator() as acc:
|
||||||
|
|
||||||
|
@custom_gradient.custom_gradient
|
||||||
|
def f(x):
|
||||||
|
y = math_ops.sin(x.numpy())
|
||||||
|
|
||||||
|
def grad(dy):
|
||||||
|
with forwardprop_util.push_forwardprop_state():
|
||||||
|
x_copy = constant_op.constant(x.numpy())
|
||||||
|
acc.watch(x_copy, dy)
|
||||||
|
y_copy = math_ops.sin(x_copy)
|
||||||
|
return dy * acc.jvp(y_copy)
|
||||||
|
|
||||||
|
return y, grad
|
||||||
|
|
||||||
|
c = constant_op.constant(1.)
|
||||||
|
d = constant_op.constant(2.)
|
||||||
|
acc.watch(c, d)
|
||||||
|
output = f(c)
|
||||||
|
self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
[("Order{}".format(order), order, expected)
|
[("Order{}".format(order), order, expected)
|
||||||
for order, expected in enumerate(_X11_35_DERIVATIVES)])
|
for order, expected in enumerate(_X11_35_DERIVATIVES)])
|
||||||
|
47
tensorflow/python/eager/forwardprop_util.py
Normal file
47
tensorflow/python/eager/forwardprop_util.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities for managing forward accumulators.
|
||||||
|
|
||||||
|
A separate file from forwardprop.py so that functions can use these utilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def push_forwardprop_state():
|
||||||
|
"""Temporarily push or pop transient state for accumulators in the active set.
|
||||||
|
|
||||||
|
Allows an accumulator which is currently processing an operation to
|
||||||
|
temporarily reset its state. This is useful when building forwardprop versions
|
||||||
|
of functions, where an accumulator will trigger function building and then
|
||||||
|
must process captured symbolic tensors while building it. Without pushing and
|
||||||
|
poping, accumulators ignore operations executed as a direct result of their
|
||||||
|
own jvp computations.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None (used for its side effect).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState()
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState()
|
@ -256,6 +256,17 @@ void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
|
|||||||
// `accumulator`. Returns None if no JVP is available.
|
// `accumulator`. Returns None if no JVP is available.
|
||||||
PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
|
PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
|
||||||
|
|
||||||
|
// Temporarily push or pop transient state for accumulators in the active set.
|
||||||
|
//
|
||||||
|
// Allows an accumulator which is currently processing an operation to
|
||||||
|
// temporarily reset its state. This is useful when building forwardprop
|
||||||
|
// versions of functions, where an accumulator will trigger function building
|
||||||
|
// and then must process captured symbolic tensors while building it. Without
|
||||||
|
// pushing and poping, accumulators ignore operations executed as a direct
|
||||||
|
// result of their own jvp computations.
|
||||||
|
PyObject* TFE_Py_ForwardAccumulatorPushState();
|
||||||
|
PyObject* TFE_Py_ForwardAccumulatorPopState();
|
||||||
|
|
||||||
// Collects state from all current forward accumulators related to `tensors`.
|
// Collects state from all current forward accumulators related to `tensors`.
|
||||||
//
|
//
|
||||||
// This is useful for packing JVPs as function inputs before executing a
|
// This is useful for packing JVPs as function inputs before executing a
|
||||||
|
@ -1640,6 +1640,22 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
|
|||||||
Py_RETURN_FALSE;
|
Py_RETURN_FALSE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyObject* TFE_Py_ForwardAccumulatorPushState() {
|
||||||
|
auto forward_accumulators = *GetAccumulatorSet();
|
||||||
|
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
|
||||||
|
accumulator->accumulator->PushState();
|
||||||
|
}
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* TFE_Py_ForwardAccumulatorPopState() {
|
||||||
|
auto forward_accumulators = *GetAccumulatorSet();
|
||||||
|
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
|
||||||
|
accumulator->accumulator->PopState();
|
||||||
|
}
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
|
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
|
||||||
if (!TapeCouldPossiblyRecord(tensors)) {
|
if (!TapeCouldPossiblyRecord(tensors)) {
|
||||||
return GetPythonObjectFromInt(0);
|
return GetPythonObjectFromInt(0);
|
||||||
|
@ -88,6 +88,8 @@ limitations under the License.
|
|||||||
%rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
|
%rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorWatch;
|
%rename("%s") TFE_Py_ForwardAccumulatorWatch;
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorJVP;
|
%rename("%s") TFE_Py_ForwardAccumulatorJVP;
|
||||||
|
%rename("%s") TFE_Py_ForwardAccumulatorPushState;
|
||||||
|
%rename("%s") TFE_Py_ForwardAccumulatorPopState;
|
||||||
%rename("%s") TFE_Py_PackForwardGradients;
|
%rename("%s") TFE_Py_PackForwardGradients;
|
||||||
%rename("%s") TFE_NewContextOptions;
|
%rename("%s") TFE_NewContextOptions;
|
||||||
%rename("%s") TFE_ContextOptionsSetConfig;
|
%rename("%s") TFE_ContextOptionsSetConfig;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user