Use RefCountPtr in LookupResource to avoid leaks
LookupResource returns a raw pointer that the caller must Unref. The prevalent pattern is to follow the call with a ScopedUnref, which is problematic: if a caller forgets to add the ScopedUnref, we have a memory leak. We resolve this by using RefCountPtr instead of a raw pointer in LookupResource. Most use cases have been migrated in this change. Note that some variables were renamed to satisfy line-length restrictions.

PiperOrigin-RevId: 250423227
commit 2e758833c3
parent 4941b4a73f
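To make the migration concrete, below is a minimal sketch of the before/after pattern, written against a hypothetical MyResource kernel (illustrative only; not code from this commit):

// Before: LookupResource hands back a raw pointer that holds a reference,
// and the caller must remember to pair it with a ScopedUnref.
void Compute(OpKernelContext* ctx) override {
  MyResource* resource = nullptr;  // MyResource: hypothetical RefCounted type
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
  core::ScopedUnref unref_me(resource);  // forgetting this line leaks a reference
  resource->DoWork();
}

// After: core::RefCountPtr (a std::unique_ptr with a core::RefCountDeleter, as
// seen in the GetVariableInfosFromCtxInputs hunk below) drops the reference
// automatically when it goes out of scope.
void Compute(OpKernelContext* ctx) override {
  core::RefCountPtr<MyResource> resource;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
  resource->DoWork();  // pass resource.get() where a raw pointer is still needed
}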
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
 
@@ -31,12 +32,11 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
   std::map<int, OptionalTensor> variables;
   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
     if (ctx->input(i).dtype() == DT_RESOURCE) {
-      Var* variable = nullptr;
+      core::RefCountPtr<Var> variable;
       ResourceHandle handle = HandleFromInput(ctx, i);
       OptionalTensor& optional = variables[i];
       optional.name = handle.name();
       if (LookupResource(ctx, handle, &variable).ok()) {
-        core::ScopedUnref scoped_unref(variable);
         tf_shared_lock lock(*variable->mu());
         optional.present = true;
         optional.value = *variable->tensor();
@@ -40,7 +40,7 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
                   "Variable and value dtypes don't match; respectively, ",
                   DataTypeString(dtype_), " and ",
                   DataTypeString(context->input(1).dtype())));
-  Var* variable = nullptr;
+  core::RefCountPtr<Var> variable;
   const Tensor& value = context->input(1);
   // Note: every resource-variable-manipulating op assumes copy-on-write
   // semantics, and creates a copy of the variable's Tensor if its refcount is
@@ -58,7 +58,6 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
                             (*ptr)->is_initialized = true;
                             return Status::OK();
                           }));
-  core::ScopedUnref s(variable);
   mutex_lock ml(*variable->mu());
   OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
               errors::InvalidArgument(
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/util/stream_executor_util.h"
 
 namespace tensorflow {
@@ -86,7 +87,7 @@ static Status GetVariableInfosFromCtxInputs(
       variable_indices, std::back_inserter(resource_handles),
       [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
 
-  std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables;
+  std::vector<core::RefCountPtr<Var>> variables;
   TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));
 
   result->clear();
@@ -84,9 +84,8 @@ class PopulateTRTEngineCache : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     ResourceHandle handle = HandleFromInput(ctx, 0);
-    TRTEngineCacheResource* resource = nullptr;
+    core::RefCountPtr<TRTEngineCacheResource> resource;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
-    core::ScopedUnref unref_me(resource);
 
     auto allocator = resource->allocator_.get();
     OP_REQUIRES(ctx, allocator != nullptr,
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 
 namespace tensorflow {
@@ -139,19 +140,19 @@ class BigtableTableOp : public OpKernel {
       ResourceMgr* mgr = ctx->resource_manager();
       OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
 
-      BigtableClientResource* client_resource;
+      core::RefCountPtr<BigtableClientResource> client_resource;
       OP_REQUIRES_OK(
           ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
-      core::ScopedUnref unref_client(client_resource);
 
       BigtableTableResource* resource;
-      OP_REQUIRES_OK(
-          ctx, mgr->LookupOrCreate<BigtableTableResource>(
-                   cinfo_.container(), cinfo_.name(), &resource,
-                   [this, client_resource](BigtableTableResource** ret) {
-                     *ret = new BigtableTableResource(client_resource, table_);
-                     return Status::OK();
-                   }));
+      OP_REQUIRES_OK(ctx,
+                     mgr->LookupOrCreate<BigtableTableResource>(
+                         cinfo_.container(), cinfo_.name(), &resource,
+                         [this, &client_resource](BigtableTableResource** ret) {
+                           *ret = new BigtableTableResource(
+                               client_resource.get(), table_);
+                           return Status::OK();
+                         }));
       initialized_ = true;
     }
     OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
@@ -236,10 +237,9 @@ class ToBigtableOp : public AsyncOpKernel {
                       errors::InvalidArgument("timestamp must be >= -1"),
                       done);
 
-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK_ASYNC(
         ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
-    core::ScopedUnref resource_cleanup(resource);
 
     std::vector<Tensor> components;
     components.reserve(dataset->output_dtypes().size());
@@ -26,9 +26,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    BigtableTableResource* table;
+    core::RefCountPtr<BigtableTableResource> table;
     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
-    core::ScopedUnref scoped_unref(table);
 
     std::vector<string> column_families;
     std::vector<string> columns;
@@ -50,7 +49,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
     }
 
     *output =
-        new Dataset(ctx, input, table, std::move(column_families),
+        new Dataset(ctx, input, table.get(), std::move(column_families),
                     std::move(columns), output_types, std::move(output_shapes));
   }
 
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
 namespace data {
@@ -28,12 +29,10 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
     string prefix;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
 
-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
-
-    *output = new Dataset(ctx, resource, std::move(prefix));
+    *output = new Dataset(ctx, resource.get(), std::move(prefix));
   }
 
  private:
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
 namespace data {
@@ -31,13 +32,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
     string end_key;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
 
-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
-
-    *output =
-        new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
+    *output = new Dataset(ctx, resource.get(), std::move(start_key),
                           std::move(end_key));
   }
 
  private:
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
 namespace data {
@@ -35,10 +36,9 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
     string end_key;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
 
-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
 
     OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
                 errors::InvalidArgument(
@@ -49,7 +49,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
                       "If prefix is specified, end_key must be empty."));
     }
 
-    *output = new Dataset(ctx, resource, std::move(prefix),
+    *output = new Dataset(ctx, resource.get(), std::move(prefix),
                           std::move(start_key), std::move(end_key));
   }
 
@@ -25,11 +25,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
   using DatasetOpKernel::DatasetOpKernel;
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
-    *output = new Dataset(ctx, resource);
+    *output = new Dataset(ctx, resource.get());
   }
 
  private:
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
 namespace data {
@@ -64,10 +65,9 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
                 errors::InvalidArgument(
                     "Probability outside the range of (0, 1]. Got: ", probability));
 
-    BigtableTableResource* resource;
+    core::RefCountPtr<BigtableTableResource> resource;
     OP_REQUIRES_OK(ctx,
                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
-    core::ScopedUnref scoped_unref(resource);
 
     const uint64 num_outputs = columns.size() + 1;
     std::vector<PartialTensorShape> output_shapes;
@@ -79,7 +79,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
       output_types.push_back(DT_STRING);
     }
 
-    *output = new Dataset(ctx, resource, std::move(prefix),
+    *output = new Dataset(ctx, resource.get(), std::move(prefix),
                           std::move(start_key), std::move(end_key),
                           std::move(column_families), std::move(columns),
                           probability, output_types, std::move(output_shapes));
@@ -21,6 +21,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
 namespace tensorflow {
@@ -44,7 +45,7 @@ class CreateTreeEnsembleVariableOp : public OpKernel {
     const Tensor* tree_ensemble_config_t;
     OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
                                            &tree_ensemble_config_t));
-    auto* result = new boosted_trees::models::DecisionTreeEnsembleResource();
+    auto* result = new DecisionTreeEnsembleResource();
     if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<string>()(),
                                     stamp_token)) {
       result->Unref();
@@ -69,11 +70,10 @@ class TreeEnsembleStampTokenOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    tf_shared_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);
    Tensor* output_stamp_token_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
                                                     &output_stamp_token_t));
@@ -88,11 +88,10 @@ class TreeEnsembleSerializeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    tf_shared_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);
    Tensor* output_stamp_token_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
                                                     &output_stamp_token_t));
@@ -112,11 +111,10 @@ class TreeEnsembleDeserializeOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    mutex_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);
 
    // Get the stamp token.
    const Tensor* stamp_token_t;
@@ -146,12 +144,11 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* context) override {
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
 
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    tf_shared_lock l(*ensemble_resource->get_mutex());
-    core::ScopedUnref unref_me(ensemble_resource);
 
    // Get the stamp token.
    const Tensor* stamp_token_t;
@@ -194,9 +191,8 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
 
 REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);
 
-REGISTER_KERNEL_BUILDER(
-    Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
-    IsResourceInitialized<boosted_trees::models::DecisionTreeEnsembleResource>);
+REGISTER_KERNEL_BUILDER(Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
+                        IsResourceInitialized<DecisionTreeEnsembleResource>);
 
 REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
                         CreateTreeEnsembleVariableOp);
@@ -163,12 +163,11 @@ class GradientTreesPredictionOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* const context) override {
-    DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     // Gets the resource. Grabs the mutex but releases it.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
-    // Release the reference to the resource once we're done using it.
-    core::ScopedUnref unref_me(ensemble_resource);
     if (use_locking_) {
       tf_shared_lock l(*ensemble_resource->get_mutex());
       DoCompute(context, ensemble_resource,
@@ -184,9 +183,10 @@ class GradientTreesPredictionOp : public OpKernel {
   // leaf index in prediction. Though this class invokes only with this param
   // value as false, the subclass GradientTreesPredictionVerboseOp will invoke
   // with the true value.
-  virtual void DoCompute(OpKernelContext* context,
-                         DecisionTreeEnsembleResource* ensemble_resource,
-                         const bool return_output_leaf_index) {
+  virtual void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
+      const bool return_output_leaf_index) {
     // Read dense float features list;
     OpInputList dense_float_features_list;
     OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
@@ -352,9 +352,10 @@ class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
       : GradientTreesPredictionOp(context) {}
 
  protected:
-  void DoCompute(OpKernelContext* context,
-                 DecisionTreeEnsembleResource* ensemble_resource,
-                 bool return_output_leaf_index) override {
+  void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
+      bool return_output_leaf_index) override {
     GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
                                          /*return_output_leaf_index=*/true);
   }
@@ -372,12 +373,10 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* const context) override {
-    DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
     // Gets the resource. Grabs the mutex but releases it.
     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                            &ensemble_resource));
-    // Release the reference to the resource once we're done using it.
-    core::ScopedUnref unref_me(ensemble_resource);
     if (use_locking_) {
       tf_shared_lock l(*ensemble_resource->get_mutex());
       DoCompute(context, ensemble_resource);
@@ -387,17 +386,18 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
   }
 
  private:
-  void DoCompute(OpKernelContext* context,
-                 DecisionTreeEnsembleResource* ensemble_resource) {
+  void DoCompute(
+      OpKernelContext* context,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
     // The last non-finalized tree in the ensemble is by convention the
     // one to partition on. If no such tree exists, a nodeless tree is
     // created.
     boosted_trees::trees::DecisionTreeConfig empty_tree_config;
     const boosted_trees::trees::DecisionTreeConfig& tree_config =
-        (ensemble_resource->num_trees() <= 0 ||
-         ensemble_resource->LastTreeMetadata()->is_finalized())
+        (resource->num_trees() <= 0 ||
+         resource->LastTreeMetadata()->is_finalized())
            ? empty_tree_config
-            : *ensemble_resource->LastTree();
+            : *resource->LastTree();
 
     // Read dense float features list;
     OpInputList dense_float_features_list;
@@ -28,6 +28,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/types.h"
@@ -299,13 +300,12 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
           const ResourceHandle& handle =
               resource_handle_list[resource_handle_idx]
                  .flat<ResourceHandle>()(0);
-          QuantileStreamResource* streams_resource;
+          core::RefCountPtr<QuantileStreamResource> streams_resource;
          // Create a reference to the underlying resource using the handle.
          OP_REQUIRES_OK(context,
                         LookupResource(context, handle, &streams_resource));
          // Remove the reference at the end of this scope.
          mutex_lock l(*streams_resource->mutex());
-          core::ScopedUnref unref_me(streams_resource);
 
          // If the stamp is invalid we drop the update.
          if (!streams_resource->is_stamp_valid(stamp_token)) {
@@ -467,13 +467,12 @@ class QuantileAccumulatorSerializeOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
    int64 stamp_token = streams_resource->stamp();
    Tensor* stream_state_t;
@@ -526,13 +525,12 @@ class QuantileAccumulatorDeserializeOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
    int64 old_stamp_token = streams_resource->stamp();
 
@@ -595,13 +593,12 @@ class QuantileAccumulatorFlushOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
@@ -641,13 +638,12 @@ class QuantileAccumulatorFlushSummaryOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    QuantileStreamResource* streams_resource;
+    core::RefCountPtr<QuantileStreamResource> streams_resource;
    // Create a reference to the underlying resource using the handle.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &streams_resource));
    // Remove the reference at the end of this scope.
    mutex_lock l(*streams_resource->mutex());
-    core::ScopedUnref unref_me(streams_resource);
 
    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
@@ -713,12 +709,11 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel {
        const ResourceHandle& handle =
            resource_handle_list[resource_handle_idx]
                .flat<ResourceHandle>()(0);
-        QuantileStreamResource* streams_resource;
+        core::RefCountPtr<QuantileStreamResource> streams_resource;
        OP_REQUIRES_OK(context,
                       LookupResource(context, handle, &streams_resource));
        // Remove the reference at the end of this scope.
        mutex_lock l(*streams_resource->mutex());
-        core::ScopedUnref unref_me(streams_resource);
 
        bool are_buckets_ready =
            streams_resource->is_stamp_valid(stamp_token) &&
@@ -27,6 +27,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -130,9 +131,8 @@ using StatsAccumulatorTensorResource =
     StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
 
 void SerializeScalarAccumulatorToOutput(
-    const StatsAccumulatorScalarResource& accumulator_resource,
-    OpKernelContext* context) {
-  int64 num_slots = accumulator_resource.values().size();
+    const StatsAccumulatorScalarResource& resource, OpKernelContext* context) {
+  int64 num_slots = resource.values().size();
   Tensor* partition_ids_t = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                    TensorShape({num_slots}),
@@ -159,7 +159,7 @@ void SerializeScalarAccumulatorToOutput(
   auto hessians = hessians_t->vec<float>();
 
   int i = 0;
-  for (const auto& iter : accumulator_resource.values()) {
+  for (const auto& iter : resource.values()) {
     partition_ids(i) = iter.first.partition_id;
     feature_ids(i, 0) = iter.first.feature_id;
     feature_ids(i, 1) = iter.first.dimension;
@@ -171,9 +171,8 @@ void SerializeScalarAccumulatorToOutput(
 }
 
 void SerializeTensorAccumulatorToOutput(
-    const StatsAccumulatorTensorResource& accumulator_resource,
-    OpKernelContext* context) {
-  int64 num_slots = accumulator_resource.values().size();
+    const StatsAccumulatorTensorResource& resource, OpKernelContext* context) {
+  int64 num_slots = resource.values().size();
   Tensor* partition_ids_t = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                    TensorShape({num_slots}),
@@ -186,7 +185,7 @@ void SerializeTensorAccumulatorToOutput(
                                                    &feature_ids_t));
   auto feature_ids = feature_ids_t->matrix<int64>();
 
-  TensorShape gradient_shape = accumulator_resource.gradient_shape();
+  TensorShape gradient_shape = resource.gradient_shape();
   int64 num_gradient_elements = gradient_shape.num_elements();
   gradient_shape.InsertDim(0, num_slots);
   Tensor* gradients_t = nullptr;
@@ -195,7 +194,7 @@ void SerializeTensorAccumulatorToOutput(
                                                    &gradients_t));
   auto gradients = gradients_t->flat_outer_dims<float>();
 
-  TensorShape hessian_shape = accumulator_resource.hessian_shape();
+  TensorShape hessian_shape = resource.hessian_shape();
   int64 num_hessian_elements = hessian_shape.num_elements();
   hessian_shape.InsertDim(0, num_slots);
   Tensor* hessians_t = nullptr;
@@ -204,7 +203,7 @@ void SerializeTensorAccumulatorToOutput(
   auto hessians = hessians_t->flat_outer_dims<float>();
 
   int i = 0;
-  for (const auto& iter : accumulator_resource.values()) {
+  for (const auto& iter : resource.values()) {
     partition_ids(i) = iter.first.partition_id;
     feature_ids(i, 0) = iter.first.feature_id;
     feature_ids(i, 1) = iter.first.dimension;
@@ -220,11 +219,10 @@ void SerializeTensorAccumulatorToOutput(
 }
 
 void AddToScalarAccumulator(
-    StatsAccumulatorScalarResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
     const Tensor& gradients_t, const Tensor& hessians_t) {
-  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
-                                        1);
+  resource->set_num_updates(resource->num_updates() + 1);
   const TensorShape& partition_ids_shape = partition_ids_t.shape();
   const auto& partition_ids = partition_ids_t.vec<int32>();
   const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
@@ -232,7 +230,7 @@ void AddToScalarAccumulator(
   const auto& hessians = hessians_t.vec<float>();
 
   int64 num_updates = partition_ids_shape.dim_size(0);
-  auto stats_map = accumulator_resource->mutable_values();
+  auto stats_map = resource->mutable_values();
   for (int64 i = 0; i < num_updates; ++i) {
     const auto key =
         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@@ -248,7 +246,7 @@ void AddToScalarAccumulator(
 }
 
 void AddToScalarAccumulator(
-    StatsAccumulatorScalarResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
     OpKernelContext* context) {
   const Tensor* partition_ids_t;
   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@@ -258,17 +256,16 @@ void AddToScalarAccumulator(
   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
   const Tensor* hessians_t;
   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
-  AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
+  AddToScalarAccumulator(resource, *partition_ids_t, *feature_ids_t,
                          *gradients_t, *hessians_t);
 }
 
 void AddToTensorAccumulator(
-    StatsAccumulatorTensorResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
     const Tensor& gradients_t, const Tensor& hessians_t,
     OpKernelContext* context) {
-  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
-                                        1);
+  resource->set_num_updates(resource->num_updates() + 1);
 
   const TensorShape& partition_ids_shape = partition_ids_t.shape();
   const auto& partition_ids = partition_ids_t.vec<int32>();
@@ -283,19 +280,19 @@ void AddToTensorAccumulator(
 
   // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
   OP_REQUIRES(
-      context, gradients_shape == accumulator_resource->gradient_shape(),
+      context, gradients_shape == resource->gradient_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Gradients dimensions must match: ", gradients_shape.DebugString(),
-          ", ", accumulator_resource->gradient_shape().DebugString())));
+          ", ", resource->gradient_shape().DebugString())));
 
   OP_REQUIRES(
-      context, hessians_shape == accumulator_resource->hessian_shape(),
+      context, hessians_shape == resource->hessian_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
-          accumulator_resource->hessian_shape().DebugString())));
+          resource->hessian_shape().DebugString())));
 
   int64 num_updates = partition_ids_shape.dim_size(0);
-  auto stats_map = accumulator_resource->mutable_values();
+  auto stats_map = resource->mutable_values();
   for (int64 i = 0; i < num_updates; ++i) {
     const auto key =
         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
@@ -325,7 +322,7 @@ void AddToTensorAccumulator(
 }
 
 void AddToTensorAccumulator(
-    StatsAccumulatorTensorResource* accumulator_resource,
+    const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
     OpKernelContext* context) {
   const Tensor* partition_ids_t;
   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
@@ -335,7 +332,7 @@ void AddToTensorAccumulator(
   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
   const Tensor* hessians_t;
   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
-  AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
+  AddToTensorAccumulator(resource, *partition_ids_t, *feature_ids_t,
                          *gradients_t, *hessians_t, context);
 }
 
@@ -452,20 +449,18 @@ class StatsAccumulatorScalarAddOp : public OpKernel {
              resource_handle_list[resource_handle_idx]
                  .flat<ResourceHandle>()(0);
 
-          StatsAccumulatorScalarResource* accumulator_resource;
-          OP_REQUIRES_OK(context, LookupResource(context, handle,
-                                                 &accumulator_resource));
-          mutex_lock l(*accumulator_resource->mutex());
-          core::ScopedUnref unref_me(accumulator_resource);
+          core::RefCountPtr<StatsAccumulatorScalarResource> resource;
+          OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
+          mutex_lock l(*resource->mutex());
 
          // If the stamp is invalid we drop the update.
-          if (!accumulator_resource->is_stamp_valid(stamp_token)) {
+          if (!resource->is_stamp_valid(stamp_token)) {
            VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                    << "Passed stamp token: " << stamp_token << " "
-                    << "Current token: " << accumulator_resource->stamp();
+                    << "Current token: " << resource->stamp();
            return;
          }
-          AddToScalarAccumulator(accumulator_resource,
+          AddToScalarAccumulator(resource,
                                 partition_ids_list[resource_handle_idx],
                                 feature_ids_list[resource_handle_idx],
                                 gradients_list[resource_handle_idx],
@@ -517,20 +512,18 @@ class StatsAccumulatorTensorAddOp : public OpKernel {
              resource_handle_list[resource_handle_idx]
                  .flat<ResourceHandle>()(0);
 
-          StatsAccumulatorTensorResource* accumulator_resource;
-          OP_REQUIRES_OK(context, LookupResource(context, handle,
-                                                 &accumulator_resource));
-          mutex_lock l(*accumulator_resource->mutex());
-          core::ScopedUnref unref_me(accumulator_resource);
+          core::RefCountPtr<StatsAccumulatorTensorResource> resource;
+          OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
+          mutex_lock l(*resource->mutex());
 
          // If the stamp is invalid we drop the update.
-          if (!accumulator_resource->is_stamp_valid(stamp_token)) {
+          if (!resource->is_stamp_valid(stamp_token)) {
            VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                    << "Passed stamp token: " << stamp_token << " "
-                    << "Current token: " << accumulator_resource->stamp();
+                    << "Current token: " << resource->stamp();
            return;
          }
-          AddToTensorAccumulator(accumulator_resource,
+          AddToTensorAccumulator(resource,
                                 partition_ids_list[resource_handle_idx],
                                 feature_ids_list[resource_handle_idx],
                                 gradients_list[resource_handle_idx],
@@ -549,11 +542,10 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorScalarResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@@ -562,7 +554,7 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
    // If the stamp is invalid we restart the PS. It shouldn't happen since
    // only Chief should call this function and chief is guaranteed to be in
    // a consistent state.
-    CHECK(accumulator_resource->is_stamp_valid(stamp_token));
+    CHECK(resource->is_stamp_valid(stamp_token));
 
    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
@@ -570,15 +562,15 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    CHECK(stamp_token != next_stamp_token);
 
-    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
+    SerializeScalarAccumulatorToOutput(*resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
-    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
+    num_updates_t->scalar<int64>()() = resource->num_updates();
 
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(next_stamp_token);
+    resource->Clear();
+    resource->set_stamp(next_stamp_token);
   }
 };
 
@@ -591,11 +583,10 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorTensorResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
@@ -609,16 +600,16 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
    // If the stamp is invalid we restart the PS. It shouldn't happen since
    // only Chief should call this function and chief is guaranteed to be in
    // a consistent state.
-    CHECK(accumulator_resource->is_stamp_valid(stamp_token));
+    CHECK(resource->is_stamp_valid(stamp_token));
    CHECK(stamp_token != next_stamp_token);
-    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
+    SerializeTensorAccumulatorToOutput(*resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
-    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(next_stamp_token);
+    num_updates_t->scalar<int64>()() = resource->num_updates();
+    resource->Clear();
+    resource->set_stamp(next_stamp_token);
   }
 };
 
@@ -631,22 +622,21 @@ class StatsAccumulatorScalarDeserializeOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorScalarResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
    // Check the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(stamp_token);
-    AddToScalarAccumulator(accumulator_resource, context);
+    resource->Clear();
+    resource->set_stamp(stamp_token);
+    AddToScalarAccumulator(resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
-    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
+    resource->set_num_updates(num_updates_t->scalar<int64>()());
   }
 };
 
@@ -660,22 +650,21 @@ class StatsAccumulatorTensorDeserializeOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorTensorResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
 
    // Check the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
-    accumulator_resource->Clear();
-    accumulator_resource->set_stamp(stamp_token);
-    AddToTensorAccumulator(accumulator_resource, context);
+    resource->Clear();
+    resource->set_stamp(stamp_token);
+    AddToTensorAccumulator(resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
-    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
+    resource->set_num_updates(num_updates_t->scalar<int64>()());
   }
 };
 
@@ -689,23 +678,22 @@ class StatsAccumulatorScalarSerializeOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorScalarResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
-    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
+    SerializeScalarAccumulatorToOutput(*resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
-    stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
+    stamp_token_t->scalar<int64>()() = resource->stamp();
 
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
-    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
+    num_updates_t->scalar<int64>()() = resource->num_updates();
   }
 };
 
@@ -719,23 +707,22 @@ class StatsAccumulatorTensorSerializeOp : public OpKernel {
      : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    StatsAccumulatorTensorResource* accumulator_resource;
+    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
-                                           &accumulator_resource));
-    mutex_lock l(*accumulator_resource->mutex());
-    core::ScopedUnref unref_me(accumulator_resource);
-    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
+                                           &resource));
+    mutex_lock l(*resource->mutex());
+    SerializeTensorAccumulatorToOutput(*resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
-    stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
+    stamp_token_t->scalar<int64>()() = resource->stamp();
 
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
-    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
+    num_updates_t->scalar<int64>()() = resource->num_updates();
   }
 };
 
@@ -751,12 +738,11 @@ class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
    TensorShape gradient_shape = TensorShape({});
    TensorShape hessian_shape = TensorShape({});
-    StatsAccumulatorScalarResource* accumulator_resource =
-        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
-    core::ScopedUnref unref_me(accumulator_resource);
+    core::RefCountPtr<StatsAccumulatorScalarResource> resource(
+        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape));
    // Check the stamp token.
-    AddToScalarAccumulator(accumulator_resource, context);
-    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
+    AddToScalarAccumulator(resource, context);
+    SerializeScalarAccumulatorToOutput(*resource, context);
   }
 };
 
@@ -780,12 +766,11 @@ class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
    TensorShape hessians_shape = hessians_t->shape();
    hessians_shape.RemoveDim(0);
 
-    StatsAccumulatorTensorResource* accumulator_resource =
-        new StatsAccumulatorTensorResource(gradients_shape, hessians_shape);
-    core::ScopedUnref unref_me(accumulator_resource);
+    core::RefCountPtr<StatsAccumulatorTensorResource> resource(
+        new StatsAccumulatorTensorResource(gradients_shape, hessians_shape));
    // Check the stamp token.
-    AddToTensorAccumulator(accumulator_resource, context);
-    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
+    AddToTensorAccumulator(resource, context);
+    SerializeTensorAccumulatorToOutput(*resource, context);
   }
 };
 
@@ -21,6 +21,7 @@
 #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/refcount.h"
 
 namespace tensorflow {
 using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
@@ -31,6 +32,8 @@ namespace {
 
 using boosted_trees::learner::LearnerConfig;
 using boosted_trees::learner::LearningRateConfig;
+using boosted_trees::models::DecisionTreeEnsembleResource;
+using boosted_trees::trees::DecisionTreeConfig;
 using boosted_trees::trees::Leaf;
 using boosted_trees::trees::TreeNode;
 using boosted_trees::trees::TreeNodeMetadata;
@@ -193,10 +196,9 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
 
   void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
-    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());
 
    // Get the stamp token.
@@ -255,8 +257,8 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
 
  private:
  // Helper method to retrieve the bias from the tree ensemble.
-  boosted_trees::trees::Leaf* RetrieveBias(
-      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource,
+  Leaf* RetrieveBias(
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
      int64 logits_dimension) {
    const int32 num_trees = ensemble_resource->num_trees();
    if (num_trees <= 0) {
@@ -319,10 +321,9 @@ class GrowTreeEnsembleOp : public OpKernel {
 
   void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
-    core::ScopedUnref unref_me(ensemble_resource);
    mutex_lock l(*ensemble_resource->get_mutex());
 
    // Get the stamp token.
@@ -400,10 +401,9 @@ class GrowTreeEnsembleOp : public OpKernel {
    // Update and retrieve the growable tree.
    // If the tree is fully built and dropout was applied, it also adjusts the
    // weights of dropped and the last tree.
-    boosted_trees::trees::DecisionTreeConfig* const tree_config =
-        UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
-                                      dropout_seed, max_tree_depth,
-                                      weak_learner_type);
+    DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(
+        ensemble_resource, learning_rate, dropout_seed, max_tree_depth,
+        weak_learner_type);
    // Split tree nodes.
    switch (weak_learner_type) {
      case LearnerConfig::NORMAL_DECISION_TREE: {
@@ -559,8 +559,7 @@ class GrowTreeEnsembleOp : public OpKernel {
   }
 
   void UpdateTreeWeightsIfDropout(
-      boosted_trees::models::DecisionTreeEnsembleResource* const
-          ensemble_resource,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
      const uint64 dropout_seed) {
    // It is possible that the tree was built with dropout. If it is the case,
    // we need to adjust the tree weight, or bail out.
@@ -609,9 +608,8 @@ class GrowTreeEnsembleOp : public OpKernel {
 
   // Helper method to update the growable tree which is by definition the last
   // tree in the ensemble.
-  boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
-      boosted_trees::models::DecisionTreeEnsembleResource* const
-          ensemble_resource,
+  DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
      const float learning_rate, const uint64 dropout_seed,
      const int32 max_tree_depth, const int32 weak_learner_type) {
    const auto num_trees = ensemble_resource->num_trees();
@@ -719,8 +717,8 @@ class GrowTreeEnsembleOp : public OpKernel {
  // leaf children given the split candidate.
  void SplitTreeNode(
      const int32 node_id, SplitCandidate* split,
-      boosted_trees::trees::DecisionTreeConfig* tree_config,
-      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
+      DecisionTreeConfig* tree_config,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
    // No-op if we have no real node.
    CHECK(node_id < tree_config->nodes_size())
        << "Invalid node " << node_id << " to split.";
@@ -761,14 +759,13 @@ class GrowTreeEnsembleOp : public OpKernel {
    (*tree_config->mutable_nodes(node_id)) =
        *split->split_info.mutable_split_node();
    if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
-      ensemble_resource->MaybeAddUsedHandler(split->handler_id);
+      resource->MaybeAddUsedHandler(split->handler_id);
    }
  }
 
  void SplitTreeLayer(
-      SplitCandidate* split,
-      boosted_trees::trees::DecisionTreeConfig* tree_config,
-      boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
+      SplitCandidate* split, DecisionTreeConfig* tree_config,
+      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
    int depth = 0;
    while (depth < tree_config->nodes_size() &&
           tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
@@ -903,10 +900,9 @@ class TreeEnsembleStatsOp : public OpKernel {
 
   void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
-    boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
+    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
-    core::ScopedUnref unref_me(ensemble_resource);
    tf_shared_lock l(*ensemble_resource->get_mutex());
 
    // Get the stamp token.
@@ -95,7 +95,7 @@ class ZeroVarInitializer : public OpKernel {
   }
 
   void Compute(OpKernelContext* ctx) override {
-    Var* variable = nullptr;
+    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
                            ctx, HandleFromInput(ctx, 0), &variable,
                            [this, ctx](Var** var_ptr) {
@@ -117,7 +117,6 @@ class ZeroVarInitializer : public OpKernel {
                              return Status::OK();
                            }));
 
-    core::ScopedUnref scoped(variable);
    mutex_lock ml(*variable->mu());
 
    OP_REQUIRES(ctx, !variable->is_initialized,
@ -13,6 +13,7 @@
// limitations under the License.
// =============================================================================
#include <functional>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
@ -24,6 +25,7 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@ -76,11 +78,10 @@ class TreeSerializeOp : public OpKernel {
explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -100,12 +101,11 @@ class TreeDeserializeOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
auto handle = HandleFromInput(context, 0);
OP_REQUIRES_OK(context,
LookupResource(context, handle, &decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);

const Tensor* tree_config_t;
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
@ -131,11 +131,10 @@ class TreeSizeOp : public OpKernel {
explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape(), &output_t));
@ -144,7 +143,7 @@ class TreeSizeOp : public OpKernel {
}
};

void TraverseTree(const DecisionTreeResource* tree_resource,
void TraverseTree(DecisionTreeResource* tree_resource,
const std::unique_ptr<TensorDataSet>& data, int32 start,
int32 end,
const std::function<void(int32, int32)>& set_leaf_id,
@ -182,11 +181,10 @@ class TreePredictionsV4Op : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);

DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
&resource));
mutex_lock l(*resource->get_mutex());

const int num_data = data_set->NumItems();
const int32 num_outputs = param_proto_.num_outputs();
@ -205,15 +203,16 @@ class TreePredictionsV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
auto traverse = [this, &out, &data_set, decision_tree_resource, num_data,
&tree_paths](int64 start, int64 end) {
auto traverse = [this, &out, &data_set, &resource, num_data, &tree_paths](
int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),

TraverseTree(resource.get(), data_set, static_cast<int32>(start),
static_cast<int32>(end),
std::bind(&TreePredictionsV4Op::set_output_value, this,
std::placeholders::_1, std::placeholders::_2,
decision_tree_resource, &out),
resource.get(), &out),
param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@ -283,11 +282,10 @@ class TraverseTreeV4Op : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);

DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
&resource));
mutex_lock l(*resource->get_mutex());

const int num_data = data_set->NumItems();

@ -304,11 +302,11 @@ class TraverseTreeV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource,
num_data](int64 start, int64 end) {
auto traverse = [&set_leaf_ids, &data_set, &resource, num_data](int64 start,
int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
TraverseTree(resource.get(), data_set, static_cast<int32>(start),
static_cast<int32>(end), set_leaf_ids, nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@ -336,11 +334,10 @@ class UpdateModelV4Op : public OpKernel {
const Tensor& input_labels = context->input(2);
const Tensor& input_weights = context->input(3);

DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);

const int num_data = input_labels.shape().dim_size(0);
const int32 label_dim =
@ -356,9 +353,10 @@ class UpdateModelV4Op : public OpKernel {
UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
}

void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
int32 start, int32 end,
DecisionTreeResource* decision_tree_resource) {
void UpdateModel(
const Tensor& leaf_ids, const TensorInputTarget& target, int32 start,
int32 end,
const core::RefCountPtr<DecisionTreeResource>& decision_tree_resource) {
const auto leaves = leaf_ids.unaligned_flat<int32>();
for (int i = start; i < end; ++i) {
model_op_->UpdateModel(
@ -384,11 +382,10 @@ class FeatureUsageCountsOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
DecisionTreeResource* decision_tree_resource;
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);

const auto& tree = decision_tree_resource->decision_tree();

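Two recurring mechanics in the hunks above: core::RefCountPtr is move-only, so the sharding lambdas now capture it by reference (&resource) instead of copying the old raw pointer, and helpers that still take a raw pointer (TraverseTree, set_output_value) receive resource.get(). A condensed, self-contained sketch of the same shape, using Var as a stand-in resource (the helper and its body are hypothetical):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Stand-in for a legacy helper that still takes a raw pointer.
static void UseResource(Var* v, int64 start, int64 end) { /* ... */ }

static void ShardedTraversal(OpKernelContext* context, int64 num_data) {
  core::RefCountPtr<Var> resource;
  OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                         &resource));
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  // Capture by reference: the RefCountPtr cannot be copied, and the capture
  // is safe here because Shard() blocks until every shard has run.
  auto traverse = [&resource](int64 start, int64 end) {
    UseResource(resource.get(), start, end);
  };
  Shard(worker_threads->num_threads, worker_threads->workers, num_data,
        /*cost_per_unit=*/500, traverse);
}

}  // namespace tensorflow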
@ -26,6 +26,7 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
@ -87,11 +88,10 @@ class FertileStatsSerializeOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&fertile_stats_resource));
mutex_lock l(*fertile_stats_resource->get_mutex());
core::ScopedUnref unref_me(fertile_stats_resource);
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -116,11 +116,10 @@ class FertileStatsDeserializeOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&fertile_stats_resource));
mutex_lock l(*fertile_stats_resource->get_mutex());
core::ScopedUnref unref_me(fertile_stats_resource);

const Tensor* stats_config_t;
OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
@ -145,7 +144,7 @@ class FertileStatsDeserializeOp : public OpKernel {
// acquired, put it in a waiting queue to come back to later and try the next
// one. Once all leaf_ids have been visited, cycle through the waiting ids
// until they're gone.
void UpdateStats(FertileStatsResource* fertile_stats_resource,
void UpdateStats(const core::RefCountPtr<FertileStatsResource>& resource,
const std::unique_ptr<TensorDataSet>& data,
const TensorInputTarget& target, int num_targets,
const Tensor& leaf_ids_tensor,
@ -183,8 +182,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
}

bool is_finished;
fertile_stats_resource->AddExampleToStatsAndInitialize(
data, &target, {example_id}, leaf_id, &is_finished);
resource->AddExampleToStatsAndInitialize(data, &target, {example_id},
leaf_id, &is_finished);
leaf_lock->unlock();
if (is_finished) {
set_lock->lock();
@ -196,8 +195,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,

// Update leaves from start through end in the leaf_examples iterator.
void UpdateStatsCollated(
FertileStatsResource* fertile_stats_resource,
DecisionTreeResource* tree_resource,
const core::RefCountPtr<FertileStatsResource>& fertile_stats_resource,
const core::RefCountPtr<DecisionTreeResource>& tree_resource,
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
int num_targets,
const std::unordered_map<int32, std::vector<int>>& leaf_examples,
@ -251,18 +250,15 @@ class ProcessInputOp : public OpKernel {
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);

FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource));
DecisionTreeResource* tree_resource;
core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex());

core::ScopedUnref unref_stats(fertile_stats_resource);
core::ScopedUnref unref_tree(tree_resource);

const int32 num_data = data_set->NumItems();
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
@ -308,7 +304,7 @@ class ProcessInputOp : public OpKernel {
// if it really matters that much.
const int64 costPerUpdate = 1000;
auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set,
fertile_stats_resource, &locks, &set_lock, &ready_to_split,
&fertile_stats_resource, &locks, &set_lock, &ready_to_split,
num_data](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
@ -317,8 +313,8 @@ class ProcessInputOp : public OpKernel {
static_cast<int32>(end), &ready_to_split);
};

auto update_collated = [&target, &num_targets, fertile_stats_resource,
tree_resource, &leaf_examples, &set_lock,
auto update_collated = [&target, &num_targets, &fertile_stats_resource,
&tree_resource, &leaf_examples, &set_lock,
&ready_to_split, &data_set,
num_leaves](int64 start, int64 end) {
CHECK(start <= end);
@ -362,18 +358,15 @@ class GrowTreeOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource));
DecisionTreeResource* tree_resource;
core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource));
mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex());

core::ScopedUnref unref_stats(fertile_stats_resource);
core::ScopedUnref unref_tree(tree_resource);

const Tensor& finished_nodes = context->input(2);

const auto finished = finished_nodes.unaligned_flat<int32>();
@ -463,19 +456,16 @@ class FinalizeTreeOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
DecisionTreeResource* tree_resource;
core::RefCountPtr<DecisionTreeResource> tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_resource));
FertileStatsResource* fertile_stats_resource;
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
&fertile_stats_resource));

mutex_lock l1(*fertile_stats_resource->get_mutex());
mutex_lock l2(*tree_resource->get_mutex());

core::ScopedUnref unref_me(tree_resource);
core::ScopedUnref unref_stats(fertile_stats_resource);

// TODO(thomaswc): Add threads
int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
for (int i = 0; i < num_nodes; i++) {

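A side effect visible in UpdateStats and UpdateStatsCollated above: once the owning scope holds a core::RefCountPtr, helpers that merely borrow the resource can take it by const reference, which costs nothing and documents that the callee neither Refs nor Unrefs. A small sketch under stated assumptions (the CounterResource type is invented for illustration, and the exact ResourceBase::DebugString signature varies across TensorFlow versions):

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

// Invented resource type, for illustration only.
struct CounterResource : public ResourceBase {
  int64 value = 0;
  string DebugString() override { return "CounterResource"; }  // assumption
};

// Borrowing helper: const reference to the smart pointer, no ownership
// transfer, no Ref()/Unref() churn on the hot path.
static int64 ReadCounter(const core::RefCountPtr<CounterResource>& counter) {
  return counter->value;
}

}  // namespace tensorflow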
@ -270,12 +270,19 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
template <typename T, bool use_dynamic_cast = false>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);

// Looks up a resource pointed by a given resource handle.
//
// Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
// requiring the caller to explicitly call `Unref()`.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value);

// Looks up multiple resources pointed by a sequence of resource handles. If
// p[i] is uninitialized then values[i] is unmodified.
template <typename T>
Status LookupResources(
OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
std::vector<core::RefCountPtr<T>>* values);

// Looks up or creates a resource.
//
@ -283,10 +290,19 @@ Status LookupResources(
// must call its `Unref()` method when it has finished using it. If the
// `creator` is invoked, its reference on the created resource is transferred
// to `ctx->resource_mgr()`.
//
// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
// requiring the caller to explicitly call `Unref()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator);

// Looks up or creates a resource.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value,
std::function<Status(T**)> creator);

// Destroys a resource pointed by a given resource handle.
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
@ -587,9 +603,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
}

template <typename T>
Status LookupResources(
OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) {
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value) {
T* raw_ptr = nullptr;
TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
value->reset(raw_ptr);

return Status::OK();
}

template <typename T>
Status LookupResources(OpKernelContext* ctx,
absl::Span<ResourceHandle const* const> p,
std::vector<core::RefCountPtr<T>>* values) {
std::vector<std::pair<const string*, const string*>> containers_and_names(
p.size());
for (size_t i = 0; i < p.size(); ++i) {
@ -607,6 +633,17 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
creator);
}

template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value,
std::function<Status(T**)> creator) {
T* raw_ptr = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
value->reset(raw_ptr);

return Status::OK();
}

template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));

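A usage sketch of the overloads declared above (the surrounding function is hypothetical): the creator callback still produces a raw pointer carrying one reference, and the RefCountPtr overload adopts that reference, so the caller never touches Unref():

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

static void GetOrCreateVariable(OpKernelContext* context) {
  core::RefCountPtr<Var> variable;
  OP_REQUIRES_OK(context,
                 LookupOrCreateResource<Var>(
                     context, HandleFromInput(context, 0), &variable,
                     [](Var** ptr) {
                       *ptr = new Var(DT_FLOAT);  // refcount starts at one
                       return Status::OK();
                     }));
  // `variable` owns exactly one reference and releases it on scope exit,
  // including on the OP_REQUIRES_OK error path.
}

}  // namespace tensorflow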
@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -266,15 +267,14 @@ TEST(ResourceHandleTest, CRUD) {
TF_EXPECT_OK(CreateResource(&ctx, p, r));
}
{
StubResource* r = nullptr;
core::RefCountPtr<StubResource> r;
TF_ASSERT_OK(LookupResource(&ctx, p, &r));
ASSERT_TRUE(r != nullptr);
EXPECT_EQ(r->value_, 42);
r->Unref();
}
{
TF_EXPECT_OK(DeleteResource<StubResource>(&ctx, p));
StubResource* unused = nullptr;
core::RefCountPtr<StubResource> unused;
EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok());
}
}
@ -338,13 +338,12 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) {
StubResource* r = new StubResource;
TF_EXPECT_OK(CreateResource(&ctx, p, r));

StubResource* lookup_r = nullptr;
core::RefCountPtr<StubResource> lookup_r;
TF_EXPECT_OK(LookupResource<StubResource>(&ctx, p, &lookup_r));
EXPECT_EQ(lookup_r, r);
EXPECT_EQ(lookup_r.get(), r);

TF_EXPECT_OK(DeleteResource(&ctx, p));
EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
r->Unref();
}

} // end namespace tensorflow

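Since core::RefCountPtr is just std::unique_ptr<T, core::RefCountDeleter>, its lifetime behaviour can be unit-tested with any RefCounted type. A hypothetical companion test in the same style as the file above (StubRefCounted is invented here; the sketch assumes the file's existing gtest includes):

struct StubRefCounted : public core::RefCounted {};

TEST(RefCountPtrTest, ReleasesReferenceOnScopeExit) {
  StubRefCounted* raw = new StubRefCounted;  // refcount == 1
  raw->Ref();                                // refcount == 2
  {
    core::RefCountPtr<StubRefCounted> ptr(raw);
    EXPECT_EQ(ptr.get(), raw);
  }  // ~RefCountPtr called Unref(): refcount back to 1
  EXPECT_TRUE(raw->RefCountIsOne());
  raw->Unref();  // destroys the object
}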
@ -166,6 +166,7 @@ tf_kernel_library(
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
],
)
@ -5022,6 +5023,7 @@ STATE_DEPS = [
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
] + if_sycl(["//tensorflow/core:sycl_runtime"])

tf_kernel_library(
@ -7589,6 +7591,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/db:sqlite",
"//tensorflow/core/summary:schema",

@ -83,6 +83,7 @@ tf_kernel_library(
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
@ -106,6 +107,7 @@ tf_kernel_library(
":tree_helper",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
"//third_party/eigen3",
],
@ -117,6 +119,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
],
)

@ -51,12 +51,10 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
}

void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource;
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);

// Get the inputs.
OpInputList bucketized_features_list;
@ -198,12 +196,10 @@ class BoostedTreesPredictOp : public OpKernel {
}

void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource;
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);

// Get the inputs.
OpInputList bucketized_features_list;
@ -302,12 +298,10 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
}

void Compute(OpKernelContext* const context) override {
BoostedTreesEnsembleResource* resource;
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(resource);

// Get the inputs.
OpInputList bucketized_features_list;

@ -27,6 +27,7 @@
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@ -224,12 +225,11 @@ class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
ResourceHandle handle;
OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource;
core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);

OpInputList summaries_list;
OP_REQUIRES_OK(context,
@ -281,13 +281,12 @@ class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource;
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);

OpInputList bucket_boundaries_list;
OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
@ -336,12 +335,11 @@ class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
ResourceHandle handle;
OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource;
core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);

const Tensor* num_buckets_t;
OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
@ -391,12 +389,11 @@ class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
ResourceHandle handle;
OP_REQUIRES_OK(context,
HandleFromInput(context, kResourceHandleName, &handle));
QuantileStreamResource* stream_resource;
core::RefCountPtr<QuantileStreamResource> stream_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*stream_resource->mutex());
core::ScopedUnref unref_me(stream_resource);

const int64 num_streams = stream_resource->num_streams();
CHECK_EQ(num_features_, num_streams);

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

@ -78,11 +79,10 @@ class BoostedTreesGetEnsembleStatesOp : public OpKernel {

void Compute(OpKernelContext* context) override {
// Looks up the resource.
BoostedTreesEnsembleResource* tree_ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource));
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);

// Sets the outputs.
const int num_trees = tree_ensemble_resource->num_trees();
@ -141,11 +141,10 @@ class BoostedTreesSerializeEnsembleOp : public OpKernel {
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
BoostedTreesEnsembleResource* tree_ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource));
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
&output_stamp_token_t));
@ -169,11 +168,10 @@ class BoostedTreesDeserializeEnsembleOp : public OpKernel {
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
BoostedTreesEnsembleResource* tree_ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&tree_ensemble_resource));
mutex_lock l(*tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(tree_ensemble_resource);

// Get the stamp token.
const Tensor* stamp_token_t;

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

@ -55,10 +56,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {

void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
BoostedTreesEnsembleResource* ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex());
// Increase the ensemble stamp.
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
@ -176,19 +176,19 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {

private:
int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
BoostedTreesEnsembleResource* const ensemble_resource) {
int32 num_trees = ensemble_resource->num_trees();
const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) {
int32 num_trees = resource->num_trees();
int32 current_tree = num_trees - 1;

// Increment global attempt stats.
ensemble_resource->UpdateGrowingMetadata();
resource->UpdateGrowingMetadata();

// Note we don't set tree weight to be equal to learning rate, since we
// apply learning rate to leaf weights instead, when doing layer-by-layer
// boosting.
if (num_trees <= 0) {
// Create a new tree with a no-op leaf.
current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
}
return current_tree;
}
@ -250,10 +250,9 @@ class BoostedTreesCenterBiasOp : public OpKernel {

void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
BoostedTreesEnsembleResource* ensemble_resource;
core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
core::ScopedUnref unref_me(ensemble_resource);
mutex_lock l(*ensemble_resource->get_mutex());
// Increase the ensemble stamp.
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);

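One subtlety these hunks make visible: the old code had to order ScopedUnref against mutex_lock by hand, and different files in this diff chose different orders. With RefCountPtr the safe order falls out of normal C++ destruction rules, since locals are destroyed in reverse declaration order. Restated as a sketch (fragment only; `context` and the resource type come from the surrounding kernel):

core::RefCountPtr<BoostedTreesEnsembleResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                       &resource));
mutex_lock l(*resource->get_mutex());
// ... mutate the ensemble ...
// On scope exit `l` is destroyed first (mutex released), and only then does
// `resource` drop its reference, so the mutex cannot be torn down while held.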
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

@ -65,11 +66,9 @@ class ResourceCountUpToOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
Var* variable = nullptr;
OP_REQUIRES_OK(
context,
LookupResource<Var>(context, HandleFromInput(context, 0), &variable));
core::ScopedUnref s(variable);
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&variable));
mutex_lock l(*variable->mu());
Tensor before_increment = *variable->tensor();
OP_REQUIRES(

@ -342,6 +342,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:summary_interface",
],
@ -380,6 +381,7 @@ tf_kernel_library(
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/data:dataset_utils",
"//third_party/eigen3",
],

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
@ -83,17 +84,16 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {

void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
StatsAggregatorResource* stats_aggregator_resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
core::RefCountPtr<StatsAggregatorResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
string tag;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
string prefix;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));

*output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
tag, prefix);
*output =
new Dataset(ctx, input, ctx->input(1), resource.get(), tag, prefix);
}

private:
@ -101,12 +101,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
const Tensor& resource_handle,
StatsAggregatorResource* stats_aggregator_resource,
const string& tag, const string& prefix)
StatsAggregatorResource* resource, const string& tag,
const string& prefix)
: DatasetBase(DatasetContext(ctx)),
input_(input),
resource_handle_(resource_handle),
stats_aggregator_resource_(stats_aggregator_resource),
stats_aggregator_resource_(resource),
tag_(tag),
prefix_(prefix) {
input_->Ref();
@ -169,13 +169,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
StatsAggregatorResource* stats_aggregator_resource =
StatsAggregatorResource* resource =
dataset()->stats_aggregator_resource_;
IteratorContext::Params params(ctx);
params.stats_aggregator = std::shared_ptr<StatsAggregator>(
new StatsAggregatorWithTagAndPrefix(
stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
dataset()->prefix_));
new StatsAggregatorWithTagAndPrefix(resource->stats_aggregator(),
dataset()->tag_,
dataset()->prefix_));
IteratorContext iter_ctx(std::move(params));
return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
}

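Note the lifetime reasoning behind resource.get() here: the Dataset stores a raw StatsAggregatorResource* plus the handle tensor, and the local RefCountPtr only spans MakeDataset, exactly like the ScopedUnref it replaces. A condensed sketch of that flow (names abridged from the code above; the lifetime argument restates the existing behaviour rather than introducing new behaviour):

void MakeDatasetSketch(OpKernelContext* ctx, DatasetBase* input,
                       DatasetBase** output) {
  core::RefCountPtr<StatsAggregatorResource> resource;
  OP_REQUIRES_OK(ctx,
                 LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
  // The Dataset keeps ctx->input(1) (the handle tensor) alive, which keeps
  // the resource registered in the ResourceMgr, so dropping this local
  // reference when `resource` goes out of scope is as safe as the old
  // ScopedUnref was.
  *output = new Dataset(ctx, input, ctx->input(1), resource.get(),
                        /*tag=*/"", /*prefix=*/"");
}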
@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/stats_aggregator.h"

#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/kernels/summary_interface.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/monitoring/counter.h"
@ -258,10 +258,9 @@ class StatsAggregatorSummaryOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));

StatsAggregatorResource* resource;
core::RefCountPtr<StatsAggregatorResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref unref_iterator(resource);

Tensor* summary_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
@ -281,21 +280,19 @@ class StatsAggregatorSetSummaryWriterOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));

StatsAggregatorResource* resource;
core::RefCountPtr<StatsAggregatorResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref unref_iterator(resource);

const Tensor& summary_resource_handle_t = ctx->input(1);
OP_REQUIRES(ctx,
TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
SummaryWriterInterface* sumamry_resource;
core::RefCountPtr<SummaryWriterInterface> sumamry_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource));
core::ScopedUnref unref_sumamry_resource(sumamry_resource);
TF_CHECK_OK(
resource->stats_aggregator()->SetSummaryWriter(sumamry_resource));
resource->stats_aggregator()->SetSummaryWriter(sumamry_resource.get()));
}
};

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/work_sharder.h"

@ -127,12 +128,10 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {

void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
ThreadPoolResource* threadpool_resource;
core::RefCountPtr<ThreadPoolResource> threadpool_resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&threadpool_resource));
core::ScopedUnref unref_iterator(threadpool_resource);

*output = new Dataset(ctx, input, ctx->input(1), threadpool_resource);
*output = new Dataset(ctx, input, ctx->input(1), threadpool_resource.get());
}

private:

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
@ -595,10 +596,9 @@ int64 AnonymousIteratorHandleOp::current_id_(0);
void MakeIteratorOp::Compute(OpKernelContext* ctx) {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
IteratorResource* iterator_resource;
core::RefCountPtr<IteratorResource> iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref(iterator_resource);
OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
}

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -492,10 +493,9 @@ class MultiDeviceIteratorInitOp : public OpKernel {

DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
MultiDeviceIterator* resource;
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
core::ScopedUnref unref(resource);

std::unique_ptr<IteratorBase> iterator;
IteratorContext::Params params(ctx);
@ -535,7 +535,7 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();

MultiDeviceIterator* iterator;
core::RefCountPtr<MultiDeviceIterator> iterator;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);

@ -557,7 +557,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
std::placeholders::_1, std::move(done));

iterator->GetNextFromShard(ctx, shard_num, incarnation_id, callback);
iterator->Unref();
}
};

@ -577,10 +576,9 @@ class MultiDeviceIteratorToStringHandleOp : public OpKernel {

// Validate that the handle corresponds to a real resource, and
// that it is an MultiDeviceIterator.
MultiDeviceIterator* resource;
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
resource->Unref();

Tensor* string_handle_t;
OP_REQUIRES_OK(ctx,
@ -629,9 +627,8 @@ class MultiDeviceIteratorFromStringHandleOp : public OpKernel {

// Validate that the handle corresponds to a real resource, and
// that it is an MultiDeviceIterator.
MultiDeviceIterator* resource;
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
core::ScopedUnref unref_iterator(resource);
if (!output_types_.empty()) {
OP_REQUIRES_OK(ctx,
VerifyTypesMatch(output_types_, resource->output_types()));

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@ -371,10 +372,9 @@ class RandomBinomialOp : public OpKernel {
"Input probs should have length 1 or shape[0], got shape: ",
probs_tensor.shape().DebugString()));
}
Var* var = nullptr;
core::RefCountPtr<Var> var;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));

ScopedUnlockUnrefVar var_guard(var);
Tensor* var_tensor = var->tensor();
OP_REQUIRES(
ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
@ -404,7 +404,6 @@ class RandomBinomialOp : public OpKernel {
auto philox = GetPhiloxRandomFromMem(var_data);
UpdateMemWithPhiloxRandom(
philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);
var_guard.Release();

auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,

@ -129,7 +129,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
} // namespace

void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
@ -139,7 +139,6 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
". This could mean that the variable was uninitialized. ",
status.ToString()));

core::ScopedUnref s(variable);
{
tf_shared_lock ml(*variable->mu());
// We're acquiring a reference to the underlying buffer while
@ -175,8 +174,7 @@ ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
}

void ReadVariablesOp::Compute(OpKernelContext* ctx) {
std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
dtypes_.size());
std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
std::vector<const ResourceHandle*> handles(dtypes_.size());
for (size_t i = 0; i < dtypes_.size(); ++i) {
handles[i] = &HandleFromInput(ctx, i);
@ -265,10 +263,9 @@ class VariableShapeOp : public OpKernel {
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}

void Compute(OpKernelContext* ctx) override {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
core::ScopedUnref s(variable);
variable->mu()->lock_shared();
TensorShape shape = variable->tensor()->shape();
variable->mu()->unlock_shared();
@ -343,7 +340,7 @@ class AssignVariableOp : public OpKernel {
"Variable and value dtypes don't match; respectively, ",
DataTypeString(dtype_), " and ",
DataTypeString(context->input(1).dtype())));
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
const Tensor& value = context->input(1);
// Note: every resource-variable-manipulating op assumes copy-on-write
// semantics, and creates a copy of the variable's Tensor if its refcount is
@ -361,7 +358,6 @@ class AssignVariableOp : public OpKernel {
(*ptr)->is_initialized = true;
return Status::OK();
}));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
@ -404,7 +400,7 @@ class AssignVariableOp<Device, Variant> : public OpKernel {

void Compute(OpKernelContext* context) override {
const Tensor& value = context->input(1);
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
context, HandleFromInput(context, 0), &variable,
[](Var** ptr) {
@ -412,7 +408,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
*ptr = new Var(DT_VARIANT);
return Status::OK();
}));
core::ScopedUnref s(variable);

// For purposes of forwarding DT_VARIANT, we want the least
// restrictive attr; we already know the input is on host.
@ -500,10 +495,9 @@ class AssignUpdateVariableOp : public OpKernel {
explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

void Compute(OpKernelContext* context) override {
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&variable));
core::ScopedUnref s(variable);

const Tensor& value = context->input(1);
// TODO(apassos): We could possibly avoid the copy done by
@ -568,13 +562,12 @@ class VarIsInitializedOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}), &output));
auto output_tensor = output->tensor<bool, 0>();
Var* variable = nullptr;
core::RefCountPtr<Var> variable;
Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
if (!s.ok()) {
output_tensor() = false;
return;
}
core::ScopedUnref su(variable);
mutex_lock ml(*variable->mu());
output_tensor() = variable->is_initialized;
}
@ -623,10 +616,9 @@ class ResourceGatherOp : public OpKernel {
}

void Compute(OpKernelContext* c) override {
Var* v = nullptr;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref su(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
@ -765,10 +757,9 @@ class ResourceGatherNdOp : public OpKernel {
explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}

void Compute(OpKernelContext* c) override {
Var* v = nullptr;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref su(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
@ -821,10 +812,9 @@ class ResourceScatterUpdateOp : public OpKernel {
explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}

void Compute(OpKernelContext* c) override {
Var* v = nullptr;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref unref_v(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
tf_shared_lock ml(*v->mu());
Tensor* params = v->tensor();
const Tensor& indices = c->input(1);

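The gather/scatter kernels above all share one shape, shown once here as a self-contained sketch: the smart pointer owns the reference, raw-pointer helpers such as EnsureSparseVariableAccess receive v.get(), and the variable mutex is taken afterwards (the wrapper function is hypothetical):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

template <typename Device, typename T>
void SparseAccessSketch(OpKernelContext* c) {
  core::RefCountPtr<Var> v;
  OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
  // EnsureSparseVariableAccess still takes Var*, so unwrap with .get();
  // ownership stays with `v` for the rest of the scope.
  OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
  tf_shared_lock ml(*v->mu());
  Tensor* params = v->tensor();
  (void)params;  // gather/scatter against *params would happen here
}

}  // namespace tensorflow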
@ -243,10 +243,9 @@ class ScatterNdUpdateOp : public OpKernel {

void Compute(OpKernelContext* c) override {
if (dtype_ == DT_RESOURCE) {
Var* v;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref scoped_unref(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
@ -271,9 +270,8 @@ class ScatterNdUpdateOp : public OpKernel {
TensorShape params_shape;

if (dtype_ == DT_RESOURCE) {
Var* v;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref scoped_unref(v);
Tensor* t = v->tensor();
params = *t;
params_shape = params.shape();

@ -15,6 +15,7 @@ limitations under the License.

// See docs in ../ops/array_ops.cc.

#include "tensorflow/core/lib/core/refcount.h"
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
@ -323,12 +324,11 @@ class StridedSliceAssignOp : public OpKernel {
}
} else {
if (context->input_dtype(0) == DT_RESOURCE) {
Var* v;
core::RefCountPtr<Var> v;
OP_REQUIRES_OK(
context, LookupResource(context, HandleFromInput(context, 0), &v));
core::ScopedUnref scoped_unref(v);
OP_REQUIRES_OK(context,
EnsureSparseVariableAccess<Device, T>(context, v));
EnsureSparseVariableAccess<Device, T>(context, v.get()));
mutex_lock ml(*v->mu());
old_lhs = v->tensor();
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,

@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/db/sqlite.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/summary/schema.h"
|
||||
@ -45,7 +46,7 @@ class CreateSummaryFileWriterOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
|
||||
const string filename_suffix = tmp->scalar<string>()();
|
||||
|
||||
SummaryWriterInterface* s = nullptr;
|
||||
core::RefCountPtr<SummaryWriterInterface> s;
|
||||
OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
|
||||
ctx, HandleFromInput(ctx, 0), &s,
|
||||
[max_queue, flush_millis, logdir, filename_suffix,
|
||||
@ -54,7 +55,6 @@ class CreateSummaryFileWriterOp : public OpKernel {
|
||||
max_queue, flush_millis, logdir,
|
||||
filename_suffix, ctx->env(), s);
|
||||
}));
|
||||
core::ScopedUnref unref(s);
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
|
||||
@ -75,7 +75,7 @@ class CreateSummaryDbWriterOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
|
||||
const string user_name = tmp->scalar<string>()();
|
||||
|
||||
SummaryWriterInterface* s = nullptr;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(
ctx,
LookupOrCreateResource<SummaryWriterInterface>(
@ -91,7 +91,6 @@ class CreateSummaryDbWriterOp : public OpKernel {
db, experiment_name, run_name, user_name, ctx->env(), s));
return Status::OK();
}));
core::ScopedUnref unref(s);
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
@ -102,9 +101,8 @@ class FlushSummaryWriterOp : public OpKernel {
explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
OP_REQUIRES_OK(ctx, s->Flush());
}
};
@ -128,9 +126,8 @@ class WriteSummaryOp : public OpKernel {
explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -153,9 +150,8 @@ class WriteRawProtoSummaryOp : public OpKernel {
explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
@ -190,9 +186,8 @@ class ImportEventOp : public OpKernel {
explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("event", &t));
std::unique_ptr<Event> event{new Event};
@ -211,9 +206,8 @@ class WriteScalarSummaryOp : public OpKernel {
explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -234,9 +228,8 @@ class WriteHistogramSummaryOp : public OpKernel {
explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -263,9 +256,8 @@ class WriteImageSummaryOp : public OpKernel {
}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -299,9 +291,8 @@ class WriteAudioSummaryOp : public OpKernel {
}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* tmp;
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
const int64 step = tmp->scalar<int64>()();
@ -328,9 +319,8 @@ class WriteGraphSummaryOp : public OpKernel {
explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
SummaryWriterInterface* s;
core::RefCountPtr<SummaryWriterInterface> s;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
core::ScopedUnref unref(s);
const Tensor* t;
OP_REQUIRES_OK(ctx, ctx->input("step", &t));
const int64 step = t->scalar<int64>()();
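
The hunks above all make the same substitution: a raw SummaryWriterInterface* guarded by a manual core::ScopedUnref becomes a core::RefCountPtr, which drops its reference on every exit path. Below is a minimal sketch of the resulting caller pattern; FlushWriterAtInput0 is an illustrative helper, not part of this change, and the include paths are assumed from a TensorFlow tree of this era.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/summary_interface.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

// Illustrative helper: flush the summary writer held by resource input 0.
Status FlushWriterAtInput0(OpKernelContext* ctx) {
  // Old shape of this code:
  //   SummaryWriterInterface* s = nullptr;
  //   TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &s));
  //   core::ScopedUnref unref(s);  // a leak if this line is forgotten
  core::RefCountPtr<SummaryWriterInterface> s;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &s));
  return s->Flush();  // s unrefs automatically, even on early returns
}

}  // namespace tensorflow
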
|
@ -26,6 +26,7 @@ tf_kernel_library(
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
@ -37,6 +38,7 @@ tf_kernel_library(
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
|
@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/tensor_forest/resources.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/work_sharder.h"

@ -29,12 +30,10 @@ class TensorForestTreePredictOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);

const Tensor* dense_features_t = nullptr;
OP_REQUIRES_OK(context,
context->input("dense_features", &dense_features_t));
@ -60,7 +59,7 @@ class TensorForestTreePredictOp : public OpKernel {
// We will need to run it on a number of trees of diff depth
// and see the num of cpu cycles
const int64 cost_per_traverse = 500;
auto traverse = [this, &out, &dense_features, decision_tree_resource,
auto traverse = [this, &out, &dense_features, &decision_tree_resource,
batch_size](int64 start, int64 end) {
DCHECK_LE(start, end) << "Start exceeding End";
DCHECK_LE(end, batch_size) << "End exceeding batch size";
@ -74,9 +73,10 @@ class TensorForestTreePredictOp : public OpKernel {
traverse);
};

void set_output_value(const int32 example_id, const int32 leaf_id,
const TensorForestTreeResource* decision_tree_resource,
TTypes<float>::Matrix* out) const {
void set_output_value(
const int32 example_id, const int32 leaf_id,
const core::RefCountPtr<TensorForestTreeResource>& decision_tree_resource,
TTypes<float>::Matrix* out) const {
for (int j = 0; j < logits_dimension_; ++j) {
const float logit = decision_tree_resource->get_prediction(leaf_id, j);
(*out)(example_id, j) = logit;
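
Two details in this hunk follow from core::RefCountPtr being move-only (a std::unique_ptr whose deleter calls Unref): the traverse lambda now captures decision_tree_resource by reference, since copying a move-only type into the capture would not compile and moving it in would drop the reference when the lambda is destroyed, and set_output_value correspondingly takes the pointer by const reference. A sketch under those assumptions; PredictOne and ExampleCompute are illustrative names, not from the change.

// Borrowing callee: the caller's RefCountPtr keeps the resource alive.
void PredictOne(const core::RefCountPtr<TensorForestTreeResource>& tree,
                const int32 leaf_id) {
  const float logit = tree->get_prediction(leaf_id, /*dimension=*/0);
  (void)logit;  // a real caller would write this into the output matrix
}

void ExampleCompute(OpKernelContext* context) {
  core::RefCountPtr<TensorForestTreeResource> tree;
  OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                         &tree));
  // Capture by reference: the lambda borrows tree, it does not own it.
  auto traverse = [&tree](int64 start, int64 end) {
    for (int64 i = start; i < end; ++i) PredictOne(tree, /*leaf_id=*/0);
  };
  traverse(0, 1);
}  // tree drops its reference here, after the last use of traverse
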
|
@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/tensor_forest/resources.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

@ -55,11 +56,10 @@ class TensorForestTreeSerializeOp : public OpKernel {
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t));
@ -74,13 +74,11 @@ class TensorForestTreeDeserializeOp : public OpKernel {
explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));

mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);

const Tensor* tree_config_t;
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));

@ -102,11 +100,10 @@ class TensorForestTreeSizeOp : public OpKernel {
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape(), &output_t));
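
Note the declaration order in these rewritten kernels: the RefCountPtr now comes before the mutex_lock, so at scope exit the lock is released first and the reference is dropped second. In the old code the ScopedUnref was constructed after the lock, which reversed that teardown order. A sketch of the ordering; SerializeTreeExample is illustrative, not from this change.

Status SerializeTreeExample(OpKernelContext* context) {
  core::RefCountPtr<TensorForestTreeResource> tree;  // destroyed second
  TF_RETURN_IF_ERROR(
      LookupResource(context, HandleFromInput(context, 0), &tree));
  mutex_lock l(*tree->get_mutex());                  // destroyed first
  // ... read or mutate the resource while holding its mutex ...
  return Status::OK();
}  // l unlocks, then tree unrefs: the mutex safely outlives its last use
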
|
@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

@ -167,10 +168,8 @@ VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

std::unique_ptr<std::vector<mutex_lock>> locks =
absl::make_unique<std::vector<mutex_lock>>();
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
absl::make_unique<std::vector<tf_shared_lock>>();
auto locks = absl::make_unique<std::vector<mutex_lock>>();
auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
locks->reserve(acquire_order.size());

for (auto input : acquire_order) {
@ -241,11 +240,10 @@ template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
bool lock_held, bool sparse, Tensor* out) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
core::RefCountPtr<Var> var;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
core::ScopedUnref unref_var(var);
if (sparse) {
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
*out = *var->tensor();
return Status::OK();
}
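
Two smaller cleanups ride along in this last file: the spelled-out std::unique_ptr declarations collapse to auto with absl::make_unique, and EnsureSparseVariableAccess, which still takes a raw Var*, now receives var.get(), lending out the pointer without transferring ownership. A sketch of that borrow; UseVar stands in for any such legacy callee.

Status UseVar(Var* var);  // legacy signature that still wants a raw pointer

Status ExampleBorrow(OpKernelContext* ctx, int input) {
  core::RefCountPtr<Var> var;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
  TF_RETURN_IF_ERROR(UseVar(var.get()));  // borrow; var keeps the reference
  return Status::OK();
}  // the single Unref happens here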