Refactor training helper functions to separate library.

Change: 154613254
This commit is contained in:
A. Unique TensorFlower 2017-04-28 20:56:13 -08:00 committed by TensorFlower Gardener
parent dae9329b0a
commit 2d264f38fd
6 changed files with 294 additions and 158 deletions

View File

@ -4,6 +4,7 @@ tensorflow/core/kernels/variable_ops.cc
tensorflow/core/kernels/unpack_op.cc
tensorflow/core/kernels/transpose_op.cc
tensorflow/core/kernels/transpose_functor_cpu.cc
tensorflow/core/kernels/training_op_helpers.cc
tensorflow/core/kernels/training_ops.cc
tensorflow/core/kernels/topk_op.cc
tensorflow/core/kernels/tile_ops.cc

View File

@ -382,6 +382,19 @@ cc_library(
],
)
cc_library(
name = "training_op_helpers",
srcs = ["training_op_helpers.cc"],
hdrs = ["training_op_helpers.h"],
visibility = [":friends"],
deps = [
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
)
cc_library(
name = "bounds_check",
hdrs = ["bounds_check.h"],
@ -3498,6 +3511,7 @@ tf_kernel_library(
prefix = "training_ops",
deps = [
":bounds_check",
":training_op_helpers",
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -4039,6 +4053,7 @@ filegroup(
"tensor_array.h",
"tile_ops_cpu_impl.h",
"tile_ops_impl.h",
"training_op_helpers.h",
"training_ops.h",
"transpose_functor.h",
"transpose_op.h",
@ -4176,6 +4191,7 @@ filegroup(
"tile_ops_cpu_impl_6.cc",
"tile_ops_cpu_impl_7.cc",
"topk_op.cc",
"training_op_helpers.cc",
"training_ops.cc",
"transpose_functor_cpu.cc",
"transpose_op.cc",

View File

@ -0,0 +1,96 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
return var->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
return nullptr;
}
}
return ctx->input_ref_mutex(input);
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
// Safe to pass duplicates - will only lock each distinct mutex once. If
// do_lock is false, returns immediately. Note that this silently doesn't lock
// mutexes for invalid variable references; in all usages this is followed by
// GetInputTensor which will signal a failure.
std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
std::vector<mutex_lock> locks;
if (!do_lock) {
return locks;
}
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
mutex* mutex = GetTrainingVariableMutex(ctx, input);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
mutexes.push_back(mutex);
}
}
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
for (auto input : acquire_order) {
mutex* mu = GetTrainingVariableMutex(ctx, input);
if (mu != nullptr) {
locks.emplace_back(*mu);
}
}
return locks;
}
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
bool lock_held, Tensor* out) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
if (lock_held) {
*out = *var->tensor();
} else {
mutex_lock ml(*var->mu());
*out = *var->tensor();
}
return Status::OK();
} else {
return errors::Internal("Invalid variable reference.");
}
}
*out = ctx->mutable_input(input, lock_held);
return Status::OK();
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
int output) {
if (ctx->input_dtype(input) != DT_RESOURCE) {
ctx->forward_ref_input_to_ref_output(input, output);
}
}
} // end namespace tensorflow

View File

