Use RefCountPtr in LookupResource to avoid leaks

LookupResource returns a raw pointer which the caller needs to Unref.
The prevalent pattern is that the call is followed by a ScopedUnref. This
can be problematic: if a caller forgets to add the ScopedUnref, we have a
memory leak. We resolve this by using RefCountPtr instead of a raw
pointer in LookupResource. Most use cases have been migrated in this
change.
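
As an illustration, a minimal sketch of the pattern being migrated (the
surrounding kernel code is elided; see the hunks below):

    // Before: the caller holds a raw reference and must remember to Unref it.
    Var* variable = nullptr;
    TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
    core::ScopedUnref unref(variable);  // Leaks if this line is forgotten.

    // After: RefCountPtr releases the reference when it goes out of scope.
    core::RefCountPtr<Var> variable;
    TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &variable));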

Note some variables were renamed to handle line length restrictions.

PiperOrigin-RevId: 250423227
Gaurav Jain 2019-05-28 21:53:42 -07:00 committed by TensorFlower Gardener
parent 4941b4a73f
commit 2e758833c3
43 changed files with 372 additions and 421 deletions
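
Where a constructor keeps its own reference to the resource rather than
taking over the caller's (the Bigtable Dataset classes below), the call
sites now pass the raw pointer out of the smart pointer. A sketch of that
shape, following the Bigtable hunks:

    core::RefCountPtr<BigtableTableResource> resource;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
    // resource.get() hands out a raw pointer without transferring this
    // scope's reference (the Dataset is assumed to keep its own).
    *output = new Dataset(ctx, resource.get(), std::move(prefix));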

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/lib/core/refcount.h"

 namespace tensorflow {

@@ -31,12 +32,11 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
   std::map<int, OptionalTensor> variables;
   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
     if (ctx->input(i).dtype() == DT_RESOURCE) {
-      Var* variable = nullptr;
+      core::RefCountPtr<Var> variable;
       ResourceHandle handle = HandleFromInput(ctx, i);
       OptionalTensor& optional = variables[i];
       optional.name = handle.name();
       if (LookupResource(ctx, handle, &variable).ok()) {
-        core::ScopedUnref scoped_unref(variable);
         tf_shared_lock lock(*variable->mu());
         optional.present = true;
         optional.value = *variable->tensor();

View File

@@ -40,7 +40,7 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
                   "Variable and value dtypes don't match; respectively, ",
                   DataTypeString(dtype_), " and ",
                   DataTypeString(context->input(1).dtype())));
-  Var* variable = nullptr;
+  core::RefCountPtr<Var> variable;
   const Tensor& value = context->input(1);
   // Note: every resource-variable-manipulating op assumes copy-on-write
   // semantics, and creates a copy of the variable's Tensor if its refcount is

@@ -58,7 +58,6 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
                             (*ptr)->is_initialized = true;
                             return Status::OK();
                           }));
-  core::ScopedUnref s(variable);
   mutex_lock ml(*variable->mu());
   OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
               errors::InvalidArgument(

View File

@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/util/stream_executor_util.h"

 namespace tensorflow {

@@ -86,7 +87,7 @@ static Status GetVariableInfosFromCtxInputs(
       variable_indices, std::back_inserter(resource_handles),
       [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
-  std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables;
+  std::vector<core::RefCountPtr<Var>> variables;
   TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));
   result->clear();

View File

@@ -84,9 +84,8 @@ class PopulateTRTEngineCache : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     ResourceHandle handle = HandleFromInput(ctx, 0);
-    TRTEngineCacheResource* resource = nullptr;
+    core::RefCountPtr<TRTEngineCacheResource> resource;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
-    core::ScopedUnref unref_me(resource);

     auto allocator = resource->allocator_.get();
     OP_REQUIRES(ctx, allocator != nullptr,

View File

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"

 namespace tensorflow {

@@ -139,19 +140,19 @@ class BigtableTableOp : public OpKernel {
       ResourceMgr* mgr = ctx->resource_manager();
       OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
-      BigtableClientResource* client_resource;
+      core::RefCountPtr<BigtableClientResource> client_resource;
       OP_REQUIRES_OK(
           ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
-      core::ScopedUnref unref_client(client_resource);

       BigtableTableResource* resource;
-      OP_REQUIRES_OK(
-          ctx, mgr->LookupOrCreate<BigtableTableResource>(
-                   cinfo_.container(), cinfo_.name(), &resource,
-                   [this, client_resource](BigtableTableResource** ret) {
-                     *ret = new BigtableTableResource(client_resource, table_);
-                     return Status::OK();
-                   }));
+      OP_REQUIRES_OK(ctx,
+                     mgr->LookupOrCreate<BigtableTableResource>(
+                         cinfo_.container(), cinfo_.name(), &resource,
+                         [this, &client_resource](BigtableTableResource** ret) {
+                           *ret = new BigtableTableResource(
+                               client_resource.get(), table_);
+                           return Status::OK();
+                         }));
       initialized_ = true;
     }
     OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(

@@ -236,10 +237,9 @@ class ToBigtableOp : public AsyncOpKernel {
         errors::InvalidArgument("timestamp must be >= -1"),
         done);

-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK_ASYNC(
         ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
-    core::ScopedUnref resource_cleanup(resource);

     std::vector<Tensor> components;
     components.reserve(dataset->output_dtypes().size());

View File

@@ -26,9 +26,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    BigtableTableResource* table;
+    core::RefCountPtr<BigtableTableResource> table;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
-    core::ScopedUnref scoped_unref(table);

     std::vector<string> column_families;
     std::vector<string> columns;

@@ -50,7 +49,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
     }

     *output =
-        new Dataset(ctx, input, table, std::move(column_families),
+        new Dataset(ctx, input, table.get(), std::move(column_families),
                     std::move(columns), output_types, std::move(output_shapes));
   }

View File

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"

 namespace tensorflow {
 namespace data {

@@ -28,12 +29,10 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
     string prefix;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));

-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
-
-    *output = new Dataset(ctx, resource, std::move(prefix));
+    *output = new Dataset(ctx, resource.get(), std::move(prefix));
   }

  private:

View File

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"

 namespace tensorflow {
 namespace data {

@@ -31,13 +32,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
     string end_key;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));

-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
-
-    *output =
-        new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
+    *output = new Dataset(ctx, resource.get(), std::move(start_key),
+                          std::move(end_key));
   }

  private:

View File

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"

 namespace tensorflow {
 namespace data {

@@ -35,10 +36,9 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
     string end_key;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));

-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);

     OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
                 errors::InvalidArgument(

@@ -49,7 +49,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
                       "If prefix is specified, end_key must be empty."));
     }

-    *output = new Dataset(ctx, resource, std::move(prefix),
+    *output = new Dataset(ctx, resource.get(), std::move(prefix),
                           std::move(start_key), std::move(end_key));
   }

View File

@@ -25,11 +25,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
   using DatasetOpKernel::DatasetOpKernel;

   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
-    *output = new Dataset(ctx, resource);
+    *output = new Dataset(ctx, resource.get());
   }

  private:

View File

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"

 namespace tensorflow {
 namespace data {

@@ -64,10 +65,9 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
                 errors::InvalidArgument(
                     "Probability outside the range of (0, 1]. Got: ", probability));

-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);

     const uint64 num_outputs = columns.size() + 1;
     std::vector<PartialTensorShape> output_shapes;

@@ -79,7 +79,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
       output_types.push_back(DT_STRING);
     }

-    *output = new Dataset(ctx, resource, std::move(prefix),
+    *output = new Dataset(ctx, resource.get(), std::move(prefix),
                           std::move(start_key), std::move(end_key),
                           std::move(column_families), std::move(columns),
                           probability, output_types, std::move(output_shapes));

View File

@@ -21,6 +21,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/thread_annotations.h"

 namespace tensorflow {

@@ -44,7 +45,7 @@ class CreateTreeEnsembleVariableOp : public OpKernel {
     const Tensor* tree_ensemble_config_t;
     OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
                                            &tree_ensemble_config_t));
-    auto* result = new boosted_trees::models::DecisionTreeEnsembleResource();
+    auto* result = new DecisionTreeEnsembleResource();
     if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<string>()(),
                                     stamp_token)) {
       result->Unref();

@@ -69,11 +70,10 @@ class TreeEnsembleStampTokenOp : public OpKernel {
       : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
     tf_shared_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);
     Tensor* output_stamp_token_t = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
                                                      &output_stamp_token_t));

@@ -88,11 +88,10 @@ class TreeEnsembleSerializeOp : public OpKernel {
       : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
     tf_shared_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);
     Tensor* output_stamp_token_t = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
                                                      &output_stamp_token_t));

@@ -112,11 +111,10 @@ class TreeEnsembleDeserializeOp : public OpKernel {
       : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
     mutex_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);

     // Get the stamp token.
     const Tensor* stamp_token_t;

@@ -146,12 +144,11 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
   }

   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
     tf_shared_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);

     // Get the stamp token.
     const Tensor* stamp_token_t;

@@ -194,9 +191,8 @@
 REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);

-REGISTER_KERNEL_BUILDER(
-    Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
-    IsResourceInitialized<boosted_trees::models::DecisionTreeEnsembleResource>);
+REGISTER_KERNEL_BUILDER(Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
+                        IsResourceInitialized<DecisionTreeEnsembleResource>);

 REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
                         CreateTreeEnsembleVariableOp);

View File

@@ -163,12 +163,11 @@ class GradientTreesPredictionOp : public OpKernel {
   }

   void Compute(OpKernelContext* const context) override {
-    DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     // Gets the resource. Grabs the mutex but releases it.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
     // Release the reference to the resource once we're done using it.
-    core::ScopedUnref unref_me(ensemble_resource);
     if (use_locking_) {
       tf_shared_lock l(*ensemble_resource->get_mutex());
       DoCompute(context, ensemble_resource,

@@ -184,9 +183,10 @@ class GradientTreesPredictionOp : public OpKernel {
   // leaf index in prediction. Though this class invokes only with this param
   // value as false, the subclass GradientTreesPredictionVerboseOp will invoke
   // with the true value.
-  virtual void DoCompute(OpKernelContext* context,
-                         DecisionTreeEnsembleResource* ensemble_resource,
-                         const bool return_output_leaf_index) {
+  virtual void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
+      const bool return_output_leaf_index) {
     // Read dense float features list;
     OpInputList dense_float_features_list;
     OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(

@@ -352,9 +352,10 @@ class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
       : GradientTreesPredictionOp(context) {}

  protected:
-  void DoCompute(OpKernelContext* context,
-                 DecisionTreeEnsembleResource* ensemble_resource,
-                 bool return_output_leaf_index) override {
+  void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
+      bool return_output_leaf_index) override {
     GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
                                          /*return_output_leaf_index=*/true);
   }

@@ -372,12 +373,10 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
   }

   void Compute(OpKernelContext* const context) override {
-    DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     // Gets the resource. Grabs the mutex but releases it.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
-    // Release the reference to the resource once we're done using it.
-    core::ScopedUnref unref_me(ensemble_resource);
     if (use_locking_) {
       tf_shared_lock l(*ensemble_resource->get_mutex());
       DoCompute(context, ensemble_resource);

@@ -387,17 +386,18 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
   }

  private:
-  void DoCompute(OpKernelContext* context,
-                 DecisionTreeEnsembleResource* ensemble_resource) {
+  void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
     // The last non-finalized tree in the ensemble is by convention the
     // one to partition on. If no such tree exists, a nodeless tree is
     // created.
     boosted_trees::trees::DecisionTreeConfig empty_tree_config;
     const boosted_trees::trees::DecisionTreeConfig& tree_config =
-        (ensemble_resource->num_trees() <= 0 ||
-         ensemble_resource->LastTreeMetadata()->is_finalized())
+        (resource->num_trees() <= 0 ||
+         resource->LastTreeMetadata()->is_finalized())
             ? empty_tree_config
-            : *ensemble_resource->LastTree();
+            : *resource->LastTree();

     // Read dense float features list;
     OpInputList dense_float_features_list;

View File

@@ -28,6 +28,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/types.h"

@@ -299,13 +300,12 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
         const ResourceHandle& handle =
             resource_handle_list[resource_handle_idx]
                 .flat<ResourceHandle>()(0);
-        QuantileStreamResource* streams_resource;
+        core::RefCountPtr<QuantileStreamResource> streams_resource;
         // Create a reference to the underlying resource using the handle.
         OP_REQUIRES_OK(context,
                        LookupResource(context, handle, &streams_resource));
         // Remove the reference at the end of this scope.
         mutex_lock l(*streams_resource->mutex());
-        core::ScopedUnref unref_me(streams_resource);

         // If the stamp is invalid we drop the update.
         if (!streams_resource->is_stamp_valid(stamp_token)) {

@@ -467,13 +467,12 @@ class QuantileAccumulatorSerializeOp : public OpKernel {
       : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);

     int64 stamp_token = streams_resource->stamp();
     Tensor* stream_state_t;

@@ -526,13 +525,12 @@ class QuantileAccumulatorDeserializeOp : public OpKernel {
       : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);

     int64 old_stamp_token = streams_resource->stamp();

@@ -595,13 +593,12 @@ class QuantileAccumulatorFlushOp : public OpKernel {
       : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);

     const Tensor* next_stamp_token_t;
     OP_REQUIRES_OK(context,

@@ -641,13 +638,12 @@ class QuantileAccumulatorFlushSummaryOp : public OpKernel {
       : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);

     const Tensor* next_stamp_token_t;
     OP_REQUIRES_OK(context,

@@ -713,12 +709,11 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel {
         const ResourceHandle& handle =
             resource_handle_list[resource_handle_idx]
                 .flat<ResourceHandle>()(0);
-        QuantileStreamResource* streams_resource;
+        core::RefCountPtr<QuantileStreamResource> streams_resource;
         OP_REQUIRES_OK(context,
                        LookupResource(context, handle, &streams_resource));
         // Remove the reference at the end of this scope.
         mutex_lock l(*streams_resource->mutex());
-        core::ScopedUnref unref_me(streams_resource);

         bool are_buckets_ready =
             streams_resource->is_stamp_valid(stamp_token) &&

View File

@ -27,6 +27,7 @@
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h" #include "tensorflow/core/util/work_sharder.h"
@ -130,9 +131,8 @@ using StatsAccumulatorTensorResource =
StatsAccumulatorResource<std::vector<float>, std::vector<float>>; StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
void SerializeScalarAccumulatorToOutput( void SerializeScalarAccumulatorToOutput(
const StatsAccumulatorScalarResource& accumulator_resource, const StatsAccumulatorScalarResource& resource, OpKernelContext* context) {
OpKernelContext* context) { int64 num_slots = resource.values().size();
int64 num_slots = accumulator_resource.values().size();
Tensor* partition_ids_t = nullptr; Tensor* partition_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
TensorShape({num_slots}), TensorShape({num_slots}),
@ -159,7 +159,7 @@ void SerializeScalarAccumulatorToOutput(
auto hessians = hessians_t->vec<float>(); auto hessians = hessians_t->vec<float>();
int i = 0; int i = 0;
for (const auto& iter : accumulator_resource.values()) { for (const auto& iter : resource.values()) {
partition_ids(i) = iter.first.partition_id; partition_ids(i) = iter.first.partition_id;
feature_ids(i, 0) = iter.first.feature_id; feature_ids(i, 0) = iter.first.feature_id;
feature_ids(i, 1) = iter.first.dimension; feature_ids(i, 1) = iter.first.dimension;
@ -171,9 +171,8 @@ void SerializeScalarAccumulatorToOutput(
} }
void SerializeTensorAccumulatorToOutput( void SerializeTensorAccumulatorToOutput(
const StatsAccumulatorTensorResource& accumulator_resource, const StatsAccumulatorTensorResource& resource, OpKernelContext* context) {
OpKernelContext* context) { int64 num_slots = resource.values().size();
int64 num_slots = accumulator_resource.values().size();
Tensor* partition_ids_t = nullptr; Tensor* partition_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
TensorShape({num_slots}), TensorShape({num_slots}),
@ -186,7 +185,7 @@ void SerializeTensorAccumulatorToOutput(
&feature_ids_t)); &feature_ids_t));
auto feature_ids = feature_ids_t->matrix<int64>(); auto feature_ids = feature_ids_t->matrix<int64>();
TensorShape gradient_shape = accumulator_resource.gradient_shape(); TensorShape gradient_shape = resource.gradient_shape();
int64 num_gradient_elements = gradient_shape.num_elements(); int64 num_gradient_elements = gradient_shape.num_elements();
gradient_shape.InsertDim(0, num_slots); gradient_shape.InsertDim(0, num_slots);
Tensor* gradients_t = nullptr; Tensor* gradients_t = nullptr;
@ -195,7 +194,7 @@ void SerializeTensorAccumulatorToOutput(
&gradients_t)); &gradients_t));
auto gradients = gradients_t->flat_outer_dims<float>(); auto gradients = gradients_t->flat_outer_dims<float>();
TensorShape hessian_shape = accumulator_resource.hessian_shape(); TensorShape hessian_shape = resource.hessian_shape();
int64 num_hessian_elements = hessian_shape.num_elements(); int64 num_hessian_elements = hessian_shape.num_elements();
hessian_shape.InsertDim(0, num_slots); hessian_shape.InsertDim(0, num_slots);
Tensor* hessians_t = nullptr; Tensor* hessians_t = nullptr;
@ -204,7 +203,7 @@ void SerializeTensorAccumulatorToOutput(
auto hessians = hessians_t->flat_outer_dims<float>(); auto hessians = hessians_t->flat_outer_dims<float>();
int i = 0; int i = 0;
for (const auto& iter : accumulator_resource.values()) { for (const auto& iter : resource.values()) {
partition_ids(i) = iter.first.partition_id; partition_ids(i) = iter.first.partition_id;
feature_ids(i, 0) = iter.first.feature_id; feature_ids(i, 0) = iter.first.feature_id;
feature_ids(i, 1) = iter.first.dimension; feature_ids(i, 1) = iter.first.dimension;
@ -220,11 +219,10 @@ void SerializeTensorAccumulatorToOutput(
} }
void AddToScalarAccumulator( void AddToScalarAccumulator(
StatsAccumulatorScalarResource* accumulator_resource, const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
const Tensor& partition_ids_t, const Tensor& feature_ids_t, const Tensor& partition_ids_t, const Tensor& feature_ids_t,
const Tensor& gradients_t, const Tensor& hessians_t) { const Tensor& gradients_t, const Tensor& hessians_t) {
accumulator_resource->set_num_updates(accumulator_resource->num_updates() + resource->set_num_updates(resource->num_updates() + 1);
1);
const TensorShape& partition_ids_shape = partition_ids_t.shape(); const TensorShape& partition_ids_shape = partition_ids_t.shape();
const auto& partition_ids = partition_ids_t.vec<int32>(); const auto& partition_ids = partition_ids_t.vec<int32>();
const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>(); const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
@ -232,7 +230,7 @@ void AddToScalarAccumulator(
const auto& hessians = hessians_t.vec<float>(); const auto& hessians = hessians_t.vec<float>();
int64 num_updates = partition_ids_shape.dim_size(0); int64 num_updates = partition_ids_shape.dim_size(0);
auto stats_map = accumulator_resource->mutable_values(); auto stats_map = resource->mutable_values();
for (int64 i = 0; i < num_updates; ++i) { for (int64 i = 0; i < num_updates; ++i) {
const auto key = const auto key =
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@ -248,7 +246,7 @@ void AddToScalarAccumulator(
} }
void AddToScalarAccumulator( void AddToScalarAccumulator(
StatsAccumulatorScalarResource* accumulator_resource, const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
OpKernelContext* context) { OpKernelContext* context) {
const Tensor* partition_ids_t; const Tensor* partition_ids_t;
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t)); OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@ -258,17 +256,16 @@ void AddToScalarAccumulator(
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
const Tensor* hessians_t; const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t, AddToScalarAccumulator(resource, *partition_ids_t, *feature_ids_t,
*gradients_t, *hessians_t); *gradients_t, *hessians_t);
} }
void AddToTensorAccumulator( void AddToTensorAccumulator(
StatsAccumulatorTensorResource* accumulator_resource, const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
const Tensor& partition_ids_t, const Tensor& feature_ids_t, const Tensor& partition_ids_t, const Tensor& feature_ids_t,
const Tensor& gradients_t, const Tensor& hessians_t, const Tensor& gradients_t, const Tensor& hessians_t,
OpKernelContext* context) { OpKernelContext* context) {
accumulator_resource->set_num_updates(accumulator_resource->num_updates() + resource->set_num_updates(resource->num_updates() + 1);
1);
const TensorShape& partition_ids_shape = partition_ids_t.shape(); const TensorShape& partition_ids_shape = partition_ids_t.shape();
const auto& partition_ids = partition_ids_t.vec<int32>(); const auto& partition_ids = partition_ids_t.vec<int32>();
@ -283,19 +280,19 @@ void AddToTensorAccumulator(
// TODO(soroush): Move gradient and hessian shape check to ShapeFn. // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
OP_REQUIRES( OP_REQUIRES(
context, gradients_shape == accumulator_resource->gradient_shape(), context, gradients_shape == resource->gradient_shape(),
errors::InvalidArgument(strings::StrCat( errors::InvalidArgument(strings::StrCat(
"Gradients dimensions must match: ", gradients_shape.DebugString(), "Gradients dimensions must match: ", gradients_shape.DebugString(),
", ", accumulator_resource->gradient_shape().DebugString()))); ", ", resource->gradient_shape().DebugString())));
OP_REQUIRES( OP_REQUIRES(
context, hessians_shape == accumulator_resource->hessian_shape(), context, hessians_shape == resource->hessian_shape(),
errors::InvalidArgument(strings::StrCat( errors::InvalidArgument(strings::StrCat(
"Hessian dimensions must match: ", hessians_shape.DebugString(), ", ", "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
accumulator_resource->hessian_shape().DebugString()))); resource->hessian_shape().DebugString())));
int64 num_updates = partition_ids_shape.dim_size(0); int64 num_updates = partition_ids_shape.dim_size(0);
auto stats_map = accumulator_resource->mutable_values(); auto stats_map = resource->mutable_values();
for (int64 i = 0; i < num_updates; ++i) { for (int64 i = 0; i < num_updates; ++i) {
const auto key = const auto key =
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@ -325,7 +322,7 @@ void AddToTensorAccumulator(
} }
void AddToTensorAccumulator( void AddToTensorAccumulator(
StatsAccumulatorTensorResource* accumulator_resource, const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
OpKernelContext* context) { OpKernelContext* context) {
const Tensor* partition_ids_t; const Tensor* partition_ids_t;
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t)); OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@ -335,7 +332,7 @@ void AddToTensorAccumulator(
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
const Tensor* hessians_t; const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t, AddToTensorAccumulator(resource, *partition_ids_t, *feature_ids_t,
*gradients_t, *hessians_t, context); *gradients_t, *hessians_t, context);
} }
@ -452,20 +449,18 @@ class StatsAccumulatorScalarAddOp : public OpKernel {
resource_handle_list[resource_handle_idx] resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0); .flat<ResourceHandle>()(0);
StatsAccumulatorScalarResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, handle, OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
&accumulator_resource)); mutex_lock l(*resource->mutex());
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
// If the stamp is invalid we drop the update. // If the stamp is invalid we drop the update.
if (!accumulator_resource->is_stamp_valid(stamp_token)) { if (!resource->is_stamp_valid(stamp_token)) {
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. " VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
<< "Passed stamp token: " << stamp_token << " " << "Passed stamp token: " << stamp_token << " "
<< "Current token: " << accumulator_resource->stamp(); << "Current token: " << resource->stamp();
return; return;
} }
AddToScalarAccumulator(accumulator_resource, AddToScalarAccumulator(resource,
partition_ids_list[resource_handle_idx], partition_ids_list[resource_handle_idx],
feature_ids_list[resource_handle_idx], feature_ids_list[resource_handle_idx],
gradients_list[resource_handle_idx], gradients_list[resource_handle_idx],
@ -517,20 +512,18 @@ class StatsAccumulatorTensorAddOp : public OpKernel {
resource_handle_list[resource_handle_idx] resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0); .flat<ResourceHandle>()(0);
StatsAccumulatorTensorResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, handle, OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
&accumulator_resource)); mutex_lock l(*resource->mutex());
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
// If the stamp is invalid we drop the update. // If the stamp is invalid we drop the update.
if (!accumulator_resource->is_stamp_valid(stamp_token)) { if (!resource->is_stamp_valid(stamp_token)) {
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. " VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
<< "Passed stamp token: " << stamp_token << " " << "Passed stamp token: " << stamp_token << " "
<< "Current token: " << accumulator_resource->stamp(); << "Current token: " << resource->stamp();
return; return;
} }
AddToTensorAccumulator(accumulator_resource, AddToTensorAccumulator(resource,
partition_ids_list[resource_handle_idx], partition_ids_list[resource_handle_idx],
feature_ids_list[resource_handle_idx], feature_ids_list[resource_handle_idx],
gradients_list[resource_handle_idx], gradients_list[resource_handle_idx],
@ -549,11 +542,10 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
StatsAccumulatorScalarResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource)); &resource));
mutex_lock l(*accumulator_resource->mutex()); mutex_lock l(*resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
const Tensor* stamp_token_t; const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@ -562,7 +554,7 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
// If the stamp is invalid we restart the PS. It shouldn't happen since // If the stamp is invalid we restart the PS. It shouldn't happen since
// only Chief should call this function and chief is guaranteed to be in // only Chief should call this function and chief is guaranteed to be in
// a consistent state. // a consistent state.
CHECK(accumulator_resource->is_stamp_valid(stamp_token)); CHECK(resource->is_stamp_valid(stamp_token));
const Tensor* next_stamp_token_t; const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
@ -570,15 +562,15 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
CHECK(stamp_token != next_stamp_token); CHECK(stamp_token != next_stamp_token);
SerializeScalarAccumulatorToOutput(*accumulator_resource, context); SerializeScalarAccumulatorToOutput(*resource, context);
Tensor* num_updates_t = nullptr; Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}), context->allocate_output("num_updates", TensorShape({}),
&num_updates_t)); &num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); num_updates_t->scalar<int64>()() = resource->num_updates();
accumulator_resource->Clear(); resource->Clear();
accumulator_resource->set_stamp(next_stamp_token); resource->set_stamp(next_stamp_token);
} }
}; };
@ -591,11 +583,10 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
StatsAccumulatorTensorResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource)); &resource));
mutex_lock l(*accumulator_resource->mutex()); mutex_lock l(*resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
const Tensor* stamp_token_t; const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@ -609,16 +600,16 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
// If the stamp is invalid we restart the PS. It shouldn't happen since // If the stamp is invalid we restart the PS. It shouldn't happen since
// only Chief should call this function and chief is guaranteed to be in // only Chief should call this function and chief is guaranteed to be in
// a consistent state. // a consistent state.
CHECK(accumulator_resource->is_stamp_valid(stamp_token)); CHECK(resource->is_stamp_valid(stamp_token));
CHECK(stamp_token != next_stamp_token); CHECK(stamp_token != next_stamp_token);
SerializeTensorAccumulatorToOutput(*accumulator_resource, context); SerializeTensorAccumulatorToOutput(*resource, context);
Tensor* num_updates_t = nullptr; Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}), context->allocate_output("num_updates", TensorShape({}),
&num_updates_t)); &num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); num_updates_t->scalar<int64>()() = resource->num_updates();
accumulator_resource->Clear(); resource->Clear();
accumulator_resource->set_stamp(next_stamp_token); resource->set_stamp(next_stamp_token);
} }
}; };
@ -631,22 +622,21 @@ class StatsAccumulatorScalarDeserializeOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
StatsAccumulatorScalarResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource)); &resource));
mutex_lock l(*accumulator_resource->mutex()); mutex_lock l(*resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
// Check the stamp token. // Check the stamp token.
const Tensor* stamp_token_t; const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()(); int64 stamp_token = stamp_token_t->scalar<int64>()();
accumulator_resource->Clear(); resource->Clear();
accumulator_resource->set_stamp(stamp_token); resource->set_stamp(stamp_token);
AddToScalarAccumulator(accumulator_resource, context); AddToScalarAccumulator(resource, context);
const Tensor* num_updates_t; const Tensor* num_updates_t;
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t)); OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()()); resource->set_num_updates(num_updates_t->scalar<int64>()());
} }
}; };
@ -660,22 +650,21 @@ class StatsAccumulatorTensorDeserializeOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
StatsAccumulatorTensorResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource)); &resource));
mutex_lock l(*accumulator_resource->mutex()); mutex_lock l(*resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
// Check the stamp token. // Check the stamp token.
const Tensor* stamp_token_t; const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()(); int64 stamp_token = stamp_token_t->scalar<int64>()();
accumulator_resource->Clear(); resource->Clear();
accumulator_resource->set_stamp(stamp_token); resource->set_stamp(stamp_token);
AddToTensorAccumulator(accumulator_resource, context); AddToTensorAccumulator(resource, context);
const Tensor* num_updates_t; const Tensor* num_updates_t;
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t)); OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()()); resource->set_num_updates(num_updates_t->scalar<int64>()());
} }
}; };
@ -689,23 +678,22 @@ class StatsAccumulatorScalarSerializeOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
StatsAccumulatorScalarResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource)); &resource));
mutex_lock l(*accumulator_resource->mutex()); mutex_lock l(*resource->mutex());
core::ScopedUnref unref_me(accumulator_resource); SerializeScalarAccumulatorToOutput(*resource, context);
SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
Tensor* stamp_token_t = nullptr; Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output("stamp_token", TensorShape({}), context->allocate_output("stamp_token", TensorShape({}),
&stamp_token_t)); &stamp_token_t));
stamp_token_t->scalar<int64>()() = accumulator_resource->stamp(); stamp_token_t->scalar<int64>()() = resource->stamp();
Tensor* num_updates_t = nullptr; Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}), context->allocate_output("num_updates", TensorShape({}),
&num_updates_t)); &num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); num_updates_t->scalar<int64>()() = resource->num_updates();
} }
}; };
@ -719,23 +707,22 @@ class StatsAccumulatorTensorSerializeOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
StatsAccumulatorTensorResource* accumulator_resource; core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource)); &resource));
mutex_lock l(*accumulator_resource->mutex()); mutex_lock l(*resource->mutex());
core::ScopedUnref unref_me(accumulator_resource); SerializeTensorAccumulatorToOutput(*resource, context);
SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
Tensor* stamp_token_t = nullptr; Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output("stamp_token", TensorShape({}), context->allocate_output("stamp_token", TensorShape({}),
&stamp_token_t)); &stamp_token_t));
stamp_token_t->scalar<int64>()() = accumulator_resource->stamp(); stamp_token_t->scalar<int64>()() = resource->stamp();
Tensor* num_updates_t = nullptr; Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}), context->allocate_output("num_updates", TensorShape({}),
&num_updates_t)); &num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); num_updates_t->scalar<int64>()() = resource->num_updates();
} }
}; };
@ -751,12 +738,11 @@ class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
TensorShape gradient_shape = TensorShape({}); TensorShape gradient_shape = TensorShape({});
TensorShape hessian_shape = TensorShape({}); TensorShape hessian_shape = TensorShape({});
StatsAccumulatorScalarResource* accumulator_resource = core::RefCountPtr<StatsAccumulatorScalarResource> resource(
new StatsAccumulatorScalarResource(gradient_shape, hessian_shape); new StatsAccumulatorScalarResource(gradient_shape, hessian_shape));
core::ScopedUnref unref_me(accumulator_resource);
// Check the stamp token. // Check the stamp token.
AddToScalarAccumulator(accumulator_resource, context); AddToScalarAccumulator(resource, context);
SerializeScalarAccumulatorToOutput(*accumulator_resource, context); SerializeScalarAccumulatorToOutput(*resource, context);
} }
}; };
@ -780,12 +766,11 @@ class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
TensorShape hessians_shape = hessians_t->shape(); TensorShape hessians_shape = hessians_t->shape();
hessians_shape.RemoveDim(0); hessians_shape.RemoveDim(0);
StatsAccumulatorTensorResource* accumulator_resource = core::RefCountPtr<StatsAccumulatorTensorResource> resource(
new StatsAccumulatorTensorResource(gradients_shape, hessians_shape); new StatsAccumulatorTensorResource(gradients_shape, hessians_shape));
core::ScopedUnref unref_me(accumulator_resource);
// Check the stamp token. // Check the stamp token.
AddToTensorAccumulator(accumulator_resource, context); AddToTensorAccumulator(resource, context);
SerializeTensorAccumulatorToOutput(*accumulator_resource, context); SerializeTensorAccumulatorToOutput(*resource, context);
} }
}; };
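For reference, a minimal sketch of the migration pattern these hunks apply (MyResource and the surrounding op body are hypothetical stand-ins): the raw pointer plus ScopedUnref pair becomes a single core::RefCountPtr, whose destructor calls Unref(), so a forgotten ScopedUnref can no longer leak the resource.

  // Before: the caller must remember to pair LookupResource with ScopedUnref.
  MyResource* raw = nullptr;
  OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &raw));
  core::ScopedUnref unref_me(raw);  // omitting this line leaks the resource

  // After: the reference is released automatically at end of scope.
  core::RefCountPtr<MyResource> resource;
  OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &resource));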

View File

@ -21,6 +21,7 @@
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow { namespace tensorflow {
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
@ -31,6 +32,8 @@ namespace {
using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig; using boosted_trees::learner::LearningRateConfig;
using boosted_trees::models::DecisionTreeEnsembleResource;
using boosted_trees::trees::DecisionTreeConfig;
using boosted_trees::trees::Leaf; using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode; using boosted_trees::trees::TreeNode;
using boosted_trees::trees::TreeNodeMetadata; using boosted_trees::trees::TreeNodeMetadata;
@ -193,10 +196,9 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble. // Get decision tree ensemble.
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource)); &ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex()); mutex_lock l(*ensemble_resource->get_mutex());
// Get the stamp token. // Get the stamp token.
@ -255,8 +257,8 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
private: private:
// Helper method to retrieve the bias from the tree ensemble. // Helper method to retrieve the bias from the tree ensemble.
boosted_trees::trees::Leaf* RetrieveBias( Leaf* RetrieveBias(
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource, const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
int64 logits_dimension) { int64 logits_dimension) {
const int32 num_trees = ensemble_resource->num_trees(); const int32 num_trees = ensemble_resource->num_trees();
if (num_trees <= 0) { if (num_trees <= 0) {
@ -319,10 +321,9 @@ class GrowTreeEnsembleOp : public OpKernel {
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble. // Get decision tree ensemble.
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource)); &ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex()); mutex_lock l(*ensemble_resource->get_mutex());
// Get the stamp token. // Get the stamp token.
@ -400,10 +401,9 @@ class GrowTreeEnsembleOp : public OpKernel {
// Update and retrieve the growable tree. // Update and retrieve the growable tree.
// If the tree is fully built and dropout was applied, it also adjusts the // If the tree is fully built and dropout was applied, it also adjusts the
// weights of dropped and the last tree. // weights of dropped and the last tree.
boosted_trees::trees::DecisionTreeConfig* const tree_config = DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, ensemble_resource, learning_rate, dropout_seed, max_tree_depth,
dropout_seed, max_tree_depth, weak_learner_type);
weak_learner_type);
// Split tree nodes. // Split tree nodes.
switch (weak_learner_type) { switch (weak_learner_type) {
case LearnerConfig::NORMAL_DECISION_TREE: { case LearnerConfig::NORMAL_DECISION_TREE: {
@ -559,8 +559,7 @@ class GrowTreeEnsembleOp : public OpKernel {
} }
void UpdateTreeWeightsIfDropout( void UpdateTreeWeightsIfDropout(
boosted_trees::models::DecisionTreeEnsembleResource* const const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
ensemble_resource,
const uint64 dropout_seed) { const uint64 dropout_seed) {
// It is possible that the tree was built with dropout. If it is the case, // It is possible that the tree was built with dropout. If it is the case,
// we need to adjust the tree weight, or bail out. // we need to adjust the tree weight, or bail out.
@ -609,9 +608,8 @@ class GrowTreeEnsembleOp : public OpKernel {
// Helper method to update the growable tree which is by definition the last // Helper method to update the growable tree which is by definition the last
// tree in the ensemble. // tree in the ensemble.
boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree( DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
boosted_trees::models::DecisionTreeEnsembleResource* const const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
ensemble_resource,
const float learning_rate, const uint64 dropout_seed, const float learning_rate, const uint64 dropout_seed,
const int32 max_tree_depth, const int32 weak_learner_type) { const int32 max_tree_depth, const int32 weak_learner_type) {
const auto num_trees = ensemble_resource->num_trees(); const auto num_trees = ensemble_resource->num_trees();
@ -719,8 +717,8 @@ class GrowTreeEnsembleOp : public OpKernel {
// leaf children given the split candidate. // leaf children given the split candidate.
void SplitTreeNode( void SplitTreeNode(
const int32 node_id, SplitCandidate* split, const int32 node_id, SplitCandidate* split,
boosted_trees::trees::DecisionTreeConfig* tree_config, DecisionTreeConfig* tree_config,
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) { const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
// No-op if we have no real node. // No-op if we have no real node.
CHECK(node_id < tree_config->nodes_size()) CHECK(node_id < tree_config->nodes_size())
<< "Invalid node " << node_id << " to split."; << "Invalid node " << node_id << " to split.";
@ -761,14 +759,13 @@ class GrowTreeEnsembleOp : public OpKernel {
(*tree_config->mutable_nodes(node_id)) = (*tree_config->mutable_nodes(node_id)) =
*split->split_info.mutable_split_node(); *split->split_info.mutable_split_node();
if (learner_config_.constraints().max_number_of_unique_feature_columns()) { if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
ensemble_resource->MaybeAddUsedHandler(split->handler_id); resource->MaybeAddUsedHandler(split->handler_id);
} }
} }
void SplitTreeLayer( void SplitTreeLayer(
SplitCandidate* split, SplitCandidate* split, DecisionTreeConfig* tree_config,
boosted_trees::trees::DecisionTreeConfig* tree_config, const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
int depth = 0; int depth = 0;
while (depth < tree_config->nodes_size() && while (depth < tree_config->nodes_size() &&
tree_config->nodes(depth).node_case() != TreeNode::kLeaf) { tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
@ -903,10 +900,9 @@ class TreeEnsembleStatsOp : public OpKernel {
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble. // Get decision tree ensemble.
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource)); &ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
tf_shared_lock l(*ensemble_resource->get_mutex()); tf_shared_lock l(*ensemble_resource->get_mutex());
// Get the stamp token. // Get the stamp token.
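Note the helper signatures in this file now take const core::RefCountPtr<DecisionTreeEnsembleResource>& rather than a raw pointer. RefCountPtr is move-only, so passing it by const reference lets helpers borrow the resource without copying or transferring the reference; a condensed sketch (HelperOnEnsemble is a hypothetical name):

  void HelperOnEnsemble(
      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble) {
    // Borrows the caller's reference for the duration of the call;
    // no Ref()/Unref() churn and no ownership transfer.
    const int32 num_trees = ensemble->num_trees();
  }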

View File

@ -95,7 +95,7 @@ class ZeroVarInitializer : public OpKernel {
} }
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
Var* variable = nullptr; core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>( OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, 0), &variable, ctx, HandleFromInput(ctx, 0), &variable,
[this, ctx](Var** var_ptr) { [this, ctx](Var** var_ptr) {
@ -117,7 +117,6 @@ class ZeroVarInitializer : public OpKernel {
return Status::OK(); return Status::OK();
})); }));
core::ScopedUnref scoped(variable);
mutex_lock ml(*variable->mu()); mutex_lock ml(*variable->mu());
OP_REQUIRES(ctx, !variable->is_initialized, OP_REQUIRES(ctx, !variable->is_initialized,

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#include <functional> #include <functional>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
@ -24,6 +25,7 @@
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -76,11 +78,10 @@ class TreeSerializeOp : public OpKernel {
explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {} explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource; core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource)); &decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex()); mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_config_t = nullptr; Tensor* output_config_t = nullptr;
OP_REQUIRES_OK( OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t)); context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -100,12 +101,11 @@ class TreeDeserializeOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource; core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
auto handle = HandleFromInput(context, 0); auto handle = HandleFromInput(context, 0);
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
LookupResource(context, handle, &decision_tree_resource)); LookupResource(context, handle, &decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex()); mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const Tensor* tree_config_t; const Tensor* tree_config_t;
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t)); OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
@ -131,11 +131,10 @@ class TreeSizeOp : public OpKernel {
explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {} explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource; core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource)); &decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex()); mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_t = nullptr; Tensor* output_t = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape(), &output_t)); context->allocate_output(0, TensorShape(), &output_t));
@ -144,7 +143,7 @@ class TreeSizeOp : public OpKernel {
} }
}; };
void TraverseTree(const DecisionTreeResource* tree_resource, void TraverseTree(DecisionTreeResource* tree_resource,
const std::unique_ptr<TensorDataSet>& data, int32 start, const std::unique_ptr<TensorDataSet>& data, int32 start,
int32 end, int32 end,
const std::function<void(int32, int32)>& set_leaf_id, const std::function<void(int32, int32)>& set_leaf_id,
@ -182,11 +181,10 @@ class TreePredictionsV4Op : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices, data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape); sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource; core::RefCountPtr<DecisionTreeResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource)); &resource));
mutex_lock l(*decision_tree_resource->get_mutex()); mutex_lock l(*resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = data_set->NumItems(); const int num_data = data_set->NumItems();
const int32 num_outputs = param_proto_.num_outputs(); const int32 num_outputs = param_proto_.num_outputs();
@ -205,15 +203,16 @@ class TreePredictionsV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads; int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500; const int64 costPerTraverse = 500;
auto traverse = [this, &out, &data_set, decision_tree_resource, num_data, auto traverse = [this, &out, &data_set, &resource, num_data, &tree_paths](
&tree_paths](int64 start, int64 end) { int64 start, int64 end) {
CHECK(start <= end); CHECK(start <= end);
CHECK(end <= num_data); CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
TraverseTree(resource.get(), data_set, static_cast<int32>(start),
static_cast<int32>(end), static_cast<int32>(end),
std::bind(&TreePredictionsV4Op::set_output_value, this, std::bind(&TreePredictionsV4Op::set_output_value, this,
std::placeholders::_1, std::placeholders::_2, std::placeholders::_1, std::placeholders::_2,
decision_tree_resource, &out), resource.get(), &out),
param_proto_.inference_tree_paths() ? &tree_paths : nullptr); param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
}; };
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@ -283,11 +282,10 @@ class TraverseTreeV4Op : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices, data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape); sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource; core::RefCountPtr<DecisionTreeResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource)); &resource));
mutex_lock l(*decision_tree_resource->get_mutex()); mutex_lock l(*resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = data_set->NumItems(); const int num_data = data_set->NumItems();
@ -304,11 +302,11 @@ class TraverseTreeV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads; int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500; const int64 costPerTraverse = 500;
auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource, auto traverse = [&set_leaf_ids, &data_set, &resource, num_data](int64 start,
num_data](int64 start, int64 end) { int64 end) {
CHECK(start <= end); CHECK(start <= end);
CHECK(end <= num_data); CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start), TraverseTree(resource.get(), data_set, static_cast<int32>(start),
static_cast<int32>(end), set_leaf_ids, nullptr); static_cast<int32>(end), set_leaf_ids, nullptr);
}; };
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@ -336,11 +334,10 @@ class UpdateModelV4Op : public OpKernel {
const Tensor& input_labels = context->input(2); const Tensor& input_labels = context->input(2);
const Tensor& input_weights = context->input(3); const Tensor& input_weights = context->input(3);
DecisionTreeResource* decision_tree_resource; core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource)); &decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex()); mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = input_labels.shape().dim_size(0); const int num_data = input_labels.shape().dim_size(0);
const int32 label_dim = const int32 label_dim =
@ -356,9 +353,10 @@ class UpdateModelV4Op : public OpKernel {
UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource); UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
} }
void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target, void UpdateModel(
int32 start, int32 end, const Tensor& leaf_ids, const TensorInputTarget& target, int32 start,
DecisionTreeResource* decision_tree_resource) { int32 end,
const core::RefCountPtr<DecisionTreeResource>& decision_tree_resource) {
const auto leaves = leaf_ids.unaligned_flat<int32>(); const auto leaves = leaf_ids.unaligned_flat<int32>();
for (int i = start; i < end; ++i) { for (int i = start; i < end; ++i) {
model_op_->UpdateModel( model_op_->UpdateModel(
@ -384,11 +382,10 @@ class FeatureUsageCountsOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource; core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource)); &decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex()); mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const auto& tree = decision_tree_resource->decision_tree(); const auto& tree = decision_tree_resource->decision_tree();

View File

@ -26,6 +26,7 @@
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
@ -87,11 +88,10 @@ class FertileStatsSerializeOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource; core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&fertile_stats_resource)); &fertile_stats_resource));
mutex_lock l(*fertile_stats_resource->get_mutex()); mutex_lock l(*fertile_stats_resource->get_mutex());
core::ScopedUnref unref_me(fertile_stats_resource);
Tensor* output_config_t = nullptr; Tensor* output_config_t = nullptr;
OP_REQUIRES_OK( OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t)); context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -116,11 +116,10 @@ class FertileStatsDeserializeOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource; core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&fertile_stats_resource)); &fertile_stats_resource));
mutex_lock l(*fertile_stats_resource->get_mutex()); mutex_lock l(*fertile_stats_resource->get_mutex());
core::ScopedUnref unref_me(fertile_stats_resource);
const Tensor* stats_config_t; const Tensor* stats_config_t;
OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t)); OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
@ -145,7 +144,7 @@ class FertileStatsDeserializeOp : public OpKernel {
// acquired, put it in a waiting queue to come back to later and try the next // acquired, put it in a waiting queue to come back to later and try the next
// one. Once all leaf_ids have been visited, cycle through the waiting ids // one. Once all leaf_ids have been visited, cycle through the waiting ids
// until they're gone. // until they're gone.
void UpdateStats(FertileStatsResource* fertile_stats_resource, void UpdateStats(const core::RefCountPtr<FertileStatsResource>& resource,
const std::unique_ptr<TensorDataSet>& data, const std::unique_ptr<TensorDataSet>& data,
const TensorInputTarget& target, int num_targets, const TensorInputTarget& target, int num_targets,
const Tensor& leaf_ids_tensor, const Tensor& leaf_ids_tensor,
@ -183,8 +182,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
} }
bool is_finished; bool is_finished;
fertile_stats_resource->AddExampleToStatsAndInitialize( resource->AddExampleToStatsAndInitialize(data, &target, {example_id},
data, &target, {example_id}, leaf_id, &is_finished); leaf_id, &is_finished);
leaf_lock->unlock(); leaf_lock->unlock();
if (is_finished) { if (is_finished) {
set_lock->lock(); set_lock->lock();
@ -196,8 +195,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
// Update leaves from start through end in the leaf_examples iterator. // Update leaves from start through end in the leaf_examples iterator.
void UpdateStatsCollated( void UpdateStatsCollated(
FertileStatsResource* fertile_stats_resource, const core::RefCountPtr<FertileStatsResource>& fertile_stats_resource,
DecisionTreeResource* tree_resource, const core::RefCountPtr<DecisionTreeResource>& tree_resource,
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target, const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
int num_targets, int num_targets,
const std::unordered_map<int32, std::vector<int>>& leaf_examples, const std::unordered_map<int32, std::vector<int>>& leaf_examples,
@ -251,18 +250,15 @@ class ProcessInputOp : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices, data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape); sparse_input_values, sparse_input_shape);
FertileStatsResource* fertile_stats_resource; core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource)); &fertile_stats_resource));
DecisionTreeResource* tree_resource; core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource)); &tree_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex()); mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex()); mutex_lock l2(*tree_resource->get_mutex());
core::ScopedUnref unref_stats(fertile_stats_resource);
core::ScopedUnref unref_tree(tree_resource);
const int32 num_data = data_set->NumItems(); const int32 num_data = data_set->NumItems();
auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads; int num_threads = worker_threads->num_threads;
@ -308,7 +304,7 @@ class ProcessInputOp : public OpKernel {
// if it really matters that much. // if it really matters that much.
const int64 costPerUpdate = 1000; const int64 costPerUpdate = 1000;
auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set, auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set,
fertile_stats_resource, &locks, &set_lock, &ready_to_split, &fertile_stats_resource, &locks, &set_lock, &ready_to_split,
num_data](int64 start, int64 end) { num_data](int64 start, int64 end) {
CHECK(start <= end); CHECK(start <= end);
CHECK(end <= num_data); CHECK(end <= num_data);
@ -317,8 +313,8 @@ class ProcessInputOp : public OpKernel {
static_cast<int32>(end), &ready_to_split); static_cast<int32>(end), &ready_to_split);
}; };
auto update_collated = [&target, &num_targets, fertile_stats_resource, auto update_collated = [&target, &num_targets, &fertile_stats_resource,
tree_resource, &leaf_examples, &set_lock, &tree_resource, &leaf_examples, &set_lock,
&ready_to_split, &data_set, &ready_to_split, &data_set,
num_leaves](int64 start, int64 end) { num_leaves](int64 start, int64 end) {
CHECK(start <= end); CHECK(start <= end);
@ -362,18 +358,15 @@ class GrowTreeOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource; core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource)); &fertile_stats_resource));
DecisionTreeResource* tree_resource; core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource)); &tree_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex()); mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex()); mutex_lock l2(*tree_resource->get_mutex());
core::ScopedUnref unref_stats(fertile_stats_resource);
core::ScopedUnref unref_tree(tree_resource);
const Tensor& finished_nodes = context->input(2); const Tensor& finished_nodes = context->input(2);
const auto finished = finished_nodes.unaligned_flat<int32>(); const auto finished = finished_nodes.unaligned_flat<int32>();
@ -463,19 +456,16 @@ class FinalizeTreeOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
DecisionTreeResource* tree_resource; core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource)); &tree_resource));
FertileStatsResource* fertile_stats_resource; core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource)); &fertile_stats_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex()); mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex()); mutex_lock l2(*tree_resource->get_mutex());
core::ScopedUnref unref_me(tree_resource);
core::ScopedUnref unref_stats(fertile_stats_resource);
// TODO(thomaswc): Add threads // TODO(thomaswc): Add threads
int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size(); int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
for (int i = 0; i < num_nodes; i++) { for (int i = 0; i < num_nodes; i++) {

View File

@ -270,12 +270,19 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
template <typename T, bool use_dynamic_cast = false> template <typename T, bool use_dynamic_cast = false>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
// Looks up a resource pointed to by a given resource handle.
//
// Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
// requiring the caller to explicitly call `Unref()`.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value);
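A minimal usage sketch of this overload inside an op kernel (the Var type and input index are illustrative):

  core::RefCountPtr<Var> variable;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
  // variable behaves like a pointer; the reference it holds is released
  // in its destructor, so no ScopedUnref is required.
  tf_shared_lock lock(*variable->mu());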
// Looks up multiple resources pointed to by a sequence of resource handles. If // Looks up multiple resources pointed to by a sequence of resource handles. If
// p[i] is uninitialized then values[i] is unmodified. // p[i] is uninitialized then values[i] is unmodified.
template <typename T> template <typename T>
Status LookupResources( Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
OpKernelContext* ctx, absl::Span<ResourceHandle const> p, std::vector<core::RefCountPtr<T>>* values);
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
// Looks up or creates a resource. // Looks up or creates a resource.
// //
@ -283,10 +290,19 @@ Status LookupResources(
// must call its `Unref()` method when it has finished using it. If the // must call its `Unref()` method when it has finished using it. If the
// `creator` is invoked, its reference on the created resource is transferred // `creator` is invoked, its reference on the created resource is transferred
// to `ctx->resource_mgr()`. // to `ctx->resource_mgr()`.
//
// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
// requiring the caller to explicitly call `Unref()`.
template <typename T> template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator); T** value, std::function<Status(T**)> creator);
// Looks up or creates a resource.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value,
std::function<Status(T**)> creator);
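A sketch of how a kernel might call this overload, mirroring the migrated call sites in this change (MyResource is hypothetical):

  core::RefCountPtr<MyResource> resource;
  OP_REQUIRES_OK(ctx, LookupOrCreateResource<MyResource>(
                          ctx, HandleFromInput(ctx, 0), &resource,
                          [](MyResource** out) {
                            // The creator still yields a raw pointer; its
                            // reference transfers to the resource manager.
                            *out = new MyResource();
                            return Status::OK();
                          }));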
// Destroys a resource pointed to by a given resource handle. // Destroys a resource pointed to by a given resource handle.
template <typename T> template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
@ -587,9 +603,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
} }
template <typename T> template <typename T>
Status LookupResources( Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p, core::RefCountPtr<T>* value) {
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) { T* raw_ptr = nullptr;
TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
value->reset(raw_ptr);
return Status::OK();
}
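This wrapper adopts the reference acquired by the raw-pointer overload rather than adding a new one: core::RefCountPtr is, per refcount.h, a std::unique_ptr alias whose deleter calls Unref(), so reset() simply takes ownership of the existing reference. Roughly (a sketch of the alias, not a verbatim copy):

  struct RefCountDeleter {
    void operator()(core::RefCounted* ptr) const { ptr->Unref(); }
  };
  template <typename T>
  using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;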
template <typename T>
Status LookupResources(OpKernelContext* ctx,
absl::Span<ResourceHandle const* const> p,
std::vector<core::RefCountPtr<T>>* values) {
std::vector<std::pair<const string*, const string*>> containers_and_names( std::vector<std::pair<const string*, const string*>> containers_and_names(
p.size()); p.size());
for (size_t i = 0; i < p.size(); ++i) { for (size_t i = 0; i < p.size(); ++i) {
@ -607,6 +633,17 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
creator); creator);
} }
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value,
std::function<Status(T**)> creator) {
T* raw_ptr = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
value->reset(raw_ptr);
return Status::OK();
}
template <typename T> template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
@ -266,15 +267,14 @@ TEST(ResourceHandleTest, CRUD) {
TF_EXPECT_OK(CreateResource(&ctx, p, r)); TF_EXPECT_OK(CreateResource(&ctx, p, r));
} }
{ {
StubResource* r = nullptr; core::RefCountPtr<StubResource> r;
TF_ASSERT_OK(LookupResource(&ctx, p, &r)); TF_ASSERT_OK(LookupResource(&ctx, p, &r));
ASSERT_TRUE(r != nullptr); ASSERT_TRUE(r != nullptr);
EXPECT_EQ(r->value_, 42); EXPECT_EQ(r->value_, 42);
r->Unref();
} }
{ {
TF_EXPECT_OK(DeleteResource<StubResource>(&ctx, p)); TF_EXPECT_OK(DeleteResource<StubResource>(&ctx, p));
StubResource* unused = nullptr; core::RefCountPtr<StubResource> unused;
EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok()); EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok());
} }
} }
@ -338,13 +338,12 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) {
StubResource* r = new StubResource; StubResource* r = new StubResource;
TF_EXPECT_OK(CreateResource(&ctx, p, r)); TF_EXPECT_OK(CreateResource(&ctx, p, r));
StubResource* lookup_r = nullptr; core::RefCountPtr<StubResource> lookup_r;
TF_EXPECT_OK(LookupResource<StubResource>(&ctx, p, &lookup_r)); TF_EXPECT_OK(LookupResource<StubResource>(&ctx, p, &lookup_r));
EXPECT_EQ(lookup_r, r); EXPECT_EQ(lookup_r.get(), r);
TF_EXPECT_OK(DeleteResource(&ctx, p)); TF_EXPECT_OK(DeleteResource(&ctx, p));
EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true); EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
r->Unref();
} }
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -166,6 +166,7 @@ tf_kernel_library(
":variable_ops", ":variable_ops",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3", "//third_party/eigen3",
], ],
) )
@ -5022,6 +5023,7 @@ STATE_DEPS = [
"//third_party/eigen3", "//third_party/eigen3",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
] + if_sycl(["//tensorflow/core:sycl_runtime"]) ] + if_sycl(["//tensorflow/core:sycl_runtime"])
tf_kernel_library( tf_kernel_library(
@ -7589,6 +7591,7 @@ tf_kernel_library(
deps = [ deps = [
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/db:sqlite", "//tensorflow/core/lib/db:sqlite",
"//tensorflow/core/summary:schema", "//tensorflow/core/summary:schema",

View File

@ -83,6 +83,7 @@ tf_kernel_library(
":resources", ":resources",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
], ],
) )
@ -106,6 +107,7 @@ tf_kernel_library(
":tree_helper", ":tree_helper",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
"//third_party/eigen3", "//third_party/eigen3",
], ],
@ -117,6 +119,7 @@ tf_kernel_library(
deps = [ deps = [
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles", "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
], ],
) )

View File

@ -51,12 +51,10 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
} }
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource; core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource. // Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource)); &resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);
// Get the inputs. // Get the inputs.
OpInputList bucketized_features_list; OpInputList bucketized_features_list;
@ -198,12 +196,10 @@ class BoostedTreesPredictOp : public OpKernel {
} }
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource; core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource. // Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource)); &resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);
// Get the inputs. // Get the inputs.
OpInputList bucketized_features_list; OpInputList bucketized_features_list;
@ -302,12 +298,10 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
} }
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource; core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource. // Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource)); &resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);
// Get the inputs. // Get the inputs.
OpInputList bucketized_features_list; OpInputList bucketized_features_list;

View File

@ -27,6 +27,7 @@
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h" #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h" #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -224,12 +225,11 @@ class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
ResourceHandle handle; ResourceHandle handle;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle)); HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource; core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle. // Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope. // Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex()); mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);
OpInputList summaries_list; OpInputList summaries_list;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
@ -281,13 +281,12 @@ class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource; core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle. // Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource)); &streams_resource));
// Remove the reference at the end of this scope. // Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex()); mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
OpInputList bucket_boundaries_list; OpInputList bucket_boundaries_list;
OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName, OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
@ -336,12 +335,11 @@ class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
ResourceHandle handle; ResourceHandle handle;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle)); HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource; core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle. // Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope. // Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex()); mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);
const Tensor* num_buckets_t; const Tensor* num_buckets_t;
OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t)); OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
@ -391,12 +389,11 @@ class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
ResourceHandle handle; ResourceHandle handle;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle)); HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource; core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle. // Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope. // Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex()); mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);
const int64 num_streams = stream_resource->num_streams(); const int64 num_streams = stream_resource->num_streams();
CHECK_EQ(num_features_, num_streams); CHECK_EQ(num_features_, num_streams);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h" #include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow { namespace tensorflow {
@ -78,11 +79,10 @@ class BoostedTreesGetEnsembleStatesOp : public OpKernel {
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
// Looks up the resource. // Looks up the resource.
BoostedTreesEnsembleResource* tree_ensemble_resource; core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource)); &tree_ensemble_resource));
tf_shared_lock l(*tree_ensemble_resource->get_mutex()); tf_shared_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);
// Sets the outputs. // Sets the outputs.
const int num_trees = tree_ensemble_resource->num_trees(); const int num_trees = tree_ensemble_resource->num_trees();
@ -141,11 +141,10 @@ class BoostedTreesSerializeEnsembleOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
BoostedTreesEnsembleResource* tree_ensemble_resource; core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource)); &tree_ensemble_resource));
tf_shared_lock l(*tree_ensemble_resource->get_mutex()); tf_shared_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);
Tensor* output_stamp_token_t = nullptr; Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
&output_stamp_token_t)); &output_stamp_token_t));
@ -169,11 +168,10 @@ class BoostedTreesDeserializeEnsembleOp : public OpKernel {
: OpKernel(context) {} : OpKernel(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
BoostedTreesEnsembleResource* tree_ensemble_resource; core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource)); &tree_ensemble_resource));
mutex_lock l(*tree_ensemble_resource->get_mutex()); mutex_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);
// Get the stamp token. // Get the stamp token.
const Tensor* stamp_token_t; const Tensor* stamp_token_t;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h" #include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h" #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow { namespace tensorflow {
@ -55,10 +56,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble. // Get decision tree ensemble.
BoostedTreesEnsembleResource* ensemble_resource; core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource)); &ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex()); mutex_lock l(*ensemble_resource->get_mutex());
// Increase the ensemble stamp. // Increase the ensemble stamp.
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
@ -176,19 +176,19 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
private: private:
int32 UpdateGlobalAttemptsAndRetrieveGrowableTree( int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
BoostedTreesEnsembleResource* const ensemble_resource) { const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) {
int32 num_trees = ensemble_resource->num_trees(); int32 num_trees = resource->num_trees();
int32 current_tree = num_trees - 1; int32 current_tree = num_trees - 1;
// Increment global attempt stats. // Increment global attempt stats.
ensemble_resource->UpdateGrowingMetadata(); resource->UpdateGrowingMetadata();
// Note we don't set tree weight to be equal to learning rate, since we // Note we don't set tree weight to be equal to learning rate, since we
// apply learning rate to leaf weights instead, when doing layer-by-layer // apply learning rate to leaf weights instead, when doing layer-by-layer
// boosting. // boosting.
if (num_trees <= 0) { if (num_trees <= 0) {
// Create a new tree with a no-op leaf. // Create a new tree with a no-op leaf.
current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight); current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
} }
return current_tree; return current_tree;
} }
@ -250,10 +250,9 @@ class BoostedTreesCenterBiasOp : public OpKernel {
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble. // Get decision tree ensemble.
BoostedTreesEnsembleResource* ensemble_resource; core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource)); &ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex()); mutex_lock l(*ensemble_resource->get_mutex());
// Increase the ensemble stamp. // Increase the ensemble stamp.
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -65,11 +66,9 @@ class ResourceCountUpToOp : public OpKernel {
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
Var* variable = nullptr; core::RefCountPtr<Var> variable;
OP_REQUIRES_OK( OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
context, &variable));
LookupResource<Var>(context, HandleFromInput(context, 0), &variable));
core::ScopedUnref s(variable);
mutex_lock l(*variable->mu()); mutex_lock l(*variable->mu());
Tensor before_increment = *variable->tensor(); Tensor before_increment = *variable->tensor();
OP_REQUIRES( OP_REQUIRES(

View File

@ -342,6 +342,7 @@ tf_kernel_library(
deps = [ deps = [
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:summary_interface", "//tensorflow/core/kernels:summary_interface",
], ],
@ -380,6 +381,7 @@ tf_kernel_library(
"//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:dataset_utils",
"//third_party/eigen3", "//third_party/eigen3",
], ],

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/kernels/data/stats_utils.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
@@ -83,17 +84,16 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    StatsAggregatorResource* stats_aggregator_resource;
-    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
-                                       &stats_aggregator_resource));
-    core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
+    core::RefCountPtr<StatsAggregatorResource> resource;
+    OP_REQUIRES_OK(ctx,
+                   LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
 
     string tag;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
     string prefix;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
 
-    *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
-                          tag, prefix);
+    *output =
+        new Dataset(ctx, input, ctx->input(1), resource.get(), tag, prefix);
   }
 
  private:
@@ -101,12 +101,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
    public:
     explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
                      const Tensor& resource_handle,
-                     StatsAggregatorResource* stats_aggregator_resource,
-                     const string& tag, const string& prefix)
+                     StatsAggregatorResource* resource, const string& tag,
+                     const string& prefix)
         : DatasetBase(DatasetContext(ctx)),
          input_(input),
          resource_handle_(resource_handle),
-          stats_aggregator_resource_(stats_aggregator_resource),
+          stats_aggregator_resource_(resource),
          tag_(tag),
          prefix_(prefix) {
       input_->Ref();
@@ -169,13 +169,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      StatsAggregatorResource* stats_aggregator_resource =
+      StatsAggregatorResource* resource =
           dataset()->stats_aggregator_resource_;
       IteratorContext::Params params(ctx);
       params.stats_aggregator = std::shared_ptr<StatsAggregator>(
-          new StatsAggregatorWithTagAndPrefix(
-              stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
-              dataset()->prefix_));
+          new StatsAggregatorWithTagAndPrefix(resource->stats_aggregator(),
+                                              dataset()->tag_,
+                                              dataset()->prefix_));
       IteratorContext iter_ctx(std::move(params));
       return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
     }
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/stats_aggregator.h"
-
 #include <memory>
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/summary.pb.h"
 #include "tensorflow/core/kernels/summary_interface.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/histogram/histogram.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
@@ -258,10 +258,9 @@ class StatsAggregatorSummaryOp : public OpKernel {
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                 errors::InvalidArgument("resource_handle must be a scalar"));
 
-    StatsAggregatorResource* resource;
+    core::RefCountPtr<StatsAggregatorResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref unref_iterator(resource);
 
     Tensor* summary_t;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
@@ -281,21 +280,19 @@ class StatsAggregatorSetSummaryWriterOp : public OpKernel {
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                 errors::InvalidArgument("resource_handle must be a scalar"));
 
-    StatsAggregatorResource* resource;
+    core::RefCountPtr<StatsAggregatorResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref unref_iterator(resource);
 
     const Tensor& summary_resource_handle_t = ctx->input(1);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
                 errors::InvalidArgument("resource_handle must be a scalar"));
-    SummaryWriterInterface* sumamry_resource;
+    core::RefCountPtr<SummaryWriterInterface> sumamry_resource;
     OP_REQUIRES_OK(
         ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource));
-    core::ScopedUnref unref_sumamry_resource(sumamry_resource);
     TF_CHECK_OK(
-        resource->stats_aggregator()->SetSummaryWriter(sumamry_resource));
+        resource->stats_aggregator()->SetSummaryWriter(sumamry_resource.get()));
   }
 };
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/util/work_sharder.h"
 
@@ -127,12 +128,10 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    ThreadPoolResource* threadpool_resource;
+    core::RefCountPtr<ThreadPoolResource> threadpool_resource;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
                                        &threadpool_resource));
-    core::ScopedUnref unref_iterator(threadpool_resource);
-    *output = new Dataset(ctx, input, ctx->input(1), threadpool_resource);
+    *output = new Dataset(ctx, input, ctx->input(1), threadpool_resource.get());
   }
 
  private:
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
@@ -595,10 +596,9 @@ int64 AnonymousIteratorHandleOp::current_id_(0);
 void MakeIteratorOp::Compute(OpKernelContext* ctx) {
   DatasetBase* dataset;
   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
-  IteratorResource* iterator_resource;
+  core::RefCountPtr<IteratorResource> iterator_resource;
   OP_REQUIRES_OK(
       ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
-  core::ScopedUnref unref(iterator_resource);
   OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
 }
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 #include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -492,10 +493,9 @@ class MultiDeviceIteratorInitOp : public OpKernel {
 
     DatasetBase* dataset;
     OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
-    MultiDeviceIterator* resource;
+    core::RefCountPtr<MultiDeviceIterator> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
-    core::ScopedUnref unref(resource);
 
     std::unique_ptr<IteratorBase> iterator;
     IteratorContext::Params params(ctx);
@@ -535,7 +535,7 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
         ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
     int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
 
-    MultiDeviceIterator* iterator;
+    core::RefCountPtr<MultiDeviceIterator> iterator;
     OP_REQUIRES_OK_ASYNC(
         ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
 
@@ -557,7 +557,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
         std::placeholders::_1, std::move(done));
 
     iterator->GetNextFromShard(ctx, shard_num, incarnation_id, callback);
-    iterator->Unref();
   }
 };
@@ -577,10 +576,9 @@ class MultiDeviceIteratorToStringHandleOp : public OpKernel {
 
     // Validate that the handle corresponds to a real resource, and
    // that it is an MultiDeviceIterator.
-    MultiDeviceIterator* resource;
+    core::RefCountPtr<MultiDeviceIterator> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    resource->Unref();
 
     Tensor* string_handle_t;
     OP_REQUIRES_OK(ctx,
@@ -629,9 +627,8 @@ class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
     // Validate that the handle corresponds to a real resource, and
     // that it is an MultiDeviceIterator.
-    MultiDeviceIterator* resource;
+    core::RefCountPtr<MultiDeviceIterator> resource;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
-    core::ScopedUnref unref_iterator(resource);
 
     if (!output_types_.empty()) {
       OP_REQUIRES_OK(ctx,
                      VerifyTypesMatch(output_types_, resource->output_types()));
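Two hunks above remove bare `Unref()` calls rather than `ScopedUnref` guards, and those call sites are the most fragile of all: any failed `OP_REQUIRES*` check between the lookup and the `Unref()` returns from the kernel early and strands a reference. A sketch of that hazard, using the same hypothetical single-threaded stand-ins as the model earlier in this change (not the real `MultiDeviceIterator`):

```cpp
#include <memory>

// Hypothetical stand-ins; see the fuller model earlier in this change.
class RefCounted {
 public:
  void Unref() const {
    if (--count_ == 0) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  mutable int count_ = 1;
};

struct RefCountDeleter {
  void operator()(const RefCounted* o) const { o->Unref(); }
};

template <typename T>
using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;

class Iterator : public RefCounted {};  // Stands in for MultiDeviceIterator.

// Old shape: a manual Unref late in Compute. OP_REQUIRES-style macros return
// from the function on failure, so the Unref never runs on that path.
void ComputeOld(bool validation_fails) {
  Iterator* resource = new Iterator;  // Stands in for LookupResource.
  if (validation_fails) return;       // Leaks: the Unref below is skipped.
  resource->Unref();
}

// New shape: the deleter runs on every exit path, including the early one.
void ComputeNew(bool validation_fails) {
  RefCountPtr<Iterator> resource(new Iterator);
  if (validation_fails) return;  // No leak: ~RefCountPtr Unrefs here too.
}

int main() {
  ComputeOld(false);  // Only the happy path releases its reference.
  ComputeNew(true);   // The failure path now does as well.
  return 0;
}
```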
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/guarded_philox_random.h"
@@ -371,10 +372,9 @@ class RandomBinomialOp : public OpKernel {
           "Input probs should have length 1 or shape[0], got shape: ",
           probs_tensor.shape().DebugString()));
     }
 
-    Var* var = nullptr;
+    core::RefCountPtr<Var> var;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
-    ScopedUnlockUnrefVar var_guard(var);
     Tensor* var_tensor = var->tensor();
     OP_REQUIRES(
         ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
@@ -404,7 +404,6 @@ class RandomBinomialOp : public OpKernel {
     auto philox = GetPhiloxRandomFromMem(var_data);
     UpdateMemWithPhiloxRandom(
         philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);
-    var_guard.Release();
 
     auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
     binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
@@ -129,7 +129,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
 }  // namespace
 
 void ReadVariableOp::Compute(OpKernelContext* ctx) {
-  Var* variable = nullptr;
+  core::RefCountPtr<Var> variable;
   const ResourceHandle& handle = HandleFromInput(ctx, 0);
   const auto status = LookupResource(ctx, handle, &variable);
   OP_REQUIRES(ctx, status.ok(),
@@ -139,7 +139,6 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
                   ". This could mean that the variable was uninitialized. ",
                   status.ToString()));
 
-  core::ScopedUnref s(variable);
   {
     tf_shared_lock ml(*variable->mu());
     // We're acquiring a reference to the underlying buffer while
@@ -175,8 +174,7 @@ ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
 }
 
 void ReadVariablesOp::Compute(OpKernelContext* ctx) {
-  std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
-      dtypes_.size());
+  std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
   std::vector<const ResourceHandle*> handles(dtypes_.size());
   for (size_t i = 0; i < dtypes_.size(); ++i) {
     handles[i] = &HandleFromInput(ctx, i);
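`ReadVariablesOp` holds every looked-up variable in a `std::vector<core::RefCountPtr<Var>>`, so one container owns one reference per variable and destroying it releases them all at once. A self-contained sketch of that shape, again with simplified single-threaded stand-ins rather than the real `Var`:

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-ins mirroring the earlier model.
class RefCounted {
 public:
  void Unref() const {
    if (--count_ == 0) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  mutable int count_ = 1;
};

struct RefCountDeleter {
  void operator()(const RefCounted* o) const { o->Unref(); }
};

template <typename T>
using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;

class Var : public RefCounted {};  // Stands in for the real Var.

int main() {
  // One slot per handle, the way ReadVariablesOp sizes its vector to
  // dtypes_.size(); each successful lookup would fill one slot.
  std::vector<RefCountPtr<Var>> variables(3);
  for (auto& v : variables) v.reset(new Var);
  return 0;
}  // Destroying the vector Unrefs every element; nothing is left behind.
```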
@@ -265,10 +263,9 @@ class VariableShapeOp : public OpKernel {
   explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
 
   void Compute(OpKernelContext* ctx) override {
-    Var* variable = nullptr;
+    core::RefCountPtr<Var> variable;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
-    core::ScopedUnref s(variable);
     variable->mu()->lock_shared();
     TensorShape shape = variable->tensor()->shape();
     variable->mu()->unlock_shared();
@@ -343,7 +340,7 @@ class AssignVariableOp : public OpKernel {
                     "Variable and value dtypes don't match; respectively, ",
                     DataTypeString(dtype_), " and ",
                     DataTypeString(context->input(1).dtype())));
-    Var* variable = nullptr;
+    core::RefCountPtr<Var> variable;
     const Tensor& value = context->input(1);
     // Note: every resource-variable-manipulating op assumes copy-on-write
     // semantics, and creates a copy of the variable's Tensor if its refcount is
@@ -361,7 +358,6 @@ class AssignVariableOp : public OpKernel {
                               (*ptr)->is_initialized = true;
                               return Status::OK();
                             }));
-    core::ScopedUnref s(variable);
     mutex_lock ml(*variable->mu());
     OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
                 errors::InvalidArgument(
@@ -404,7 +400,7 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
   void Compute(OpKernelContext* context) override {
     const Tensor& value = context->input(1);
-    Var* variable = nullptr;
+    core::RefCountPtr<Var> variable;
     OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                 context, HandleFromInput(context, 0), &variable,
                                 [](Var** ptr) {
@@ -412,7 +408,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
                                   *ptr = new Var(DT_VARIANT);
                                   return Status::OK();
                                 }));
-    core::ScopedUnref s(variable);
 
     // For purposes of forwarding DT_VARIANT, we want the least
     // restrictive attr; we already know the input is on host.
@@ -500,10 +495,9 @@ class AssignUpdateVariableOp : public OpKernel {
   explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
 
   void Compute(OpKernelContext* context) override {
-    Var* variable = nullptr;
+    core::RefCountPtr<Var> variable;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &variable));
-    core::ScopedUnref s(variable);
 
     const Tensor& value = context->input(1);
     // TODO(apassos): We could possibly avoid the copy done by
@@ -568,13 +562,12 @@ class VarIsInitializedOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({}), &output));
     auto output_tensor = output->tensor<bool, 0>();
-    Var* variable = nullptr;
+    core::RefCountPtr<Var> variable;
     Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
     if (!s.ok()) {
       output_tensor() = false;
       return;
     }
-    core::ScopedUnref su(variable);
     mutex_lock ml(*variable->mu());
     output_tensor() = variable->is_initialized;
   }
@@ -623,10 +616,9 @@ class ResourceGatherOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* c) override {
-    Var* v = nullptr;
+    core::RefCountPtr<Var> v;
     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
-    core::ScopedUnref su(v);
-    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
+    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
     // NOTE: We hold the lock for the whole gather operation instead
     // of increasing the reference count of v->tensor() to avoid a
     // situation where a write to the same variable will see a
@@ -765,10 +757,9 @@ class ResourceGatherNdOp : public OpKernel {
   explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
 
   void Compute(OpKernelContext* c) override {
-    Var* v = nullptr;
+    core::RefCountPtr<Var> v;
     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
-    core::ScopedUnref su(v);
-    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
+    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
     // NOTE: We hold the lock for the whole gather operation instead
     // of increasing the reference count of v->tensor() to avoid a
     // situation where a write to the same variable will see a
@@ -821,10 +812,9 @@ class ResourceScatterUpdateOp : public OpKernel {
   explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}
 
   void Compute(OpKernelContext* c) override {
-    Var* v = nullptr;
+    core::RefCountPtr<Var> v;
     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
-    core::ScopedUnref unref_v(v);
-    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
+    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
     tf_shared_lock ml(*v->mu());
     Tensor* params = v->tensor();
     const Tensor& indices = c->input(1);
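Helpers such as `EnsureSparseVariableAccess` still take a raw `Var*`, which is why the gather and scatter hunks above lend the pointer with `v.get()`: the `RefCountPtr` keeps its reference alive for the duration of the call, and the helper neither stores nor releases it. A small sketch of that borrowing convention; `Var` and `EnsureAccess` here are hypothetical stand-ins, not the real signatures:

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in; only the ownership convention matters here.
struct Var {
  bool accessed = false;
};

// A borrowing helper: it takes a raw pointer and neither stores nor frees
// it, mirroring how EnsureSparseVariableAccess accepts a raw Var*.
void EnsureAccess(Var* v) { v->accessed = true; }

int main() {
  // Any owning smart pointer works the same way; the diff uses
  // core::RefCountPtr, whose deleter Unrefs instead of deleting.
  std::unique_ptr<Var> v(new Var);
  EnsureAccess(v.get());  // Lend, don't transfer: v still owns the object.
  assert(v->accessed);
  return 0;
}
```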
@@ -243,10 +243,9 @@ class ScatterNdUpdateOp : public OpKernel {
 
   void Compute(OpKernelContext* c) override {
     if (dtype_ == DT_RESOURCE) {
-      Var* v;
+      core::RefCountPtr<Var> v;
       OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
-      core::ScopedUnref scoped_unref(v);
-      OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
+      OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
       mutex_lock m(*v->mu());
       DoCompute(c);
     } else if (use_exclusive_lock_) {
@@ -271,9 +270,8 @@ class ScatterNdUpdateOp : public OpKernel {
     TensorShape params_shape;
 
     if (dtype_ == DT_RESOURCE) {
-      Var* v;
+      core::RefCountPtr<Var> v;
       OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
-      core::ScopedUnref scoped_unref(v);
       Tensor* t = v->tensor();
       params = *t;
       params_shape = params.shape();
@@ -15,6 +15,7 @@ limitations under the License.
 
 // See docs in ../ops/array_ops.cc.
+#include "tensorflow/core/lib/core/refcount.h"
 
 #define EIGEN_USE_THREADS
 
 #if GOOGLE_CUDA
@@ -323,12 +324,11 @@ class StridedSliceAssignOp : public OpKernel {
       }
     } else {
       if (context->input_dtype(0) == DT_RESOURCE) {
-        Var* v;
+        core::RefCountPtr<Var> v;
         OP_REQUIRES_OK(
             context, LookupResource(context, HandleFromInput(context, 0), &v));
-        core::ScopedUnref scoped_unref(v);
         OP_REQUIRES_OK(context,
-                       EnsureSparseVariableAccess<Device, T>(context, v));
+                       EnsureSparseVariableAccess<Device, T>(context, v.get()));
         mutex_lock ml(*v->mu());
         old_lhs = v->tensor();
         OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/db/sqlite.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/summary/schema.h"
@@ -45,7 +46,7 @@ class CreateSummaryFileWriterOp : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
     const string filename_suffix = tmp->scalar<string>()();
 
-    SummaryWriterInterface* s = nullptr;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
                             ctx, HandleFromInput(ctx, 0), &s,
                             [max_queue, flush_millis, logdir, filename_suffix,
@@ -54,7 +55,6 @@ class CreateSummaryFileWriterOp : public OpKernel {
                                   max_queue, flush_millis, logdir,
                                   filename_suffix, ctx->env(), s);
                             }));
-    core::ScopedUnref unref(s);
   }
 };
 REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
@@ -75,7 +75,7 @@ class CreateSummaryDbWriterOp : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
     const string user_name = tmp->scalar<string>()();
 
-    SummaryWriterInterface* s = nullptr;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(
         ctx,
         LookupOrCreateResource<SummaryWriterInterface>(
@@ -91,7 +91,6 @@ class CreateSummaryDbWriterOp : public OpKernel {
                   db, experiment_name, run_name, user_name, ctx->env(), s));
               return Status::OK();
             }));
-    core::ScopedUnref unref(s);
   }
 };
 REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
@@ -102,9 +101,8 @@ class FlushSummaryWriterOp : public OpKernel {
   explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     OP_REQUIRES_OK(ctx, s->Flush());
   }
 };
@@ -128,9 +126,8 @@ class WriteSummaryOp : public OpKernel {
   explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
@@ -153,9 +150,8 @@ class WriteRawProtoSummaryOp : public OpKernel {
   explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
@@ -190,9 +186,8 @@ class ImportEventOp : public OpKernel {
   explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* t;
     OP_REQUIRES_OK(ctx, ctx->input("event", &t));
     std::unique_ptr<Event> event{new Event};
@@ -211,9 +206,8 @@ class WriteScalarSummaryOp : public OpKernel {
   explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
@@ -234,9 +228,8 @@ class WriteHistogramSummaryOp : public OpKernel {
   explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
@@ -263,9 +256,8 @@ class WriteImageSummaryOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
@@ -299,9 +291,8 @@ class WriteAudioSummaryOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* tmp;
     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
     const int64 step = tmp->scalar<int64>()();
@@ -328,9 +319,8 @@ class WriteGraphSummaryOp : public OpKernel {
   explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    SummaryWriterInterface* s;
+    core::RefCountPtr<SummaryWriterInterface> s;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
-    core::ScopedUnref unref(s);
     const Tensor* t;
     OP_REQUIRES_OK(ctx, ctx->input("step", &t));
     const int64 step = t->scalar<int64>()();
@@ -26,6 +26,7 @@ tf_kernel_library(
         ":resources",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
     ],
 )
@@ -37,6 +38,7 @@ tf_kernel_library(
         ":resources",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
     ],
 )
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/tensor_forest/resources.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/util/work_sharder.h"
 
@@ -29,12 +30,10 @@ class TensorForestTreePredictOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* context) override {
-    TensorForestTreeResource* decision_tree_resource;
+    core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
-
     const Tensor* dense_features_t = nullptr;
     OP_REQUIRES_OK(context,
                    context->input("dense_features", &dense_features_t));
@@ -60,7 +59,7 @@ class TensorForestTreePredictOp : public OpKernel {
     // We will need to run it on a number of trees of diff depth
     // and see the num of cpu cycles
     const int64 cost_per_traverse = 500;
-    auto traverse = [this, &out, &dense_features, decision_tree_resource,
+    auto traverse = [this, &out, &dense_features, &decision_tree_resource,
                      batch_size](int64 start, int64 end) {
       DCHECK_LE(start, end) << "Start exceeding End";
       DCHECK_LE(end, batch_size) << "End exceeding batch size";
@@ -74,9 +73,10 @@ class TensorForestTreePredictOp : public OpKernel {
           traverse);
   };
 
-  void set_output_value(const int32 example_id, const int32 leaf_id,
-                        const TensorForestTreeResource* decision_tree_resource,
-                        TTypes<float>::Matrix* out) const {
+  void set_output_value(
+      const int32 example_id, const int32 leaf_id,
+      const core::RefCountPtr<TensorForestTreeResource>& decision_tree_resource,
+      TTypes<float>::Matrix* out) const {
     for (int j = 0; j < logits_dimension_; ++j) {
       const float logit = decision_tree_resource->get_prediction(leaf_id, j);
       (*out)(example_id, j) = logit;
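Because `core::RefCountPtr` is a `std::unique_ptr` alias, it is move-only and cannot be captured by copy, which is why the `traverse` lambda above now captures `&decision_tree_resource` and `set_output_value` takes the pointer by `const` reference. That is safe here only because `Shard` blocks until every range is processed, so the closure never outlives the pointer. A sketch of the capture rule with a plain `std::unique_ptr` (`Tree` is a hypothetical stand-in):

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for the decision tree resource.
struct Tree {
  float prediction(int leaf) const { return 0.5f * static_cast<float>(leaf); }
};

int main() {
  std::unique_ptr<Tree> tree(new Tree);
  std::vector<float> out(4);

  // A unique_ptr cannot be captured by value ([tree] would not compile);
  // capture by reference instead, as the traverse lambda above does. The
  // reference stays valid because the calls below finish before `tree`
  // dies, just as Shard returns only after all ranges have run.
  auto traverse = [&tree, &out](int start, int end) {
    for (int i = start; i < end; ++i) out[i] = tree->prediction(i);
  };
  traverse(0, 2);  // Stands in for Shard(...) invoking the closure per range.
  traverse(2, 4);

  for (float v : out) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}
```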
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
 #include "tensorflow/core/kernels/tensor_forest/resources.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
@@ -55,11 +56,10 @@ class TensorForestTreeSerializeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    TensorForestTreeResource* decision_tree_resource;
+    core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
     Tensor* output_config_t = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(0, TensorShape(), &output_config_t));
@@ -74,13 +74,11 @@ class TensorForestTreeDeserializeOp : public OpKernel {
   explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    TensorForestTreeResource* decision_tree_resource;
+    core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
-
     const Tensor* tree_config_t;
     OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
@@ -102,11 +100,10 @@ class TensorForestTreeSizeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    TensorForestTreeResource* decision_tree_resource;
+    core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
     Tensor* output_t = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape(), &output_t));
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
@@ -167,10 +168,8 @@ VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
   std::sort(acquire_order.begin(), acquire_order.end(),
             [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
 
-  std::unique_ptr<std::vector<mutex_lock>> locks =
-      absl::make_unique<std::vector<mutex_lock>>();
-  std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
-      absl::make_unique<std::vector<tf_shared_lock>>();
+  auto locks = absl::make_unique<std::vector<mutex_lock>>();
+  auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
   locks->reserve(acquire_order.size());
 
   for (auto input : acquire_order) {
@@ -241,11 +240,10 @@ template <typename Device, typename T>
 Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                   bool lock_held, bool sparse, Tensor* out) {
   if (ctx->input_dtype(input) == DT_RESOURCE) {
-    Var* var;
+    core::RefCountPtr<Var> var;
     TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
-    core::ScopedUnref unref_var(var);
     if (sparse) {
-      TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
+      TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
      *out = *var->tensor();
       return Status::OK();
     }