Use RefCountPtr in LookupResource to avoid leaks

LookupResource returns a raw pointer that the caller must Unref.
The prevalent pattern is to follow the call with a ScopedUnref. This is
error-prone: if a caller forgets to add the ScopedUnref, the resource is
never released and we leak memory. We resolve this by using RefCountPtr
instead of a raw pointer in LookupResource, so the reference is dropped
automatically. Most use cases have been migrated in this change.

Note that some variables were renamed to stay within line-length limits.

PiperOrigin-RevId: 250423227
This commit is contained in:
Gaurav Jain 2019-05-28 21:53:42 -07:00 committed by TensorFlower Gardener
parent 4941b4a73f
commit 2e758833c3
43 changed files with 372 additions and 421 deletions

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
@ -31,12 +32,11 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
std::map<int, OptionalTensor> variables;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
if (ctx->input(i).dtype() == DT_RESOURCE) {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();

View File

@ -40,7 +40,7 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
"Variable and value dtypes don't match; respectively, ",
DataTypeString(dtype_), " and ",
DataTypeString(context->input(1).dtype())));
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
const Tensor& value = context->input(1);
// Note: every resource-variable-manipulating op assumes copy-on-write
// semantics, and creates a copy of the variable's Tensor if its refcount is
@ -58,7 +58,6 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
(*ptr)->is_initialized = true;
return Status::OK();
}));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/stream_executor_util.h"
namespace tensorflow {
@ -86,7 +87,7 @@ static Status GetVariableInfosFromCtxInputs(
variable_indices, std::back_inserter(resource_handles),
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables;
std::vector<core::RefCountPtr<Var>> variables;
TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));
result->clear();

View File

@ -84,9 +84,8 @@ class PopulateTRTEngineCache : public OpKernel {
void Compute(OpKernelContext* ctx) override {
ResourceHandle handle = HandleFromInput(ctx, 0);
TRTEngineCacheResource* resource = nullptr;
core::RefCountPtr<TRTEngineCacheResource> resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
core::ScopedUnref unref_me(resource);
auto allocator = resource->allocator_.get();
OP_REQUIRES(ctx, allocator != nullptr,

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
@ -139,19 +140,19 @@ class BigtableTableOp : public OpKernel {
ResourceMgr* mgr = ctx->resource_manager();
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
BigtableClientResource* client_resource;
core::RefCountPtr<BigtableClientResource> client_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
core::ScopedUnref unref_client(client_resource);
BigtableTableResource* resource;
OP_REQUIRES_OK(
ctx, mgr->LookupOrCreate<BigtableTableResource>(
cinfo_.container(), cinfo_.name(), &resource,
[this, client_resource](BigtableTableResource** ret) {
*ret = new BigtableTableResource(client_resource, table_);
return Status::OK();
}));
OP_REQUIRES_OK(ctx,
mgr->LookupOrCreate<BigtableTableResource>(
cinfo_.container(), cinfo_.name(), &resource,
[this, &client_resource](BigtableTableResource** ret) {
*ret = new BigtableTableResource(
client_resource.get(), table_);
return Status::OK();
}));
initialized_ = true;
}
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
@ -236,10 +237,9 @@ class ToBigtableOp : public AsyncOpKernel {
errors::InvalidArgument("timestamp must be >= -1"),
done);
BigtableTableResource* resource;
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
core::ScopedUnref resource_cleanup(resource);
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());

View File

@ -26,9 +26,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
BigtableTableResource* table;
core::RefCountPtr<BigtableTableResource> table;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
core::ScopedUnref scoped_unref(table);
std::vector<string> column_families;
std::vector<string> columns;
@ -50,7 +49,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
}
*output =
new Dataset(ctx, input, table, std::move(column_families),
new Dataset(ctx, input, table.get(), std::move(column_families),
std::move(columns), output_types, std::move(output_shapes));
}

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
@ -28,12 +29,10 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
string prefix;
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
BigtableTableResource* resource;
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref scoped_unref(resource);
*output = new Dataset(ctx, resource, std::move(prefix));
*output = new Dataset(ctx, resource.get(), std::move(prefix));
}
private:

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
@ -31,13 +32,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
string end_key;
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
BigtableTableResource* resource;
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref scoped_unref(resource);
*output =
new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
*output = new Dataset(ctx, resource.get(), std::move(start_key),
std::move(end_key));
}
private:

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
@ -35,10 +36,9 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
string end_key;
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
BigtableTableResource* resource;
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref scoped_unref(resource);
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
errors::InvalidArgument(
@ -49,7 +49,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
"If prefix is specified, end_key must be empty."));
}
*output = new Dataset(ctx, resource, std::move(prefix),
*output = new Dataset(ctx, resource.get(), std::move(prefix),
std::move(start_key), std::move(end_key));
}

View File

@ -25,11 +25,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
BigtableTableResource* resource;
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref scoped_unref(resource);
*output = new Dataset(ctx, resource);
*output = new Dataset(ctx, resource.get());
}
private:

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
@ -64,10 +65,9 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
errors::InvalidArgument(
"Probability outside the range of (0, 1]. Got: ", probability));
BigtableTableResource* resource;
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref scoped_unref(resource);
const uint64 num_outputs = columns.size() + 1;
std::vector<PartialTensorShape> output_shapes;
@ -79,7 +79,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
output_types.push_back(DT_STRING);
}
*output = new Dataset(ctx, resource, std::move(prefix),
*output = new Dataset(ctx, resource.get(), std::move(prefix),
std::move(start_key), std::move(end_key),
std::move(column_families), std::move(columns),
probability, output_types, std::move(output_shapes));

View File

@ -21,6 +21,7 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
@ -44,7 +45,7 @@ class CreateTreeEnsembleVariableOp : public OpKernel {
const Tensor* tree_ensemble_config_t;
OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
&tree_ensemble_config_t));
auto* result = new boosted_trees::models::DecisionTreeEnsembleResource();
auto* result = new DecisionTreeEnsembleResource();
if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<string>()(),
stamp_token)) {
result->Unref();
@ -69,11 +70,10 @@ class TreeEnsembleStampTokenOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
tf_shared_lock l(*ensemble_resource->get_mutex());
core::ScopedUnref unref_me(ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
&output_stamp_token_t));
@ -88,11 +88,10 @@ class TreeEnsembleSerializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
tf_shared_lock l(*ensemble_resource->get_mutex());
core::ScopedUnref unref_me(ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
&output_stamp_token_t));
@ -112,11 +111,10 @@ class TreeEnsembleDeserializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
mutex_lock l(*ensemble_resource->get_mutex());
core::ScopedUnref unref_me(ensemble_resource);
// Get the stamp token.
const Tensor* stamp_token_t;
@ -146,12 +144,11 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
tf_shared_lock l(*ensemble_resource->get_mutex());
core::ScopedUnref unref_me(ensemble_resource);
// Get the stamp token.
const Tensor* stamp_token_t;
@ -194,9 +191,8 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);
REGISTER_KERNEL_BUILDER(
Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
IsResourceInitialized<boosted_trees::models::DecisionTreeEnsembleResource>);
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
IsResourceInitialized<DecisionTreeEnsembleResource>);
REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
CreateTreeEnsembleVariableOp);