@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
bool lock_held, Tensor* out);
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
int output);
} // end namespace tensorflow
#endif // TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
@ -294,80 +295,6 @@ struct ApplyCenteredRMSProp<CPUDevice, T> {
} // namespace functor
mutex* GetMutex(OpKernelContext* ctx, int input) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
return var->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
return nullptr;
}
}
return ctx->input_ref_mutex(input);
}
// MaybeLockMutexesInOrder is a helper function to acquire mutexes in address
// order to mitigate deadlock. Returns a vector of acquired mutexes. Safe to
// pass duplicates - will only lock each distinct mutex once. If do_lock is
// false, returns immediately. Note that this silently doesn't lock mutexes for
// invalid variable references; in all usages this is followed by GetInputTensor
// which will signal a failure.
std::vector<mutex_lock> MaybeLockMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
std::vector<mutex_lock> locks;
if (!do_lock) {
return locks;
}
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
mutex* mutex = GetMutex(ctx, input);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(input);
mutexes.push_back(mutex);
}
}
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
for (auto input : acquire_order) {
mutex* mu = GetMutex(ctx, input);
if (mu != nullptr) {
locks.emplace_back(*mu);
}
}
return locks;
}
Status GetInputTensor(OpKernelContext* ctx, int input, bool lock_held,
Tensor* out) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
if (lock_held) {
*out = *var->tensor();
} else {
mutex_lock ml(*var->mu());
*out = *var->tensor();
}
return Status::OK();
} else {
return errors::Internal("Invalid variable reference.");
}
}
*out = ctx->mutable_input(input, lock_held);
return Status::OK();
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
int output) {
if (ctx->input_dtype(input) != DT_RESOURCE) {
ctx->forward_ref_input_to_ref_output(input, output);
}
}
template <typename Device, typename T>
class ApplyGradientDescentOp : public OpKernel {
@ -377,9 +304,11 @@ class ApplyGradientDescentOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES(
ctx, var.IsInitialized(),
@ -461,7 +390,7 @@ class ApplyAdadeltaOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
if (use_exclusive_lock_) {
mutex_lock l1(*GetMutex(ctx, 0));
mutex_lock l1(*GetTrainingVariableMutex(ctx, 0));
// Don't try to acquire a lock on the second ref as they share the same
// mutex.
//
@ -482,12 +411,14 @@ class ApplyAdadeltaOp : public OpKernel {
void DoValidate(OpKernelContext* ctx) {
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
Tensor accum_update;
OP_REQUIRES_OK(ctx,
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
&accum_update));
OP_REQUIRES(
ctx, var.IsInitialized(),
@ -534,12 +465,14 @@ class ApplyAdadeltaOp : public OpKernel {
void DoCompute(OpKernelContext* ctx) {
const Device& device = ctx->template eigen_device<Device>();
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
Tensor accum_update;
OP_REQUIRES_OK(ctx,
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
&accum_update));
const Tensor& lr = ctx->input(3);
const Tensor& rho = ctx->input(4);
@ -606,7 +539,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
mutex* mu_var = GetMutex(ctx, 0);
mutex* mu_var = GetTrainingVariableMutex(ctx, 0);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
//
@ -615,13 +548,14 @@ class SparseApplyAdadeltaOp : public OpKernel {
mu_var->lock();
}
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum_grad;
OP_REQUIRES_OK(ctx,
GetInputTensor(ctx, 1, use_exclusive_lock_, &accum_grad));
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
&accum_grad));
Tensor accum_update;
OP_REQUIRES_OK(ctx,
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
&accum_update));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -756,9 +690,11 @@ class ApplyProximalGradientDescentOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES(
ctx, var.IsInitialized(),
@ -823,9 +759,11 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
errors::InvalidArgument("var must be at least 1 dimensional"));
@ -965,11 +903,14 @@ class ApplyAdagradOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -1055,11 +996,14 @@ class ApplyProximalAdagradOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -1159,11 +1103,14 @@ class SparseApplyAdagradOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -1290,11 +1237,14 @@ class SparseApplyProximalAdagradOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -1459,15 +1409,17 @@ class ApplyAdagradDAOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
Tensor gradient_accum;
OP_REQUIRES_OK(
ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor gradient_accum;
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
&gradient_accum));
Tensor gradient_squared_accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
&gradient_squared_accum));
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
&gradient_squared_accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -1559,15 +1511,17 @@ class SparseApplyAdagradDAOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
Tensor gradient_accum;
OP_REQUIRES_OK(
ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor gradient_accum;
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
&gradient_accum));
Tensor gradient_squared_accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
&gradient_squared_accum));
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
&gradient_squared_accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -1753,14 +1707,18 @@ class ApplyFtrlOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
{0, 1, 2});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
Tensor linear;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &linear));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -1864,13 +1822,17 @@ class SparseApplyFtrlOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
{0, 1, 2});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
Tensor linear;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &linear));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -2070,12 +2032,15 @@ class ApplyMomentumOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -2170,12 +2135,15 @@ class SparseApplyMomentumOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
auto locks =
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -2283,14 +2251,18 @@ class ApplyAdamOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
{0, 1, 2});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor m;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &m));
OP_REQUIRES_OK(ctx,
GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &m));
Tensor v;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &v));
OP_REQUIRES_OK(ctx,
GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &v));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@ -2423,14 +2395,18 @@ class ApplyRMSPropOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
{0, 1, 2});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor ms;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &ms));
Tensor mom;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),
@ -2501,17 +2477,21 @@ class ApplyCenteredRMSPropOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
auto locks =
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
{0, 1, 2, 3});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor mg;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &mg));
Tensor ms;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &ms));
Tensor mom;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 3, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),
@ -2658,14 +2638,18 @@ class SparseApplyRMSPropOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
{0, 1, 2});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor ms;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &ms));
Tensor mom;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),
@ -2783,17 +2767,21 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks =
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
{0, 1, 2, 3});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
Tensor mg;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &mg));
Tensor ms;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &ms));
Tensor mom;
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
OP_REQUIRES_OK(
ctx, GetInputTensorFromVariable(ctx, 3, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),

View File

@ -148,7 +148,6 @@ struct ApplyCenteredRMSProp {
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad);
};
} // end namespace functor
} // end namespace tensorflow