Changing the copy-on-write semantics of resource variables.
A variable now has a bit which, when turned on, makes it act as copy-on-read instead of copy-on-write. This allows sparse writes to happen concurrently while holding only a shared lock, mimicking the use_locking behavior of ref variables.

PiperOrigin-RevId: 224855851
Parent: 4e7564ef05
Commit: 95358d2da3
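To make the new semantics concrete, here is a minimal, self-contained C++17 sketch of the idea described above. It is illustrative only, not TensorFlow code: ToyVar, its members, and the float-vector buffer are invented stand-ins for Var and its Tensor; the real implementation appears in the diff below.

#include <atomic>
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <utility>
#include <vector>

// Toy variable with the two access modes described in the commit message:
// copy-on-write by default, copy-on-read once sparse access is requested.
class ToyVar {
 public:
  explicit ToyVar(std::vector<float> init)
      : buffer_(std::make_shared<std::vector<float>>(std::move(init))) {}

  // Dense read. In copy-on-write mode, alias the buffer under a shared lock;
  // in copy-on-read mode, return a copy instead, since no aliases may exist.
  std::shared_ptr<const std::vector<float>> Read() {
    std::shared_lock<std::shared_mutex> hold(mu_);
    if (copy_on_read_.load()) {
      return std::make_shared<std::vector<float>>(*buffer_);
    }
    return buffer_;
  }

  // Dense write under an exclusive lock. In copy-on-write mode, copy first if
  // an outstanding read still aliases the buffer (detected via the refcount).
  void Assign(std::vector<float> value) {
    std::unique_lock<std::shared_mutex> hold(mu_);
    if (!copy_on_read_.load() && buffer_.use_count() > 1) {
      buffer_ = std::make_shared<std::vector<float>>(std::move(value));
    } else {
      *buffer_ = std::move(value);
    }
  }

  // Switch to copy-on-read mode (the role EnsureSparseVariableAccess plays
  // for Var): exclusive lock, copy if aliased, then flip the flag.
  void EnsureSparseAccess() {
    if (copy_on_read_.load()) return;
    std::unique_lock<std::shared_mutex> hold(mu_);
    if (buffer_.use_count() > 1) {
      buffer_ = std::make_shared<std::vector<float>>(*buffer_);
    }
    copy_on_read_.store(true);
  }

  // Sparse write. Once in copy-on-read mode a shared lock is enough, because
  // nothing aliases the buffer and dense writes hold the exclusive lock.
  void ScatterUpdate(std::size_t index, float value) {
    EnsureSparseAccess();
    std::shared_lock<std::shared_mutex> hold(mu_);
    (*buffer_)[index] = value;
  }

 private:
  std::shared_mutex mu_;
  std::atomic<bool> copy_on_read_{false};
  std::shared_ptr<std::vector<float>> buffer_;
};

int main() {
  ToyVar v({1.f, 2.f, 3.f});
  auto snapshot = v.Read();   // aliases the buffer (copy-on-write mode)
  v.Assign({4.f, 5.f, 6.f});  // copies, because `snapshot` is still alive
  v.ScatterUpdate(1, 42.f);   // switches to copy-on-read, writes under a shared lock
  assert((*v.Read())[1] == 42.f);
  return 0;
}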
@@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext(
}
}

void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
Device* device,
Tensor* output_tensor,
StatusCallback done) const {
done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
}

void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,

@@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
absl::string_view tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override;

xla::LocalClient* client() const { return client_; }
se::Stream* stream() const { return stream_.get(); }

@@ -37,6 +37,14 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
}

void GPUDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
Device* device,
Tensor* output_tensor,
StatusCallback done) const {
GPUUtil::CopyGPUTensorToSameGPU(device, this, input_tensor, output_tensor,
done);
}

Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream,
std::function<void()> func) {
const DeviceBase::GpuDeviceInfo* gpu_info =

@@ -57,6 +57,10 @@ class GPUDeviceContext : public DeviceContext {
Device* device, Tensor* cpu_tensor,
StatusCallback done) override;

void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override;

void MaintainLifetimeOnStream(const Tensor* t,
se::Stream* stream) const override {}

@@ -82,6 +82,13 @@ class DeviceContext : public core::RefCounted {
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
}

// Copies a tensor in this device.
virtual void CopyTensorInSameDevice(const Tensor* input_tensor,
Device* device, Tensor* output_tensor,
StatusCallback done) const {
done(errors::Unimplemented("Copy in same device not implemented."));
}

// "device_tensor" is a tensor on a non-CPU device. Copies
// device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
// to be of the same size as "device_tensor".

@@ -278,6 +278,12 @@ class DummyDeviceContext : public DeviceContext {
~DummyDeviceContext() override {}
int stream_id() const { return stream_id_; }

void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override {
done(Status::OK());
}

private:
const int stream_id_;
};
@@ -20,14 +20,46 @@ limitations under the License.

namespace tensorflow {

// Resource stored by variables in the resource manager
// (new, resource-style version).
// Resource stored by variables in the resource manager (new, resource-style
// version).
//
// These variables have a mixed access mode: they can operate in copy-on-write
// mode (the default) or copy-on-read mode (used only for sparse access).
//
// When copy-on-write mode is enabled reading the value of the variable involves
// grabbing its mutex in shared mode and aliasing the internal tensor as the
// output of the read operation, increasing its reference count. Writing,
// conversely, works by, under an exclusive lock, detecting whether there are
// outstanding aliases of the tensor, using the reference count, copying the
// tensor if they exist, and writing to either the original or a copy with no
// outstanding aliases. Sparse operations are not supported in copy-on-write
// mode.
//
// When a variable is accessed sparsely it switches to copy-on-read mode. To
// switch we need to grab an exclusive lock and might (if there are aliases)
// need to copy the entire tensor. Once copy-on-read mode is enabled, no tensor
// is allowed to alias the variable's internal tensor. This means dense reads
// must return a copy of the variable, done while holding a shared lock. Dense
// writes do not need to check whether aliases exist, and can always write
// directly to the buffer without making a copy, while holding an exclusive
// lock. Sparse reads and sparse writes, on the other hand, can be done under a
// shared or exclusive mutex (the damage from writes under a shared mutex is
// limited since no other buffer is allowed to alias the variable's
// buffer). Using an exclusive mutex disallows concurrent writes and concurrent
// sparse reads, providing some extra safety at the expense of performance,
// while a shared mutex allows for "hogwild" behavior. Doing sparse writes under a
// shared mutex prevents them from overlapping with dense writes, which is
// necessary as dense writes can change the shape of the tensor.
//
// Transitioning a variable from copy-on-read mode to copy-on-write mode is
// currently not supported. To upgrade a variable from copy-on-write to
// copy-on-read use `EnsureSparseVariableAccess()`, and then grab the variable's
// mutex as desired. To access the variable in dense mode grab the mutex either
// directly or via `MaybeLockVariableInputMutexesInOrder` on all variables being
// modified and then call `PrepareToUpdateVariable` on them in any order.
class Var : public ResourceBase {
public:
explicit Var(DataType dtype) : tensor_(dtype) {}
// Not copyable or movable.
Var(const Var&) = delete;
Var& operator=(const Var&) = delete;

// When locking multiple variables, the locks must be acquired in order of
// increasing mu() address.

@@ -48,11 +80,19 @@ class Var : public ResourceBase {
bool is_initialized = false; // GUARDED_BY(mu_) but annotalysis doesn't like
// it.

// Also fake-guarded by mu_. Should be set to True whenever any sparse
// operation uses the variable. Once this is true no tensor is allowed to
// alias the memory of the variable, and we always copy the variable on
// reads. This allows sparse operations to happen with only a shared lock if
// so desired.
std::atomic<bool> copy_on_read_mode{false};

private:
mutex mu_;
Tensor tensor_;

~Var() override {}
TF_DISALLOW_COPY_AND_ASSIGN(Var);
};

} // end namespace tensorflow
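Note that `copy_on_read_mode` above is an `std::atomic<bool>` that is only "fake-guarded" by `mu_`: callers check it cheaply without the exclusive lock, and the one-time switch is serialized under the lock (see `EnsureSparseVariableAccess` later in this diff). A self-contained sketch of that general flag-plus-mutex idiom, with invented names, purely for illustration:

#include <atomic>
#include <iostream>
#include <mutex>

// One-way boolean mode switch: lock-free check on the hot path, with the
// actual transition work serialized under a mutex and double-checked so the
// work runs at most once.
class ModeFlag {
 public:
  template <typename Fn>
  void EnsureOn(Fn&& transition_work) {
    if (flag_.load(std::memory_order_acquire)) return;  // fast path
    std::lock_guard<std::mutex> hold(mu_);
    if (flag_.load(std::memory_order_relaxed)) return;  // lost the race
    transition_work();                                  // e.g. copy the buffer
    flag_.store(true, std::memory_order_release);
  }
  bool on() const { return flag_.load(std::memory_order_acquire); }

 private:
  std::mutex mu_;
  std::atomic<bool> flag_{false};
};

int main() {
  ModeFlag copy_on_read;
  copy_on_read.EnsureOn([] { std::cout << "switching mode once\n"; });
  copy_on_read.EnsureOn([] { std::cout << "never printed again\n"; });
  return copy_on_read.on() ? 0 : 1;
}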
@@ -45,6 +45,7 @@ class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
class Var;

namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);

@@ -581,11 +582,16 @@ class Tensor {
friend class XlaTensor; // For access to RefCountIsOne().
friend class XlaTensorBuffer; // For access to the private constructor taking
// the buffer
friend class Var;
template <typename Device, typename T>
friend class AssignVariableOp; // For access to RefCountIsOne().
template <typename Device, typename T>
friend Status PrepareToUpdateVariable(
OpKernelContext* ctx, Tensor* tensor); // For access to RefCountIsOne().
OpKernelContext* ctx, Tensor* tensor,
bool copy_on_read_mode); // For access to RefCountIsOne().
template <typename Device, typename T>
friend Status EnsureSparseVariableAccess(
OpKernelContext* ctx, Var* var); // For access to RefCountIsOne().
friend Status batch_util::CopyElementToSlice(
Tensor element, Tensor* parent,
int64 index); // For access to RefCountIsOne().

@@ -2196,6 +2196,7 @@ tf_kernel_library(
":state",
":training_op_helpers",
":variable_ops",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",

@@ -55,6 +55,7 @@ limitations under the License.
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -84,6 +85,47 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}

namespace {
Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
Tensor* output;
Notification n;
Status status;
AllocatorAttributes attr;
if (t->dtype() == DT_VARIANT) {
attr.set_on_host(true);
}
TF_RETURN_IF_ERROR(
ctx->allocate_output(output_idx, t->shape(), &output, attr));
if (t->dtype() == DT_VARIANT) {
output->flat<Variant>() = t->flat<Variant>();
} else if (ctx->op_device_context() != nullptr) {
// TODO(apassos): remove the down_cast by just returning Device* from
// OpKernelContext
Device* device = static_cast<Device*>(ctx->device());
ctx->op_device_context()->CopyTensorInSameDevice(
t, device, output, [&n, &status](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
} else {
switch (t->dtype()) {
#define HANDLER(type) \
case DataTypeToEnum<type>::value: \
output->flat<type>() = t->flat<type>(); \
break;
TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
default:
return errors::Internal("Unsupported dtype", t->dtype());
}
}
return Status::OK();
}

} // namespace

void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
const ResourceHandle& handle = HandleFromInput(ctx, 0);

@@ -100,12 +142,16 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
// holding a shared lock to guarantee ordering of reads and
// writes.
tf_shared_lock ml(*variable->mu());
const Tensor& t = *variable->tensor();
OP_REQUIRES(ctx, dtype_ == t.dtype(),
const Tensor* t = variable->tensor();
OP_REQUIRES(ctx, dtype_ == t->dtype(),
errors::InvalidArgument(
"Trying to read variable with wrong dtype. Expected ",
DataTypeString(dtype_), " got ", DataTypeString(t.dtype())));
ctx->set_output(0, t);
DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
if (variable->copy_on_read_mode.load()) {
OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
} else {
ctx->set_output(0, *t);
}
}

ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {

@@ -146,14 +192,18 @@ void ReadVariablesOp::Compute(OpKernelContext* ctx) {
// holding a shared lock to guarantee ordering of reads and
// writes.
tf_shared_lock ml(*variables[i]->mu());
const Tensor& t = *variables[i]->tensor();
OP_REQUIRES(ctx, dtypes_[i] == t.dtype(),
OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
errors::InvalidArgument(
"Trying to read variable ", handles[i]->name(),
" from Container: ", handles[i]->container(),
" with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
" got ", DataTypeString(t.dtype())));
ctx->set_output(i, t);
" got ", DataTypeString(variables[i]->tensor()->dtype())));
if (variables[i]->copy_on_read_mode.load()) {
OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
} else {
const Tensor& t = *variables[i]->tensor();
ctx->set_output(i, t);
}
}
}
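CopyVariable above issues an asynchronous same-device copy and blocks on a Notification until the callback delivers a status. A self-contained standard-library sketch of that wait-for-callback pattern; Status, AsyncCopy, and CopySync are invented stand-ins, not TensorFlow APIs:

#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

// Toy stand-ins for the status type and the asynchronous copy interface.
struct Status {
  bool ok = true;
  std::string message;
};
using StatusCallback = std::function<void(const Status&)>;

// Pretend device copy that completes on another thread and then invokes the
// callback, like DeviceContext::CopyTensorInSameDevice.
void AsyncCopy(const float* src, float* dst, std::size_t n, StatusCallback done) {
  std::thread([src, dst, n, done] {
    std::copy(src, src + n, dst);
    done(Status{});
  }).detach();
}

// Blocks until the callback fires, mirroring Notification::WaitForNotification.
// The shared state is heap-allocated so the callback can safely outlive the
// caller's stack frame.
Status CopySync(const float* src, float* dst, std::size_t n) {
  struct Latch {
    std::mutex mu;
    std::condition_variable cv;
    bool notified = false;
    Status status;
  };
  auto latch = std::make_shared<Latch>();
  AsyncCopy(src, dst, n, [latch](const Status& s) {
    std::lock_guard<std::mutex> hold(latch->mu);
    latch->status = s;
    latch->notified = true;
    latch->cv.notify_one();
  });
  std::unique_lock<std::mutex> hold(latch->mu);
  latch->cv.wait(hold, [&] { return latch->notified; });
  return latch->status;
}

int main() {
  float src[3] = {1.f, 2.f, 3.f};
  float dst[3] = {0.f, 0.f, 0.f};
  Status s = CopySync(src, dst, 3);
  return (s.ok && dst[2] == 3.f) ? 0 : 1;
}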
@@ -308,8 +358,23 @@ class AssignVariableOp : public OpKernel {
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
if (variable->copy_on_read_mode.load()) {
PersistentTensor unused;
Tensor* tmp;
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
OP_REQUIRES_OK(context,
context->allocate_persistent(value.dtype(), value.shape(),
&unused, &tmp, attr));
functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
copy_functor(context->eigen_device<Device>(), tmp->flat<T>(),
value.flat<T>());
*variable->tensor() = *tmp;
} else {
*variable->tensor() = value;
}
variable->is_initialized = true;
*variable->tensor() = value;
}

private:

@@ -442,8 +507,9 @@ class AssignUpdateVariableOp : public OpKernel {
" using a Tensor with shape ",
value.shape().DebugString(),
", shapes must be equal."));
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, var_tensor));
OP_REQUIRES_OK(
context, PrepareToUpdateVariable<Device, T>(
context, var_tensor, variable->copy_on_read_mode.load()));
functor::DenseUpdate<Device, T, Op> update_functor;
update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
value.flat<T>());

@@ -524,6 +590,7 @@ class ResourceGatherOp : public OpKernel {
Var* v = nullptr;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref su(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a

@@ -639,9 +706,9 @@ class ResourceScatterUpdateOp : public OpKernel {
Var* v = nullptr;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref unref_v(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
tf_shared_lock ml(*v->mu());
Tensor* params = v->tensor();
OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, params));
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
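The ResourceScatterUpdateOp change above (EnsureSparseVariableAccess followed by a tf_shared_lock instead of a mutex_lock) is what allows sparse writes to overlap. A small self-contained illustration, with invented names, of why a shared lock is enough once nothing aliases the buffer: two sparse writers touching disjoint indices run concurrently, while dense access still takes the exclusive lock. Writers racing on the same index under shared locks would be the "hogwild" behavior the Var comment describes.

#include <cassert>
#include <cstddef>
#include <shared_mutex>
#include <thread>
#include <vector>

int main() {
  std::shared_mutex mu;               // plays the role of Var::mu()
  std::vector<float> buffer(8, 0.f);  // plays the role of the variable's tensor

  // A sparse writer: holds only a shared lock while updating its indices.
  auto sparse_writer = [&](std::size_t begin, std::size_t end, float value) {
    std::shared_lock<std::shared_mutex> hold(mu);
    for (std::size_t i = begin; i < end; ++i) buffer[i] = value;
  };

  // Both writers hold shared locks, so they can run at the same time.
  std::thread t1(sparse_writer, 0, 4, 1.f);
  std::thread t2(sparse_writer, 4, 8, 2.f);
  t1.join();
  t2.join();

  {
    // Dense access (e.g. assigning a new value, which may change the shape)
    // still takes the exclusive lock and is serialized against all writers.
    std::unique_lock<std::shared_mutex> hold(mu);
    assert(buffer[0] == 1.f && buffer[7] == 2.f);
  }
  return 0;
}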
@@ -231,6 +231,7 @@ class ScatterNdUpdateOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref scoped_unref(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {

@@ -258,7 +259,6 @@ class ScatterNdUpdateOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
Tensor* t = v->tensor();
OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
params = *t;
params_shape = params.shape();
} else if (IsRefType(c->input_dtype(0))) {

@@ -307,9 +307,9 @@ class StridedSliceAssignOp : public OpKernel {
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
EnsureSparseVariableAccess<Device, T>(context, v));
mutex_lock ml(*v->mu());
old_lhs = v->tensor();
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
@@ -19,70 +19,6 @@ limitations under the License.

namespace tensorflow {

mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
Var** maybe_resource) {
*maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
return nullptr;
}
}
return ctx->input_ref_mutex(input);
}

// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. Safe to pass duplicates - will
// only lock each distinct mutex once. If do_lock is false, returns
// immediately. Note that this silently doesn't lock mutexes for invalid
// variable references; in all usages this is followed by GetInputTensor which
// will signal a failure.
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
if (ctx->input_dtype(i) == DT_RESOURCE) {
any_resource = true;
break;
}
}
if (!do_lock && !any_resource) {
return VariableInputLockHolder({}, {});
}
std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
Var* var;
mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
mutexes.push_back(mutex);
}
}
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

std::unique_ptr<std::vector<mutex_lock>> locks =
MakeUnique<std::vector<mutex_lock>>();
locks->reserve(acquire_order.size());

for (auto input : acquire_order) {
Var* var;
mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
locks->emplace_back(*mu);
}
}
return VariableInputLockHolder(std::move(vars), std::move(locks));
}

void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
int output) {
@@ -17,30 +17,72 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"

namespace tensorflow {

// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and the
// caller will be responsible for calling `Unref()` on that resource.
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
Var** maybe_resource);
// Must be called before performing a sparse operation on a variable. Ensures
// that no concurrent dense operations can happen while holding the variable's
// lock.
template <typename Device, typename T>
Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
if (var->copy_on_read_mode.load()) {
return Status::OK();
}
mutex_lock ml(*var->mu());
// Once copy-on-read mode is True the refcount is guaranteed to be 1. This can
// also happen if there are no concurrent reads of the variable and
// copy-on-read mode is false.
if (var->tensor()->RefCountIsOne()) {
var->copy_on_read_mode.store(true);
return Status::OK();
}
PersistentTensor unused;
Tensor* tmp;
if (std::is_same<T, Variant>::value) {
AllocatorAttributes attr;
attr.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));

const auto elements_in = var->tensor()->flat<Variant>();
auto elements_out = tmp->flat<Variant>();
for (int64 i = 0; i < elements_in.size(); ++i) {
elements_out(i) = elements_in(i);
}
} else {
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
const_cast<const Tensor*>(var->tensor())->flat<T>());
}
*var->tensor() = *tmp;
var->copy_on_read_mode.store(true);
return Status::OK();
}

// Utility structure that releases a sequence of borrowed mutexes when it is
// deleted.
struct VariableInputLockHolder {
public:
VariableInputLockHolder(std::vector<Var*> vars,
std::unique_ptr<std::vector<mutex_lock>> locks)
: vars_(std::move(vars)), locks_(std::move(locks)) {}
VariableInputLockHolder(
std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
: vars_(std::move(vars)),
locks_(std::move(locks)),
shared_locks_(std::move(shared_locks)) {}

VariableInputLockHolder(VariableInputLockHolder&& other)
: vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
: vars_(std::move(other.vars_)),
locks_(std::move(other.locks_)),
shared_locks_(std::move(other.shared_locks_)) {}

~VariableInputLockHolder() {
// Release the locks before unreffing the Vars, because each lock

@@ -56,10 +98,95 @@ struct VariableInputLockHolder {
// NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
// because a `std::vector<mutex_lock>` is not movable on all platforms.
std::unique_ptr<std::vector<mutex_lock>> locks_;
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
};

// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and the
// caller will be responsible for calling `Unref()` on that resource.
template <typename Device, typename T>
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
Var** maybe_resource) {
*maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
if (sparse) {
EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource);
}
return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
return nullptr;
}
}
return ctx->input_ref_mutex(input);
}

// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. Safe to pass duplicates - will
// only lock each distinct mutex once. If sparse is true will ensure the
// variable gets switched to copy-on-read mode before trying to acquire the
// locks. If do_lock is false, returns immediately for reference variables. For
// resource variables in copy-on-read-mode it will grab a shared lock if do_lock
// is false, exclusive lock otherwise. Note that this silently doesn't lock
// mutexes for invalid variable references; in all usages this is followed by
// GetInputTensor which will signal a failure.
template <typename Device, typename T>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
OpKernelContext* ctx, bool do_lock, bool sparse,
const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
if (ctx->input_dtype(i) == DT_RESOURCE) {
any_resource = true;
break;
}
}
if (!do_lock && !any_resource) {
return VariableInputLockHolder({}, {}, {});
}
std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
Var* var;
mutex* mutex =
GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
mutexes.push_back(mutex);
}
}
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

std::unique_ptr<std::vector<mutex_lock>> locks =
absl::make_unique<std::vector<mutex_lock>>();
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
absl::make_unique<std::vector<tf_shared_lock>>();
locks->reserve(acquire_order.size());

for (auto input : acquire_order) {
Var* var;
mutex* mu = GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
if (do_lock) {
locks->emplace_back(*mu);
} else {
shared_locks->emplace_back(*mu);
}
}
}
return VariableInputLockHolder(std::move(vars), std::move(locks),
std::move(shared_locks));
}

void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
int output);

@@ -68,8 +195,9 @@ void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
// reference count of 1 before you update it.
// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
template <typename Device, typename T>
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) {
if (!tensor->RefCountIsOne()) {
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
bool copy_on_read_mode) {
if (copy_on_read_mode || !tensor->RefCountIsOne()) {
// Tensor's buffer is in use by some read, so we need to copy before
// updating.
PersistentTensor unused;

@@ -100,12 +228,14 @@ Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) {
return Status::OK();
}

// This gives you `*out`, a tensor you can update, corresponding to a
// variable passed as input index `input`. This handles the
// differences between reference and resource variables. For resource
// variables, we ensure `*out` has a reference count of 1 (using
// PrepareToUpdateVariable() to copy if necessary) unless
// sparse && !lock_held, in which case it never copies.
// This gives you `*out`, a tensor you can update, corresponding to a variable
// passed as input index `input`. This handles the differences between
// reference and resource variables. For reference variables we can just grab
// the tensor, grabbing the lock if lock_held is False.
//
// For resource variables we, if sparse is true, ensure it's in copy-on-read
// mode, and then, regardless of the value of sparse, ensure its refcount is 1
// (by potentially copying its contents). In this case lock_held is ignored.
template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
bool lock_held, bool sparse, Tensor* out) {

@@ -113,7 +243,13 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
Var* var;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
core::ScopedUnref unref_var(var);
TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
if (sparse) {
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
*out = *var->tensor();
return Status::OK();
}
TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
ctx, var->tensor(), var->copy_on_read_mode.load()));
*out = *var->tensor();
return Status::OK();
}
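The header above turns MaybeLockVariableInputMutexesInOrder into a template and gives it a shared-lock path. Stripped of the TensorFlow types, the deadlock-avoidance idea is simply: deduplicate the mutexes, sort them by address, and lock them in that order, releasing everything through an RAII holder. A minimal self-contained C++17 sketch with invented names:

#include <algorithm>
#include <shared_mutex>
#include <vector>

// RAII holder releasing all acquired locks on destruction, analogous to
// VariableInputLockHolder.
struct LockHolder {
  std::vector<std::unique_lock<std::shared_mutex>> exclusive;
  std::vector<std::shared_lock<std::shared_mutex>> shared;
};

// Deduplicate, sort by address, then lock each mutex once in that order.
// Taking exclusive locks when `exclusive` is true and shared locks otherwise
// mirrors the do_lock / copy-on-read distinction described above.
LockHolder LockInAddressOrder(std::vector<std::shared_mutex*> mutexes,
                              bool exclusive) {
  std::sort(mutexes.begin(), mutexes.end());
  mutexes.erase(std::unique(mutexes.begin(), mutexes.end()), mutexes.end());
  LockHolder holder;
  for (std::shared_mutex* mu : mutexes) {
    if (mu == nullptr) continue;
    if (exclusive) {
      holder.exclusive.emplace_back(*mu);
    } else {
      holder.shared.emplace_back(*mu);
    }
  }
  return holder;
}

int main() {
  std::shared_mutex a, b;
  // Whatever order (and with duplicates) the caller passes, the locks are
  // taken in address order, so two threads doing this cannot deadlock.
  auto locks = LockInAddressOrder({&b, &a, &b}, /*exclusive=*/true);
  return 0;
}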
@ -465,11 +465,12 @@ class ApplyGradientDescentOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -506,11 +507,12 @@ class ApplyGradientDescentOp<SYCLDevice, T> : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<SYCLDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -600,7 +602,8 @@ class ApplyAdadeltaOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Var* resource;
|
||||
mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource);
|
||||
const bool sparse = false;
|
||||
mutex* mu = GetTrainingVariableMutex<Device, T>(ctx, 0, sparse, &resource);
|
||||
core::ScopedUnref scoped_unref(resource);
|
||||
if (use_exclusive_lock_ && mu != nullptr) {
|
||||
mutex_lock l1(*mu);
|
||||
@ -624,14 +627,16 @@ class ApplyAdadeltaOp : public OpKernel {
|
||||
|
||||
void DoValidate(OpKernelContext* ctx) {
|
||||
Tensor var;
|
||||
const bool sparse = false;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
Tensor accum_update;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &accum_update));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_,
|
||||
sparse, &accum_update));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -678,14 +683,16 @@ class ApplyAdadeltaOp : public OpKernel {
|
||||
void DoCompute(OpKernelContext* ctx) {
|
||||
const Device& device = ctx->template eigen_device<Device>();
|
||||
Tensor var;
|
||||
const bool sparse = false;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
Tensor accum_update;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &accum_update));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_,
|
||||
sparse, &accum_update));
|
||||
|
||||
const Tensor& lr = ctx->input(3);
|
||||
const Tensor& rho = ctx->input(4);
|
||||
@ -751,7 +758,8 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Var* var;
|
||||
mutex* mu = GetTrainingVariableMutex(ctx, 0, &var);
|
||||
const bool sparse = true;
|
||||
mutex* mu = GetTrainingVariableMutex<CPUDevice, T>(ctx, 0, sparse, &var);
|
||||
core::ScopedUnref scoped_unref(var);
|
||||
// mu_accum is actually the same mutex as mu_var since currently we use a
|
||||
// global mutex.
|
||||
@ -767,14 +775,16 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
||||
|
||||
void DoCompute(OpKernelContext* ctx) {
|
||||
Tensor var;
|
||||
const bool sparse = true;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum_grad;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &accum_grad));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum_grad));
|
||||
Tensor accum_update;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 2, use_exclusive_lock_, true, &accum_update));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 2, use_exclusive_lock_, sparse, &accum_update));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -907,11 +917,12 @@ class ApplyProximalGradientDescentOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -976,11 +987,12 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
|
||||
errors::InvalidArgument("var must be at least 1 dimensional"));
|
||||
|
||||
@ -1121,14 +1133,15 @@ class ApplyAdagradOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1214,14 +1227,15 @@ class ApplyProximalAdagradOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1316,14 +1330,15 @@ class SparseApplyAdagradOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1456,14 +1471,15 @@ class SparseApplyProximalAdagradOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1628,19 +1644,20 @@ class ApplyAdagradDAOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor gradient_accum;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInputTensorFromVariable<Device, T>(ctx, 1, use_exclusive_lock_,
|
||||
false, &gradient_accum));
|
||||
sparse, &gradient_accum));
|
||||
Tensor gradient_squared_accum;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &gradient_squared_accum));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1729,19 +1746,20 @@ class SparseApplyAdagradDAOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor gradient_accum;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &gradient_accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &gradient_accum));
|
||||
Tensor gradient_squared_accum;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 2, use_exclusive_lock_, true, &gradient_squared_accum));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1927,18 +1945,19 @@ class ApplyFtrlOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
Tensor linear;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &linear));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &linear));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2079,17 +2098,18 @@ class SparseApplyFtrlOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
Tensor linear;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, true, &linear));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &linear));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2353,15 +2373,16 @@ class ApplyMomentumOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2454,15 +2475,16 @@ class SparseApplyMomentumOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2572,15 +2594,16 @@ class ApplyKerasMomentumOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2671,15 +2694,16 @@ class SparseApplyKerasMomentumOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &accum));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2783,18 +2807,19 @@ class ApplyAdamOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor m;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &m));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &m));
|
||||
Tensor v;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &v));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &v));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2873,18 +2898,19 @@ class ApplyAdamOp<SYCLDevice, T> : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<SYCLDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor m;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &m));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &m));
|
||||
Tensor v;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &v));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &v));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -3043,21 +3069,22 @@ class ApplyAdamWithAmsgradOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor m;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &m));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &m));
|
||||
Tensor v;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &v));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &v));
|
||||
Tensor vhat;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 3, use_exclusive_lock_, false, &vhat));
|
||||
ctx, 3, use_exclusive_lock_, sparse, &vhat));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -3184,18 +3211,19 @@ class ApplyAdaMaxOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor m;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &m));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &m));
|
||||
Tensor v;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &v));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &v));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -3312,18 +3340,19 @@ class ApplyRMSPropOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &ms));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &mom));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -3394,21 +3423,22 @@ class ApplyCenteredRMSPropOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2, 3});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor mg;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &mg));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &mg));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 2, use_exclusive_lock_, false, &ms));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 3, use_exclusive_lock_, false, &mom));
|
||||
ctx, 3, use_exclusive_lock_, sparse, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -3553,18 +3583,19 @@ class SparseApplyRMSPropOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &ms));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 2, use_exclusive_lock_, true, &mom));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -3682,21 +3713,22 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||
{0, 1, 2, 3});
|
||||
const bool sparse = true;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 0, use_exclusive_lock_, true, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor mg;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 1, use_exclusive_lock_, true, &mg));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &mg));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 2, use_exclusive_lock_, true, &ms));
|
||||
ctx, 2, use_exclusive_lock_, sparse, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
|
||||
ctx, 3, use_exclusive_lock_, true, &mom));
|
||||
ctx, 3, use_exclusive_lock_, sparse, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -3852,15 +3884,16 @@ class ApplyAddSignOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor m;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &m));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &m));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -3958,15 +3991,16 @@ class ApplyPowerSignOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks =
|
||||
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
const bool sparse = false;
|
||||
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
|
||||
ctx, use_exclusive_lock_, sparse, {0, 1});
|
||||
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 0, use_exclusive_lock_, false, &var));
|
||||
ctx, 0, use_exclusive_lock_, sparse, &var));
|
||||
Tensor m;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
|
||||
ctx, 1, use_exclusive_lock_, false, &m));
|
||||
ctx, 1, use_exclusive_lock_, sparse, &m));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
|
@@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops

@@ -953,6 +954,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
      state_ops.scatter_sub(v, [1], [3])
      self.assertAllEqual([1.0, -1.0], v.numpy())

  def testScatterUpdateVariant(self):
    with context.eager_mode():
      v = resource_variable_ops.ResourceVariable([
          list_ops.empty_tensor_list(
              element_dtype=dtypes.float32, element_shape=[])
      ])
      v.scatter_update(
          ops.IndexedSlices(
              list_ops.tensor_list_from_tensor([1., 2.], element_shape=[]), 0))
      self.assertAllEqual(
          list_ops.tensor_list_get_item(v[0], 0, element_dtype=dtypes.float32),
          1.)

  def testScatterNdAddStateOps(self):
    with context.eager_mode():
      v = resource_variable_ops.ResourceVariable(