From 2e758833c3642cde938a87a4dff93b27606fe8bf Mon Sep 17 00:00:00 2001
From: Gaurav Jain
Date: Tue, 28 May 2019 21:53:42 -0700
Subject: [PATCH] Use RefCountPtr in LookupResource to avoid leaks

LookupResource returns a raw pointer which the caller then needs to
Unref. The prevalent pattern is that the lookup is immediately followed
by a ScopedUnref. This is error-prone: if a caller forgets to add the
ScopedUnref call, we have a memory leak. We resolve this by using
RefCountPtr instead of a raw pointer in LookupResource. Most use cases
have been migrated in this change. Note that some variables were
renamed to satisfy line-length restrictions.

PiperOrigin-RevId: 250423227
---
 .../compiler/jit/xla_compile_on_demand_op.cc  |   4 +-
 tensorflow/compiler/jit/xla_device_ops.cc     |   3 +-
 tensorflow/compiler/jit/xla_launch_util.cc    |   3 +-
 .../kernels/trt_engine_resource_ops.cc        |   3 +-
 .../bigtable/kernels/bigtable_kernels.cc      |  22 +--
 .../kernels/bigtable_lookup_dataset_op.cc     |   5 +-
 .../kernels/bigtable_prefix_key_dataset_op.cc |   7 +-
 .../kernels/bigtable_range_key_dataset_op.cc  |   9 +-
 .../bigtable_sample_key_pairs_dataset_op.cc   |   6 +-
 .../bigtable_sample_keys_dataset_op.cc        |   5 +-
 .../kernels/bigtable_scan_dataset_op.cc       |   6 +-
 .../boosted_trees/kernels/model_ops.cc        |  20 +-
 .../boosted_trees/kernels/prediction_ops.cc   |  32 +--
 .../boosted_trees/kernels/quantile_ops.cc     |  19 +-
 .../kernels/stats_accumulator_ops.cc          | 185 ++++++++----------
 .../boosted_trees/kernels/training_ops.cc     |  42 ++--
 .../framework/kernels/zero_initializer_op.cc  |   3 +-
 .../tensor_forest/kernels/model_ops.cc        |  55 +++---
 .../tensor_forest/kernels/stats_ops.cc        |  44 ++---
 tensorflow/core/framework/resource_mgr.h      |  49 ++++-
 .../core/framework/resource_mgr_test.cc       |  11 +-
 tensorflow/core/kernels/BUILD                 |   3 +
 tensorflow/core/kernels/boosted_trees/BUILD   |   3 +
 .../kernels/boosted_trees/prediction_ops.cc   |  12 +-
 .../kernels/boosted_trees/quantile_ops.cc     |  13 +-
 .../kernels/boosted_trees/resource_ops.cc     |  10 +-
 .../kernels/boosted_trees/training_ops.cc     |  15 +-
 tensorflow/core/kernels/count_up_to_op.cc     |   9 +-
 .../core/kernels/data/experimental/BUILD      |   2 +
 .../set_stats_aggregator_dataset_op.cc        |  26 +--
 .../data/experimental/stats_aggregator_ops.cc |  15 +-
 .../experimental/threadpool_dataset_op.cc     |   7 +-
 tensorflow/core/kernels/data/iterator_ops.cc  |   4 +-
 .../kernels/data/multi_device_iterator_ops.cc |  13 +-
 tensorflow/core/kernels/random_binomial_op.cc |   5 +-
 .../core/kernels/resource_variable_ops.cc     |  36 ++--
 tensorflow/core/kernels/scatter_nd_op.cc      |   8 +-
 tensorflow/core/kernels/strided_slice_op.cc   |   6 +-
 tensorflow/core/kernels/summary_kernels.cc    |  34 ++--
 tensorflow/core/kernels/tensor_forest/BUILD   |   2 +
 .../kernels/tensor_forest/prediction_ops.cc   |  14 +-
 .../kernels/tensor_forest/resource_ops.cc     |  11 +-
 tensorflow/core/kernels/training_op_helpers.h |  12 +-
 43 files changed, 372 insertions(+), 421 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index c7e8d61d280..23349f61965 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { @@ -31,12 +32,11 @@ std::map GetVariables(OpKernelContext* ctx) { std::map variables; for (int64 i = 0; i < ctx->num_inputs(); ++i) { if (ctx->input(i).dtype() == DT_RESOURCE) { - Var* variable = nullptr; + core::RefCountPtr variable; ResourceHandle handle = HandleFromInput(ctx, i); OptionalTensor& optional = variables[i]; optional.name = handle.name(); if (LookupResource(ctx, handle, &variable).ok()) { - core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); optional.present = true; optional.value = *variable->tensor(); diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index f56c26ba010..8126059262b 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -40,7 +40,7 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) { "Variable and value dtypes don't match; respectively, ", DataTypeString(dtype_), " and ", DataTypeString(context->input(1).dtype()))); - Var* variable = nullptr; + core::RefCountPtr variable; const Tensor& value = context->input(1); // Note: every resource-variable-manipulating op assumes copy-on-write // semantics, and creates a copy of the variable's Tensor if its refcount is @@ -58,7 +58,6 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) { (*ptr)->is_initialized = true; return Status::OK(); })); - core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); OP_REQUIRES(context, variable->tensor()->dtype() == dtype_, errors::InvalidArgument( diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index d66c80fea90..e9c4eb6e8ee 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -86,7 +87,7 @@ static Status GetVariableInfosFromCtxInputs( variable_indices, std::back_inserter(resource_handles), [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); }); - std::vector> variables; + std::vector> variables; TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables)); result->clear(); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc index a41a8a2c1c4..bb5c643b7b0 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -84,9 +84,8 @@ class PopulateTRTEngineCache : public OpKernel { void Compute(OpKernelContext* ctx) override { ResourceHandle handle = HandleFromInput(ctx, 0); - TRTEngineCacheResource* resource = nullptr; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource)); - core::ScopedUnref unref_me(resource); auto allocator = resource->allocator_.get(); OP_REQUIRES(ctx, allocator != nullptr, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index 002d68111cd..51b27ea4212 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { @@ -139,19 +140,19 @@ class BigtableTableOp : public OpKernel { ResourceMgr* mgr = ctx->resource_manager(); OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); - BigtableClientResource* client_resource; + core::RefCountPtr client_resource; OP_REQUIRES_OK( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); - core::ScopedUnref unref_client(client_resource); BigtableTableResource* resource; - OP_REQUIRES_OK( - ctx, mgr->LookupOrCreate( - cinfo_.container(), cinfo_.name(), &resource, - [this, client_resource](BigtableTableResource** ret) { - *ret = new BigtableTableResource(client_resource, table_); - return Status::OK(); - })); + OP_REQUIRES_OK(ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, &client_resource](BigtableTableResource** ret) { + *ret = new BigtableTableResource( + client_resource.get(), table_); + return Status::OK(); + })); initialized_ = true; } OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( @@ -236,10 +237,9 @@ class ToBigtableOp : public AsyncOpKernel { errors::InvalidArgument("timestamp must be >= -1"), done); - BigtableTableResource* resource; + core::RefCountPtr resource; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done); - core::ScopedUnref resource_cleanup(resource); std::vector components; components.reserve(dataset->output_dtypes().size()); diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index 98ec991a934..8039ef8cd77 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ 
b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -26,9 +26,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - BigtableTableResource* table; + core::RefCountPtr table; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); - core::ScopedUnref scoped_unref(table); std::vector column_families; std::vector columns; @@ -50,7 +49,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { } *output = - new Dataset(ctx, input, table, std::move(column_families), + new Dataset(ctx, input, table.get(), std::move(column_families), std::move(columns), output_types, std::move(output_shapes)); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index 92a36586672..e9d4a1e05ea 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { namespace data { @@ -28,12 +29,10 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { string prefix; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); - BigtableTableResource* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - core::ScopedUnref scoped_unref(resource); - - *output = new Dataset(ctx, resource, std::move(prefix)); + *output = new Dataset(ctx, resource.get(), std::move(prefix)); } private: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index bd8805a3827..be3c7cc5f38 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { namespace data { @@ -31,13 +32,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { string end_key; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); - BigtableTableResource* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - core::ScopedUnref scoped_unref(resource); - - *output = - new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); + *output = new Dataset(ctx, resource.get(), std::move(start_key), + std::move(end_key)); } private: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index 88284c5a4e9..880f5e40f25 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" #include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { namespace data { @@ -35,10 +36,9 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { string end_key; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); - BigtableTableResource* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - core::ScopedUnref scoped_unref(resource); OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), errors::InvalidArgument( @@ -49,7 +49,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { "If prefix is specified, end_key must be empty.")); } - *output = new Dataset(ctx, resource, std::move(prefix), + *output = new Dataset(ctx, resource.get(), std::move(prefix), std::move(start_key), std::move(end_key)); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index 119da35973a..53be3b5a2bb 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -25,11 +25,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { using DatasetOpKernel::DatasetOpKernel; void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - BigtableTableResource* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - core::ScopedUnref scoped_unref(resource); - *output = new Dataset(ctx, resource); + *output = new Dataset(ctx, resource.get()); } private: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index 688289a4e24..e68c83ed547 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { namespace data { @@ -64,10 +65,9 @@ class BigtableScanDatasetOp : public DatasetOpKernel { errors::InvalidArgument( "Probability outside the range of (0, 1]. 
Got: ", probability)); - BigtableTableResource* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - core::ScopedUnref scoped_unref(resource); const uint64 num_outputs = columns.size() + 1; std::vector output_shapes; @@ -79,7 +79,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel { output_types.push_back(DT_STRING); } - *output = new Dataset(ctx, resource, std::move(prefix), + *output = new Dataset(ctx, resource.get(), std::move(prefix), std::move(start_key), std::move(end_key), std::move(column_families), std::move(columns), probability, output_types, std::move(output_shapes)); diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 3bf33186ec1..9655e49d91b 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -21,6 +21,7 @@ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/thread_annotations.h" namespace tensorflow { @@ -44,7 +45,7 @@ class CreateTreeEnsembleVariableOp : public OpKernel { const Tensor* tree_ensemble_config_t; OP_REQUIRES_OK(context, context->input("tree_ensemble_config", &tree_ensemble_config_t)); - auto* result = new boosted_trees::models::DecisionTreeEnsembleResource(); + auto* result = new DecisionTreeEnsembleResource(); if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), stamp_token)) { result->Unref(); @@ -69,11 +70,10 @@ class TreeEnsembleStampTokenOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; + core::RefCountPtr ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &ensemble_resource)); tf_shared_lock l(*ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_stamp_token_t)); @@ -88,11 +88,10 @@ class TreeEnsembleSerializeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; + core::RefCountPtr ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &ensemble_resource)); tf_shared_lock l(*ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_stamp_token_t)); @@ -112,11 +111,10 @@ class TreeEnsembleDeserializeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; + core::RefCountPtr ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &ensemble_resource)); mutex_lock l(*ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(ensemble_resource); // Get the stamp token. 
     const Tensor* stamp_token_t;
@@ -146,12 +144,11 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
 
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
     tf_shared_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);
 
     // Get the stamp token.
     const Tensor* stamp_token_t;
