Refactor training helper functions to separate library.
Change: 154613254
This commit is contained in:
parent
dae9329b0a
commit
2d264f38fd
@ -4,6 +4,7 @@ tensorflow/core/kernels/variable_ops.cc
|
|||||||
tensorflow/core/kernels/unpack_op.cc
|
tensorflow/core/kernels/unpack_op.cc
|
||||||
tensorflow/core/kernels/transpose_op.cc
|
tensorflow/core/kernels/transpose_op.cc
|
||||||
tensorflow/core/kernels/transpose_functor_cpu.cc
|
tensorflow/core/kernels/transpose_functor_cpu.cc
|
||||||
|
tensorflow/core/kernels/training_op_helpers.cc
|
||||||
tensorflow/core/kernels/training_ops.cc
|
tensorflow/core/kernels/training_ops.cc
|
||||||
tensorflow/core/kernels/topk_op.cc
|
tensorflow/core/kernels/topk_op.cc
|
||||||
tensorflow/core/kernels/tile_ops.cc
|
tensorflow/core/kernels/tile_ops.cc
|
||||||
|
@ -382,6 +382,19 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "training_op_helpers",
|
||||||
|
srcs = ["training_op_helpers.cc"],
|
||||||
|
hdrs = ["training_op_helpers.h"],
|
||||||
|
visibility = [":friends"],
|
||||||
|
deps = [
|
||||||
|
":variable_ops",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "bounds_check",
|
name = "bounds_check",
|
||||||
hdrs = ["bounds_check.h"],
|
hdrs = ["bounds_check.h"],
|
||||||
@ -3498,6 +3511,7 @@ tf_kernel_library(
|
|||||||
prefix = "training_ops",
|
prefix = "training_ops",
|
||||||
deps = [
|
deps = [
|
||||||
":bounds_check",
|
":bounds_check",
|
||||||
|
":training_op_helpers",
|
||||||
":variable_ops",
|
":variable_ops",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -4039,6 +4053,7 @@ filegroup(
|
|||||||
"tensor_array.h",
|
"tensor_array.h",
|
||||||
"tile_ops_cpu_impl.h",
|
"tile_ops_cpu_impl.h",
|
||||||
"tile_ops_impl.h",
|
"tile_ops_impl.h",
|
||||||
|
"training_op_helpers.h",
|
||||||
"training_ops.h",
|
"training_ops.h",
|
||||||
"transpose_functor.h",
|
"transpose_functor.h",
|
||||||
"transpose_op.h",
|
"transpose_op.h",
|
||||||
@ -4176,6 +4191,7 @@ filegroup(
|
|||||||
"tile_ops_cpu_impl_6.cc",
|
"tile_ops_cpu_impl_6.cc",
|
||||||
"tile_ops_cpu_impl_7.cc",
|
"tile_ops_cpu_impl_7.cc",
|
||||||
"topk_op.cc",
|
"topk_op.cc",
|
||||||
|
"training_op_helpers.cc",
|
||||||
"training_ops.cc",
|
"training_ops.cc",
|
||||||
"transpose_functor_cpu.cc",
|
"transpose_functor_cpu.cc",
|
||||||
"transpose_op.cc",
|
"transpose_op.cc",
|
||||||
|
96
tensorflow/core/kernels/training_op_helpers.cc
Normal file
96
tensorflow/core/kernels/training_op_helpers.cc
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/training_op_helpers.h"
|
||||||
|
#include "tensorflow/core/kernels/variable_ops.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
|
||||||
|
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
||||||
|
Var* var;
|
||||||
|
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
|
||||||
|
return var->mu();
|
||||||
|
} else {
|
||||||
|
ctx->CtxFailureWithWarning(
|
||||||
|
errors::Internal("Invalid variable reference."));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx->input_ref_mutex(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
|
||||||
|
// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
|
||||||
|
// Safe to pass duplicates - will only lock each distinct mutex once. If
|
||||||
|
// do_lock is false, returns immediately. Note that this silently doesn't lock
|
||||||
|
// mutexes for invalid variable references; in all usages this is followed by
|
||||||
|
// GetInputTensor which will signal a failure.
|
||||||
|
std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
|
||||||
|
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
|
||||||
|
std::vector<mutex_lock> locks;
|
||||||
|
if (!do_lock) {
|
||||||
|
return locks;
|
||||||
|
}
|
||||||
|
std::vector<mutex*> mutexes;
|
||||||
|
std::vector<int> acquire_order;
|
||||||
|
for (auto input : input_ids) {
|
||||||
|
mutex* mutex = GetTrainingVariableMutex(ctx, input);
|
||||||
|
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
|
||||||
|
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
|
||||||
|
acquire_order.push_back(mutexes.size());
|
||||||
|
mutexes.push_back(mutex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::sort(acquire_order.begin(), acquire_order.end(),
|
||||||
|
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
|
||||||
|
|
||||||
|
for (auto input : acquire_order) {
|
||||||
|
mutex* mu = GetTrainingVariableMutex(ctx, input);
|
||||||
|
if (mu != nullptr) {
|
||||||
|
locks.emplace_back(*mu);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return locks;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
|
||||||
|
bool lock_held, Tensor* out) {
|
||||||
|
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
||||||
|
Var* var;
|
||||||
|
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
|
||||||
|
if (lock_held) {
|
||||||
|
*out = *var->tensor();
|
||||||
|
} else {
|
||||||
|
mutex_lock ml(*var->mu());
|
||||||
|
*out = *var->tensor();
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
} else {
|
||||||
|
return errors::Internal("Invalid variable reference.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*out = ctx->mutable_input(input, lock_held);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
|
||||||
|
int output) {
|
||||||
|
if (ctx->input_dtype(input) != DT_RESOURCE) {
|
||||||
|
ctx->forward_ref_input_to_ref_output(input, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace tensorflow
|
36
tensorflow/core/kernels/training_op_helpers.h
Normal file
36
tensorflow/core/kernels/training_op_helpers.h
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
|
||||||
|
#define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
|
||||||
|
|
||||||
|
std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
|
||||||
|
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
|
||||||
|
|
||||||
|
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
|
||||||
|
bool lock_held, Tensor* out);
|
||||||
|
|
||||||
|
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
|
||||||
|
int output);
|
||||||
|
|
||||||
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/bounds_check.h"
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
|
#include "tensorflow/core/kernels/training_op_helpers.h"
|
||||||
#include "tensorflow/core/kernels/variable_ops.h"
|
#include "tensorflow/core/kernels/variable_ops.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -294,80 +295,6 @@ struct ApplyCenteredRMSProp<CPUDevice, T> {
|
|||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
||||||
mutex* GetMutex(OpKernelContext* ctx, int input) {
|
|
||||||
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
|
||||||
Var* var;
|
|
||||||
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
|
|
||||||
return var->mu();
|
|
||||||
} else {
|
|
||||||
ctx->CtxFailureWithWarning(
|
|
||||||
errors::Internal("Invalid variable reference."));
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ctx->input_ref_mutex(input);
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaybeLockMutexesInOrder is a helper function to acquire mutexes in address
|
|
||||||
// order to mitigate deadlock. Returns a vector of acquired mutexes. Safe to
|
|
||||||
// pass duplicates - will only lock each distinct mutex once. If do_lock is
|
|
||||||
// false, returns immediately. Note that this silently doesn't lock mutexes for
|
|
||||||
// invalid variable references; in all usages this is followed by GetInputTensor
|
|
||||||
// which will signal a failure.
|
|
||||||
std::vector<mutex_lock> MaybeLockMutexesInOrder(
|
|
||||||
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
|
|
||||||
std::vector<mutex_lock> locks;
|
|
||||||
if (!do_lock) {
|
|
||||||
return locks;
|
|
||||||
}
|
|
||||||
std::vector<mutex*> mutexes;
|
|
||||||
std::vector<int> acquire_order;
|
|
||||||
for (auto input : input_ids) {
|
|
||||||
mutex* mutex = GetMutex(ctx, input);
|
|
||||||
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
|
|
||||||
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
|
|
||||||
acquire_order.push_back(input);
|
|
||||||
mutexes.push_back(mutex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::sort(acquire_order.begin(), acquire_order.end(),
|
|
||||||
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
|
|
||||||
|
|
||||||
for (auto input : acquire_order) {
|
|
||||||
mutex* mu = GetMutex(ctx, input);
|
|
||||||
if (mu != nullptr) {
|
|
||||||
locks.emplace_back(*mu);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return locks;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GetInputTensor(OpKernelContext* ctx, int input, bool lock_held,
|
|
||||||
Tensor* out) {
|
|
||||||
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
|
||||||
Var* var;
|
|
||||||
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
|
|
||||||
if (lock_held) {
|
|
||||||
*out = *var->tensor();
|
|
||||||
} else {
|
|
||||||
mutex_lock ml(*var->mu());
|
|
||||||
*out = *var->tensor();
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
} else {
|
|
||||||
return errors::Internal("Invalid variable reference.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*out = ctx->mutable_input(input, lock_held);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
|
|
||||||
int output) {
|
|
||||||
if (ctx->input_dtype(input) != DT_RESOURCE) {
|
|
||||||
ctx->forward_ref_input_to_ref_output(input, output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class ApplyGradientDescentOp : public OpKernel {
|
class ApplyGradientDescentOp : public OpKernel {
|
||||||
@ -377,9 +304,11 @@ class ApplyGradientDescentOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
@ -461,7 +390,7 @@ class ApplyAdadeltaOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
if (use_exclusive_lock_) {
|
if (use_exclusive_lock_) {
|
||||||
mutex_lock l1(*GetMutex(ctx, 0));
|
mutex_lock l1(*GetTrainingVariableMutex(ctx, 0));
|
||||||
// Don't try to acquire a lock on the second ref as they share the same
|
// Don't try to acquire a lock on the second ref as they share the same
|
||||||
// mutex.
|
// mutex.
|
||||||
//
|
//
|
||||||
@ -482,12 +411,14 @@ class ApplyAdadeltaOp : public OpKernel {
|
|||||||
|
|
||||||
void DoValidate(OpKernelContext* ctx) {
|
void DoValidate(OpKernelContext* ctx) {
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
Tensor accum_update;
|
Tensor accum_update;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
|
||||||
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
|
&accum_update));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
@ -534,12 +465,14 @@ class ApplyAdadeltaOp : public OpKernel {
|
|||||||
void DoCompute(OpKernelContext* ctx) {
|
void DoCompute(OpKernelContext* ctx) {
|
||||||
const Device& device = ctx->template eigen_device<Device>();
|
const Device& device = ctx->template eigen_device<Device>();
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
Tensor accum_update;
|
Tensor accum_update;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
|
||||||
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
|
&accum_update));
|
||||||
|
|
||||||
const Tensor& lr = ctx->input(3);
|
const Tensor& lr = ctx->input(3);
|
||||||
const Tensor& rho = ctx->input(4);
|
const Tensor& rho = ctx->input(4);
|
||||||
@ -606,7 +539,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
mutex* mu_var = GetMutex(ctx, 0);
|
mutex* mu_var = GetTrainingVariableMutex(ctx, 0);
|
||||||
// mu_accum is actually the same mutex as mu_var since currently we use a
|
// mu_accum is actually the same mutex as mu_var since currently we use a
|
||||||
// global mutex.
|
// global mutex.
|
||||||
//
|
//
|
||||||
@ -615,13 +548,14 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
|||||||
mu_var->lock();
|
mu_var->lock();
|
||||||
}
|
}
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum_grad;
|
Tensor accum_grad;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
|
||||||
GetInputTensor(ctx, 1, use_exclusive_lock_, &accum_grad));
|
&accum_grad));
|
||||||
Tensor accum_update;
|
Tensor accum_update;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
|
||||||
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
|
&accum_update));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -756,9 +690,11 @@ class ApplyProximalGradientDescentOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
@ -823,9 +759,11 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
|
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
|
||||||
errors::InvalidArgument("var must be at least 1 dimensional"));
|
errors::InvalidArgument("var must be at least 1 dimensional"));
|
||||||
|
|
||||||
@ -965,11 +903,14 @@ class ApplyAdagradOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -1055,11 +996,14 @@ class ApplyProximalAdagradOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -1159,11 +1103,14 @@ class SparseApplyAdagradOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -1290,11 +1237,14 @@ class SparseApplyProximalAdagradOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -1459,15 +1409,17 @@ class ApplyAdagradDAOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
|
||||||
Tensor gradient_accum;
|
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
|
Tensor gradient_accum;
|
||||||
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
|
||||||
|
&gradient_accum));
|
||||||
Tensor gradient_squared_accum;
|
Tensor gradient_squared_accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
|
||||||
&gradient_squared_accum));
|
&gradient_squared_accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -1559,15 +1511,17 @@ class SparseApplyAdagradDAOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
|
||||||
Tensor gradient_accum;
|
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
|
Tensor gradient_accum;
|
||||||
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_,
|
||||||
|
&gradient_accum));
|
||||||
Tensor gradient_squared_accum;
|
Tensor gradient_squared_accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
|
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_,
|
||||||
&gradient_squared_accum));
|
&gradient_squared_accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -1753,14 +1707,18 @@ class ApplyFtrlOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||||
|
{0, 1, 2});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
Tensor linear;
|
Tensor linear;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &linear));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -1864,13 +1822,17 @@ class SparseApplyFtrlOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||||
|
{0, 1, 2});
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
Tensor linear;
|
Tensor linear;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &linear));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -2070,12 +2032,15 @@ class ApplyMomentumOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -2170,12 +2135,15 @@ class SparseApplyMomentumOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
auto locks =
|
||||||
|
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor accum;
|
Tensor accum;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &accum));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -2283,14 +2251,18 @@ class ApplyAdamOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||||
|
{0, 1, 2});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor m;
|
Tensor m;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &m));
|
OP_REQUIRES_OK(ctx,
|
||||||
|
GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &m));
|
||||||
Tensor v;
|
Tensor v;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &v));
|
OP_REQUIRES_OK(ctx,
|
||||||
|
GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &v));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
errors::FailedPrecondition(
|
errors::FailedPrecondition(
|
||||||
@ -2423,14 +2395,18 @@ class ApplyRMSPropOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||||
|
{0, 1, 2});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor ms;
|
Tensor ms;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &ms));
|
||||||
Tensor mom;
|
Tensor mom;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &mom));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
@ -2501,17 +2477,21 @@ class ApplyCenteredRMSPropOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
auto locks =
|
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||||
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
|
{0, 1, 2, 3});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor mg;
|
Tensor mg;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &mg));
|
||||||
Tensor ms;
|
Tensor ms;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &ms));
|
||||||
Tensor mom;
|
Tensor mom;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 3, use_exclusive_lock_, &mom));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
@ -2658,14 +2638,18 @@ class SparseApplyRMSPropOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||||
|
{0, 1, 2});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor ms;
|
Tensor ms;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &ms));
|
||||||
Tensor mom;
|
Tensor mom;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &mom));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
@ -2783,17 +2767,21 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||||
auto locks =
|
auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
|
||||||
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
|
{0, 1, 2, 3});
|
||||||
|
|
||||||
Tensor var;
|
Tensor var;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var));
|
||||||
Tensor mg;
|
Tensor mg;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &mg));
|
||||||
Tensor ms;
|
Tensor ms;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &ms));
|
||||||
Tensor mom;
|
Tensor mom;
|
||||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
|
OP_REQUIRES_OK(
|
||||||
|
ctx, GetInputTensorFromVariable(ctx, 3, use_exclusive_lock_, &mom));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var.IsInitialized(),
|
ctx, var.IsInitialized(),
|
||||||
|
@ -148,7 +148,6 @@ struct ApplyCenteredRMSProp {
|
|||||||
typename TTypes<T>::ConstScalar epsilon,
|
typename TTypes<T>::ConstScalar epsilon,
|
||||||
typename TTypes<T>::ConstFlat grad);
|
typename TTypes<T>::ConstFlat grad);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace functor
|
} // end namespace functor
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user