View File

@ -163,12 +163,11 @@ class GradientTreesPredictionOp : public OpKernel {
}
void Compute(OpKernelContext* const context) override {
DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
// Gets the resource. Grabs the mutex but releases it.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(ensemble_resource);
if (use_locking_) {
tf_shared_lock l(*ensemble_resource->get_mutex());
DoCompute(context, ensemble_resource,
@ -184,9 +183,10 @@ class GradientTreesPredictionOp : public OpKernel {
// leaf index in prediction. Though this class invokes only with this param
// value as false, the subclass GradientTreesPredictionVerboseOp will invoke
// with the true value.
virtual void DoCompute(OpKernelContext* context,
DecisionTreeEnsembleResource* ensemble_resource,
const bool return_output_leaf_index) {
virtual void DoCompute(
OpKernelContext* context,
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
const bool return_output_leaf_index) {
// Read dense float features list;
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
@ -352,9 +352,10 @@ class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
: GradientTreesPredictionOp(context) {}
protected:
void DoCompute(OpKernelContext* context,
DecisionTreeEnsembleResource* ensemble_resource,
bool return_output_leaf_index) override {
void DoCompute(
OpKernelContext* context,
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
bool return_output_leaf_index) override {
GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
/*return_output_leaf_index=*/true);
}
@ -372,12 +373,10 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
}
void Compute(OpKernelContext* const context) override {
DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
// Gets the resource. Grabs the mutex but releases it.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(ensemble_resource);
if (use_locking_) {
tf_shared_lock l(*ensemble_resource->get_mutex());
DoCompute(context, ensemble_resource);
@ -387,17 +386,18 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
}
private:
void DoCompute(OpKernelContext* context,
DecisionTreeEnsembleResource* ensemble_resource) {
void DoCompute(
OpKernelContext* context,
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
// The last non-finalized tree in the ensemble is by convention the
// one to partition on. If no such tree exists, a nodeless tree is
// created.
boosted_trees::trees::DecisionTreeConfig empty_tree_config;
const boosted_trees::trees::DecisionTreeConfig& tree_config =
(ensemble_resource->num_trees() <= 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized())
(resource->num_trees() <= 0 ||
resource->LastTreeMetadata()->is_finalized())
? empty_tree_config
: *ensemble_resource->LastTree();
: *resource->LastTree();
// Read dense float features list;
OpInputList dense_float_features_list;

View File

@ -28,6 +28,7 @@
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
@ -299,13 +300,12 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
const ResourceHandle& handle =
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context,
LookupResource(context, handle, &streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
// If the stamp is invalid we drop the update.
if (!streams_resource->is_stamp_valid(stamp_token)) {
@ -467,13 +467,12 @@ class QuantileAccumulatorSerializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
int64 stamp_token = streams_resource->stamp();
Tensor* stream_state_t;
@ -526,13 +525,12 @@ class QuantileAccumulatorDeserializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
int64 old_stamp_token = streams_resource->stamp();
@ -595,13 +593,12 @@ class QuantileAccumulatorFlushOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
@ -641,13 +638,12 @@ class QuantileAccumulatorFlushSummaryOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
@ -713,12 +709,11 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel {
const ResourceHandle& handle =
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
OP_REQUIRES_OK(context,
LookupResource(context, handle, &streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
bool are_buckets_ready =
streams_resource->is_stamp_valid(stamp_token) &&

View File

@ -27,6 +27,7 @@
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
@ -130,9 +131,8 @@ using StatsAccumulatorTensorResource =
StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
void SerializeScalarAccumulatorToOutput(
const StatsAccumulatorScalarResource& accumulator_resource,
OpKernelContext* context) {
int64 num_slots = accumulator_resource.values().size();
const StatsAccumulatorScalarResource& resource, OpKernelContext* context) {
int64 num_slots = resource.values().size();
Tensor* partition_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
TensorShape({num_slots}),
@ -159,7 +159,7 @@ void SerializeScalarAccumulatorToOutput(
auto hessians = hessians_t->vec<float>();
int i = 0;
for (const auto& iter : accumulator_resource.values()) {
for (const auto& iter : resource.values()) {
partition_ids(i) = iter.first.partition_id;
feature_ids(i, 0) = iter.first.feature_id;
feature_ids(i, 1) = iter.first.dimension;
@ -171,9 +171,8 @@ void SerializeScalarAccumulatorToOutput(
}
void SerializeTensorAccumulatorToOutput(
const StatsAccumulatorTensorResource& accumulator_resource,
OpKernelContext* context) {
int64 num_slots = accumulator_resource.values().size();
const StatsAccumulatorTensorResource& resource, OpKernelContext* context) {
int64 num_slots = resource.values().size();
Tensor* partition_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
TensorShape({num_slots}),
@ -186,7 +185,7 @@ void SerializeTensorAccumulatorToOutput(
&feature_ids_t));
auto feature_ids = feature_ids_t->matrix<int64>();
TensorShape gradient_shape = accumulator_resource.gradient_shape();
TensorShape gradient_shape = resource.gradient_shape();
int64 num_gradient_elements = gradient_shape.num_elements();
gradient_shape.InsertDim(0, num_slots);
Tensor* gradients_t = nullptr;
@ -195,7 +194,7 @@ void SerializeTensorAccumulatorToOutput(
&gradients_t));
auto gradients = gradients_t->flat_outer_dims<float>();
TensorShape hessian_shape = accumulator_resource.hessian_shape();
TensorShape hessian_shape = resource.hessian_shape();
int64 num_hessian_elements = hessian_shape.num_elements();
hessian_shape.InsertDim(0, num_slots);
Tensor* hessians_t = nullptr;
@ -204,7 +203,7 @@ void SerializeTensorAccumulatorToOutput(
auto hessians = hessians_t->flat_outer_dims<float>();
int i = 0;
for (const auto& iter : accumulator_resource.values()) {
for (const auto& iter : resource.values()) {
partition_ids(i) = iter.first.partition_id;
feature_ids(i, 0) = iter.first.feature_id;
feature_ids(i, 1) = iter.first.dimension;
@ -220,11 +219,10 @@ void SerializeTensorAccumulatorToOutput(
}
void AddToScalarAccumulator(
StatsAccumulatorScalarResource* accumulator_resource,
const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
const Tensor& gradients_t, const Tensor& hessians_t) {
accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
1);
resource->set_num_updates(resource->num_updates() + 1);
const TensorShape& partition_ids_shape = partition_ids_t.shape();
const auto& partition_ids = partition_ids_t.vec<int32>();
const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
@ -232,7 +230,7 @@ void AddToScalarAccumulator(
const auto& hessians = hessians_t.vec<float>();
int64 num_updates = partition_ids_shape.dim_size(0);
auto stats_map = accumulator_resource->mutable_values();
auto stats_map = resource->mutable_values();
for (int64 i = 0; i < num_updates; ++i) {
const auto key =
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@ -248,7 +246,7 @@ void AddToScalarAccumulator(
}
void AddToScalarAccumulator(
StatsAccumulatorScalarResource* accumulator_resource,
const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
OpKernelContext* context) {
const Tensor* partition_ids_t;
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@ -258,17 +256,16 @@ void AddToScalarAccumulator(
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
AddToScalarAccumulator(resource, *partition_ids_t, *feature_ids_t,
*gradients_t, *hessians_t);
}
void AddToTensorAccumulator(
StatsAccumulatorTensorResource* accumulator_resource,
const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
const Tensor& gradients_t, const Tensor& hessians_t,
OpKernelContext* context) {
accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
1);
resource->set_num_updates(resource->num_updates() + 1);
const TensorShape& partition_ids_shape = partition_ids_t.shape();
const auto& partition_ids = partition_ids_t.vec<int32>();
@ -283,19 +280,19 @@ void AddToTensorAccumulator(
// TODO(soroush): Move gradient and hessian shape check to ShapeFn.
OP_REQUIRES(
context, gradients_shape == accumulator_resource->gradient_shape(),
context, gradients_shape == resource->gradient_shape(),
errors::InvalidArgument(strings::StrCat(
"Gradients dimensions must match: ", gradients_shape.DebugString(),
", ", accumulator_resource->gradient_shape().DebugString())));
", ", resource->gradient_shape().DebugString())));
OP_REQUIRES(
context, hessians_shape == accumulator_resource->hessian_shape(),
context, hessians_shape == resource->hessian_shape(),
errors::InvalidArgument(strings::StrCat(
"Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
accumulator_resource->hessian_shape().DebugString())));
resource->hessian_shape().DebugString())));
int64 num_updates = partition_ids_shape.dim_size(0);
auto stats_map = accumulator_resource->mutable_values();
auto stats_map = resource->mutable_values();
for (int64 i = 0; i < num_updates; ++i) {
const auto key =
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@ -325,7 +322,7 @@ void AddToTensorAccumulator(
}
void AddToTensorAccumulator(
StatsAccumulatorTensorResource* accumulator_resource,
const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
OpKernelContext* context) {
const Tensor* partition_ids_t;
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@ -335,7 +332,7 @@ void AddToTensorAccumulator(
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
AddToTensorAccumulator(resource, *partition_ids_t, *feature_ids_t,
*gradients_t, *hessians_t, context);
}
@ -452,20 +449,18 @@ class StatsAccumulatorScalarAddOp : public OpKernel {
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
StatsAccumulatorScalarResource* accumulator_resource;
OP_REQUIRES_OK(context, LookupResource(context, handle,
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
mutex_lock l(*resource->mutex());
// If the stamp is invalid we drop the update.
if (!accumulator_resource->is_stamp_valid(stamp_token)) {
if (!resource->is_stamp_valid(stamp_token)) {
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
<< "Passed stamp token: " << stamp_token << " "
<< "Current token: " << accumulator_resource->stamp();
<< "Current token: " << resource->stamp();
return;
}
AddToScalarAccumulator(accumulator_resource,
AddToScalarAccumulator(resource,
partition_ids_list[resource_handle_idx],
feature_ids_list[resource_handle_idx],
gradients_list[resource_handle_idx],
@ -517,20 +512,18 @@ class StatsAccumulatorTensorAddOp : public OpKernel {
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
StatsAccumulatorTensorResource* accumulator_resource;
OP_REQUIRES_OK(context, LookupResource(context, handle,
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
mutex_lock l(*resource->mutex());
// If the stamp is invalid we drop the update.
if (!accumulator_resource->is_stamp_valid(stamp_token)) {
if (!resource->is_stamp_valid(stamp_token)) {
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
<< "Passed stamp token: " << stamp_token << " "
<< "Current token: " << accumulator_resource->stamp();
<< "Current token: " << resource->stamp();
return;
}
AddToTensorAccumulator(accumulator_resource,
AddToTensorAccumulator(resource,
partition_ids_list[resource_handle_idx],
feature_ids_list[resource_handle_idx],
gradients_list[resource_handle_idx],
@ -549,11 +542,10 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
StatsAccumulatorScalarResource* accumulator_resource;
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
&resource));
mutex_lock l(*resource->mutex());
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@ -562,7 +554,7 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
// If the stamp is invalid we restart the PS. It shouldn't happen since
// only Chief should call this function and chief is guaranteed to be in
// a consistent state.
CHECK(accumulator_resource->is_stamp_valid(stamp_token));
CHECK(resource->is_stamp_valid(stamp_token));
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
@ -570,15 +562,15 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
CHECK(stamp_token != next_stamp_token);
SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
SerializeScalarAccumulatorToOutput(*resource, context);
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
num_updates_t->scalar<int64>()() = resource->num_updates();
accumulator_resource->Clear();
accumulator_resource->set_stamp(next_stamp_token);
resource->Clear();
resource->set_stamp(next_stamp_token);
}
};
@ -591,11 +583,10 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
StatsAccumulatorTensorResource* accumulator_resource;
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
&resource));
mutex_lock l(*resource->mutex());
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@ -609,16 +600,16 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
// If the stamp is invalid we restart the PS. It shouldn't happen since
// only Chief should call this function and chief is guaranteed to be in
// a consistent state.
CHECK(accumulator_resource->is_stamp_valid(stamp_token));
CHECK(resource->is_stamp_valid(stamp_token));
CHECK(stamp_token != next_stamp_token);
SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
SerializeTensorAccumulatorToOutput(*resource, context);
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
accumulator_resource->Clear();
accumulator_resource->set_stamp(next_stamp_token);
num_updates_t->scalar<int64>()() = resource->num_updates();
resource->Clear();
resource->set_stamp(next_stamp_token);
}
};
@ -631,22 +622,21 @@ class StatsAccumulatorScalarDeserializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
StatsAccumulatorScalarResource* accumulator_resource;
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
&resource));
mutex_lock l(*resource->mutex());
// Check the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
accumulator_resource->Clear();
accumulator_resource->set_stamp(stamp_token);
AddToScalarAccumulator(accumulator_resource, context);
resource->Clear();
resource->set_stamp(stamp_token);
AddToScalarAccumulator(resource, context);
const Tensor* num_updates_t;
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
resource->set_num_updates(num_updates_t->scalar<int64>()());
}
};
@ -660,22 +650,21 @@ class StatsAccumulatorTensorDeserializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
StatsAccumulatorTensorResource* accumulator_resource;
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
&resource));
mutex_lock l(*resource->mutex());
// Check the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
accumulator_resource->Clear();
accumulator_resource->set_stamp(stamp_token);
AddToTensorAccumulator(accumulator_resource, context);
resource->Clear();
resource->set_stamp(stamp_token);
AddToTensorAccumulator(resource, context);
const Tensor* num_updates_t;
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
resource->set_num_updates(num_updates_t->scalar<int64>()());
}
};
@ -689,23 +678,22 @@ class StatsAccumulatorScalarSerializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
StatsAccumulatorScalarResource* accumulator_resource;
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
&resource));
mutex_lock l(*resource->mutex());
SerializeScalarAccumulatorToOutput(*resource, context);
Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("stamp_token", TensorShape({}),
&stamp_token_t));
stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
stamp_token_t->scalar<int64>()() = resource->stamp();
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
num_updates_t->scalar<int64>()() = resource->num_updates();
}
};
@ -719,23 +707,22 @@ class StatsAccumulatorTensorSerializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
StatsAccumulatorTensorResource* accumulator_resource;
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&accumulator_resource));
mutex_lock l(*accumulator_resource->mutex());
core::ScopedUnref unref_me(accumulator_resource);
SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
&resource));
mutex_lock l(*resource->mutex());
SerializeTensorAccumulatorToOutput(*resource, context);
Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("stamp_token", TensorShape({}),
&stamp_token_t));
stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
stamp_token_t->scalar<int64>()() = resource->stamp();
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
num_updates_t->scalar<int64>()() = resource->num_updates();
}
};
@ -751,12 +738,11 @@ class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
void Compute(OpKernelContext* context) override {
TensorShape gradient_shape = TensorShape({});
TensorShape hessian_shape = TensorShape({});
StatsAccumulatorScalarResource* accumulator_resource =
new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
core::ScopedUnref unref_me(accumulator_resource);
core::RefCountPtr<StatsAccumulatorScalarResource> resource(
new StatsAccumulatorScalarResource(gradient_shape, hessian_shape));
// Check the stamp token.
AddToScalarAccumulator(accumulator_resource, context);
SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
AddToScalarAccumulator(resource, context);
SerializeScalarAccumulatorToOutput(*resource, context);
}
};
@ -780,12 +766,11 @@ class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
TensorShape hessians_shape = hessians_t->shape();
hessians_shape.RemoveDim(0);
StatsAccumulatorTensorResource* accumulator_resource =
new StatsAccumulatorTensorResource(gradients_shape, hessians_shape);
core::ScopedUnref unref_me(accumulator_resource);
core::RefCountPtr<StatsAccumulatorTensorResource> resource(
new StatsAccumulatorTensorResource(gradients_shape, hessians_shape));
// Check the stamp token.
AddToTensorAccumulator(accumulator_resource, context);
SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
AddToTensorAccumulator(resource, context);
SerializeTensorAccumulatorToOutput(*resource, context);
}
};