@@ -194,9 +191,8 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
 
 REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);
 
-REGISTER_KERNEL_BUILDER(
-    Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
-    IsResourceInitialized<DecisionTreeEnsembleResource>);
+REGISTER_KERNEL_BUILDER(Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
+                        IsResourceInitialized<DecisionTreeEnsembleResource>);
 
 REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
                         CreateTreeEnsembleVariableOp);
diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
index 9493c1a1394..b740e0629ad 100644
--- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
@@ -163,12 +163,11 @@ class GradientTreesPredictionOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* const context) override {
-    DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     // Gets the resource. Grabs the mutex but releases it.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
     // Release the reference to the resource once we're done using it.
-    core::ScopedUnref unref_me(ensemble_resource);
     if (use_locking_) {
       tf_shared_lock l(*ensemble_resource->get_mutex());
       DoCompute(context, ensemble_resource,
@@ -184,9 +183,10 @@ class GradientTreesPredictionOp : public OpKernel {
   // leaf index in prediction. Though this class invokes only with this param
   // value as false, the subclass GradientTreesPredictionVerboseOp will invoke
   // with the true value.
-  virtual void DoCompute(OpKernelContext* context,
-                         DecisionTreeEnsembleResource* ensemble_resource,
-                         const bool return_output_leaf_index) {
+  virtual void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
+      const bool return_output_leaf_index) {
     // Read dense float features list;
     OpInputList dense_float_features_list;
     OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
@@ -352,9 +352,10 @@ class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
       : GradientTreesPredictionOp(context) {}
 
  protected:
-  void DoCompute(OpKernelContext* context,
-                 DecisionTreeEnsembleResource* ensemble_resource,
-                 bool return_output_leaf_index) override {
+  void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
+      bool return_output_leaf_index) override {
     GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
                                          /*return_output_leaf_index=*/true);
   }
@@ -372,12 +373,10 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* const context) override {
-    DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     // Gets the resource. Grabs the mutex but releases it.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
-    // Release the reference to the resource once we're done using it.
-    core::ScopedUnref unref_me(ensemble_resource);
     if (use_locking_) {
       tf_shared_lock l(*ensemble_resource->get_mutex());
       DoCompute(context, ensemble_resource);
@@ -387,17 +386,18 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
   }
 
  private:
-  void DoCompute(OpKernelContext* context,
-                 DecisionTreeEnsembleResource* ensemble_resource) {
+  void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
     // The last non-finalized tree in the ensemble is by convention the
     // one to partition on. If no such tree exists, a nodeless tree is
     // created.
     boosted_trees::trees::DecisionTreeConfig empty_tree_config;
     const boosted_trees::trees::DecisionTreeConfig& tree_config =
-        (ensemble_resource->num_trees() <= 0 ||
-         ensemble_resource->LastTreeMetadata()->is_finalized())
+        (resource->num_trees() <= 0 ||
+         resource->LastTreeMetadata()->is_finalized())
             ? empty_tree_config
-            : *ensemble_resource->LastTree();
+            : *resource->LastTree();
 
     // Read dense float features list;
     OpInputList dense_float_features_list;
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 606da663dc2..431dc68836b 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/types.h"
@@ -299,13 +300,12 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
       const ResourceHandle& handle =
           resource_handle_list[resource_handle_idx]
               .flat<ResourceHandle>()(0);
-      QuantileStreamResource* streams_resource;
+      core::RefCountPtr<QuantileStreamResource> streams_resource;
       // Create a reference to the underlying resource using the handle.
       OP_REQUIRES_OK(context,
                      LookupResource(context, handle, &streams_resource));
       // Remove the reference at the end of this scope.
       mutex_lock l(*streams_resource->mutex());
-      core::ScopedUnref unref_me(streams_resource);
 
       // If the stamp is invalid we drop the update.
       if (!streams_resource->is_stamp_valid(stamp_token)) {
@@ -467,13 +467,12 @@ class QuantileAccumulatorSerializeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
     int64 stamp_token = streams_resource->stamp();
     Tensor* stream_state_t;
@@ -526,13 +525,12 @@ class QuantileAccumulatorDeserializeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
     int64 old_stamp_token = streams_resource->stamp();
@@ -595,13 +593,12 @@ class QuantileAccumulatorFlushOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
     const Tensor* next_stamp_token_t;
     OP_REQUIRES_OK(context,
@@ -641,13 +638,12 @@ class QuantileAccumulatorFlushSummaryOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
     // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
     // Remove the reference at the end of this scope.
     mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
     const Tensor* next_stamp_token_t;
     OP_REQUIRES_OK(context,
@@ -713,12 +709,11 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel {
       const ResourceHandle& handle =
           resource_handle_list[resource_handle_idx]
               .flat<ResourceHandle>()(0);
-      QuantileStreamResource* streams_resource;
+      core::RefCountPtr<QuantileStreamResource> streams_resource;
       OP_REQUIRES_OK(context,
                      LookupResource(context, handle, &streams_resource));
       // Remove the reference at the end of this scope.
       mutex_lock l(*streams_resource->mutex());
-      core::ScopedUnref unref_me(streams_resource);
 
       bool are_buckets_ready =
           streams_resource->is_stamp_valid(stamp_token) &&
diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
index 6faf6963011..3a3836ef937 100644
--- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
@@ -27,6 +27,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -130,9 +131,8 @@ using StatsAccumulatorTensorResource =
     StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
 
 void SerializeScalarAccumulatorToOutput(
-    const StatsAccumulatorScalarResource& accumulator_resource,
-    OpKernelContext* context) {
-  int64 num_slots = accumulator_resource.values().size();
+    const StatsAccumulatorScalarResource& resource, OpKernelContext* context) {
+  int64 num_slots = resource.values().size();
   Tensor* partition_ids_t = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                    TensorShape({num_slots}),
                                                    &partition_ids_t));
@@ -159,7 +159,7 @@ void SerializeScalarAccumulatorToOutput(
   auto hessians = hessians_t->vec<float>();
 
   int i = 0;
-  for (const auto& iter : accumulator_resource.values()) {
+  for (const auto& iter : resource.values()) {
     partition_ids(i) = iter.first.partition_id;
     feature_ids(i, 0) = iter.first.feature_id;
     feature_ids(i, 1) = iter.first.dimension;
@@ -171,9 +171,8 @@ void SerializeScalarAccumulatorToOutput(
 }
 
 void SerializeTensorAccumulatorToOutput(
-    const StatsAccumulatorTensorResource& accumulator_resource,
-    OpKernelContext* context) {
-  int64 num_slots = accumulator_resource.values().size();
+    const StatsAccumulatorTensorResource& resource, OpKernelContext* context) {
+  int64 num_slots = resource.values().size();
   Tensor* partition_ids_t = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                    TensorShape({num_slots}),
                                                    &partition_ids_t));
@@ -186,7 +185,7 @@ void SerializeTensorAccumulatorToOutput(
                                                    &feature_ids_t));
   auto feature_ids = feature_ids_t->matrix<int64>();
 
-  TensorShape gradient_shape = accumulator_resource.gradient_shape();
+  TensorShape gradient_shape = resource.gradient_shape();
   int64 num_gradient_elements = gradient_shape.num_elements();
   gradient_shape.InsertDim(0, num_slots);
   Tensor* gradients_t = nullptr;
@@ -195,7 +194,7 @@ void SerializeTensorAccumulatorToOutput(
                                                    &gradients_t));
   auto gradients = gradients_t->flat_outer_dims<float>();
 
-  TensorShape hessian_shape = accumulator_resource.hessian_shape();
+  TensorShape hessian_shape = resource.hessian_shape();
   int64 num_hessian_elements = hessian_shape.num_elements();
   hessian_shape.InsertDim(0, num_slots);
   Tensor* hessians_t = nullptr;
@@ -204,7 +203,7 @@ void SerializeTensorAccumulatorToOutput(
   auto hessians = hessians_t->flat_outer_dims<float>();
 
   int i = 0;
-  for (const auto& iter : accumulator_resource.values()) {
+  for (const auto& iter : resource.values()) {
     partition_ids(i) = iter.first.partition_id;
     feature_ids(i, 0) = iter.first.feature_id;
     feature_ids(i, 1) = iter.first.dimension;
@@ -220,11 +219,10 @@ void SerializeTensorAccumulatorToOutput(
 }
 
 void AddToScalarAccumulator(
-    StatsAccumulatorScalarResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
     const Tensor& gradients_t, const Tensor& hessians_t) {
-  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
-                                        1);
+  resource->set_num_updates(resource->num_updates() + 1);
   const TensorShape& partition_ids_shape = partition_ids_t.shape();
   const auto& partition_ids = partition_ids_t.vec<int32>();
   const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
@@ -232,7 +230,7 @@ void AddToScalarAccumulator(
   const auto& hessians = hessians_t.vec<float>();
 
   int64 num_updates = partition_ids_shape.dim_size(0);
-  auto stats_map = accumulator_resource->mutable_values();
+  auto stats_map = resource->mutable_values();
   for (int64 i = 0; i < num_updates; ++i) {
     const auto key =
         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@@ -248,7 +246,7 @@ void AddToScalarAccumulator(
 }
 
 void AddToScalarAccumulator(
-    StatsAccumulatorScalarResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
     OpKernelContext* context) {
   const Tensor* partition_ids_t;
   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@@ -258,17 +256,16 @@ void AddToScalarAccumulator(
   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
   const Tensor* hessians_t;
   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
-  AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
+  AddToScalarAccumulator(resource, *partition_ids_t, *feature_ids_t,
                          *gradients_t, *hessians_t);
 }
 
 void AddToTensorAccumulator(
-    StatsAccumulatorTensorResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
     const Tensor& gradients_t, const Tensor& hessians_t,
     OpKernelContext* context) {
-  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
-                                        1);
+  resource->set_num_updates(resource->num_updates() + 1);
 
   const TensorShape& partition_ids_shape = partition_ids_t.shape();
   const auto& partition_ids = partition_ids_t.vec<int32>();
@@ -283,19 +280,19 @@ void AddToTensorAccumulator(
 
   // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
   OP_REQUIRES(
-      context, gradients_shape == accumulator_resource->gradient_shape(),
+      context, gradients_shape == resource->gradient_shape(),
       errors::InvalidArgument(strings::StrCat(
           "Gradients dimensions must match: ", gradients_shape.DebugString(),
-          ", ", accumulator_resource->gradient_shape().DebugString())));
+          ", ", resource->gradient_shape().DebugString())));
   OP_REQUIRES(
-      context, hessians_shape == accumulator_resource->hessian_shape(),
+      context, hessians_shape == resource->hessian_shape(),
       errors::InvalidArgument(strings::StrCat(
          "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
-          accumulator_resource->hessian_shape().DebugString())));
+          resource->hessian_shape().DebugString())));
 
   int64 num_updates = partition_ids_shape.dim_size(0);
-  auto stats_map = accumulator_resource->mutable_values();
+  auto stats_map = resource->mutable_values();
   for (int64 i = 0; i < num_updates; ++i) {
     const auto key =
         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@@ -325,7 +322,7 @@ void AddToTensorAccumulator(
 }
 
 void AddToTensorAccumulator(
-    StatsAccumulatorTensorResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
     OpKernelContext* context) {
   const Tensor* partition_ids_t;
   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@@ -335,7 +332,7 @@ void AddToTensorAccumulator(
   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
   const Tensor* hessians_t;
   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
-  AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
+  AddToTensorAccumulator(resource, *partition_ids_t, *feature_ids_t,
                          *gradients_t, *hessians_t, context);
 }
 
@@ -452,20 +449,18 @@ class StatsAccumulatorScalarAddOp : public OpKernel {
           resource_handle_list[resource_handle_idx]
               .flat<ResourceHandle>()(0);
 
-      StatsAccumulatorScalarResource* accumulator_resource;
-      OP_REQUIRES_OK(context, LookupResource(context, handle,
-                                             &accumulator_resource));
-      mutex_lock l(*accumulator_resource->mutex());
-      core::ScopedUnref unref_me(accumulator_resource);
+      core::RefCountPtr<StatsAccumulatorScalarResource> resource;
+      OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
+      mutex_lock l(*resource->mutex());
 
       // If the stamp is invalid we drop the update.
-      if (!accumulator_resource->is_stamp_valid(stamp_token)) {
+      if (!resource->is_stamp_valid(stamp_token)) {
         VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                 << "Passed stamp token: " << stamp_token << " "
-                << "Current token: " << accumulator_resource->stamp();
+                << "Current token: " << resource->stamp();
         return;
       }
-      AddToScalarAccumulator(accumulator_resource,
+      AddToScalarAccumulator(resource,
                              partition_ids_list[resource_handle_idx],
                              feature_ids_list[resource_handle_idx],
                              gradients_list[resource_handle_idx],
@@ -517,20 +512,18 @@ class StatsAccumulatorTensorAddOp : public OpKernel {
           resource_handle_list[resource_handle_idx]
               .flat<ResourceHandle>()(0);
 
-      StatsAccumulatorTensorResource* accumulator_resource;
-      OP_REQUIRES_OK(context, LookupResource(context, handle,
-                                             &accumulator_resource));
-      mutex_lock l(*accumulator_resource->mutex());
-      core::ScopedUnref unref_me(accumulator_resource);
+      core::RefCountPtr<StatsAccumulatorTensorResource> resource;
+      OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
+      mutex_lock l(*resource->mutex());
 
       // If the stamp is invalid we drop the update.
-      if (!accumulator_resource->is_stamp_valid(stamp_token)) {
+      if (!resource->is_stamp_valid(stamp_token)) {
         VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                 << "Passed stamp token: " << stamp_token << " "
-                << "Current token: " << accumulator_resource->stamp();
+                << "Current token: " << resource->stamp();
         return;
       }
-      AddToTensorAccumulator(accumulator_resource,
+      AddToTensorAccumulator(resource,
                              partition_ids_list[resource_handle_idx],
                              feature_ids_list[resource_handle_idx],
                              gradients_list[resource_handle_idx],
@@ -549,11 +542,10 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorScalarResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
     const Tensor* stamp_token_t;
     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@@ -562,7 +554,7 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
     // If the stamp is invalid we restart the PS. It shouldn't happen since
     // only Chief should call this function and chief is guaranteed to be in
     // a consistent state.
-    CHECK(accumulator_resource->is_stamp_valid(stamp_token));
+    CHECK(resource->is_stamp_valid(stamp_token));
 
     const Tensor* next_stamp_token_t;
     OP_REQUIRES_OK(context,
@@ -570,15 +562,15 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
     CHECK(stamp_token != next_stamp_token);
 
-    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
+    SerializeScalarAccumulatorToOutput(*resource, context);
     Tensor* num_updates_t = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output("num_updates", TensorShape({}),
                                             &num_updates_t));
-    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
+    num_updates_t->scalar<int64>()() = resource->num_updates();
 
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(next_stamp_token);
+    resource->Clear();
+    resource->set_stamp(next_stamp_token);
   }
 };
 
@@ -591,11 +583,10 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorTensorResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
     const Tensor* stamp_token_t;
     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@@ -609,16 +600,16 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
     // If the stamp is invalid we restart the PS. It shouldn't happen since
     // only Chief should call this function and chief is guaranteed to be in
     // a consistent state.
-    CHECK(accumulator_resource->is_stamp_valid(stamp_token));
+    CHECK(resource->is_stamp_valid(stamp_token));
     CHECK(stamp_token != next_stamp_token);
-    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
+    SerializeTensorAccumulatorToOutput(*resource, context);
     Tensor* num_updates_t = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output("num_updates", TensorShape({}),
                                             &num_updates_t));
-    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(next_stamp_token);
+    num_updates_t->scalar<int64>()() = resource->num_updates();
+    resource->Clear();
+    resource->set_stamp(next_stamp_token);
   }
 };
 
@@ -631,22 +622,21 @@ class StatsAccumulatorScalarDeserializeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorScalarResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
     // Check the stamp token.
     const Tensor* stamp_token_t;
     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
     int64 stamp_token = stamp_token_t->scalar<int64>()();
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(stamp_token);
-    AddToScalarAccumulator(accumulator_resource, context);
+    resource->Clear();
+    resource->set_stamp(stamp_token);
+    AddToScalarAccumulator(resource, context);
 
     const Tensor* num_updates_t;
     OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
-    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
+    resource->set_num_updates(num_updates_t->scalar<int64>()());
   }
 };
 
@@ -660,22 +650,21 @@ class StatsAccumulatorTensorDeserializeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorTensorResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
     // Check the stamp token.
     const Tensor* stamp_token_t;
     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
     int64 stamp_token = stamp_token_t->scalar<int64>()();
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(stamp_token);
-    AddToTensorAccumulator(accumulator_resource, context);
+    resource->Clear();
+    resource->set_stamp(stamp_token);
+    AddToTensorAccumulator(resource, context);
 
     const Tensor* num_updates_t;
     OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
-    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
+    resource->set_num_updates(num_updates_t->scalar<int64>()());
   }
 };
 
@@ -689,23 +678,22 @@ class StatsAccumulatorScalarSerializeOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorScalarResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
-    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
+    SerializeScalarAccumulatorToOutput(*resource, context);
     Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("stamp_token", TensorShape({}), &stamp_token_t)); - stamp_token_t->scalar()() = accumulator_resource->stamp(); + stamp_token_t->scalar()() = resource->stamp(); Tensor* num_updates_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output("num_updates", TensorShape({}), &num_updates_t)); - num_updates_t->scalar()() = accumulator_resource->num_updates(); + num_updates_t->scalar()() = resource->num_updates(); } }; @@ -751,12 +738,11 @@ class StatsAccumulatorScalarMakeSummaryOp : public OpKernel { void Compute(OpKernelContext* context) override { TensorShape gradient_shape = TensorShape({}); TensorShape hessian_shape = TensorShape({}); - StatsAccumulatorScalarResource* accumulator_resource = - new StatsAccumulatorScalarResource(gradient_shape, hessian_shape); - core::ScopedUnref unref_me(accumulator_resource); + core::RefCountPtr resource( + new StatsAccumulatorScalarResource(gradient_shape, hessian_shape)); // Check the stamp token. - AddToScalarAccumulator(accumulator_resource, context); - SerializeScalarAccumulatorToOutput(*accumulator_resource, context); + AddToScalarAccumulator(resource, context); + SerializeScalarAccumulatorToOutput(*resource, context); } }; @@ -780,12 +766,11 @@ class StatsAccumulatorTensorMakeSummaryOp : public OpKernel { TensorShape hessians_shape = hessians_t->shape(); hessians_shape.RemoveDim(0); - StatsAccumulatorTensorResource* accumulator_resource = - new StatsAccumulatorTensorResource(gradients_shape, hessians_shape); - core::ScopedUnref unref_me(accumulator_resource); + core::RefCountPtr resource( + new StatsAccumulatorTensorResource(gradients_shape, hessians_shape)); // Check the stamp token. - AddToTensorAccumulator(accumulator_resource, context); - SerializeTensorAccumulatorToOutput(*accumulator_resource, context); + AddToTensorAccumulator(resource, context); + SerializeTensorAccumulatorToOutput(*resource, context); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index a30cfa663f4..91c017839b5 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -21,6 +21,7 @@ #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; @@ -31,6 +32,8 @@ namespace { using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearningRateConfig; +using boosted_trees::models::DecisionTreeEnsembleResource; +using boosted_trees::trees::DecisionTreeConfig; using boosted_trees::trees::Leaf; using boosted_trees::trees::TreeNode; using boosted_trees::trees::TreeNodeMetadata; @@ -193,10 +196,9 @@ class CenterTreeEnsembleBiasOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; + core::RefCountPtr ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &ensemble_resource)); - core::ScopedUnref unref_me(ensemble_resource); mutex_lock l(*ensemble_resource->get_mutex()); // Get the stamp token. 
@@ -255,8 +257,8 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
 
  private:
   // Helper method to retrieve the bias from the tree ensemble.
-  boosted_trees::trees::Leaf* RetrieveBias(
-      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource,
+  Leaf* RetrieveBias(
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
       int64 logits_dimension) {
     const int32 num_trees = ensemble_resource->num_trees();
     if (num_trees <= 0) {
@@ -319,10 +321,9 @@ class GrowTreeEnsembleOp : public OpKernel {
 
   void Compute(OpKernelContext* const context) override {
     // Get decision tree ensemble.
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
-    core::ScopedUnref unref_me(ensemble_resource);
     mutex_lock l(*ensemble_resource->get_mutex());
 
     // Get the stamp token.
@@ -400,10 +401,9 @@ class GrowTreeEnsembleOp : public OpKernel {
     // Update and retrieve the growable tree.
     // If the tree is fully built and dropout was applied, it also adjusts the
     // weights of dropped and the last tree.
-    boosted_trees::trees::DecisionTreeConfig* const tree_config =
-        UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
-                                      dropout_seed, max_tree_depth,
-                                      weak_learner_type);
+    DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(
+        ensemble_resource, learning_rate, dropout_seed, max_tree_depth,
+        weak_learner_type);
     // Split tree nodes.
     switch (weak_learner_type) {
       case LearnerConfig::NORMAL_DECISION_TREE: {
@@ -559,8 +559,7 @@ class GrowTreeEnsembleOp : public OpKernel {
   }
 
   void UpdateTreeWeightsIfDropout(
-      boosted_trees::models::DecisionTreeEnsembleResource* const
-          ensemble_resource,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
       const uint64 dropout_seed) {
     // It is possible that the tree was built with dropout. If it is the case,
     // we need to adjust the tree weight, or bail out.
@@ -609,9 +608,8 @@ class GrowTreeEnsembleOp : public OpKernel {
 
   // Helper method to update the growable tree which is by definition the last
   // tree in the ensemble.
-  boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
-      boosted_trees::models::DecisionTreeEnsembleResource* const
-          ensemble_resource,
+  DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
       const float learning_rate, const uint64 dropout_seed,
       const int32 max_tree_depth, const int32 weak_learner_type) {
     const auto num_trees = ensemble_resource->num_trees();
@@ -719,8 +717,8 @@ class GrowTreeEnsembleOp : public OpKernel {
   // leaf children given the split candidate.
   void SplitTreeNode(
       const int32 node_id, SplitCandidate* split,
-      boosted_trees::trees::DecisionTreeConfig* tree_config,
-      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
+      DecisionTreeConfig* tree_config,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
     // No-op if we have no real node.
     CHECK(node_id < tree_config->nodes_size())
         << "Invalid node " << node_id << " to split.";
@@ -761,14 +759,13 @@ class GrowTreeEnsembleOp : public OpKernel {
     (*tree_config->mutable_nodes(node_id)) =
         *split->split_info.mutable_split_node();
     if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
-      ensemble_resource->MaybeAddUsedHandler(split->handler_id);
+      resource->MaybeAddUsedHandler(split->handler_id);
     }
   }
 
   void SplitTreeLayer(
-      SplitCandidate* split,
-      boosted_trees::trees::DecisionTreeConfig* tree_config,
-      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
+      SplitCandidate* split, DecisionTreeConfig* tree_config,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
     int depth = 0;
     while (depth < tree_config->nodes_size() &&
            tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
@@ -903,10 +900,9 @@ class TreeEnsembleStatsOp : public OpKernel {
 
   void Compute(OpKernelContext* const context) override {
     // Get decision tree ensemble.
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
-    core::ScopedUnref unref_me(ensemble_resource);
     tf_shared_lock l(*ensemble_resource->get_mutex());
 
     // Get the stamp token.
diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
index 6ab3f460b36..3dd68d5ddf7 100644
--- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
+++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
@@ -95,7 +95,7 @@ class ZeroVarInitializer : public OpKernel {
   }
 
   void Compute(OpKernelContext* ctx) override {
-    Var* variable = nullptr;
+    core::RefCountPtr<Var> variable;
     OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
                             ctx, HandleFromInput(ctx, 0), &variable,
                             [this, ctx](Var** var_ptr) {
@@ -117,7 +117,6 @@ class ZeroVarInitializer : public OpKernel {
                               return Status::OK();
                             }));
 
-    core::ScopedUnref scoped(variable);
     mutex_lock ml(*variable->mu());
 
     OP_REQUIRES(ctx, !variable->is_initialized,
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
index 76b1d2b4da2..94650fe108b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 // =============================================================================
 #include <unordered_map>
+
 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
@@ -24,6 +25,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
@@ -76,11 +78,10 @@ class TreeSerializeOp : public OpKernel {
   explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    DecisionTreeResource* decision_tree_resource;
+    core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
     Tensor* output_config_t = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(0, TensorShape(), &output_config_t));
@@ -100,12 +101,11 @@ class TreeDeserializeOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* context) override {
-    DecisionTreeResource* decision_tree_resource;
+    core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
     auto handle = HandleFromInput(context, 0);
     OP_REQUIRES_OK(context,
                    LookupResource(context, handle, &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
 
     const Tensor* tree_config_t;
     OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
@@ -131,11 +131,10 @@ class TreeSizeOp : public OpKernel {
   explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    DecisionTreeResource* decision_tree_resource;
+    core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
     Tensor* output_t = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape(), &output_t));
@@ -144,7 +143,7 @@ class TreeSizeOp : public OpKernel {
   }
 };
 
-void TraverseTree(const DecisionTreeResource* tree_resource,
+void TraverseTree(DecisionTreeResource* tree_resource,
                   const std::unique_ptr<TensorDataSet>& data, int32 start,
                   int32 end,
                   const std::function<void(int32, int32)>& set_leaf_id,
@@ -182,11 +181,10 @@ class TreePredictionsV4Op : public OpKernel {
     data_set->set_input_tensors(input_data, sparse_input_indices,
                                 sparse_input_values, sparse_input_shape);
 
-    DecisionTreeResource* decision_tree_resource;
+    core::RefCountPtr<DecisionTreeResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &decision_tree_resource));
-    mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
+                                           &resource));
+    mutex_lock l(*resource->get_mutex());
 
     const int num_data = data_set->NumItems();
     const int32 num_outputs = param_proto_.num_outputs();
@@ -205,15 +203,16 @@ class TreePredictionsV4Op : public OpKernel {
     auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
     int num_threads = worker_threads->num_threads;
     const int64 costPerTraverse = 500;
-    auto traverse = [this, &out, &data_set, decision_tree_resource, num_data,
-                     &tree_paths](int64 start, int64 end) {
+    auto traverse = [this, &out, &data_set, &resource, num_data, &tree_paths](
+                        int64 start, int64 end) {
       CHECK(start <= end);
       CHECK(end <= num_data);
-      TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
+
+      TraverseTree(resource.get(), data_set, static_cast<int32>(start),
                    static_cast<int32>(end),
                    std::bind(&TreePredictionsV4Op::set_output_value, this,
                              std::placeholders::_1, std::placeholders::_2,
-                             decision_tree_resource, &out),
+                             resource.get(), &out),
                    param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
     };
     Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@@ -283,11 +282,10 @@ class TraverseTreeV4Op : public OpKernel {
     data_set->set_input_tensors(input_data, sparse_input_indices,
                                 sparse_input_values, sparse_input_shape);
 
-    DecisionTreeResource* decision_tree_resource;
+    core::RefCountPtr<DecisionTreeResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &decision_tree_resource));
-    mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
+                                           &resource));
+    mutex_lock l(*resource->get_mutex());
 
     const int num_data = data_set->NumItems();
 
@@ -304,11 +302,11 @@ class TraverseTreeV4Op : public OpKernel {
     auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
     int num_threads = worker_threads->num_threads;
     const int64 costPerTraverse = 500;
-    auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource,
-                     num_data](int64 start, int64 end) {
+    auto traverse = [&set_leaf_ids, &data_set, &resource, num_data](int64 start,
+                                                                    int64 end) {
      CHECK(start <= end);
      CHECK(end <= num_data);
-      TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
+      TraverseTree(resource.get(), data_set, static_cast<int32>(start),
                    static_cast<int32>(end), set_leaf_ids, nullptr);
     };
     Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@@ -336,11 +334,10 @@ class UpdateModelV4Op : public OpKernel {
     const Tensor& input_labels = context->input(2);
     const Tensor& input_weights = context->input(3);
 
-    DecisionTreeResource* decision_tree_resource;
+    core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
 
     const int num_data = input_labels.shape().dim_size(0);
     const int32 label_dim =
@@ -356,9 +353,10 @@ class UpdateModelV4Op : public OpKernel {
     UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
   }
 
-  void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
-                   int32 start, int32 end,
-                   DecisionTreeResource* decision_tree_resource) {
+  void UpdateModel(
+      const Tensor& leaf_ids, const TensorInputTarget& target, int32 start,
+      int32 end,
+      const core::RefCountPtr<DecisionTreeResource>& decision_tree_resource) {
     const auto leaves = leaf_ids.unaligned_flat<int32>();
     for (int i = start; i < end; ++i) {
       model_op_->UpdateModel(
@@ -384,11 +382,10 @@ class FeatureUsageCountsOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* context) override {
-    DecisionTreeResource* decision_tree_resource;
+    core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
     mutex_lock l(*decision_tree_resource->get_mutex());
-    core::ScopedUnref unref_me(decision_tree_resource);
 
     const auto& tree = decision_tree_resource->decision_tree();
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index 0243f106814..ede6e1abc9f 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -26,6 +26,7 @@ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" @@ -87,11 +88,10 @@ class FertileStatsSerializeOp : public OpKernel { } void Compute(OpKernelContext* context) override { - FertileStatsResource* fertile_stats_resource; + core::RefCountPtr<FertileStatsResource> fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &fertile_stats_resource)); mutex_lock l(*fertile_stats_resource->get_mutex()); - core::ScopedUnref unref_me(fertile_stats_resource); Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape(), &output_config_t)); @@ -116,11 +116,10 @@ class FertileStatsDeserializeOp : public OpKernel { } void Compute(OpKernelContext* context) override { - FertileStatsResource* fertile_stats_resource; + core::RefCountPtr<FertileStatsResource> fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &fertile_stats_resource)); mutex_lock l(*fertile_stats_resource->get_mutex()); - core::ScopedUnref unref_me(fertile_stats_resource); const Tensor* stats_config_t; OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t)); @@ -145,7 +144,7 @@ class FertileStatsDeserializeOp : public OpKernel { // acquired, put it in a waiting queue to come back to later and try the next // one. Once all leaf_ids have been visited, cycle through the waiting ids // until they're gone. -void UpdateStats(FertileStatsResource* fertile_stats_resource, +void UpdateStats(const core::RefCountPtr<FertileStatsResource>& resource, const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target, int num_targets, const Tensor& leaf_ids_tensor, @@ -183,8 +182,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource, } bool is_finished; - fertile_stats_resource->AddExampleToStatsAndInitialize( - data, &target, {example_id}, leaf_id, &is_finished); + resource->AddExampleToStatsAndInitialize(data, &target, {example_id}, + leaf_id, &is_finished); leaf_lock->unlock(); if (is_finished) { set_lock->lock(); @@ -196,8 +195,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource, // Update leaves from start through end in the leaf_examples iterator.
void UpdateStatsCollated( - FertileStatsResource* fertile_stats_resource, - DecisionTreeResource* tree_resource, + const core::RefCountPtr<FertileStatsResource>& fertile_stats_resource, + const core::RefCountPtr<DecisionTreeResource>& tree_resource, const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target, int num_targets, const std::unordered_map<int32, std::vector<int>>& leaf_examples, @@ -251,18 +250,15 @@ class ProcessInputOp : public OpKernel { data_set->set_input_tensors(input_data, sparse_input_indices, sparse_input_values, sparse_input_shape); - FertileStatsResource* fertile_stats_resource; + core::RefCountPtr<FertileStatsResource> fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), &fertile_stats_resource)); - DecisionTreeResource* tree_resource; + core::RefCountPtr<DecisionTreeResource> tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &tree_resource)); mutex_lock l1(*fertile_stats_resource->get_mutex()); mutex_lock l2(*tree_resource->get_mutex()); - core::ScopedUnref unref_stats(fertile_stats_resource); - core::ScopedUnref unref_tree(tree_resource); - const int32 num_data = data_set->NumItems(); auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; @@ -308,7 +304,7 @@ class ProcessInputOp : public OpKernel { // if it really matters that much. const int64 costPerUpdate = 1000; auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set, - fertile_stats_resource, &locks, &set_lock, &ready_to_split, + &fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); @@ -317,8 +313,8 @@ class ProcessInputOp : public OpKernel { static_cast<int32>(end), &ready_to_split); }; - auto update_collated = [&target, &num_targets, fertile_stats_resource, - tree_resource, &leaf_examples, &set_lock, + auto update_collated = [&target, &num_targets, &fertile_stats_resource, + &tree_resource, &leaf_examples, &set_lock, &ready_to_split, &data_set, num_leaves](int64 start, int64 end) { CHECK(start <= end); @@ -362,18 +358,15 @@ class GrowTreeOp : public OpKernel { } void Compute(OpKernelContext* context) override { - FertileStatsResource* fertile_stats_resource; + core::RefCountPtr<FertileStatsResource> fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), &fertile_stats_resource)); - DecisionTreeResource* tree_resource; + core::RefCountPtr<DecisionTreeResource> tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &tree_resource)); mutex_lock l1(*fertile_stats_resource->get_mutex()); mutex_lock l2(*tree_resource->get_mutex()); - core::ScopedUnref unref_stats(fertile_stats_resource); - core::ScopedUnref unref_tree(tree_resource); - const Tensor& finished_nodes = context->input(2); const auto finished = finished_nodes.unaligned_flat<int32>(); @@ -463,19 +456,16 @@ class FinalizeTreeOp : public OpKernel { } void Compute(OpKernelContext* context) override { - DecisionTreeResource* tree_resource; + core::RefCountPtr<DecisionTreeResource> tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &tree_resource)); - FertileStatsResource* fertile_stats_resource; + core::RefCountPtr<FertileStatsResource> fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), &fertile_stats_resource)); mutex_lock l1(*fertile_stats_resource->get_mutex()); mutex_lock l2(*tree_resource->get_mutex()); - core::ScopedUnref unref_me(tree_resource); - core::ScopedUnref unref_stats(fertile_stats_resource); - // TODO(thomaswc): Add threads int
num_nodes = tree_resource->decision_tree().decision_tree().nodes_size(); for (int i = 0; i < num_nodes; i++) { diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index da547d5829f..0f1dfc4efd0 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -270,12 +270,19 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value); template <typename T> Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); +// Looks up a resource pointed by a given resource handle. +// +// Prefer usage of LookupResource taking `core::RefCountPtr` to avoid +// requiring the caller to explicitly call `Unref()`. +template <typename T> +Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr<T>* value); + // Looks up multiple resources pointed by a sequence of resource handles. If // p[i] is uninitialized then values[i] is unmodified. template <typename T> -Status LookupResources( - OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p, - std::vector<core::RefCountPtr<T>>* values); +Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p, + std::vector<core::RefCountPtr<T>>* values); // Looks up or creates a resource. // @@ -283,10 +290,19 @@ Status LookupResources( // must call its `Unref()` method when it has finished using it. If the // `creator` is invoked, its reference on the created resource is transferred // to `ctx->resource_mgr()`. +// +// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid +// requiring the caller to explicitly call `Unref()`. template <typename T> Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, T** value, std::function<Status(T**)> creator); +// Looks up or creates a resource. +template <typename T> +Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr<T>* value, + std::function<Status(T**)> creator); + // Destroys a resource pointed by a given resource handle. template <typename T> Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); @@ -587,9 +603,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, } template <typename T> -Status LookupResources( - OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p, - std::vector<core::RefCountPtr<T>>* values) { +Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr<T>* value) { + T* raw_ptr = nullptr; + TF_RETURN_IF_ERROR(LookupResource(ctx, p, &raw_ptr)); + value->reset(raw_ptr); + + return Status::OK(); +} + +template <typename T> +Status LookupResources(OpKernelContext* ctx, + absl::Span<ResourceHandle const* const> p, + std::vector<core::RefCountPtr<T>>* values) { std::vector<std::pair<const string*, const string*>> containers_and_names( p.size()); for (size_t i = 0; i < p.size(); ++i) { @@ -607,6 +633,17 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, creator); } +template <typename T> +Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr<T>* value, + std::function<Status(T**)> creator) { + T* raw_ptr = nullptr; + TF_RETURN_IF_ERROR(LookupOrCreateResource(ctx, p, &raw_ptr, creator)); + value->reset(raw_ptr); + + return Status::OK(); +} + template <typename T> Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index 2f5880e9396..e093b144c85 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -266,15 +267,14 @@ TEST(ResourceHandleTest, CRUD) { TF_EXPECT_OK(CreateResource(&ctx, p, r)); } { - StubResource* r = nullptr; + core::RefCountPtr r; TF_ASSERT_OK(LookupResource(&ctx, p, &r)); ASSERT_TRUE(r != nullptr); EXPECT_EQ(r->value_, 42); - r->Unref(); } { TF_EXPECT_OK(DeleteResource(&ctx, p)); - StubResource* unused = nullptr; + core::RefCountPtr unused; EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok()); } } @@ -338,13 +338,12 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) { StubResource* r = new StubResource; TF_EXPECT_OK(CreateResource(&ctx, p, r)); - StubResource* lookup_r = nullptr; + core::RefCountPtr lookup_r; TF_EXPECT_OK(LookupResource(&ctx, p, &lookup_r)); - EXPECT_EQ(lookup_r, r); + EXPECT_EQ(lookup_r.get(), r); TF_EXPECT_OK(DeleteResource(&ctx, p)); EXPECT_NE(LookupResource(&ctx, p, &lookup_r).ok(), true); - r->Unref(); } } // end namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index e58ab4a676b..0963effcc3e 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -166,6 +166,7 @@ tf_kernel_library( ":variable_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//third_party/eigen3", ], ) @@ -5022,6 +5023,7 @@ STATE_DEPS = [ "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ] + if_sycl(["//tensorflow/core:sycl_runtime"]) tf_kernel_library( @@ -7589,6 +7591,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/db:sqlite", "//tensorflow/core/summary:schema", diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD index 5f6ccac1b19..ac94371de14 100644 --- a/tensorflow/core/kernels/boosted_trees/BUILD +++ b/tensorflow/core/kernels/boosted_trees/BUILD @@ -83,6 +83,7 @@ tf_kernel_library( ":resources", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", ], ) @@ -106,6 +107,7 @@ tf_kernel_library( ":tree_helper", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", "//third_party/eigen3", ], @@ -117,6 +119,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles", ], ) diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc index 72194601000..718cf8e4139 100644 --- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc @@ -51,12 +51,10 @@ class BoostedTreesTrainingPredictOp : public OpKernel { } void Compute(OpKernelContext* const context) override { - BoostedTreesEnsembleResource* resource; + core::RefCountPtr resource; // Get the resource. 
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &resource)); - // Release the reference to the resource once we're done using it. - core::ScopedUnref unref_me(resource); // Get the inputs. OpInputList bucketized_features_list; @@ -198,12 +196,10 @@ class BoostedTreesPredictOp : public OpKernel { } void Compute(OpKernelContext* const context) override { - BoostedTreesEnsembleResource* resource; + core::RefCountPtr<BoostedTreesEnsembleResource> resource; // Get the resource. OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &resource)); - // Release the reference to the resource once we're done using it. - core::ScopedUnref unref_me(resource); // Get the inputs. OpInputList bucketized_features_list; @@ -302,12 +298,10 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel { } void Compute(OpKernelContext* const context) override { - BoostedTreesEnsembleResource* resource; + core::RefCountPtr<BoostedTreesEnsembleResource> resource; // Get the resource. OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &resource)); - // Release the reference to the resource once we're done using it. - core::ScopedUnref unref_me(resource); // Get the inputs. OpInputList bucketized_features_list; diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc index 81f04732d33..da6ad38b425 100644 --- a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc @@ -27,6 +27,7 @@ #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h" #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -224,12 +225,11 @@ class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel { ResourceHandle handle; OP_REQUIRES_OK(context, HandleFromInput(context, kResourceHandleName, &handle)); - QuantileStreamResource* stream_resource; + core::RefCountPtr<QuantileStreamResource> stream_resource; // Create a reference to the underlying resource using the handle. OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); // Remove the reference at the end of this scope. mutex_lock l(*stream_resource->mutex()); - core::ScopedUnref unref_me(stream_resource); OpInputList summaries_list; OP_REQUIRES_OK(context, context->input_list(kSummariesName, &summaries_list)); @@ -281,13 +281,12 @@ class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel { } void Compute(OpKernelContext* context) override { - QuantileStreamResource* streams_resource; + core::RefCountPtr<QuantileStreamResource> streams_resource; // Create a reference to the underlying resource using the handle. OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &streams_resource)); // Remove the reference at the end of this scope. mutex_lock l(*streams_resource->mutex()); - core::ScopedUnref unref_me(streams_resource); OpInputList bucket_boundaries_list; OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName, &bucket_boundaries_list)); @@ -336,12 +335,11 @@ class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel { ResourceHandle handle; OP_REQUIRES_OK(context, HandleFromInput(context, kResourceHandleName, &handle)); - QuantileStreamResource* stream_resource; + core::RefCountPtr<QuantileStreamResource> stream_resource; // Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); // Remove the reference at the end of this scope. mutex_lock l(*stream_resource->mutex()); - core::ScopedUnref unref_me(stream_resource); const Tensor* num_buckets_t; OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t)); @@ -391,12 +389,11 @@ class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp ResourceHandle handle; OP_REQUIRES_OK(context, HandleFromInput(context, kResourceHandleName, &handle)); - QuantileStreamResource* stream_resource; + core::RefCountPtr<QuantileStreamResource> stream_resource; // Create a reference to the underlying resource using the handle. OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); // Remove the reference at the end of this scope. mutex_lock l(*stream_resource->mutex()); - core::ScopedUnref unref_me(stream_resource); const int64 num_streams = stream_resource->num_streams(); CHECK_EQ(num_features_, num_streams); diff --git a/tensorflow/core/kernels/boosted_trees/resource_ops.cc b/tensorflow/core/kernels/boosted_trees/resource_ops.cc index 563f7b8b08c..5a9c3549041 100644 --- a/tensorflow/core/kernels/boosted_trees/resource_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/resource_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/boosted_trees/resources.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { @@ -78,11 +79,10 @@ class BoostedTreesGetEnsembleStatesOp : public OpKernel { void Compute(OpKernelContext* context) override { // Looks up the resource. - BoostedTreesEnsembleResource* tree_ensemble_resource; + core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &tree_ensemble_resource)); tf_shared_lock l(*tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(tree_ensemble_resource); // Sets the outputs. const int num_trees = tree_ensemble_resource->num_trees(); @@ -141,11 +141,10 @@ class BoostedTreesSerializeEnsembleOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - BoostedTreesEnsembleResource* tree_ensemble_resource; + core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &tree_ensemble_resource)); tf_shared_lock l(*tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(tree_ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_stamp_token_t)); @@ -169,11 +168,10 @@ class BoostedTreesDeserializeEnsembleOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - BoostedTreesEnsembleResource* tree_ensemble_resource; + core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &tree_ensemble_resource)); mutex_lock l(*tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(tree_ensemble_resource); // Get the stamp token. const Tensor* stamp_token_t; diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc index 68cf99a6607..eabb8361127 100644 --- a/tensorflow/core/kernels/boosted_trees/training_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc @@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/boosted_trees/resources.h" #include "tensorflow/core/kernels/boosted_trees/tree_helper.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { @@ -55,10 +56,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - BoostedTreesEnsembleResource* ensemble_resource; + core::RefCountPtr ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &ensemble_resource)); - core::ScopedUnref unref_me(ensemble_resource); mutex_lock l(*ensemble_resource->get_mutex()); // Increase the ensemble stamp. ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); @@ -176,19 +176,19 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { private: int32 UpdateGlobalAttemptsAndRetrieveGrowableTree( - BoostedTreesEnsembleResource* const ensemble_resource) { - int32 num_trees = ensemble_resource->num_trees(); + const core::RefCountPtr& resource) { + int32 num_trees = resource->num_trees(); int32 current_tree = num_trees - 1; // Increment global attempt stats. - ensemble_resource->UpdateGrowingMetadata(); + resource->UpdateGrowingMetadata(); // Note we don't set tree weight to be equal to learning rate, since we // apply learning rate to leaf weights instead, when doing layer-by-layer // boosting. if (num_trees <= 0) { // Create a new tree with a no-op leaf. - current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight); + current_tree = resource->AddNewTree(kLayerByLayerTreeWeight); } return current_tree; } @@ -250,10 +250,9 @@ class BoostedTreesCenterBiasOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - BoostedTreesEnsembleResource* ensemble_resource; + core::RefCountPtr ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &ensemble_resource)); - core::ScopedUnref unref_me(ensemble_resource); mutex_lock l(*ensemble_resource->get_mutex()); // Increase the ensemble stamp. ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); diff --git a/tensorflow/core/kernels/count_up_to_op.cc b/tensorflow/core/kernels/count_up_to_op.cc index 9da0015fa2d..e59cda7b011 100644 --- a/tensorflow/core/kernels/count_up_to_op.cc +++ b/tensorflow/core/kernels/count_up_to_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -65,11 +66,9 @@ class ResourceCountUpToOp : public OpKernel { } void Compute(OpKernelContext* context) override { - Var* variable = nullptr; - OP_REQUIRES_OK( - context, - LookupResource(context, HandleFromInput(context, 0), &variable)); - core::ScopedUnref s(variable); + core::RefCountPtr variable; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &variable)); mutex_lock l(*variable->mu()); Tensor before_increment = *variable->tensor(); OP_REQUIRES( diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 2ba1143290c..07dbef191dc 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -342,6 +342,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:summary_interface", ], @@ -380,6 +381,7 @@ tf_kernel_library( "//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/data:dataset_utils", "//third_party/eigen3", ], diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc index dcd4e68e65e..3129edf77cd 100644 --- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/data/stats_utils.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { @@ -83,17 +84,16 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - StatsAggregatorResource* stats_aggregator_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), - &stats_aggregator_resource)); - core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); + core::RefCountPtr resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 1), &resource)); string tag; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); string prefix; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix)); - *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource, - tag, prefix); + *output = + new Dataset(ctx, input, ctx->input(1), resource.get(), tag, prefix); } private: @@ -101,12 +101,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, const Tensor& resource_handle, - StatsAggregatorResource* stats_aggregator_resource, - const string& tag, const string& prefix) + StatsAggregatorResource* resource, const string& tag, + const string& prefix) : DatasetBase(DatasetContext(ctx)), input_(input), resource_handle_(resource_handle), - stats_aggregator_resource_(stats_aggregator_resource), + stats_aggregator_resource_(resource), tag_(tag), prefix_(prefix) { input_->Ref(); @@ -169,13 +169,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - StatsAggregatorResource* stats_aggregator_resource = + StatsAggregatorResource* resource = dataset()->stats_aggregator_resource_; IteratorContext::Params params(ctx); params.stats_aggregator = std::shared_ptr( - new StatsAggregatorWithTagAndPrefix( - stats_aggregator_resource->stats_aggregator(), dataset()->tag_, - dataset()->prefix_)); + new StatsAggregatorWithTagAndPrefix(resource->stats_aggregator(), + dataset()->tag_, + dataset()->prefix_)); IteratorContext iter_ctx(std::move(params)); return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence); } diff --git a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc index 0d6ec07364d..8145f7d45b3 100644 --- a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/framework/stats_aggregator.h" - #include <memory> #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/histogram/histogram.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -258,10 +258,9 @@ class StatsAggregatorSummaryOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), errors::InvalidArgument("resource_handle must be a scalar")); - StatsAggregatorResource* resource; + core::RefCountPtr<StatsAggregatorResource> resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - core::ScopedUnref unref_iterator(resource); Tensor* summary_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t)); @@ -281,21 +280,19 @@ class StatsAggregatorSetSummaryWriterOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), errors::InvalidArgument("resource_handle must be a scalar")); - StatsAggregatorResource* resource; + core::RefCountPtr<StatsAggregatorResource> resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - core::ScopedUnref unref_iterator(resource); const Tensor& summary_resource_handle_t = ctx->input(1); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()), errors::InvalidArgument("resource_handle must be a scalar")); - SummaryWriterInterface* sumamry_resource; + core::RefCountPtr<SummaryWriterInterface> sumamry_resource; OP_REQUIRES_OK( ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource)); - core::ScopedUnref unref_sumamry_resource(sumamry_resource); TF_CHECK_OK( - resource->stats_aggregator()->SetSummaryWriter(sumamry_resource)); + resource->stats_aggregator()->SetSummaryWriter(sumamry_resource.get())); } }; diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index 9d1649bf021..5c6af068ef7 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/util/work_sharder.h" @@ -127,12 +128,10 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - ThreadPoolResource* threadpool_resource; + core::RefCountPtr threadpool_resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &threadpool_resource)); - core::ScopedUnref unref_iterator(threadpool_resource); - - *output = new Dataset(ctx, input, ctx->input(1), threadpool_resource); + *output = new Dataset(ctx, input, ctx->input(1), threadpool_resource.get()); } private: diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index f5907831676..25c8b69b37f 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/unbounded_thread_pool.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" @@ -595,10 +596,9 @@ int64 AnonymousIteratorHandleOp::current_id_(0); void MakeIteratorOp::Compute(OpKernelContext* ctx) { DatasetBase* dataset; OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); - IteratorResource* iterator_resource; + core::RefCountPtr iterator_resource; OP_REQUIRES_OK( ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); - core::ScopedUnref unref(iterator_resource); OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset)); } diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index 34d9ece8b06..1c0ab96ea61 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/unbounded_thread_pool.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/util/device_name_utils.h" @@ -492,10 +493,9 @@ class MultiDeviceIteratorInitOp : public OpKernel { DatasetBase* dataset; OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); - MultiDeviceIterator* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &resource)); - core::ScopedUnref unref(resource); std::unique_ptr iterator; IteratorContext::Params params(ctx); @@ -535,7 +535,7 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done); int64 incarnation_id = tensor_incarnation_id->scalar()(); - MultiDeviceIterator* iterator; + core::RefCountPtr iterator; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); @@ -557,7 +557,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { std::placeholders::_1, std::move(done)); iterator->GetNextFromShard(ctx, shard_num, incarnation_id, callback); - iterator->Unref(); } }; @@ -577,10 +576,9 @@ class MultiDeviceIteratorToStringHandleOp : public OpKernel { // Validate that the handle corresponds to a real resource, and // that it is an MultiDeviceIterator. - MultiDeviceIterator* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); - resource->Unref(); Tensor* string_handle_t; OP_REQUIRES_OK(ctx, @@ -629,9 +627,8 @@ class MultiDeviceIteratorFromStringHandleOp : public OpKernel { // Validate that the handle corresponds to a real resource, and // that it is an MultiDeviceIterator. - MultiDeviceIterator* resource; + core::RefCountPtr resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource)); - core::ScopedUnref unref_iterator(resource); if (!output_types_.empty()) { OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_types_, resource->output_types())); diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc index 6ed36605530..f0d8310e26f 100644 --- a/tensorflow/core/kernels/random_binomial_op.cc +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" #include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/guarded_philox_random.h" @@ -371,10 +372,9 @@ class RandomBinomialOp : public OpKernel { "Input probs should have length 1 or shape[0], got shape: ", probs_tensor.shape().DebugString())); } - Var* var = nullptr; + core::RefCountPtr var; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var)); - ScopedUnlockUnrefVar var_guard(var); Tensor* var_tensor = var->tensor(); OP_REQUIRES( ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE, @@ -404,7 +404,6 @@ class RandomBinomialOp : public OpKernel { auto philox = GetPhiloxRandomFromMem(var_data); UpdateMemWithPhiloxRandom( philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data); - var_guard.Release(); auto binomial_functor = functor::RandomBinomialFunctor(); binomial_functor(ctx, ctx->eigen_device(), num_batches, diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index dda4e02b984..16f8ed12baf 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -129,7 +129,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { } // namespace void ReadVariableOp::Compute(OpKernelContext* ctx) { - Var* variable = nullptr; + core::RefCountPtr variable; const ResourceHandle& handle = HandleFromInput(ctx, 0); const auto status = LookupResource(ctx, handle, &variable); OP_REQUIRES(ctx, status.ok(), @@ -139,7 +139,6 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) { ". This could mean that the variable was uninitialized. 
", status.ToString())); - core::ScopedUnref s(variable); { tf_shared_lock ml(*variable->mu()); // We're acquiring a reference to the underlying buffer while @@ -175,8 +174,7 @@ ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) { } void ReadVariablesOp::Compute(OpKernelContext* ctx) { - std::vector> variables( - dtypes_.size()); + std::vector> variables(dtypes_.size()); std::vector handles(dtypes_.size()); for (size_t i = 0; i < dtypes_.size(); ++i) { handles[i] = &HandleFromInput(ctx, i); @@ -265,10 +263,9 @@ class VariableShapeOp : public OpKernel { explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* ctx) override { - Var* variable = nullptr; + core::RefCountPtr variable; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &variable)); - core::ScopedUnref s(variable); variable->mu()->lock_shared(); TensorShape shape = variable->tensor()->shape(); variable->mu()->unlock_shared(); @@ -343,7 +340,7 @@ class AssignVariableOp : public OpKernel { "Variable and value dtypes don't match; respectively, ", DataTypeString(dtype_), " and ", DataTypeString(context->input(1).dtype()))); - Var* variable = nullptr; + core::RefCountPtr variable; const Tensor& value = context->input(1); // Note: every resource-variable-manipulating op assumes copy-on-write // semantics, and creates a copy of the variable's Tensor if its refcount is @@ -361,7 +358,6 @@ class AssignVariableOp : public OpKernel { (*ptr)->is_initialized = true; return Status::OK(); })); - core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); OP_REQUIRES(context, variable->tensor()->dtype() == dtype_, errors::InvalidArgument( @@ -404,7 +400,7 @@ class AssignVariableOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& value = context->input(1); - Var* variable = nullptr; + core::RefCountPtr variable; OP_REQUIRES_OK(context, LookupOrCreateResource( context, HandleFromInput(context, 0), &variable, [](Var** ptr) { @@ -412,7 +408,6 @@ class AssignVariableOp : public OpKernel { *ptr = new Var(DT_VARIANT); return Status::OK(); })); - core::ScopedUnref s(variable); // For purposes of forwarding DT_VARIANT, we want the least // restrictive attr; we already know the input is on host. 
@@ -500,10 +495,9 @@ class AssignUpdateVariableOp : public OpKernel { explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* context) override { - Var* variable = nullptr; + core::RefCountPtr<Var> variable; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &variable)); - core::ScopedUnref s(variable); const Tensor& value = context->input(1); // TODO(apassos): We could possibly avoid the copy done by @@ -568,13 +562,12 @@ class VarIsInitializedOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); auto output_tensor = output->tensor<bool, 0>(); - Var* variable = nullptr; + core::RefCountPtr<Var> variable; Status s = LookupResource(context, HandleFromInput(context, 0), &variable); if (!s.ok()) { output_tensor() = false; return; } - core::ScopedUnref su(variable); mutex_lock ml(*variable->mu()); output_tensor() = variable->is_initialized; } @@ -623,10 +616,9 @@ class ResourceGatherOp : public OpKernel { } void Compute(OpKernelContext* c) override { - Var* v = nullptr; + core::RefCountPtr<Var> v; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); - core::ScopedUnref su(v); - OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v)); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); // NOTE: We hold the lock for the whole gather operation instead // of increasing the reference count of v->tensor() to avoid a // situation where a write to the same variable will see a @@ -765,10 +757,9 @@ class ResourceGatherNdOp : public OpKernel { explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { - Var* v = nullptr; + core::RefCountPtr<Var> v; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); - core::ScopedUnref su(v); - OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v)); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); // NOTE: We hold the lock for the whole gather operation instead // of increasing the reference count of v->tensor() to avoid a // situation where a write to the same variable will see a @@ -821,10 +812,9 @@ class ResourceScatterUpdateOp : public OpKernel { explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { - Var* v = nullptr; + core::RefCountPtr<Var> v; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); - core::ScopedUnref unref_v(v); - OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v)); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); tf_shared_lock ml(*v->mu()); Tensor* params = v->tensor(); const Tensor& indices = c->input(1); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index ea03abd40f5..d307385e3a7 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -243,10 +243,9 @@ class ScatterNdUpdateOp : public OpKernel { void Compute(OpKernelContext* c) override { if (dtype_ == DT_RESOURCE) { - Var* v; + core::RefCountPtr<Var> v; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); - core::ScopedUnref scoped_unref(v); - OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v)); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get())); mutex_lock m(*v->mu()); DoCompute(c); } else if (use_exclusive_lock_) { @@ -271,9 +270,8 @@ class ScatterNdUpdateOp : public OpKernel { TensorShape params_shape; if (dtype_ == DT_RESOURCE) { - Var* v; + core::RefCountPtr<Var> v; OP_REQUIRES_OK(c, LookupResource(c,
HandleFromInput(c, 0), &v)); - core::ScopedUnref scoped_unref(v); Tensor* t = v->tensor(); params = *t; params_shape = params.shape(); diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 03169965b14..eb202070042 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -15,6 +15,7 @@ limitations under the License. // See docs in ../ops/array_ops.cc. +#include "tensorflow/core/lib/core/refcount.h" #define EIGEN_USE_THREADS #if GOOGLE_CUDA @@ -323,12 +324,11 @@ class StridedSliceAssignOp : public OpKernel { } } else { if (context->input_dtype(0) == DT_RESOURCE) { - Var* v; + core::RefCountPtr<Var> v; OP_REQUIRES_OK( context, LookupResource(context, HandleFromInput(context, 0), &v)); - core::ScopedUnref scoped_unref(v); OP_REQUIRES_OK(context, - EnsureSparseVariableAccess<Device, T>(context, v)); + EnsureSparseVariableAccess<Device, T>(context, v.get())); mutex_lock ml(*v->mu()); old_lhs = v->tensor(); OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value, diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index 6ea9e19018e..e17e28efc63 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/summary/schema.h" @@ -45,7 +46,7 @@ class CreateSummaryFileWriterOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp)); const string filename_suffix = tmp->scalar<string>()(); - SummaryWriterInterface* s = nullptr; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupOrCreateResource( ctx, HandleFromInput(ctx, 0), &s, [max_queue, flush_millis, logdir, filename_suffix, ctx](SummaryWriterInterface** s) { return CreateSummaryFileWriter( max_queue, flush_millis, logdir, filename_suffix, ctx->env(), s); })); - core::ScopedUnref unref(s); } }; REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU), @@ -75,7 +75,7 @@ class CreateSummaryDbWriterOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp)); const string user_name = tmp->scalar<string>()(); - SummaryWriterInterface* s = nullptr; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK( ctx, LookupOrCreateResource( ctx, HandleFromInput(ctx, 0), &s, [db_uri, experiment_name, run_name, user_name, ctx](SummaryWriterInterface** s) { Sqlite* db; TF_RETURN_IF_ERROR(Sqlite::Open( db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db)); core::ScopedUnref unref_db(db); TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db)); TF_RETURN_IF_ERROR(CreateSummaryDbWriter( db, experiment_name, run_name, user_name, ctx->env(), s)); return Status::OK(); })); - core::ScopedUnref unref(s); } }; REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU), @@ -102,9 +101,8 @@ class FlushSummaryWriterOp : public OpKernel { explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); OP_REQUIRES_OK(ctx, s->Flush()); } }; @@ -128,9 +126,8 @@ class WriteSummaryOp : public OpKernel { explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor*
tmp; OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar<int64>()(); @@ -153,9 +150,8 @@ class WriteRawProtoSummaryOp : public OpKernel { explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), @@ -190,9 +186,8 @@ class ImportEventOp : public OpKernel { explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor* t; OP_REQUIRES_OK(ctx, ctx->input("event", &t)); std::unique_ptr<Event> event{new Event}; @@ -211,9 +206,8 @@ class WriteScalarSummaryOp : public OpKernel { explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar<int64>()(); @@ -234,9 +228,8 @@ class WriteHistogramSummaryOp : public OpKernel { explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar<int64>()(); @@ -263,9 +256,8 @@ class WriteImageSummaryOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar<int64>()(); @@ -299,9 +291,8 @@ class WriteAudioSummaryOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar<int64>()(); @@ -328,9 +319,8 @@ class WriteGraphSummaryOp : public OpKernel { explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SummaryWriterInterface* s; + core::RefCountPtr<SummaryWriterInterface> s; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); - core::ScopedUnref unref(s); const Tensor* t; OP_REQUIRES_OK(ctx, ctx->input("step", &t)); const int64 step = t->scalar<int64>()(); diff --git a/tensorflow/core/kernels/tensor_forest/BUILD b/tensorflow/core/kernels/tensor_forest/BUILD index d4c23ff3c54..5fa63a860e9 100644 --- a/tensorflow/core/kernels/tensor_forest/BUILD +++ b/tensorflow/core/kernels/tensor_forest/BUILD @@ -26,6 +26,7 @@ tf_kernel_library( ":resources", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", ], ) @@ -37,6 +38,7 @@ tf_kernel_library(
":resources", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", ], ) diff --git a/tensorflow/core/kernels/tensor_forest/prediction_ops.cc b/tensorflow/core/kernels/tensor_forest/prediction_ops.cc index 8e75421fb95..08891db6f1a 100644 --- a/tensorflow/core/kernels/tensor_forest/prediction_ops.cc +++ b/tensorflow/core/kernels/tensor_forest/prediction_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/tensor_forest/resources.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/util/work_sharder.h" @@ -29,12 +30,10 @@ class TensorForestTreePredictOp : public OpKernel { } void Compute(OpKernelContext* context) override { - TensorForestTreeResource* decision_tree_resource; + core::RefCountPtr decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &decision_tree_resource)); mutex_lock l(*decision_tree_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_resource); - const Tensor* dense_features_t = nullptr; OP_REQUIRES_OK(context, context->input("dense_features", &dense_features_t)); @@ -60,7 +59,7 @@ class TensorForestTreePredictOp : public OpKernel { // We will need to run it on a number of trees of diff depth // and see the num of cpu cycles const int64 cost_per_traverse = 500; - auto traverse = [this, &out, &dense_features, decision_tree_resource, + auto traverse = [this, &out, &dense_features, &decision_tree_resource, batch_size](int64 start, int64 end) { DCHECK_LE(start, end) << "Start exceeding End"; DCHECK_LE(end, batch_size) << "End exceeding batch size"; @@ -74,9 +73,10 @@ class TensorForestTreePredictOp : public OpKernel { traverse); }; - void set_output_value(const int32 example_id, const int32 leaf_id, - const TensorForestTreeResource* decision_tree_resource, - TTypes::Matrix* out) const { + void set_output_value( + const int32 example_id, const int32 leaf_id, + const core::RefCountPtr& decision_tree_resource, + TTypes::Matrix* out) const { for (int j = 0; j < logits_dimension_; ++j) { const float logit = decision_tree_resource->get_prediction(leaf_id, j); (*out)(example_id, j) = logit; diff --git a/tensorflow/core/kernels/tensor_forest/resource_ops.cc b/tensorflow/core/kernels/tensor_forest/resource_ops.cc index 0474d56098f..c225d83674f 100644 --- a/tensorflow/core/kernels/tensor_forest/resource_ops.cc +++ b/tensorflow/core/kernels/tensor_forest/resource_ops.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h" #include "tensorflow/core/kernels/tensor_forest/resources.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { @@ -55,11 +56,10 @@ class TensorForestTreeSerializeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - TensorForestTreeResource* decision_tree_resource; + core::RefCountPtr decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &decision_tree_resource)); mutex_lock l(*decision_tree_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_resource); Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape(), &output_config_t)); @@ -74,13 +74,11 @@ class TensorForestTreeDeserializeOp : public OpKernel { explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - TensorForestTreeResource* decision_tree_resource; + core::RefCountPtr decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &decision_tree_resource)); mutex_lock l(*decision_tree_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_resource); - const Tensor* tree_config_t; OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t)); @@ -102,11 +100,10 @@ class TensorForestTreeSizeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - TensorForestTreeResource* decision_tree_resource; + core::RefCountPtr decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &decision_tree_resource)); mutex_lock l(*decision_tree_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_resource); Tensor* output_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_t)); diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h index 715dd8af7da..d17e19db880 100644 --- a/tensorflow/core/kernels/training_op_helpers.h +++ b/tensorflow/core/kernels/training_op_helpers.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { @@ -167,10 +168,8 @@ VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( std::sort(acquire_order.begin(), acquire_order.end(), [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); - std::unique_ptr> locks = - absl::make_unique>(); - std::unique_ptr> shared_locks = - absl::make_unique>(); + auto locks = absl::make_unique>(); + auto shared_locks = absl::make_unique>(); locks->reserve(acquire_order.size()); for (auto input : acquire_order) { @@ -241,11 +240,10 @@ template Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, bool lock_held, bool sparse, Tensor* out) { if (ctx->input_dtype(input) == DT_RESOURCE) { - Var* var; + core::RefCountPtr var; TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var)); - core::ScopedUnref unref_var(var); if (sparse) { - TF_RETURN_IF_ERROR(EnsureSparseVariableAccess(ctx, var)); + TF_RETURN_IF_ERROR(EnsureSparseVariableAccess(ctx, var.get())); *out = *var->tensor(); return Status::OK(); }