Refactor training helper functions to separate library.
Change: 154613254
parent dae9329b0a
commit 2d264f38fd
@@ -4,6 +4,7 @@ tensorflow/core/kernels/variable_ops.cc
 tensorflow/core/kernels/unpack_op.cc
 tensorflow/core/kernels/transpose_op.cc
 tensorflow/core/kernels/transpose_functor_cpu.cc
+tensorflow/core/kernels/training_op_helpers.cc
 tensorflow/core/kernels/training_ops.cc
 tensorflow/core/kernels/topk_op.cc
 tensorflow/core/kernels/tile_ops.cc
tensorflow/core/kernels/BUILD
@@ -382,6 +382,19 @@ cc_library(
     ],
 )

+cc_library(
+    name = "training_op_helpers",
+    srcs = ["training_op_helpers.cc"],
+    hdrs = ["training_op_helpers.h"],
+    visibility = [":friends"],
+    deps = [
+        ":variable_ops",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+    ],
+)
+
 cc_library(
     name = "bounds_check",
     hdrs = ["bounds_check.h"],

@@ -3498,6 +3511,7 @@ tf_kernel_library(
     prefix = "training_ops",
     deps = [
         ":bounds_check",
+        ":training_op_helpers",
         ":variable_ops",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",

@@ -4039,6 +4053,7 @@ filegroup(
         "tensor_array.h",
         "tile_ops_cpu_impl.h",
         "tile_ops_impl.h",
+        "training_op_helpers.h",
         "training_ops.h",
         "transpose_functor.h",
         "transpose_op.h",

@@ -4176,6 +4191,7 @@ filegroup(
         "tile_ops_cpu_impl_6.cc",
         "tile_ops_cpu_impl_7.cc",
         "topk_op.cc",
+        "training_op_helpers.cc",
        "training_ops.cc",
        "transpose_functor_cpu.cc",
        "transpose_op.cc",
tensorflow/core/kernels/training_op_helpers.cc (new file, 96 lines)
@@ -0,0 +1,96 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"

namespace tensorflow {

mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    Var* var;
    if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
      return var->mu();
    } else {
      ctx->CtxFailureWithWarning(
          errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return ctx->input_ref_mutex(input);
}

// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
// Safe to pass duplicates - will only lock each distinct mutex once. If
// do_lock is false, returns immediately. Note that this silently doesn't lock
// mutexes for invalid variable references; in all usages this is followed by
// GetInputTensor which will signal a failure.
std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
    OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
  std::vector<mutex_lock> locks;
  if (!do_lock) {
    return locks;
  }
  std::vector<mutex*> mutexes;
  std::vector<int> acquire_order;
  for (auto input : input_ids) {
    mutex* mutex = GetTrainingVariableMutex(ctx, input);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mutex);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  for (auto input : acquire_order) {
    mutex* mu = GetTrainingVariableMutex(ctx, input);
    if (mu != nullptr) {
      locks.emplace_back(*mu);
    }
  }
  return locks;
}

Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                  bool lock_held, Tensor* out) {
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    Var* var;
    if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
      if (lock_held) {
        *out = *var->tensor();
      } else {
        mutex_lock ml(*var->mu());
        *out = *var->tensor();
      }
      return Status::OK();
    } else {
      return errors::Internal("Invalid variable reference.");
    }
  }
  *out = ctx->mutable_input(input, lock_held);
  return Status::OK();
}

void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output) {
  if (ctx->input_dtype(input) != DT_RESOURCE) {
    ctx->forward_ref_input_to_ref_output(input, output);
  }
}

}  // end namespace tensorflow
tensorflow/core/kernels/training_op_helpers.h (new file, 36 lines)
@@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);

std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
    OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);

Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                  bool lock_held, Tensor* out);

void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output);

}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
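Not part of this commit: a minimal illustrative sketch of how a training kernel is expected to call the new helpers, mirroring the call sites updated in training_ops.cc below. The ApplyFooOp kernel and its update step are hypothetical placeholders; only the helper calls come from this change.

// Hypothetical kernel sketch (assumes the TensorFlow source tree).
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

class ApplyFooOp : public OpKernel {
 public:
  explicit ApplyFooOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    // Acquire the variable input mutexes (a single input here) in address
    // order; duplicates are locked only once.
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    // Fetch the underlying tensor, for both ref-typed and resource variables.
    Tensor var;
    OP_REQUIRES_OK(
        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
    OP_REQUIRES(ctx, var.IsInitialized(),
                errors::FailedPrecondition(
                    "Attempting to use uninitialized variable"));
    // ... compute the parameter update into `var` here ...
    // Ref inputs forward their ref to the output; resource inputs do not.
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

}  // end namespace tensorflow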
tensorflow/core/kernels/training_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/training_op_helpers.h"
 #include "tensorflow/core/kernels/variable_ops.h"

 namespace tensorflow {
@@ -294,80 +295,6 @@ struct ApplyCenteredRMSProp<CPUDevice, T> {

 }  // namespace functor

-mutex* GetMutex(OpKernelContext* ctx, int input) {
-  if (ctx->input_dtype(input) == DT_RESOURCE) {
-    Var* var;
-    if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
-      return var->mu();
-    } else {
-      ctx->CtxFailureWithWarning(
-          errors::Internal("Invalid variable reference."));
-      return nullptr;
-    }
-  }
-  return ctx->input_ref_mutex(input);
-}
-
-// MaybeLockMutexesInOrder is a helper function to acquire mutexes in address
-// order to mitigate deadlock. Returns a vector of acquired mutexes. Safe to
-// pass duplicates - will only lock each distinct mutex once. If do_lock is
-// false, returns immediately. Note that this silently doesn't lock mutexes for
-// invalid variable references; in all usages this is followed by GetInputTensor
-// which will signal a failure.
-std::vector<mutex_lock> MaybeLockMutexesInOrder(
-    OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
-  std::vector<mutex_lock> locks;
-  if (!do_lock) {
-    return locks;
-  }
-  std::vector<mutex*> mutexes;
-  std::vector<int> acquire_order;
-  for (auto input : input_ids) {
-    mutex* mutex = GetMutex(ctx, input);
-    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
-    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
-      acquire_order.push_back(input);
-      mutexes.push_back(mutex);
-    }
-  }
-  std::sort(acquire_order.begin(), acquire_order.end(),
-            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
-
-  for (auto input : acquire_order) {
-    mutex* mu = GetMutex(ctx, input);
-    if (mu != nullptr) {
-      locks.emplace_back(*mu);
-    }
-  }
-  return locks;
-}
-
-Status GetInputTensor(OpKernelContext* ctx, int input, bool lock_held,
-                      Tensor* out) {
-  if (ctx->input_dtype(input) == DT_RESOURCE) {
-    Var* var;
-    if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
-      if (lock_held) {
-        *out = *var->tensor();
-      } else {
-        mutex_lock ml(*var->mu());
-        *out = *var->tensor();
-      }
-      return Status::OK();
-    } else {
-      return errors::Internal("Invalid variable reference.");
-    }
-  }
-  *out = ctx->mutable_input(input, lock_held);
-  return Status::OK();
-}
-
-void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
-                                     int output) {
-  if (ctx->input_dtype(input) != DT_RESOURCE) {
-    ctx->forward_ref_input_to_ref_output(input, output);
-  }
-}

 template <typename Device, typename T>
 class ApplyGradientDescentOp : public OpKernel {
@@ -377,9 +304,11 @@ class ApplyGradientDescentOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -461,7 +390,7 @@ class ApplyAdadeltaOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     if (use_exclusive_lock_) {
-      mutex_lock l1(*GetMutex(ctx, 0));
+      mutex_lock l1(*GetTrainingVariableMutex(ctx, 0));
       // Don't try to acquire a lock on the second ref as they share the same
       // mutex.
       //
@@ -482,12 +411,14 @@ class ApplyAdadeltaOp : public OpKernel {

   void DoValidate(OpKernelContext* ctx) {
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     Tensor accum_update;
-    OP_REQUIRES_OK(ctx,
-                   GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
+                                                   &accum_update));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -534,12 +465,14 @@ class ApplyAdadeltaOp : public OpKernel {
   void DoCompute(OpKernelContext* ctx) {
     const Device& device = ctx->template eigen_device<Device>();
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     Tensor accum_update;
-    OP_REQUIRES_OK(ctx,
-                   GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
+                                                   &accum_update));

     const Tensor& lr = ctx->input(3);
     const Tensor& rho = ctx->input(4);
@@ -606,7 +539,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    mutex* mu_var = GetMutex(ctx, 0);
+    mutex* mu_var = GetTrainingVariableMutex(ctx, 0);
     // mu_accum is actually the same mutex as mu_var since currently we use a
     // global mutex.
     //
@@ -615,13 +548,14 @@ class SparseApplyAdadeltaOp : public OpKernel {
       mu_var->lock();
     }
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum_grad;
-    OP_REQUIRES_OK(ctx,
-                   GetInputTensor(ctx, 1, use_exclusive_lock_, &accum_grad));
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
+                                                   &accum_grad));
     Tensor accum_update;
-    OP_REQUIRES_OK(ctx,
-                   GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
+                                                   &accum_update));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -756,9 +690,11 @@ class ApplyProximalGradientDescentOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -823,9 +759,11 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                 errors::InvalidArgument("var must be at least 1 dimensional"));

@@ -965,11 +903,14 @@ class ApplyAdagradOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1055,11 +996,14 @@ class ApplyProximalAdagradOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1159,11 +1103,14 @@ class SparseApplyAdagradOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1290,11 +1237,14 @@ class SparseApplyProximalAdagradOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1459,15 +1409,17 @@ class ApplyAdagradDAOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
-    Tensor gradient_accum;
     OP_REQUIRES_OK(
-        ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
+    Tensor gradient_accum;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
+                                                   &gradient_accum));
     Tensor gradient_squared_accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
-                                       &gradient_squared_accum));
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
+                                                   &gradient_squared_accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1559,15 +1511,17 @@ class SparseApplyAdagradDAOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
-    Tensor gradient_accum;
     OP_REQUIRES_OK(
-        ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
+    Tensor gradient_accum;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
+                                                   &gradient_accum));
     Tensor gradient_squared_accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
-                                       &gradient_squared_accum));
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
+                                                   &gradient_squared_accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1753,14 +1707,18 @@ class ApplyFtrlOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     Tensor linear;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &linear));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1864,13 +1822,17 @@ class SparseApplyFtrlOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2});
     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     Tensor linear;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &linear));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -2070,12 +2032,15 @@ class ApplyMomentumOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -2170,12 +2135,15 @@ class SparseApplyMomentumOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
+    auto locks =
+        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor accum;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -2283,14 +2251,18 @@ class ApplyAdamOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor m;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &m));
+    OP_REQUIRES_OK(ctx,
+                   GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &m));
     Tensor v;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &v));
+    OP_REQUIRES_OK(ctx,
+                   GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &v));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -2423,14 +2395,18 @@ class ApplyRMSPropOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor ms;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &ms));
     Tensor mom;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2501,17 +2477,21 @@ class ApplyCenteredRMSPropOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override {
-    auto locks =
-        MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2, 3});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor mg;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &mg));
     Tensor ms;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &ms));
     Tensor mom;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 3, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2658,14 +2638,18 @@ class SparseApplyRMSPropOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor ms;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &ms));
     Tensor mom;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2783,17 +2767,21 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    auto locks =
-        MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2, 3});

     Tensor var;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
     Tensor mg;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &mg));
     Tensor ms;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &ms));
     Tensor mom;
-    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
+    OP_REQUIRES_OK(
+        ctx, GetInputTensorFromVariable(ctx, 3, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
tensorflow/core/kernels/training_ops.h
@@ -148,7 +148,6 @@ struct ApplyCenteredRMSProp {
                    typename TTypes<T>::ConstScalar epsilon,
                    typename TTypes<T>::ConstFlat grad);
};

}  // end namespace functor
}  // end namespace tensorflow