View File

@ -21,6 +21,7 @@
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
@ -31,6 +32,8 @@ namespace {
using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig;
using boosted_trees::models::DecisionTreeEnsembleResource;
using boosted_trees::trees::DecisionTreeConfig;
using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode;
using boosted_trees::trees::TreeNodeMetadata;
@ -193,10 +196,9 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.
@ -255,8 +257,8 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
private:
// Helper method to retrieve the bias from the tree ensemble.
boosted_trees::trees::Leaf* RetrieveBias(
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource,
Leaf* RetrieveBias(
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
int64 logits_dimension) {
const int32 num_trees = ensemble_resource->num_trees();
if (num_trees <= 0) {
@ -319,10 +321,9 @@ class GrowTreeEnsembleOp : public OpKernel {
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.
@ -400,10 +401,9 @@ class GrowTreeEnsembleOp : public OpKernel {
// Update and retrieve the growable tree.
// If the tree is fully built and dropout was applied, it also adjusts the
// weights of dropped and the last tree.
boosted_trees::trees::DecisionTreeConfig* const tree_config =
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
dropout_seed, max_tree_depth,
weak_learner_type);
DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(
ensemble_resource, learning_rate, dropout_seed, max_tree_depth,
weak_learner_type);
// Split tree nodes.
switch (weak_learner_type) {
case LearnerConfig::NORMAL_DECISION_TREE: {
@ -559,8 +559,7 @@ class GrowTreeEnsembleOp : public OpKernel {
}
void UpdateTreeWeightsIfDropout(
boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource,
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
const uint64 dropout_seed) {
// It is possible that the tree was built with dropout. If it is the case,
// we need to adjust the tree weight, or bail out.
@ -609,9 +608,8 @@ class GrowTreeEnsembleOp : public OpKernel {
// Helper method to update the growable tree which is by definition the last
// tree in the ensemble.
boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource,
DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
const float learning_rate, const uint64 dropout_seed,
const int32 max_tree_depth, const int32 weak_learner_type) {
const auto num_trees = ensemble_resource->num_trees();
@ -719,8 +717,8 @@ class GrowTreeEnsembleOp : public OpKernel {
// leaf children given the split candidate.
void SplitTreeNode(
const int32 node_id, SplitCandidate* split,
boosted_trees::trees::DecisionTreeConfig* tree_config,
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
DecisionTreeConfig* tree_config,
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
// No-op if we have no real node.
CHECK(node_id < tree_config->nodes_size())
<< "Invalid node " << node_id << " to split.";
@ -761,14 +759,13 @@ class GrowTreeEnsembleOp : public OpKernel {
(*tree_config->mutable_nodes(node_id)) =
*split->split_info.mutable_split_node();
if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
ensemble_resource->MaybeAddUsedHandler(split->handler_id);
resource->MaybeAddUsedHandler(split->handler_id);
}
}
void SplitTreeLayer(
SplitCandidate* split,
boosted_trees::trees::DecisionTreeConfig* tree_config,
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
SplitCandidate* split, DecisionTreeConfig* tree_config,
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
int depth = 0;
while (depth < tree_config->nodes_size() &&
tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
@ -903,10 +900,9 @@ class TreeEnsembleStatsOp : public OpKernel {
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
tf_shared_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.

View File

@ -95,7 +95,7 @@ class ZeroVarInitializer : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, 0), &variable,
[this, ctx](Var** var_ptr) {
@ -117,7 +117,6 @@ class ZeroVarInitializer : public OpKernel {
return Status::OK();
}));
core::ScopedUnref scoped(variable);
mutex_lock ml(*variable->mu());
OP_REQUIRES(ctx, !variable->is_initialized,

View File

@ -13,6 +13,7 @@
// limitations under the License.
// =============================================================================
#include <functional>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
@ -24,6 +25,7 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@ -76,11 +78,10 @@ class TreeSerializeOp : public OpKernel {
explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -100,12 +101,11 @@ class TreeDeserializeOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
auto handle = HandleFromInput(context, 0);
OP_REQUIRES_OK(context,
LookupResource(context, handle, &decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const Tensor* tree_config_t;
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
@ -131,11 +131,10 @@ class TreeSizeOp : public OpKernel {
explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape(), &output_t));
@ -144,7 +143,7 @@ class TreeSizeOp : public OpKernel {
}
};
void TraverseTree(const DecisionTreeResource* tree_resource,
void TraverseTree(DecisionTreeResource* tree_resource,
const std::unique_ptr<TensorDataSet>& data, int32 start,
int32 end,
const std::function<void(int32, int32)>& set_leaf_id,
@ -182,11 +181,10 @@ class TreePredictionsV4Op : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
&resource));
mutex_lock l(*resource->get_mutex());
const int num_data = data_set->NumItems();
const int32 num_outputs = param_proto_.num_outputs();
@ -205,15 +203,16 @@ class TreePredictionsV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
auto traverse = [this, &out, &data_set, decision_tree_resource, num_data,
&tree_paths](int64 start, int64 end) {
auto traverse = [this, &out, &data_set, &resource, num_data, &tree_paths](
int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
TraverseTree(resource.get(), data_set, static_cast<int32>(start),
static_cast<int32>(end),
std::bind(&TreePredictionsV4Op::set_output_value, this,
std::placeholders::_1, std::placeholders::_2,
decision_tree_resource, &out),
resource.get(), &out),
param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@ -283,11 +282,10 @@ class TraverseTreeV4Op : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
&resource));
mutex_lock l(*resource->get_mutex());
const int num_data = data_set->NumItems();
@ -304,11 +302,11 @@ class TraverseTreeV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource,
num_data](int64 start, int64 end) {
auto traverse = [&set_leaf_ids, &data_set, &resource, num_data](int64 start,
int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
TraverseTree(resource.get(), data_set, static_cast<int32>(start),
static_cast<int32>(end), set_leaf_ids, nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@ -336,11 +334,10 @@ class UpdateModelV4Op : public OpKernel {
const Tensor& input_labels = context->input(2);
const Tensor& input_weights = context->input(3);
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = input_labels.shape().dim_size(0);
const int32 label_dim =
@ -356,9 +353,10 @@ class UpdateModelV4Op : public OpKernel {
UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
}
void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
int32 start, int32 end,
DecisionTreeResource* decision_tree_resource) {
void UpdateModel(
const Tensor& leaf_ids, const TensorInputTarget& target, int32 start,
int32 end,
const core::RefCountPtr<DecisionTreeResource>& decision_tree_resource) {
const auto leaves = leaf_ids.unaligned_flat<int32>();
for (int i = start; i < end; ++i) {
model_op_->UpdateModel(
@ -384,11 +382,10 @@ class FeatureUsageCountsOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const auto& tree = decision_tree_resource->decision_tree();

View File

@ -26,6 +26,7 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
@ -87,11 +88,10 @@ class FertileStatsSerializeOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&fertile_stats_resource));
mutex_lock l(*fertile_stats_resource->get_mutex());
core::ScopedUnref unref_me(fertile_stats_resource);
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -116,11 +116,10 @@ class FertileStatsDeserializeOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&fertile_stats_resource));
mutex_lock l(*fertile_stats_resource->get_mutex());
core::ScopedUnref unref_me(fertile_stats_resource);
const Tensor* stats_config_t;
OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
@ -145,7 +144,7 @@ class FertileStatsDeserializeOp : public OpKernel {
// acquired, put it in a waiting queue to come back to later and try the next
// one. Once all leaf_ids have been visited, cycle through the waiting ids
// until they're gone.
void UpdateStats(FertileStatsResource* fertile_stats_resource,
void UpdateStats(const core::RefCountPtr<FertileStatsResource>& resource,
const std::unique_ptr<TensorDataSet>& data,
const TensorInputTarget& target, int num_targets,
const Tensor& leaf_ids_tensor,
@ -183,8 +182,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
}
bool is_finished;
fertile_stats_resource->AddExampleToStatsAndInitialize(
data, &target, {example_id}, leaf_id, &is_finished);
resource->AddExampleToStatsAndInitialize(data, &target, {example_id},
leaf_id, &is_finished);
leaf_lock->unlock();
if (is_finished) {
set_lock->lock();
@ -196,8 +195,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
// Update leaves from start through end in the leaf_examples iterator.
void UpdateStatsCollated(
FertileStatsResource* fertile_stats_resource,
DecisionTreeResource* tree_resource,
const core::RefCountPtr<FertileStatsResource>& fertile_stats_resource,
const core::RefCountPtr<DecisionTreeResource>& tree_resource,
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
int num_targets,
const std::unordered_map<int32, std::vector<int>>& leaf_examples,
@ -251,18 +250,15 @@ class ProcessInputOp : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource));
DecisionTreeResource* tree_resource;
core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex());
core::ScopedUnref unref_stats(fertile_stats_resource);
core::ScopedUnref unref_tree(tree_resource);
const int32 num_data = data_set->NumItems();
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
@ -308,7 +304,7 @@ class ProcessInputOp : public OpKernel {
// if it really matters that much.
const int64 costPerUpdate = 1000;
auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set,
fertile_stats_resource, &locks, &set_lock, &ready_to_split,
&fertile_stats_resource, &locks, &set_lock, &ready_to_split,
num_data](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
@ -317,8 +313,8 @@ class ProcessInputOp : public OpKernel {
static_cast<int32>(end), &ready_to_split);
};
auto update_collated = [&target, &num_targets, fertile_stats_resource,
tree_resource, &leaf_examples, &set_lock,
auto update_collated = [&target, &num_targets, &fertile_stats_resource,
&tree_resource, &leaf_examples, &set_lock,
&ready_to_split, &data_set,
num_leaves](int64 start, int64 end) {
CHECK(start <= end);
@ -362,18 +358,15 @@ class GrowTreeOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource));
DecisionTreeResource* tree_resource;
core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex());
core::ScopedUnref unref_stats(fertile_stats_resource);
core::ScopedUnref unref_tree(tree_resource);
const Tensor& finished_nodes = context->input(2);
const auto finished = finished_nodes.unaligned_flat<int32>();
@ -463,19 +456,16 @@ class FinalizeTreeOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
DecisionTreeResource* tree_resource;
core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource));
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex());
core::ScopedUnref unref_me(tree_resource);
core::ScopedUnref unref_stats(fertile_stats_resource);
// TODO(thomaswc): Add threads
int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
for (int i = 0; i < num_nodes; i++) {

View File

@ -270,12 +270,19 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
template <typename T, bool use_dynamic_cast = false>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
// Looks up a resource pointed by a given resource handle.
//
// Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
// requiring the caller to explicitly call `Unref()`.
//
// On success `*value` is reset to the looked-up resource. If the lookup fails,
// the error status is returned and `*value` is left as it was.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value);
// Looks up multiple resources pointed by a sequence of resource handles. If
// p[i] is uninitialized then values[i] is unmodified.
template <typename T>
Status LookupResources(
OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
std::vector<core::RefCountPtr<T>>* values);
// Looks up or creates a resource.
//
@ -283,10 +290,19 @@ Status LookupResources(
// must call its `Unref()` method when it has finished using it. If the
// `creator` is invoked, its reference on the created resource is transferred
// to `ctx->resource_mgr()`.
//
// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
// requiring the caller to explicitly call `Unref()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator);
// Looks up or creates a resource.
//
// Same as the raw-pointer `LookupOrCreateResource`, except that the resource
// is handed back through a `core::RefCountPtr`, so the caller does not have to
// call `Unref()` explicitly. `creator` is forwarded unchanged to the
// raw-pointer overload. On success `*value` is reset to the resulting
// resource; on failure the error status is returned and `*value` is left as
// it was.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value,
std::function<Status(T**)> creator);
// Destroys a resource pointed by a given resource handle.
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
@ -587,9 +603,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
}
template <typename T>
Status LookupResources(
OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) {
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
                      core::RefCountPtr<T>* value) {
  // Resolve the handle through the raw-pointer overload first; `*value` is
  // only touched once that lookup has succeeded.
  T* found = nullptr;
  const Status lookup_status = LookupResource<T, false>(ctx, p, &found);
  if (!lookup_status.ok()) {
    return lookup_status;
  }
  // Hand the looked-up pointer to the smart pointer so that the caller does
  // not need to call `Unref()` explicitly.
  value->reset(found);
  return Status::OK();
}
template <typename T>
Status LookupResources(OpKernelContext* ctx,
absl::Span<ResourceHandle const* const> p,
std::vector<core::RefCountPtr<T>>* values) {
std::vector<std::pair<const string*, const string*>> containers_and_names(
p.size());
for (size_t i = 0; i < p.size(); ++i) {
@ -607,6 +633,17 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
creator);
}
// `core::RefCountPtr` flavour of LookupOrCreateResource: delegates to the
// raw-pointer overload and, only if that succeeds, stores the resulting
// pointer in `*value` so the caller does not need to call `Unref()`
// explicitly.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
                              core::RefCountPtr<T>* value,
                              std::function<Status(T**)> creator) {
  T* found = nullptr;
  const Status lookup_status =
      LookupOrCreateResource<T>(ctx, p, &found, creator);
  if (!lookup_status.ok()) {
    // Propagate the failure without modifying `*value`.
    return lookup_status;
  }
  value->reset(found);
  return Status::OK();
}
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -266,15 +267,14 @@ TEST(ResourceHandleTest, CRUD) {
TF_EXPECT_OK(CreateResource(&ctx, p, r));
}
{
StubResource* r = nullptr;
core::RefCountPtr<StubResource> r;
TF_ASSERT_OK(LookupResource(&ctx, p, &r));
ASSERT_TRUE(r != nullptr);
EXPECT_EQ(r->value_, 42);
r->Unref();
}
{
TF_EXPECT_OK(DeleteResource<StubResource>(&ctx, p));
StubResource* unused = nullptr;
core::RefCountPtr<StubResource> unused;
EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok());
}
}
@ -338,13 +338,12 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) {
StubResource* r = new StubResource;
TF_EXPECT_OK(CreateResource(&ctx, p, r));
StubResource* lookup_r = nullptr;
core::RefCountPtr<StubResource> lookup_r;
TF_EXPECT_OK(LookupResource<StubResource>(&ctx, p, &lookup_r));
EXPECT_EQ(lookup_r, r);
EXPECT_EQ(lookup_r.get(), r);
TF_EXPECT_OK(DeleteResource(&ctx, p));
EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
r->Unref();
}
} // end namespace tensorflow

View File

@ -166,6 +166,7 @@ tf_kernel_library(
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
],
)
@ -5022,6 +5023,7 @@ STATE_DEPS = [
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
] + if_sycl(["//tensorflow/core:sycl_runtime"])
tf_kernel_library(
@ -7589,6 +7591,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/db:sqlite",
"//tensorflow/core/summary:schema",

View File

@ -83,6 +83,7 @@ tf_kernel_library(
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
@ -106,6 +107,7 @@ tf_kernel_library(
":tree_helper",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
"//third_party/eigen3",
],
@ -117,6 +119,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
],
)

View File

@ -51,12 +51,10 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
}
void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource;
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);
// Get the inputs.
OpInputList bucketized_features_list;
@ -198,12 +196,10 @@ class BoostedTreesPredictOp : public OpKernel {
}
void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource;
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);
// Get the inputs.
OpInputList bucketized_features_list;
@ -302,12 +298,10 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
}
void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource;
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);
// Get the inputs.
OpInputList bucketized_features_list;

View File

@ -27,6 +27,7 @@
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@ -224,12 +225,11 @@ class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
ResourceHandle handle;
OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource;
core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);
OpInputList summaries_list;
OP_REQUIRES_OK(context,
@ -281,13 +281,12 @@ class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
OpInputList bucket_boundaries_list;
OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
@ -336,12 +335,11 @@ class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
ResourceHandle handle;
OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource;
core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);
const Tensor* num_buckets_t;
OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
@ -391,12 +389,11 @@ class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
ResourceHandle handle;
OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource;
core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);
const int64 num_streams = stream_resource->num_streams();
CHECK_EQ(num_features_, num_streams);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
@ -78,11 +79,10 @@ class BoostedTreesGetEnsembleStatesOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// Looks up the resource.
BoostedTreesEnsembleResource* tree_ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource));
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);
// Sets the outputs.
const int num_trees = tree_ensemble_resource->num_trees();
@ -141,11 +141,10 @@ class BoostedTreesSerializeEnsembleOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
BoostedTreesEnsembleResource* tree_ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource));
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
&output_stamp_token_t));
@ -169,11 +168,10 @@ class BoostedTreesDeserializeEnsembleOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
BoostedTreesEnsembleResource* tree_ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource));
mutex_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);
// Get the stamp token.
const Tensor* stamp_token_t;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
@ -55,10 +56,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
BoostedTreesEnsembleResource* ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex());
// Increase the ensemble stamp.
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
@ -176,19 +176,19 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
private:
// Updates the ensemble's growing metadata and returns the tree to grow.
//
// The returned value is `num_trees - 1` (the last existing tree) unless the
// ensemble reports no trees, in which case a new tree is added with
// kLayerByLayerTreeWeight and the value returned by AddNewTree is used
// instead.
int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
BoostedTreesEnsembleResource* const ensemble_resource) {
int32 num_trees = ensemble_resource->num_trees();
const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) {
int32 num_trees = resource->num_trees();
int32 current_tree = num_trees - 1;
// Increment global attempt stats.
ensemble_resource->UpdateGrowingMetadata();
resource->UpdateGrowingMetadata();
// Note we don't set tree weight to be equal to learning rate, since we
// apply learning rate to leaf weights instead, when doing layer-by-layer
// boosting.
if (num_trees <= 0) {
// Create a new tree with a no-op leaf.
current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
}
return current_tree;
}
@ -250,10 +250,9 @@ class BoostedTreesCenterBiasOp : public OpKernel {
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
BoostedTreesEnsembleResource* ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex());
// Increase the ensemble stamp.
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@ -65,11 +66,9 @@ class ResourceCountUpToOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
Var* variable = nullptr;
OP_REQUIRES_OK(
context,
LookupResource<Var>(context, HandleFromInput(context, 0), &variable));
core::ScopedUnref s(variable);
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&variable));
mutex_lock l(*variable->mu());
Tensor before_increment = *variable->tensor();
OP_REQUIRES(

View File

@ -342,6 +342,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:summary_interface",
],
@ -380,6 +381,7 @@ tf_kernel_library(
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/data:dataset_utils",
"//third_party/eigen3",
],

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
@ -83,17 +84,16 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
StatsAggregatorResource* stats_aggregator_resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
core::RefCountPtr<StatsAggregatorResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
string tag;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
string prefix;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
*output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
tag, prefix);
*output =
new Dataset(ctx, input, ctx->input(1), resource.get(), tag, prefix);
}
private:
@ -101,12 +101,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
const Tensor& resource_handle,
StatsAggregatorResource* stats_aggregator_resource,
const string& tag, const string& prefix)
StatsAggregatorResource* resource, const string& tag,
const string& prefix)
: DatasetBase(DatasetContext(ctx)),
input_(input),
resource_handle_(resource_handle),
stats_aggregator_resource_(stats_aggregator_resource),
stats_aggregator_resource_(resource),
tag_(tag),
prefix_(prefix) {
input_->Ref();
@ -169,13 +169,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
StatsAggregatorResource* stats_aggregator_resource =
StatsAggregatorResource* resource =
dataset()->stats_aggregator_resource_;
IteratorContext::Params params(ctx);
params.stats_aggregator = std::shared_ptr<StatsAggregator>(
new StatsAggregatorWithTagAndPrefix(
stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
dataset()->prefix_));
new StatsAggregatorWithTagAndPrefix(resource->stats_aggregator(),
dataset()->tag_,
dataset()->prefix_));
IteratorContext iter_ctx(std::move(params));
return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
}

View File

@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/stats_aggregator.h"
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/kernels/summary_interface.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/monitoring/counter.h"
@ -258,10 +258,9 @@ class StatsAggregatorSummaryOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
StatsAggregatorResource* resource;
core::RefCountPtr<StatsAggregatorResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref unref_iterator(resource);
Tensor* summary_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
@ -281,21 +280,19 @@ class StatsAggregatorSetSummaryWriterOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
StatsAggregatorResource* resource;
core::RefCountPtr<StatsAggregatorResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref unref_iterator(resource);
const Tensor& summary_resource_handle_t = ctx->input(1);
OP_REQUIRES(ctx,
TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
SummaryWriterInterface* sumamry_resource;
core::RefCountPtr<SummaryWriterInterface> sumamry_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource));
core::ScopedUnref unref_sumamry_resource(sumamry_resource);
TF_CHECK_OK(
resource->stats_aggregator()->SetSummaryWriter(sumamry_resource));
resource->stats_aggregator()->SetSummaryWriter(sumamry_resource.get()));
}
};

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/work_sharder.h"
@ -127,12 +128,10 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
ThreadPoolResource* threadpool_resource;
core::RefCountPtr<ThreadPoolResource> threadpool_resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&threadpool_resource));
core::ScopedUnref unref_iterator(threadpool_resource);
*output = new Dataset(ctx, input, ctx->input(1), threadpool_resource);
*output = new Dataset(ctx, input, ctx->input(1), threadpool_resource.get());
}
private:

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
@ -595,10 +596,9 @@ int64 AnonymousIteratorHandleOp::current_id_(0);
// Attaches a dataset to an existing iterator resource.
//
// Input 0 is read as a variant tensor holding the dataset; input 1 supplies
// the resource handle used to look up the IteratorResource. The dataset is
// then passed to IteratorResource::SetIteratorFromDataset. Each step is
// wrapped in OP_REQUIRES_OK.
void MakeIteratorOp::Compute(OpKernelContext* ctx) {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
IteratorResource* iterator_resource;
core::RefCountPtr<IteratorResource> iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref(iterator_resource);
OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -492,10 +493,9 @@ class MultiDeviceIteratorInitOp : public OpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
MultiDeviceIterator* resource;
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
core::ScopedUnref unref(resource);
std::unique_ptr<IteratorBase> iterator;
IteratorContext::Params params(ctx);
@ -535,7 +535,7 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
MultiDeviceIterator* iterator;
core::RefCountPtr<MultiDeviceIterator> iterator;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
@ -557,7 +557,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
std::placeholders::_1, std::move(done));
iterator->GetNextFromShard(ctx, shard_num, incarnation_id, callback);
iterator->Unref();
}
};
@ -577,10 +576,9 @@ class MultiDeviceIteratorToStringHandleOp : public OpKernel {
// Validate that the handle corresponds to a real resource, and
// that it is an MultiDeviceIterator.
MultiDeviceIterator* resource;
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
resource->Unref();
Tensor* string_handle_t;
OP_REQUIRES_OK(ctx,
@ -629,9 +627,8 @@ class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
// Validate that the handle corresponds to a real resource, and
// that it is an MultiDeviceIterator.
MultiDeviceIterator* resource;
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
core::ScopedUnref unref_iterator(resource);
if (!output_types_.empty()) {
OP_REQUIRES_OK(ctx,
VerifyTypesMatch(output_types_, resource->output_types()));

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@ -371,10 +372,9 @@ class RandomBinomialOp : public OpKernel {
"Input probs should have length 1 or shape[0], got shape: ",
probs_tensor.shape().DebugString()));
}
Var* var = nullptr;
core::RefCountPtr<Var> var;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
ScopedUnlockUnrefVar var_guard(var);
Tensor* var_tensor = var->tensor();
OP_REQUIRES(
ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
@ -404,7 +404,6 @@ class RandomBinomialOp : public OpKernel {
auto philox = GetPhiloxRandomFromMem(var_data);
UpdateMemWithPhiloxRandom(
philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);
var_guard.Release();
auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,

View File

@ -129,7 +129,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
} // namespace
void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
@ -139,7 +139,6 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
". This could mean that the variable was uninitialized. ",
status.ToString()));
core::ScopedUnref s(variable);
{
tf_shared_lock ml(*variable->mu());
// We're acquiring a reference to the underlying buffer while
@ -175,8 +174,7 @@ ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
}
void ReadVariablesOp::Compute(OpKernelContext* ctx) {
std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
dtypes_.size());
std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
std::vector<const ResourceHandle*> handles(dtypes_.size());
for (size_t i = 0; i < dtypes_.size(); ++i) {
handles[i] = &HandleFromInput(ctx, i);
@ -265,10 +263,9 @@ class VariableShapeOp : public OpKernel {
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* ctx) override {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
core::ScopedUnref s(variable);
variable->mu()->lock_shared();
TensorShape shape = variable->tensor()->shape();
variable->mu()->unlock_shared();
@ -343,7 +340,7 @@ class AssignVariableOp : public OpKernel {
"Variable and value dtypes don't match; respectively, ",
DataTypeString(dtype_), " and ",
DataTypeString(context->input(1).dtype())));
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
const Tensor& value = context->input(1);
// Note: every resource-variable-manipulating op assumes copy-on-write
// semantics, and creates a copy of the variable's Tensor if its refcount is
@ -361,7 +358,6 @@ class AssignVariableOp : public OpKernel {
(*ptr)->is_initialized = true;
return Status::OK();
}));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
@ -404,7 +400,7 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& value = context->input(1);
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
context, HandleFromInput(context, 0), &variable,
[](Var** ptr) {
@ -412,7 +408,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
*ptr = new Var(DT_VARIANT);
return Status::OK();
}));
core::ScopedUnref s(variable);
// For purposes of forwarding DT_VARIANT, we want the least
// restrictive attr; we already know the input is on host.
@ -500,10 +495,9 @@ class AssignUpdateVariableOp : public OpKernel {
explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* context) override {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&variable));
core::ScopedUnref s(variable);
const Tensor& value = context->input(1);
// TODO(apassos): We could possibly avoid the copy done by
@ -568,13 +562,12 @@ class VarIsInitializedOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}), &output));
auto output_tensor = output->tensor<bool, 0>();
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
if (!s.ok()) {
output_tensor() = false;
return;
}
core::ScopedUnref su(variable);
mutex_lock ml(*variable->mu());
output_tensor() = variable->is_initialized;
}
@ -623,10 +616,9 @@ class ResourceGatherOp : public OpKernel {
}
void Compute(OpKernelContext* c) override {
Var* v = nullptr;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref su(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
@ -765,10 +757,9 @@ class ResourceGatherNdOp : public OpKernel {
explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* c) override {
Var* v = nullptr;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref su(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
@ -821,10 +812,9 @@ class ResourceScatterUpdateOp : public OpKernel {
explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* c) override {
Var* v = nullptr;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref unref_v(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
tf_shared_lock ml(*v->mu());
Tensor* params = v->tensor();
const Tensor& indices = c->input(1);

View File

@ -243,10 +243,9 @@ class ScatterNdUpdateOp : public OpKernel {
void Compute(OpKernelContext* c) override {
if (dtype_ == DT_RESOURCE) {
Var* v;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref scoped_unref(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
@ -271,9 +270,8 @@ class ScatterNdUpdateOp : public OpKernel {
TensorShape params_shape;
if (dtype_ == DT_RESOURCE) {
Var* v;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref scoped_unref(v);
Tensor* t = v->tensor();
params = *t;
params_shape = params.shape();

View File

@ -15,6 +15,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#include "tensorflow/core/lib/core/refcount.h"
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
@ -323,12 +324,11 @@ class StridedSliceAssignOp : public OpKernel {
}
} else {
if (context->input_dtype(0) == DT_RESOURCE) {
Var* v;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(
context, LookupResource(context, HandleFromInput(context, 0), &v));
core::ScopedUnref scoped_unref(v);
OP_REQUIRES_OK(context,
EnsureSparseVariableAccess<Device, T>(context, v));
EnsureSparseVariableAccess<Device, T>(context, v.get()));
mutex_lock ml(*v->mu());
old_lhs = v->tensor();
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/summary/schema.h"
@ -45,7 +46,7 @@ class CreateSummaryFileWriterOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
const string filename_suffix = tmp->scalar<string>()();
SummaryWriterInterface* s = nullptr;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
ctx, HandleFromInput(ctx, 0), &s,
[max_queue, flush_millis, logdir, filename_suffix,
@ -54,7 +55,6 @@ class CreateSummaryFileWriterOp : public OpKernel {
max_queue, flush_millis, logdir,
filename_suffix, ctx->env(), s);
}));
core::ScopedUnref unref(s);
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
@ -75,7 +75,7 @@ class CreateSummaryDbWriterOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
const string user_name = tmp->scalar<string>()();
SummaryWriterInterface* s = nullptr;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(
ctx,
LookupOrCreateResource<SummaryWriterInterface>(
@ -91,7 +91,6 @@ class CreateSummaryDbWriterOp : public OpKernel {
db, experiment_name, run_name, user_name, ctx->env(), s));
return Status::OK();
}));
core::ScopedUnref unref(s);
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
@ -102,9 +101,8 @@ class FlushSummaryWriterOp : public OpKernel {
explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
OP_REQUIRES_OK(ctx, s->Flush());
}
};
@ -128,9 +126,8 @@ class WriteSummaryOp : public OpKernel {
explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -153,9 +150,8 @@ class WriteRawProtoSummaryOp : public OpKernel {
explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
@ -190,9 +186,8 @@ class ImportEventOp : public OpKernel {
// Constructor: empty body; only forwards `ctx` to the OpKernel base class.
explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("event", &t));
std::unique_ptr<Event> event{new Event};
@ -211,9 +206,8 @@ class WriteScalarSummaryOp : public OpKernel {
// Constructor: empty body; only forwards `ctx` to the OpKernel base class.
explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -234,9 +228,8 @@ class WriteHistogramSummaryOp : public OpKernel {
// Constructor: empty body; only forwards `ctx` to the OpKernel base class.
explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -263,9 +256,8 @@ class WriteImageSummaryOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -299,9 +291,8 @@ class WriteAudioSummaryOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -328,9 +319,8 @@ class WriteGraphSummaryOp : public OpKernel {
// Constructor: empty body; only forwards `ctx` to the OpKernel base class.
explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("step", &t));
const int64 step = t->scalar<int64>()();

View File

@ -26,6 +26,7 @@ tf_kernel_library(
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
@ -37,6 +38,7 @@ tf_kernel_library(
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/tensor_forest/resources.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/work_sharder.h"
@ -29,12 +30,10 @@ class TensorForestTreePredictOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const Tensor* dense_features_t = nullptr;
OP_REQUIRES_OK(context,
context->input("dense_features", &dense_features_t));
@ -60,7 +59,7 @@ class TensorForestTreePredictOp : public OpKernel {
// We will need to run it on a number of trees of diff depth
// and see the num of cpu cycles
const int64 cost_per_traverse = 500;
auto traverse = [this, &out, &dense_features, decision_tree_resource,
auto traverse = [this, &out, &dense_features, &decision_tree_resource,
batch_size](int64 start, int64 end) {
DCHECK_LE(start, end) << "Start exceeding End";
DCHECK_LE(end, batch_size) << "End exceeding batch size";
@ -74,9 +73,10 @@ class TensorForestTreePredictOp : public OpKernel {
traverse);
};
void set_output_value(const int32 example_id, const int32 leaf_id,
const TensorForestTreeResource* decision_tree_resource,
TTypes<float>::Matrix* out) const {
void set_output_value(
const int32 example_id, const int32 leaf_id,
const core::RefCountPtr<TensorForestTreeResource>& decision_tree_resource,
TTypes<float>::Matrix* out) const {
for (int j = 0; j < logits_dimension_; ++j) {
const float logit = decision_tree_resource->get_prediction(leaf_id, j);
(*out)(example_id, j) = logit;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/tensor_forest/resources.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
@ -55,11 +56,10 @@ class TensorForestTreeSerializeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -74,13 +74,11 @@ class TensorForestTreeDeserializeOp : public OpKernel {
// Constructor: empty body; only forwards `context` to the OpKernel base
// class.
explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const Tensor* tree_config_t;
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
@ -102,11 +100,10 @@ class TensorForestTreeSizeOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape(), &output_t));

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
@ -167,10 +168,8 @@ VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
std::unique_ptr<std::vector<mutex_lock>> locks =
absl::make_unique<std::vector<mutex_lock>>();
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
absl::make_unique<std::vector<tf_shared_lock>>();
auto locks = absl::make_unique<std::vector<mutex_lock>>();
auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
locks->reserve(acquire_order.size());
for (auto input : acquire_order) {
@ -241,11 +240,10 @@ template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
bool lock_held, bool sparse, Tensor* out) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
core::RefCountPtr<Var> var;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
core::ScopedUnref unref_var(var);
if (sparse) {
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
*out = *var->tensor();
return Status::OK();
}