Use RefCountPtr in LookupResource to avoid leaks
LookupResource returns a raw pointer which the caller needs to Unref. The prevalent pattern is this is followed by a ScopedUnref. This can be problematic, since if a caller forgets to add a ScopedUnref call, we have a memory leak. We resolve this by using RefCountPtr instead of a raw pointer in LookupResource. Most use cases have been migrated in this change. Note some variables were renamed to handle line length restrictions. PiperOrigin-RevId: 250423227
This commit is contained in:
parent
4941b4a73f
commit
2e758833c3
tensorflow
compiler
jit
tf2tensorrt/kernels
contrib
bigtable/kernels
bigtable_kernels.ccbigtable_lookup_dataset_op.ccbigtable_prefix_key_dataset_op.ccbigtable_range_key_dataset_op.ccbigtable_sample_key_pairs_dataset_op.ccbigtable_sample_keys_dataset_op.ccbigtable_scan_dataset_op.cc
boosted_trees/kernels
framework/kernels
tensor_forest/kernels
core
framework
kernels
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -31,12 +32,11 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
|
|||||||
std::map<int, OptionalTensor> variables;
|
std::map<int, OptionalTensor> variables;
|
||||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||||
if (ctx->input(i).dtype() == DT_RESOURCE) {
|
if (ctx->input(i).dtype() == DT_RESOURCE) {
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
ResourceHandle handle = HandleFromInput(ctx, i);
|
ResourceHandle handle = HandleFromInput(ctx, i);
|
||||||
OptionalTensor& optional = variables[i];
|
OptionalTensor& optional = variables[i];
|
||||||
optional.name = handle.name();
|
optional.name = handle.name();
|
||||||
if (LookupResource(ctx, handle, &variable).ok()) {
|
if (LookupResource(ctx, handle, &variable).ok()) {
|
||||||
core::ScopedUnref scoped_unref(variable);
|
|
||||||
tf_shared_lock lock(*variable->mu());
|
tf_shared_lock lock(*variable->mu());
|
||||||
optional.present = true;
|
optional.present = true;
|
||||||
optional.value = *variable->tensor();
|
optional.value = *variable->tensor();
|
||||||
|
@ -40,7 +40,7 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
|
|||||||
"Variable and value dtypes don't match; respectively, ",
|
"Variable and value dtypes don't match; respectively, ",
|
||||||
DataTypeString(dtype_), " and ",
|
DataTypeString(dtype_), " and ",
|
||||||
DataTypeString(context->input(1).dtype())));
|
DataTypeString(context->input(1).dtype())));
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
const Tensor& value = context->input(1);
|
const Tensor& value = context->input(1);
|
||||||
// Note: every resource-variable-manipulating op assumes copy-on-write
|
// Note: every resource-variable-manipulating op assumes copy-on-write
|
||||||
// semantics, and creates a copy of the variable's Tensor if its refcount is
|
// semantics, and creates a copy of the variable's Tensor if its refcount is
|
||||||
@ -58,7 +58,6 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
|
|||||||
(*ptr)->is_initialized = true;
|
(*ptr)->is_initialized = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
core::ScopedUnref s(variable);
|
|
||||||
mutex_lock ml(*variable->mu());
|
mutex_lock ml(*variable->mu());
|
||||||
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
|
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/util/stream_executor_util.h"
|
#include "tensorflow/core/util/stream_executor_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -86,7 +87,7 @@ static Status GetVariableInfosFromCtxInputs(
|
|||||||
variable_indices, std::back_inserter(resource_handles),
|
variable_indices, std::back_inserter(resource_handles),
|
||||||
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
|
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
|
||||||
|
|
||||||
std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables;
|
std::vector<core::RefCountPtr<Var>> variables;
|
||||||
TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));
|
TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));
|
||||||
|
|
||||||
result->clear();
|
result->clear();
|
||||||
|
@ -84,9 +84,8 @@ class PopulateTRTEngineCache : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
ResourceHandle handle = HandleFromInput(ctx, 0);
|
ResourceHandle handle = HandleFromInput(ctx, 0);
|
||||||
TRTEngineCacheResource* resource = nullptr;
|
core::RefCountPtr<TRTEngineCacheResource> resource;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
|
||||||
core::ScopedUnref unref_me(resource);
|
|
||||||
|
|
||||||
auto allocator = resource->allocator_.get();
|
auto allocator = resource->allocator_.get();
|
||||||
OP_REQUIRES(ctx, allocator != nullptr,
|
OP_REQUIRES(ctx, allocator != nullptr,
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -139,19 +140,19 @@ class BigtableTableOp : public OpKernel {
|
|||||||
ResourceMgr* mgr = ctx->resource_manager();
|
ResourceMgr* mgr = ctx->resource_manager();
|
||||||
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
|
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
|
||||||
|
|
||||||
BigtableClientResource* client_resource;
|
core::RefCountPtr<BigtableClientResource> client_resource;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
|
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
|
||||||
core::ScopedUnref unref_client(client_resource);
|
|
||||||
|
|
||||||
BigtableTableResource* resource;
|
BigtableTableResource* resource;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(ctx,
|
||||||
ctx, mgr->LookupOrCreate<BigtableTableResource>(
|
mgr->LookupOrCreate<BigtableTableResource>(
|
||||||
cinfo_.container(), cinfo_.name(), &resource,
|
cinfo_.container(), cinfo_.name(), &resource,
|
||||||
[this, client_resource](BigtableTableResource** ret) {
|
[this, &client_resource](BigtableTableResource** ret) {
|
||||||
*ret = new BigtableTableResource(client_resource, table_);
|
*ret = new BigtableTableResource(
|
||||||
return Status::OK();
|
client_resource.get(), table_);
|
||||||
}));
|
return Status::OK();
|
||||||
|
}));
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
}
|
}
|
||||||
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
|
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
|
||||||
@ -236,10 +237,9 @@ class ToBigtableOp : public AsyncOpKernel {
|
|||||||
errors::InvalidArgument("timestamp must be >= -1"),
|
errors::InvalidArgument("timestamp must be >= -1"),
|
||||||
done);
|
done);
|
||||||
|
|
||||||
BigtableTableResource* resource;
|
core::RefCountPtr<BigtableTableResource> resource;
|
||||||
OP_REQUIRES_OK_ASYNC(
|
OP_REQUIRES_OK_ASYNC(
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
|
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
|
||||||
core::ScopedUnref resource_cleanup(resource);
|
|
||||||
|
|
||||||
std::vector<Tensor> components;
|
std::vector<Tensor> components;
|
||||||
components.reserve(dataset->output_dtypes().size());
|
components.reserve(dataset->output_dtypes().size());
|
||||||
|
@ -26,9 +26,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
|
|
||||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||||
DatasetBase** output) override {
|
DatasetBase** output) override {
|
||||||
BigtableTableResource* table;
|
core::RefCountPtr<BigtableTableResource> table;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
|
||||||
core::ScopedUnref scoped_unref(table);
|
|
||||||
|
|
||||||
std::vector<string> column_families;
|
std::vector<string> column_families;
|
||||||
std::vector<string> columns;
|
std::vector<string> columns;
|
||||||
@ -50,7 +49,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
*output =
|
*output =
|
||||||
new Dataset(ctx, input, table, std::move(column_families),
|
new Dataset(ctx, input, table.get(), std::move(column_families),
|
||||||
std::move(columns), output_types, std::move(output_shapes));
|
std::move(columns), output_types, std::move(output_shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -28,12 +29,10 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
|
|||||||
string prefix;
|
string prefix;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
|
||||||
|
|
||||||
BigtableTableResource* resource;
|
core::RefCountPtr<BigtableTableResource> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
core::ScopedUnref scoped_unref(resource);
|
*output = new Dataset(ctx, resource.get(), std::move(prefix));
|
||||||
|
|
||||||
*output = new Dataset(ctx, resource, std::move(prefix));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -31,13 +32,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
|
|||||||
string end_key;
|
string end_key;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
|
||||||
|
|
||||||
BigtableTableResource* resource;
|
core::RefCountPtr<BigtableTableResource> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
core::ScopedUnref scoped_unref(resource);
|
*output = new Dataset(ctx, resource.get(), std::move(start_key),
|
||||||
|
std::move(end_key));
|
||||||
*output =
|
|
||||||
new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
|
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -35,10 +36,9 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
|
|||||||
string end_key;
|
string end_key;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
|
||||||
|
|
||||||
BigtableTableResource* resource;
|
core::RefCountPtr<BigtableTableResource> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
core::ScopedUnref scoped_unref(resource);
|
|
||||||
|
|
||||||
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
|
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -49,7 +49,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
|
|||||||
"If prefix is specified, end_key must be empty."));
|
"If prefix is specified, end_key must be empty."));
|
||||||
}
|
}
|
||||||
|
|
||||||
*output = new Dataset(ctx, resource, std::move(prefix),
|
*output = new Dataset(ctx, resource.get(), std::move(prefix),
|
||||||
std::move(start_key), std::move(end_key));
|
std::move(start_key), std::move(end_key));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,11 +25,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
|
|||||||
using DatasetOpKernel::DatasetOpKernel;
|
using DatasetOpKernel::DatasetOpKernel;
|
||||||
|
|
||||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
||||||
BigtableTableResource* resource;
|
core::RefCountPtr<BigtableTableResource> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
core::ScopedUnref scoped_unref(resource);
|
*output = new Dataset(ctx, resource.get());
|
||||||
*output = new Dataset(ctx, resource);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -64,10 +65,9 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
|
|||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Probability outside the range of (0, 1]. Got: ", probability));
|
"Probability outside the range of (0, 1]. Got: ", probability));
|
||||||
|
|
||||||
BigtableTableResource* resource;
|
core::RefCountPtr<BigtableTableResource> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
core::ScopedUnref scoped_unref(resource);
|
|
||||||
|
|
||||||
const uint64 num_outputs = columns.size() + 1;
|
const uint64 num_outputs = columns.size() + 1;
|
||||||
std::vector<PartialTensorShape> output_shapes;
|
std::vector<PartialTensorShape> output_shapes;
|
||||||
@ -79,7 +79,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
|
|||||||
output_types.push_back(DT_STRING);
|
output_types.push_back(DT_STRING);
|
||||||
}
|
}
|
||||||
|
|
||||||
*output = new Dataset(ctx, resource, std::move(prefix),
|
*output = new Dataset(ctx, resource.get(), std::move(prefix),
|
||||||
std::move(start_key), std::move(end_key),
|
std::move(start_key), std::move(end_key),
|
||||||
std::move(column_families), std::move(columns),
|
std::move(column_families), std::move(columns),
|
||||||
probability, output_types, std::move(output_shapes));
|
probability, output_types, std::move(output_shapes));
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -44,7 +45,7 @@ class CreateTreeEnsembleVariableOp : public OpKernel {
|
|||||||
const Tensor* tree_ensemble_config_t;
|
const Tensor* tree_ensemble_config_t;
|
||||||
OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
|
OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
|
||||||
&tree_ensemble_config_t));
|
&tree_ensemble_config_t));
|
||||||
auto* result = new boosted_trees::models::DecisionTreeEnsembleResource();
|
auto* result = new DecisionTreeEnsembleResource();
|
||||||
if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<string>()(),
|
if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<string>()(),
|
||||||
stamp_token)) {
|
stamp_token)) {
|
||||||
result->Unref();
|
result->Unref();
|
||||||
@ -69,11 +70,10 @@ class TreeEnsembleStampTokenOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
Tensor* output_stamp_token_t = nullptr;
|
Tensor* output_stamp_token_t = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
||||||
&output_stamp_token_t));
|
&output_stamp_token_t));
|
||||||
@ -88,11 +88,10 @@ class TreeEnsembleSerializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
Tensor* output_stamp_token_t = nullptr;
|
Tensor* output_stamp_token_t = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
||||||
&output_stamp_token_t));
|
&output_stamp_token_t));
|
||||||
@ -112,11 +111,10 @@ class TreeEnsembleDeserializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
mutex_lock l(*ensemble_resource->get_mutex());
|
mutex_lock l(*ensemble_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
|
|
||||||
// Get the stamp token.
|
// Get the stamp token.
|
||||||
const Tensor* stamp_token_t;
|
const Tensor* stamp_token_t;
|
||||||
@ -146,12 +144,11 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
|
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
|
|
||||||
// Get the stamp token.
|
// Get the stamp token.
|
||||||
const Tensor* stamp_token_t;
|
const Tensor* stamp_token_t;
|
||||||
@ -194,9 +191,8 @@ class TreeEnsembleUsedHandlersOp : public OpKernel {
|
|||||||
|
|
||||||
REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);
|
REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
|
||||||
Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
|
IsResourceInitialized<DecisionTreeEnsembleResource>);
|
||||||
IsResourceInitialized<boosted_trees::models::DecisionTreeEnsembleResource>);
|
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
|
||||||
CreateTreeEnsembleVariableOp);
|
CreateTreeEnsembleVariableOp);
|
||||||
|
@ -163,12 +163,11 @@ class GradientTreesPredictionOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
// Gets the resource. Grabs the mutex but releases it.
|
// Gets the resource. Grabs the mutex but releases it.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
// Release the reference to the resource once we're done using it.
|
// Release the reference to the resource once we're done using it.
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
if (use_locking_) {
|
if (use_locking_) {
|
||||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||||
DoCompute(context, ensemble_resource,
|
DoCompute(context, ensemble_resource,
|
||||||
@ -184,9 +183,10 @@ class GradientTreesPredictionOp : public OpKernel {
|
|||||||
// leaf index in prediction. Though this class invokes only with this param
|
// leaf index in prediction. Though this class invokes only with this param
|
||||||
// value as false, the subclass GradientTreesPredictionVerboseOp will invoke
|
// value as false, the subclass GradientTreesPredictionVerboseOp will invoke
|
||||||
// with the true value.
|
// with the true value.
|
||||||
virtual void DoCompute(OpKernelContext* context,
|
virtual void DoCompute(
|
||||||
DecisionTreeEnsembleResource* ensemble_resource,
|
OpKernelContext* context,
|
||||||
const bool return_output_leaf_index) {
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
|
||||||
|
const bool return_output_leaf_index) {
|
||||||
// Read dense float features list;
|
// Read dense float features list;
|
||||||
OpInputList dense_float_features_list;
|
OpInputList dense_float_features_list;
|
||||||
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
|
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
|
||||||
@ -352,9 +352,10 @@ class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
|
|||||||
: GradientTreesPredictionOp(context) {}
|
: GradientTreesPredictionOp(context) {}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void DoCompute(OpKernelContext* context,
|
void DoCompute(
|
||||||
DecisionTreeEnsembleResource* ensemble_resource,
|
OpKernelContext* context,
|
||||||
bool return_output_leaf_index) override {
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
|
||||||
|
bool return_output_leaf_index) override {
|
||||||
GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
|
GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
|
||||||
/*return_output_leaf_index=*/true);
|
/*return_output_leaf_index=*/true);
|
||||||
}
|
}
|
||||||
@ -372,12 +373,10 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
// Gets the resource. Grabs the mutex but releases it.
|
// Gets the resource. Grabs the mutex but releases it.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
// Release the reference to the resource once we're done using it.
|
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
if (use_locking_) {
|
if (use_locking_) {
|
||||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||||
DoCompute(context, ensemble_resource);
|
DoCompute(context, ensemble_resource);
|
||||||
@ -387,17 +386,18 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void DoCompute(OpKernelContext* context,
|
void DoCompute(
|
||||||
DecisionTreeEnsembleResource* ensemble_resource) {
|
OpKernelContext* context,
|
||||||
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
|
||||||
// The last non-finalized tree in the ensemble is by convention the
|
// The last non-finalized tree in the ensemble is by convention the
|
||||||
// one to partition on. If no such tree exists, a nodeless tree is
|
// one to partition on. If no such tree exists, a nodeless tree is
|
||||||
// created.
|
// created.
|
||||||
boosted_trees::trees::DecisionTreeConfig empty_tree_config;
|
boosted_trees::trees::DecisionTreeConfig empty_tree_config;
|
||||||
const boosted_trees::trees::DecisionTreeConfig& tree_config =
|
const boosted_trees::trees::DecisionTreeConfig& tree_config =
|
||||||
(ensemble_resource->num_trees() <= 0 ||
|
(resource->num_trees() <= 0 ||
|
||||||
ensemble_resource->LastTreeMetadata()->is_finalized())
|
resource->LastTreeMetadata()->is_finalized())
|
||||||
? empty_tree_config
|
? empty_tree_config
|
||||||
: *ensemble_resource->LastTree();
|
: *resource->LastTree();
|
||||||
|
|
||||||
// Read dense float features list;
|
// Read dense float features list;
|
||||||
OpInputList dense_float_features_list;
|
OpInputList dense_float_features_list;
|
||||||
|
@ -28,6 +28,7 @@
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -299,13 +300,12 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
|
|||||||
const ResourceHandle& handle =
|
const ResourceHandle& handle =
|
||||||
resource_handle_list[resource_handle_idx]
|
resource_handle_list[resource_handle_idx]
|
||||||
.flat<ResourceHandle>()(0);
|
.flat<ResourceHandle>()(0);
|
||||||
QuantileStreamResource* streams_resource;
|
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
LookupResource(context, handle, &streams_resource));
|
LookupResource(context, handle, &streams_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*streams_resource->mutex());
|
mutex_lock l(*streams_resource->mutex());
|
||||||
core::ScopedUnref unref_me(streams_resource);
|
|
||||||
|
|
||||||
// If the stamp is invalid we drop the update.
|
// If the stamp is invalid we drop the update.
|
||||||
if (!streams_resource->is_stamp_valid(stamp_token)) {
|
if (!streams_resource->is_stamp_valid(stamp_token)) {
|
||||||
@ -467,13 +467,12 @@ class QuantileAccumulatorSerializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
QuantileStreamResource* streams_resource;
|
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&streams_resource));
|
&streams_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*streams_resource->mutex());
|
mutex_lock l(*streams_resource->mutex());
|
||||||
core::ScopedUnref unref_me(streams_resource);
|
|
||||||
|
|
||||||
int64 stamp_token = streams_resource->stamp();
|
int64 stamp_token = streams_resource->stamp();
|
||||||
Tensor* stream_state_t;
|
Tensor* stream_state_t;
|
||||||
@ -526,13 +525,12 @@ class QuantileAccumulatorDeserializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
QuantileStreamResource* streams_resource;
|
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&streams_resource));
|
&streams_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*streams_resource->mutex());
|
mutex_lock l(*streams_resource->mutex());
|
||||||
core::ScopedUnref unref_me(streams_resource);
|
|
||||||
|
|
||||||
int64 old_stamp_token = streams_resource->stamp();
|
int64 old_stamp_token = streams_resource->stamp();
|
||||||
|
|
||||||
@ -595,13 +593,12 @@ class QuantileAccumulatorFlushOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
QuantileStreamResource* streams_resource;
|
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&streams_resource));
|
&streams_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*streams_resource->mutex());
|
mutex_lock l(*streams_resource->mutex());
|
||||||
core::ScopedUnref unref_me(streams_resource);
|
|
||||||
|
|
||||||
const Tensor* next_stamp_token_t;
|
const Tensor* next_stamp_token_t;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
@ -641,13 +638,12 @@ class QuantileAccumulatorFlushSummaryOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
QuantileStreamResource* streams_resource;
|
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&streams_resource));
|
&streams_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*streams_resource->mutex());
|
mutex_lock l(*streams_resource->mutex());
|
||||||
core::ScopedUnref unref_me(streams_resource);
|
|
||||||
|
|
||||||
const Tensor* next_stamp_token_t;
|
const Tensor* next_stamp_token_t;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
@ -713,12 +709,11 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel {
|
|||||||
const ResourceHandle& handle =
|
const ResourceHandle& handle =
|
||||||
resource_handle_list[resource_handle_idx]
|
resource_handle_list[resource_handle_idx]
|
||||||
.flat<ResourceHandle>()(0);
|
.flat<ResourceHandle>()(0);
|
||||||
QuantileStreamResource* streams_resource;
|
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
LookupResource(context, handle, &streams_resource));
|
LookupResource(context, handle, &streams_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*streams_resource->mutex());
|
mutex_lock l(*streams_resource->mutex());
|
||||||
core::ScopedUnref unref_me(streams_resource);
|
|
||||||
|
|
||||||
bool are_buckets_ready =
|
bool are_buckets_ready =
|
||||||
streams_resource->is_stamp_valid(stamp_token) &&
|
streams_resource->is_stamp_valid(stamp_token) &&
|
||||||
|
@ -27,6 +27,7 @@
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
@ -130,9 +131,8 @@ using StatsAccumulatorTensorResource =
|
|||||||
StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
|
StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
|
||||||
|
|
||||||
void SerializeScalarAccumulatorToOutput(
|
void SerializeScalarAccumulatorToOutput(
|
||||||
const StatsAccumulatorScalarResource& accumulator_resource,
|
const StatsAccumulatorScalarResource& resource, OpKernelContext* context) {
|
||||||
OpKernelContext* context) {
|
int64 num_slots = resource.values().size();
|
||||||
int64 num_slots = accumulator_resource.values().size();
|
|
||||||
Tensor* partition_ids_t = nullptr;
|
Tensor* partition_ids_t = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
|
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
|
||||||
TensorShape({num_slots}),
|
TensorShape({num_slots}),
|
||||||
@ -159,7 +159,7 @@ void SerializeScalarAccumulatorToOutput(
|
|||||||
auto hessians = hessians_t->vec<float>();
|
auto hessians = hessians_t->vec<float>();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (const auto& iter : accumulator_resource.values()) {
|
for (const auto& iter : resource.values()) {
|
||||||
partition_ids(i) = iter.first.partition_id;
|
partition_ids(i) = iter.first.partition_id;
|
||||||
feature_ids(i, 0) = iter.first.feature_id;
|
feature_ids(i, 0) = iter.first.feature_id;
|
||||||
feature_ids(i, 1) = iter.first.dimension;
|
feature_ids(i, 1) = iter.first.dimension;
|
||||||
@ -171,9 +171,8 @@ void SerializeScalarAccumulatorToOutput(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SerializeTensorAccumulatorToOutput(
|
void SerializeTensorAccumulatorToOutput(
|
||||||
const StatsAccumulatorTensorResource& accumulator_resource,
|
const StatsAccumulatorTensorResource& resource, OpKernelContext* context) {
|
||||||
OpKernelContext* context) {
|
int64 num_slots = resource.values().size();
|
||||||
int64 num_slots = accumulator_resource.values().size();
|
|
||||||
Tensor* partition_ids_t = nullptr;
|
Tensor* partition_ids_t = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
|
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
|
||||||
TensorShape({num_slots}),
|
TensorShape({num_slots}),
|
||||||
@ -186,7 +185,7 @@ void SerializeTensorAccumulatorToOutput(
|
|||||||
&feature_ids_t));
|
&feature_ids_t));
|
||||||
auto feature_ids = feature_ids_t->matrix<int64>();
|
auto feature_ids = feature_ids_t->matrix<int64>();
|
||||||
|
|
||||||
TensorShape gradient_shape = accumulator_resource.gradient_shape();
|
TensorShape gradient_shape = resource.gradient_shape();
|
||||||
int64 num_gradient_elements = gradient_shape.num_elements();
|
int64 num_gradient_elements = gradient_shape.num_elements();
|
||||||
gradient_shape.InsertDim(0, num_slots);
|
gradient_shape.InsertDim(0, num_slots);
|
||||||
Tensor* gradients_t = nullptr;
|
Tensor* gradients_t = nullptr;
|
||||||
@ -195,7 +194,7 @@ void SerializeTensorAccumulatorToOutput(
|
|||||||
&gradients_t));
|
&gradients_t));
|
||||||
auto gradients = gradients_t->flat_outer_dims<float>();
|
auto gradients = gradients_t->flat_outer_dims<float>();
|
||||||
|
|
||||||
TensorShape hessian_shape = accumulator_resource.hessian_shape();
|
TensorShape hessian_shape = resource.hessian_shape();
|
||||||
int64 num_hessian_elements = hessian_shape.num_elements();
|
int64 num_hessian_elements = hessian_shape.num_elements();
|
||||||
hessian_shape.InsertDim(0, num_slots);
|
hessian_shape.InsertDim(0, num_slots);
|
||||||
Tensor* hessians_t = nullptr;
|
Tensor* hessians_t = nullptr;
|
||||||
@ -204,7 +203,7 @@ void SerializeTensorAccumulatorToOutput(
|
|||||||
auto hessians = hessians_t->flat_outer_dims<float>();
|
auto hessians = hessians_t->flat_outer_dims<float>();
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (const auto& iter : accumulator_resource.values()) {
|
for (const auto& iter : resource.values()) {
|
||||||
partition_ids(i) = iter.first.partition_id;
|
partition_ids(i) = iter.first.partition_id;
|
||||||
feature_ids(i, 0) = iter.first.feature_id;
|
feature_ids(i, 0) = iter.first.feature_id;
|
||||||
feature_ids(i, 1) = iter.first.dimension;
|
feature_ids(i, 1) = iter.first.dimension;
|
||||||
@ -220,11 +219,10 @@ void SerializeTensorAccumulatorToOutput(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void AddToScalarAccumulator(
|
void AddToScalarAccumulator(
|
||||||
StatsAccumulatorScalarResource* accumulator_resource,
|
const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
|
||||||
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
|
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
|
||||||
const Tensor& gradients_t, const Tensor& hessians_t) {
|
const Tensor& gradients_t, const Tensor& hessians_t) {
|
||||||
accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
|
resource->set_num_updates(resource->num_updates() + 1);
|
||||||
1);
|
|
||||||
const TensorShape& partition_ids_shape = partition_ids_t.shape();
|
const TensorShape& partition_ids_shape = partition_ids_t.shape();
|
||||||
const auto& partition_ids = partition_ids_t.vec<int32>();
|
const auto& partition_ids = partition_ids_t.vec<int32>();
|
||||||
const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
|
const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
|
||||||
@ -232,7 +230,7 @@ void AddToScalarAccumulator(
|
|||||||
const auto& hessians = hessians_t.vec<float>();
|
const auto& hessians = hessians_t.vec<float>();
|
||||||
|
|
||||||
int64 num_updates = partition_ids_shape.dim_size(0);
|
int64 num_updates = partition_ids_shape.dim_size(0);
|
||||||
auto stats_map = accumulator_resource->mutable_values();
|
auto stats_map = resource->mutable_values();
|
||||||
for (int64 i = 0; i < num_updates; ++i) {
|
for (int64 i = 0; i < num_updates; ++i) {
|
||||||
const auto key =
|
const auto key =
|
||||||
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
|
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
|
||||||
@ -248,7 +246,7 @@ void AddToScalarAccumulator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void AddToScalarAccumulator(
|
void AddToScalarAccumulator(
|
||||||
StatsAccumulatorScalarResource* accumulator_resource,
|
const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
|
||||||
OpKernelContext* context) {
|
OpKernelContext* context) {
|
||||||
const Tensor* partition_ids_t;
|
const Tensor* partition_ids_t;
|
||||||
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
|
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
|
||||||
@ -258,17 +256,16 @@ void AddToScalarAccumulator(
|
|||||||
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
|
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
|
||||||
const Tensor* hessians_t;
|
const Tensor* hessians_t;
|
||||||
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
|
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
|
||||||
AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
|
AddToScalarAccumulator(resource, *partition_ids_t, *feature_ids_t,
|
||||||
*gradients_t, *hessians_t);
|
*gradients_t, *hessians_t);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddToTensorAccumulator(
|
void AddToTensorAccumulator(
|
||||||
StatsAccumulatorTensorResource* accumulator_resource,
|
const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
|
||||||
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
|
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
|
||||||
const Tensor& gradients_t, const Tensor& hessians_t,
|
const Tensor& gradients_t, const Tensor& hessians_t,
|
||||||
OpKernelContext* context) {
|
OpKernelContext* context) {
|
||||||
accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
|
resource->set_num_updates(resource->num_updates() + 1);
|
||||||
1);
|
|
||||||
|
|
||||||
const TensorShape& partition_ids_shape = partition_ids_t.shape();
|
const TensorShape& partition_ids_shape = partition_ids_t.shape();
|
||||||
const auto& partition_ids = partition_ids_t.vec<int32>();
|
const auto& partition_ids = partition_ids_t.vec<int32>();
|
||||||
@ -283,19 +280,19 @@ void AddToTensorAccumulator(
|
|||||||
|
|
||||||
// TODO(soroush): Move gradient and hessian shape check to ShapeFn.
|
// TODO(soroush): Move gradient and hessian shape check to ShapeFn.
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, gradients_shape == accumulator_resource->gradient_shape(),
|
context, gradients_shape == resource->gradient_shape(),
|
||||||
errors::InvalidArgument(strings::StrCat(
|
errors::InvalidArgument(strings::StrCat(
|
||||||
"Gradients dimensions must match: ", gradients_shape.DebugString(),
|
"Gradients dimensions must match: ", gradients_shape.DebugString(),
|
||||||
", ", accumulator_resource->gradient_shape().DebugString())));
|
", ", resource->gradient_shape().DebugString())));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, hessians_shape == accumulator_resource->hessian_shape(),
|
context, hessians_shape == resource->hessian_shape(),
|
||||||
errors::InvalidArgument(strings::StrCat(
|
errors::InvalidArgument(strings::StrCat(
|
||||||
"Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
|
"Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
|
||||||
accumulator_resource->hessian_shape().DebugString())));
|
resource->hessian_shape().DebugString())));
|
||||||
|
|
||||||
int64 num_updates = partition_ids_shape.dim_size(0);
|
int64 num_updates = partition_ids_shape.dim_size(0);
|
||||||
auto stats_map = accumulator_resource->mutable_values();
|
auto stats_map = resource->mutable_values();
|
||||||
for (int64 i = 0; i < num_updates; ++i) {
|
for (int64 i = 0; i < num_updates; ++i) {
|
||||||
const auto key =
|
const auto key =
|
||||||
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
|
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
|
||||||
@ -325,7 +322,7 @@ void AddToTensorAccumulator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void AddToTensorAccumulator(
|
void AddToTensorAccumulator(
|
||||||
StatsAccumulatorTensorResource* accumulator_resource,
|
const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
|
||||||
OpKernelContext* context) {
|
OpKernelContext* context) {
|
||||||
const Tensor* partition_ids_t;
|
const Tensor* partition_ids_t;
|
||||||
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
|
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
|
||||||
@ -335,7 +332,7 @@ void AddToTensorAccumulator(
|
|||||||
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
|
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
|
||||||
const Tensor* hessians_t;
|
const Tensor* hessians_t;
|
||||||
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
|
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
|
||||||
AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
|
AddToTensorAccumulator(resource, *partition_ids_t, *feature_ids_t,
|
||||||
*gradients_t, *hessians_t, context);
|
*gradients_t, *hessians_t, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -452,20 +449,18 @@ class StatsAccumulatorScalarAddOp : public OpKernel {
|
|||||||
resource_handle_list[resource_handle_idx]
|
resource_handle_list[resource_handle_idx]
|
||||||
.flat<ResourceHandle>()(0);
|
.flat<ResourceHandle>()(0);
|
||||||
|
|
||||||
StatsAccumulatorScalarResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, handle,
|
OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
|
||||||
&accumulator_resource));
|
mutex_lock l(*resource->mutex());
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
|
|
||||||
// If the stamp is invalid we drop the update.
|
// If the stamp is invalid we drop the update.
|
||||||
if (!accumulator_resource->is_stamp_valid(stamp_token)) {
|
if (!resource->is_stamp_valid(stamp_token)) {
|
||||||
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
|
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
|
||||||
<< "Passed stamp token: " << stamp_token << " "
|
<< "Passed stamp token: " << stamp_token << " "
|
||||||
<< "Current token: " << accumulator_resource->stamp();
|
<< "Current token: " << resource->stamp();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
AddToScalarAccumulator(accumulator_resource,
|
AddToScalarAccumulator(resource,
|
||||||
partition_ids_list[resource_handle_idx],
|
partition_ids_list[resource_handle_idx],
|
||||||
feature_ids_list[resource_handle_idx],
|
feature_ids_list[resource_handle_idx],
|
||||||
gradients_list[resource_handle_idx],
|
gradients_list[resource_handle_idx],
|
||||||
@ -517,20 +512,18 @@ class StatsAccumulatorTensorAddOp : public OpKernel {
|
|||||||
resource_handle_list[resource_handle_idx]
|
resource_handle_list[resource_handle_idx]
|
||||||
.flat<ResourceHandle>()(0);
|
.flat<ResourceHandle>()(0);
|
||||||
|
|
||||||
StatsAccumulatorTensorResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, handle,
|
OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
|
||||||
&accumulator_resource));
|
mutex_lock l(*resource->mutex());
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
|
|
||||||
// If the stamp is invalid we drop the update.
|
// If the stamp is invalid we drop the update.
|
||||||
if (!accumulator_resource->is_stamp_valid(stamp_token)) {
|
if (!resource->is_stamp_valid(stamp_token)) {
|
||||||
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
|
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
|
||||||
<< "Passed stamp token: " << stamp_token << " "
|
<< "Passed stamp token: " << stamp_token << " "
|
||||||
<< "Current token: " << accumulator_resource->stamp();
|
<< "Current token: " << resource->stamp();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
AddToTensorAccumulator(accumulator_resource,
|
AddToTensorAccumulator(resource,
|
||||||
partition_ids_list[resource_handle_idx],
|
partition_ids_list[resource_handle_idx],
|
||||||
feature_ids_list[resource_handle_idx],
|
feature_ids_list[resource_handle_idx],
|
||||||
gradients_list[resource_handle_idx],
|
gradients_list[resource_handle_idx],
|
||||||
@ -549,11 +542,10 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
StatsAccumulatorScalarResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&accumulator_resource));
|
&resource));
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
mutex_lock l(*resource->mutex());
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
|
|
||||||
const Tensor* stamp_token_t;
|
const Tensor* stamp_token_t;
|
||||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||||
@ -562,7 +554,7 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
|
|||||||
// If the stamp is invalid we restart the PS. It shouldn't happen since
|
// If the stamp is invalid we restart the PS. It shouldn't happen since
|
||||||
// only Chief should call this function and chief is guaranteed to be in
|
// only Chief should call this function and chief is guaranteed to be in
|
||||||
// a consistent state.
|
// a consistent state.
|
||||||
CHECK(accumulator_resource->is_stamp_valid(stamp_token));
|
CHECK(resource->is_stamp_valid(stamp_token));
|
||||||
|
|
||||||
const Tensor* next_stamp_token_t;
|
const Tensor* next_stamp_token_t;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
@ -570,15 +562,15 @@ class StatsAccumulatorScalarFlushOp : public OpKernel {
|
|||||||
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
|
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
|
||||||
CHECK(stamp_token != next_stamp_token);
|
CHECK(stamp_token != next_stamp_token);
|
||||||
|
|
||||||
SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
|
SerializeScalarAccumulatorToOutput(*resource, context);
|
||||||
Tensor* num_updates_t = nullptr;
|
Tensor* num_updates_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output("num_updates", TensorShape({}),
|
context->allocate_output("num_updates", TensorShape({}),
|
||||||
&num_updates_t));
|
&num_updates_t));
|
||||||
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
|
num_updates_t->scalar<int64>()() = resource->num_updates();
|
||||||
|
|
||||||
accumulator_resource->Clear();
|
resource->Clear();
|
||||||
accumulator_resource->set_stamp(next_stamp_token);
|
resource->set_stamp(next_stamp_token);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -591,11 +583,10 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
StatsAccumulatorTensorResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&accumulator_resource));
|
&resource));
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
mutex_lock l(*resource->mutex());
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
|
|
||||||
const Tensor* stamp_token_t;
|
const Tensor* stamp_token_t;
|
||||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||||
@ -609,16 +600,16 @@ class StatsAccumulatorTensorFlushOp : public OpKernel {
|
|||||||
// If the stamp is invalid we restart the PS. It shouldn't happen since
|
// If the stamp is invalid we restart the PS. It shouldn't happen since
|
||||||
// only Chief should call this function and chief is guaranteed to be in
|
// only Chief should call this function and chief is guaranteed to be in
|
||||||
// a consistent state.
|
// a consistent state.
|
||||||
CHECK(accumulator_resource->is_stamp_valid(stamp_token));
|
CHECK(resource->is_stamp_valid(stamp_token));
|
||||||
CHECK(stamp_token != next_stamp_token);
|
CHECK(stamp_token != next_stamp_token);
|
||||||
SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
|
SerializeTensorAccumulatorToOutput(*resource, context);
|
||||||
Tensor* num_updates_t = nullptr;
|
Tensor* num_updates_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output("num_updates", TensorShape({}),
|
context->allocate_output("num_updates", TensorShape({}),
|
||||||
&num_updates_t));
|
&num_updates_t));
|
||||||
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
|
num_updates_t->scalar<int64>()() = resource->num_updates();
|
||||||
accumulator_resource->Clear();
|
resource->Clear();
|
||||||
accumulator_resource->set_stamp(next_stamp_token);
|
resource->set_stamp(next_stamp_token);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -631,22 +622,21 @@ class StatsAccumulatorScalarDeserializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
StatsAccumulatorScalarResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&accumulator_resource));
|
&resource));
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
mutex_lock l(*resource->mutex());
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
|
|
||||||
// Check the stamp token.
|
// Check the stamp token.
|
||||||
const Tensor* stamp_token_t;
|
const Tensor* stamp_token_t;
|
||||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||||
accumulator_resource->Clear();
|
resource->Clear();
|
||||||
accumulator_resource->set_stamp(stamp_token);
|
resource->set_stamp(stamp_token);
|
||||||
AddToScalarAccumulator(accumulator_resource, context);
|
AddToScalarAccumulator(resource, context);
|
||||||
const Tensor* num_updates_t;
|
const Tensor* num_updates_t;
|
||||||
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
|
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
|
||||||
accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
|
resource->set_num_updates(num_updates_t->scalar<int64>()());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -660,22 +650,21 @@ class StatsAccumulatorTensorDeserializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
StatsAccumulatorTensorResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&accumulator_resource));
|
&resource));
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
mutex_lock l(*resource->mutex());
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
|
|
||||||
// Check the stamp token.
|
// Check the stamp token.
|
||||||
const Tensor* stamp_token_t;
|
const Tensor* stamp_token_t;
|
||||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||||
accumulator_resource->Clear();
|
resource->Clear();
|
||||||
accumulator_resource->set_stamp(stamp_token);
|
resource->set_stamp(stamp_token);
|
||||||
AddToTensorAccumulator(accumulator_resource, context);
|
AddToTensorAccumulator(resource, context);
|
||||||
const Tensor* num_updates_t;
|
const Tensor* num_updates_t;
|
||||||
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
|
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
|
||||||
accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
|
resource->set_num_updates(num_updates_t->scalar<int64>()());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -689,23 +678,22 @@ class StatsAccumulatorScalarSerializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
StatsAccumulatorScalarResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&accumulator_resource));
|
&resource));
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
mutex_lock l(*resource->mutex());
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
SerializeScalarAccumulatorToOutput(*resource, context);
|
||||||
SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
|
|
||||||
Tensor* stamp_token_t = nullptr;
|
Tensor* stamp_token_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output("stamp_token", TensorShape({}),
|
context->allocate_output("stamp_token", TensorShape({}),
|
||||||
&stamp_token_t));
|
&stamp_token_t));
|
||||||
stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
|
stamp_token_t->scalar<int64>()() = resource->stamp();
|
||||||
|
|
||||||
Tensor* num_updates_t = nullptr;
|
Tensor* num_updates_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output("num_updates", TensorShape({}),
|
context->allocate_output("num_updates", TensorShape({}),
|
||||||
&num_updates_t));
|
&num_updates_t));
|
||||||
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
|
num_updates_t->scalar<int64>()() = resource->num_updates();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -719,23 +707,22 @@ class StatsAccumulatorTensorSerializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
StatsAccumulatorTensorResource* accumulator_resource;
|
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&accumulator_resource));
|
&resource));
|
||||||
mutex_lock l(*accumulator_resource->mutex());
|
mutex_lock l(*resource->mutex());
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
SerializeTensorAccumulatorToOutput(*resource, context);
|
||||||
SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
|
|
||||||
Tensor* stamp_token_t = nullptr;
|
Tensor* stamp_token_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output("stamp_token", TensorShape({}),
|
context->allocate_output("stamp_token", TensorShape({}),
|
||||||
&stamp_token_t));
|
&stamp_token_t));
|
||||||
stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
|
stamp_token_t->scalar<int64>()() = resource->stamp();
|
||||||
|
|
||||||
Tensor* num_updates_t = nullptr;
|
Tensor* num_updates_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output("num_updates", TensorShape({}),
|
context->allocate_output("num_updates", TensorShape({}),
|
||||||
&num_updates_t));
|
&num_updates_t));
|
||||||
num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
|
num_updates_t->scalar<int64>()() = resource->num_updates();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -751,12 +738,11 @@ class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
TensorShape gradient_shape = TensorShape({});
|
TensorShape gradient_shape = TensorShape({});
|
||||||
TensorShape hessian_shape = TensorShape({});
|
TensorShape hessian_shape = TensorShape({});
|
||||||
StatsAccumulatorScalarResource* accumulator_resource =
|
core::RefCountPtr<StatsAccumulatorScalarResource> resource(
|
||||||
new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
|
new StatsAccumulatorScalarResource(gradient_shape, hessian_shape));
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
// Check the stamp token.
|
// Check the stamp token.
|
||||||
AddToScalarAccumulator(accumulator_resource, context);
|
AddToScalarAccumulator(resource, context);
|
||||||
SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
|
SerializeScalarAccumulatorToOutput(*resource, context);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -780,12 +766,11 @@ class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
|
|||||||
TensorShape hessians_shape = hessians_t->shape();
|
TensorShape hessians_shape = hessians_t->shape();
|
||||||
hessians_shape.RemoveDim(0);
|
hessians_shape.RemoveDim(0);
|
||||||
|
|
||||||
StatsAccumulatorTensorResource* accumulator_resource =
|
core::RefCountPtr<StatsAccumulatorTensorResource> resource(
|
||||||
new StatsAccumulatorTensorResource(gradients_shape, hessians_shape);
|
new StatsAccumulatorTensorResource(gradients_shape, hessians_shape));
|
||||||
core::ScopedUnref unref_me(accumulator_resource);
|
|
||||||
// Check the stamp token.
|
// Check the stamp token.
|
||||||
AddToTensorAccumulator(accumulator_resource, context);
|
AddToTensorAccumulator(resource, context);
|
||||||
SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
|
SerializeTensorAccumulatorToOutput(*resource, context);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
|
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
|
||||||
@ -31,6 +32,8 @@ namespace {
|
|||||||
|
|
||||||
using boosted_trees::learner::LearnerConfig;
|
using boosted_trees::learner::LearnerConfig;
|
||||||
using boosted_trees::learner::LearningRateConfig;
|
using boosted_trees::learner::LearningRateConfig;
|
||||||
|
using boosted_trees::models::DecisionTreeEnsembleResource;
|
||||||
|
using boosted_trees::trees::DecisionTreeConfig;
|
||||||
using boosted_trees::trees::Leaf;
|
using boosted_trees::trees::Leaf;
|
||||||
using boosted_trees::trees::TreeNode;
|
using boosted_trees::trees::TreeNode;
|
||||||
using boosted_trees::trees::TreeNodeMetadata;
|
using boosted_trees::trees::TreeNodeMetadata;
|
||||||
@ -193,10 +196,9 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
// Get decision tree ensemble.
|
// Get decision tree ensemble.
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
mutex_lock l(*ensemble_resource->get_mutex());
|
mutex_lock l(*ensemble_resource->get_mutex());
|
||||||
|
|
||||||
// Get the stamp token.
|
// Get the stamp token.
|
||||||
@ -255,8 +257,8 @@ class CenterTreeEnsembleBiasOp : public OpKernel {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// Helper method to retrieve the bias from the tree ensemble.
|
// Helper method to retrieve the bias from the tree ensemble.
|
||||||
boosted_trees::trees::Leaf* RetrieveBias(
|
Leaf* RetrieveBias(
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource,
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
|
||||||
int64 logits_dimension) {
|
int64 logits_dimension) {
|
||||||
const int32 num_trees = ensemble_resource->num_trees();
|
const int32 num_trees = ensemble_resource->num_trees();
|
||||||
if (num_trees <= 0) {
|
if (num_trees <= 0) {
|
||||||
@ -319,10 +321,9 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
// Get decision tree ensemble.
|
// Get decision tree ensemble.
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
mutex_lock l(*ensemble_resource->get_mutex());
|
mutex_lock l(*ensemble_resource->get_mutex());
|
||||||
|
|
||||||
// Get the stamp token.
|
// Get the stamp token.
|
||||||
@ -400,10 +401,9 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
// Update and retrieve the growable tree.
|
// Update and retrieve the growable tree.
|
||||||
// If the tree is fully built and dropout was applied, it also adjusts the
|
// If the tree is fully built and dropout was applied, it also adjusts the
|
||||||
// weights of dropped and the last tree.
|
// weights of dropped and the last tree.
|
||||||
boosted_trees::trees::DecisionTreeConfig* const tree_config =
|
DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(
|
||||||
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
|
ensemble_resource, learning_rate, dropout_seed, max_tree_depth,
|
||||||
dropout_seed, max_tree_depth,
|
weak_learner_type);
|
||||||
weak_learner_type);
|
|
||||||
// Split tree nodes.
|
// Split tree nodes.
|
||||||
switch (weak_learner_type) {
|
switch (weak_learner_type) {
|
||||||
case LearnerConfig::NORMAL_DECISION_TREE: {
|
case LearnerConfig::NORMAL_DECISION_TREE: {
|
||||||
@ -559,8 +559,7 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTreeWeightsIfDropout(
|
void UpdateTreeWeightsIfDropout(
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* const
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
|
||||||
ensemble_resource,
|
|
||||||
const uint64 dropout_seed) {
|
const uint64 dropout_seed) {
|
||||||
// It is possible that the tree was built with dropout. If it is the case,
|
// It is possible that the tree was built with dropout. If it is the case,
|
||||||
// we need to adjust the tree weight, or bail out.
|
// we need to adjust the tree weight, or bail out.
|
||||||
@ -609,9 +608,8 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
|
|
||||||
// Helper method to update the growable tree which is by definition the last
|
// Helper method to update the growable tree which is by definition the last
|
||||||
// tree in the ensemble.
|
// tree in the ensemble.
|
||||||
boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
|
DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* const
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
|
||||||
ensemble_resource,
|
|
||||||
const float learning_rate, const uint64 dropout_seed,
|
const float learning_rate, const uint64 dropout_seed,
|
||||||
const int32 max_tree_depth, const int32 weak_learner_type) {
|
const int32 max_tree_depth, const int32 weak_learner_type) {
|
||||||
const auto num_trees = ensemble_resource->num_trees();
|
const auto num_trees = ensemble_resource->num_trees();
|
||||||
@ -719,8 +717,8 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
// leaf children given the split candidate.
|
// leaf children given the split candidate.
|
||||||
void SplitTreeNode(
|
void SplitTreeNode(
|
||||||
const int32 node_id, SplitCandidate* split,
|
const int32 node_id, SplitCandidate* split,
|
||||||
boosted_trees::trees::DecisionTreeConfig* tree_config,
|
DecisionTreeConfig* tree_config,
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
|
||||||
// No-op if we have no real node.
|
// No-op if we have no real node.
|
||||||
CHECK(node_id < tree_config->nodes_size())
|
CHECK(node_id < tree_config->nodes_size())
|
||||||
<< "Invalid node " << node_id << " to split.";
|
<< "Invalid node " << node_id << " to split.";
|
||||||
@ -761,14 +759,13 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
(*tree_config->mutable_nodes(node_id)) =
|
(*tree_config->mutable_nodes(node_id)) =
|
||||||
*split->split_info.mutable_split_node();
|
*split->split_info.mutable_split_node();
|
||||||
if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
|
if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
|
||||||
ensemble_resource->MaybeAddUsedHandler(split->handler_id);
|
resource->MaybeAddUsedHandler(split->handler_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SplitTreeLayer(
|
void SplitTreeLayer(
|
||||||
SplitCandidate* split,
|
SplitCandidate* split, DecisionTreeConfig* tree_config,
|
||||||
boosted_trees::trees::DecisionTreeConfig* tree_config,
|
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
|
|
||||||
int depth = 0;
|
int depth = 0;
|
||||||
while (depth < tree_config->nodes_size() &&
|
while (depth < tree_config->nodes_size() &&
|
||||||
tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
|
tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
|
||||||
@ -903,10 +900,9 @@ class TreeEnsembleStatsOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
// Get decision tree ensemble.
|
// Get decision tree ensemble.
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
|
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||||
|
|
||||||
// Get the stamp token.
|
// Get the stamp token.
|
||||||
|
@ -95,7 +95,7 @@ class ZeroVarInitializer : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
|
OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
|
||||||
ctx, HandleFromInput(ctx, 0), &variable,
|
ctx, HandleFromInput(ctx, 0), &variable,
|
||||||
[this, ctx](Var** var_ptr) {
|
[this, ctx](Var** var_ptr) {
|
||||||
@ -117,7 +117,6 @@ class ZeroVarInitializer : public OpKernel {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
core::ScopedUnref scoped(variable);
|
|
||||||
mutex_lock ml(*variable->mu());
|
mutex_lock ml(*variable->mu());
|
||||||
|
|
||||||
OP_REQUIRES(ctx, !variable->is_initialized,
|
OP_REQUIRES(ctx, !variable->is_initialized,
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
|
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
|
||||||
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
|
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
|
||||||
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
|
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
|
||||||
@ -24,6 +25,7 @@
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -76,11 +78,10 @@ class TreeSerializeOp : public OpKernel {
|
|||||||
explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}
|
explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
DecisionTreeResource* decision_tree_resource;
|
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
Tensor* output_config_t = nullptr;
|
Tensor* output_config_t = nullptr;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
||||||
@ -100,12 +101,11 @@ class TreeDeserializeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
DecisionTreeResource* decision_tree_resource;
|
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
|
||||||
auto handle = HandleFromInput(context, 0);
|
auto handle = HandleFromInput(context, 0);
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
LookupResource(context, handle, &decision_tree_resource));
|
LookupResource(context, handle, &decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
|
|
||||||
const Tensor* tree_config_t;
|
const Tensor* tree_config_t;
|
||||||
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
||||||
@ -131,11 +131,10 @@ class TreeSizeOp : public OpKernel {
|
|||||||
explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}
|
explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
DecisionTreeResource* decision_tree_resource;
|
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
Tensor* output_t = nullptr;
|
Tensor* output_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, TensorShape(), &output_t));
|
context->allocate_output(0, TensorShape(), &output_t));
|
||||||
@ -144,7 +143,7 @@ class TreeSizeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void TraverseTree(const DecisionTreeResource* tree_resource,
|
void TraverseTree(DecisionTreeResource* tree_resource,
|
||||||
const std::unique_ptr<TensorDataSet>& data, int32 start,
|
const std::unique_ptr<TensorDataSet>& data, int32 start,
|
||||||
int32 end,
|
int32 end,
|
||||||
const std::function<void(int32, int32)>& set_leaf_id,
|
const std::function<void(int32, int32)>& set_leaf_id,
|
||||||
@ -182,11 +181,10 @@ class TreePredictionsV4Op : public OpKernel {
|
|||||||
data_set->set_input_tensors(input_data, sparse_input_indices,
|
data_set->set_input_tensors(input_data, sparse_input_indices,
|
||||||
sparse_input_values, sparse_input_shape);
|
sparse_input_values, sparse_input_shape);
|
||||||
|
|
||||||
DecisionTreeResource* decision_tree_resource;
|
core::RefCountPtr<DecisionTreeResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
|
|
||||||
const int num_data = data_set->NumItems();
|
const int num_data = data_set->NumItems();
|
||||||
const int32 num_outputs = param_proto_.num_outputs();
|
const int32 num_outputs = param_proto_.num_outputs();
|
||||||
@ -205,15 +203,16 @@ class TreePredictionsV4Op : public OpKernel {
|
|||||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||||
int num_threads = worker_threads->num_threads;
|
int num_threads = worker_threads->num_threads;
|
||||||
const int64 costPerTraverse = 500;
|
const int64 costPerTraverse = 500;
|
||||||
auto traverse = [this, &out, &data_set, decision_tree_resource, num_data,
|
auto traverse = [this, &out, &data_set, &resource, num_data, &tree_paths](
|
||||||
&tree_paths](int64 start, int64 end) {
|
int64 start, int64 end) {
|
||||||
CHECK(start <= end);
|
CHECK(start <= end);
|
||||||
CHECK(end <= num_data);
|
CHECK(end <= num_data);
|
||||||
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
|
|
||||||
|
TraverseTree(resource.get(), data_set, static_cast<int32>(start),
|
||||||
static_cast<int32>(end),
|
static_cast<int32>(end),
|
||||||
std::bind(&TreePredictionsV4Op::set_output_value, this,
|
std::bind(&TreePredictionsV4Op::set_output_value, this,
|
||||||
std::placeholders::_1, std::placeholders::_2,
|
std::placeholders::_1, std::placeholders::_2,
|
||||||
decision_tree_resource, &out),
|
resource.get(), &out),
|
||||||
param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
|
param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
|
||||||
};
|
};
|
||||||
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
||||||
@ -283,11 +282,10 @@ class TraverseTreeV4Op : public OpKernel {
|
|||||||
data_set->set_input_tensors(input_data, sparse_input_indices,
|
data_set->set_input_tensors(input_data, sparse_input_indices,
|
||||||
sparse_input_values, sparse_input_shape);
|
sparse_input_values, sparse_input_shape);
|
||||||
|
|
||||||
DecisionTreeResource* decision_tree_resource;
|
core::RefCountPtr<DecisionTreeResource> resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
|
|
||||||
const int num_data = data_set->NumItems();
|
const int num_data = data_set->NumItems();
|
||||||
|
|
||||||
@ -304,11 +302,11 @@ class TraverseTreeV4Op : public OpKernel {
|
|||||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||||
int num_threads = worker_threads->num_threads;
|
int num_threads = worker_threads->num_threads;
|
||||||
const int64 costPerTraverse = 500;
|
const int64 costPerTraverse = 500;
|
||||||
auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource,
|
auto traverse = [&set_leaf_ids, &data_set, &resource, num_data](int64 start,
|
||||||
num_data](int64 start, int64 end) {
|
int64 end) {
|
||||||
CHECK(start <= end);
|
CHECK(start <= end);
|
||||||
CHECK(end <= num_data);
|
CHECK(end <= num_data);
|
||||||
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
|
TraverseTree(resource.get(), data_set, static_cast<int32>(start),
|
||||||
static_cast<int32>(end), set_leaf_ids, nullptr);
|
static_cast<int32>(end), set_leaf_ids, nullptr);
|
||||||
};
|
};
|
||||||
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
||||||
@ -336,11 +334,10 @@ class UpdateModelV4Op : public OpKernel {
|
|||||||
const Tensor& input_labels = context->input(2);
|
const Tensor& input_labels = context->input(2);
|
||||||
const Tensor& input_weights = context->input(3);
|
const Tensor& input_weights = context->input(3);
|
||||||
|
|
||||||
DecisionTreeResource* decision_tree_resource;
|
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
|
|
||||||
const int num_data = input_labels.shape().dim_size(0);
|
const int num_data = input_labels.shape().dim_size(0);
|
||||||
const int32 label_dim =
|
const int32 label_dim =
|
||||||
@ -356,9 +353,10 @@ class UpdateModelV4Op : public OpKernel {
|
|||||||
UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
|
UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
|
void UpdateModel(
|
||||||
int32 start, int32 end,
|
const Tensor& leaf_ids, const TensorInputTarget& target, int32 start,
|
||||||
DecisionTreeResource* decision_tree_resource) {
|
int32 end,
|
||||||
|
const core::RefCountPtr<DecisionTreeResource>& decision_tree_resource) {
|
||||||
const auto leaves = leaf_ids.unaligned_flat<int32>();
|
const auto leaves = leaf_ids.unaligned_flat<int32>();
|
||||||
for (int i = start; i < end; ++i) {
|
for (int i = start; i < end; ++i) {
|
||||||
model_op_->UpdateModel(
|
model_op_->UpdateModel(
|
||||||
@ -384,11 +382,10 @@ class FeatureUsageCountsOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
DecisionTreeResource* decision_tree_resource;
|
core::RefCountPtr<DecisionTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
|
|
||||||
const auto& tree = decision_tree_resource->decision_tree();
|
const auto& tree = decision_tree_resource->decision_tree();
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
@ -87,11 +88,10 @@ class FertileStatsSerializeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
FertileStatsResource* fertile_stats_resource;
|
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&fertile_stats_resource));
|
&fertile_stats_resource));
|
||||||
mutex_lock l(*fertile_stats_resource->get_mutex());
|
mutex_lock l(*fertile_stats_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(fertile_stats_resource);
|
|
||||||
Tensor* output_config_t = nullptr;
|
Tensor* output_config_t = nullptr;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
||||||
@ -116,11 +116,10 @@ class FertileStatsDeserializeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
FertileStatsResource* fertile_stats_resource;
|
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&fertile_stats_resource));
|
&fertile_stats_resource));
|
||||||
mutex_lock l(*fertile_stats_resource->get_mutex());
|
mutex_lock l(*fertile_stats_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(fertile_stats_resource);
|
|
||||||
|
|
||||||
const Tensor* stats_config_t;
|
const Tensor* stats_config_t;
|
||||||
OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
|
OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
|
||||||
@ -145,7 +144,7 @@ class FertileStatsDeserializeOp : public OpKernel {
|
|||||||
// acquired, put it in a waiting queue to come back to later and try the next
|
// acquired, put it in a waiting queue to come back to later and try the next
|
||||||
// one. Once all leaf_ids have been visited, cycle through the waiting ids
|
// one. Once all leaf_ids have been visited, cycle through the waiting ids
|
||||||
// until they're gone.
|
// until they're gone.
|
||||||
void UpdateStats(FertileStatsResource* fertile_stats_resource,
|
void UpdateStats(const core::RefCountPtr<FertileStatsResource>& resource,
|
||||||
const std::unique_ptr<TensorDataSet>& data,
|
const std::unique_ptr<TensorDataSet>& data,
|
||||||
const TensorInputTarget& target, int num_targets,
|
const TensorInputTarget& target, int num_targets,
|
||||||
const Tensor& leaf_ids_tensor,
|
const Tensor& leaf_ids_tensor,
|
||||||
@ -183,8 +182,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool is_finished;
|
bool is_finished;
|
||||||
fertile_stats_resource->AddExampleToStatsAndInitialize(
|
resource->AddExampleToStatsAndInitialize(data, &target, {example_id},
|
||||||
data, &target, {example_id}, leaf_id, &is_finished);
|
leaf_id, &is_finished);
|
||||||
leaf_lock->unlock();
|
leaf_lock->unlock();
|
||||||
if (is_finished) {
|
if (is_finished) {
|
||||||
set_lock->lock();
|
set_lock->lock();
|
||||||
@ -196,8 +195,8 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
|
|||||||
|
|
||||||
// Update leaves from start through end in the leaf_examples iterator.
|
// Update leaves from start through end in the leaf_examples iterator.
|
||||||
void UpdateStatsCollated(
|
void UpdateStatsCollated(
|
||||||
FertileStatsResource* fertile_stats_resource,
|
const core::RefCountPtr<FertileStatsResource>& fertile_stats_resource,
|
||||||
DecisionTreeResource* tree_resource,
|
const core::RefCountPtr<DecisionTreeResource>& tree_resource,
|
||||||
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
|
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
|
||||||
int num_targets,
|
int num_targets,
|
||||||
const std::unordered_map<int32, std::vector<int>>& leaf_examples,
|
const std::unordered_map<int32, std::vector<int>>& leaf_examples,
|
||||||
@ -251,18 +250,15 @@ class ProcessInputOp : public OpKernel {
|
|||||||
data_set->set_input_tensors(input_data, sparse_input_indices,
|
data_set->set_input_tensors(input_data, sparse_input_indices,
|
||||||
sparse_input_values, sparse_input_shape);
|
sparse_input_values, sparse_input_shape);
|
||||||
|
|
||||||
FertileStatsResource* fertile_stats_resource;
|
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
|
||||||
&fertile_stats_resource));
|
&fertile_stats_resource));
|
||||||
DecisionTreeResource* tree_resource;
|
core::RefCountPtr<DecisionTreeResource> tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&tree_resource));
|
&tree_resource));
|
||||||
mutex_lock l1(*fertile_stats_resource->get_mutex());
|
mutex_lock l1(*fertile_stats_resource->get_mutex());
|
||||||
mutex_lock l2(*tree_resource->get_mutex());
|
mutex_lock l2(*tree_resource->get_mutex());
|
||||||
|
|
||||||
core::ScopedUnref unref_stats(fertile_stats_resource);
|
|
||||||
core::ScopedUnref unref_tree(tree_resource);
|
|
||||||
|
|
||||||
const int32 num_data = data_set->NumItems();
|
const int32 num_data = data_set->NumItems();
|
||||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||||
int num_threads = worker_threads->num_threads;
|
int num_threads = worker_threads->num_threads;
|
||||||
@ -308,7 +304,7 @@ class ProcessInputOp : public OpKernel {
|
|||||||
// if it really matters that much.
|
// if it really matters that much.
|
||||||
const int64 costPerUpdate = 1000;
|
const int64 costPerUpdate = 1000;
|
||||||
auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set,
|
auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set,
|
||||||
fertile_stats_resource, &locks, &set_lock, &ready_to_split,
|
&fertile_stats_resource, &locks, &set_lock, &ready_to_split,
|
||||||
num_data](int64 start, int64 end) {
|
num_data](int64 start, int64 end) {
|
||||||
CHECK(start <= end);
|
CHECK(start <= end);
|
||||||
CHECK(end <= num_data);
|
CHECK(end <= num_data);
|
||||||
@ -317,8 +313,8 @@ class ProcessInputOp : public OpKernel {
|
|||||||
static_cast<int32>(end), &ready_to_split);
|
static_cast<int32>(end), &ready_to_split);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto update_collated = [&target, &num_targets, fertile_stats_resource,
|
auto update_collated = [&target, &num_targets, &fertile_stats_resource,
|
||||||
tree_resource, &leaf_examples, &set_lock,
|
&tree_resource, &leaf_examples, &set_lock,
|
||||||
&ready_to_split, &data_set,
|
&ready_to_split, &data_set,
|
||||||
num_leaves](int64 start, int64 end) {
|
num_leaves](int64 start, int64 end) {
|
||||||
CHECK(start <= end);
|
CHECK(start <= end);
|
||||||
@ -362,18 +358,15 @@ class GrowTreeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
FertileStatsResource* fertile_stats_resource;
|
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
|
||||||
&fertile_stats_resource));
|
&fertile_stats_resource));
|
||||||
DecisionTreeResource* tree_resource;
|
core::RefCountPtr<DecisionTreeResource> tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&tree_resource));
|
&tree_resource));
|
||||||
mutex_lock l1(*fertile_stats_resource->get_mutex());
|
mutex_lock l1(*fertile_stats_resource->get_mutex());
|
||||||
mutex_lock l2(*tree_resource->get_mutex());
|
mutex_lock l2(*tree_resource->get_mutex());
|
||||||
|
|
||||||
core::ScopedUnref unref_stats(fertile_stats_resource);
|
|
||||||
core::ScopedUnref unref_tree(tree_resource);
|
|
||||||
|
|
||||||
const Tensor& finished_nodes = context->input(2);
|
const Tensor& finished_nodes = context->input(2);
|
||||||
|
|
||||||
const auto finished = finished_nodes.unaligned_flat<int32>();
|
const auto finished = finished_nodes.unaligned_flat<int32>();
|
||||||
@ -463,19 +456,16 @@ class FinalizeTreeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
DecisionTreeResource* tree_resource;
|
core::RefCountPtr<DecisionTreeResource> tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&tree_resource));
|
&tree_resource));
|
||||||
FertileStatsResource* fertile_stats_resource;
|
core::RefCountPtr<FertileStatsResource> fertile_stats_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
|
||||||
&fertile_stats_resource));
|
&fertile_stats_resource));
|
||||||
|
|
||||||
mutex_lock l1(*fertile_stats_resource->get_mutex());
|
mutex_lock l1(*fertile_stats_resource->get_mutex());
|
||||||
mutex_lock l2(*tree_resource->get_mutex());
|
mutex_lock l2(*tree_resource->get_mutex());
|
||||||
|
|
||||||
core::ScopedUnref unref_me(tree_resource);
|
|
||||||
core::ScopedUnref unref_stats(fertile_stats_resource);
|
|
||||||
|
|
||||||
// TODO(thomaswc): Add threads
|
// TODO(thomaswc): Add threads
|
||||||
int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
|
int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
|
||||||
for (int i = 0; i < num_nodes; i++) {
|
for (int i = 0; i < num_nodes; i++) {
|
||||||
|
@ -270,12 +270,19 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
|
|||||||
template <typename T, bool use_dynamic_cast = false>
|
template <typename T, bool use_dynamic_cast = false>
|
||||||
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
|
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
|
||||||
|
|
||||||
|
// Looks up a resource pointed by a given resource handle.
|
||||||
|
//
|
||||||
|
// Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
|
||||||
|
// requiring the caller to explicitly call `Unref()`.
|
||||||
|
template <typename T>
|
||||||
|
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
|
||||||
|
core::RefCountPtr<T>* value);
|
||||||
|
|
||||||
// Looks up multiple resources pointed by a sequence of resource handles. If
|
// Looks up multiple resources pointed by a sequence of resource handles. If
|
||||||
// p[i] is uninitialized then values[i] is unmodified.
|
// p[i] is uninitialized then values[i] is unmodified.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status LookupResources(
|
Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
|
||||||
OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
|
std::vector<core::RefCountPtr<T>>* values);
|
||||||
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
|
|
||||||
|
|
||||||
// Looks up or creates a resource.
|
// Looks up or creates a resource.
|
||||||
//
|
//
|
||||||
@ -283,10 +290,19 @@ Status LookupResources(
|
|||||||
// must call its `Unref()` method when it has finished using it. If the
|
// must call its `Unref()` method when it has finished using it. If the
|
||||||
// `creator` is invoked, its reference on the created resource is transferred
|
// `creator` is invoked, its reference on the created resource is transferred
|
||||||
// to `ctx->resource_mgr()`.
|
// to `ctx->resource_mgr()`.
|
||||||
|
//
|
||||||
|
// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
|
||||||
|
// requiring the caller to explicitly call `Unref()`.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
|
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
|
||||||
T** value, std::function<Status(T**)> creator);
|
T** value, std::function<Status(T**)> creator);
|
||||||
|
|
||||||
|
// Looks up or creates a resource.
|
||||||
|
template <typename T>
|
||||||
|
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
|
||||||
|
core::RefCountPtr<T>* value,
|
||||||
|
std::function<Status(T**)> creator);
|
||||||
|
|
||||||
// Destroys a resource pointed by a given resource handle.
|
// Destroys a resource pointed by a given resource handle.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
|
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
|
||||||
@ -587,9 +603,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status LookupResources(
|
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
|
||||||
OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
|
core::RefCountPtr<T>* value) {
|
||||||
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) {
|
T* raw_ptr = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
|
||||||
|
value->reset(raw_ptr);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status LookupResources(OpKernelContext* ctx,
|
||||||
|
absl::Span<ResourceHandle const* const> p,
|
||||||
|
std::vector<core::RefCountPtr<T>>* values) {
|
||||||
std::vector<std::pair<const string*, const string*>> containers_and_names(
|
std::vector<std::pair<const string*, const string*>> containers_and_names(
|
||||||
p.size());
|
p.size());
|
||||||
for (size_t i = 0; i < p.size(); ++i) {
|
for (size_t i = 0; i < p.size(); ++i) {
|
||||||
@ -607,6 +633,17 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
|
|||||||
creator);
|
creator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
|
||||||
|
core::RefCountPtr<T>* value,
|
||||||
|
std::function<Status(T**)> creator) {
|
||||||
|
T* raw_ptr = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
|
||||||
|
value->reset(raw_ptr);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
|
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
|
||||||
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
|
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
@ -266,15 +267,14 @@ TEST(ResourceHandleTest, CRUD) {
|
|||||||
TF_EXPECT_OK(CreateResource(&ctx, p, r));
|
TF_EXPECT_OK(CreateResource(&ctx, p, r));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
StubResource* r = nullptr;
|
core::RefCountPtr<StubResource> r;
|
||||||
TF_ASSERT_OK(LookupResource(&ctx, p, &r));
|
TF_ASSERT_OK(LookupResource(&ctx, p, &r));
|
||||||
ASSERT_TRUE(r != nullptr);
|
ASSERT_TRUE(r != nullptr);
|
||||||
EXPECT_EQ(r->value_, 42);
|
EXPECT_EQ(r->value_, 42);
|
||||||
r->Unref();
|
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
TF_EXPECT_OK(DeleteResource<StubResource>(&ctx, p));
|
TF_EXPECT_OK(DeleteResource<StubResource>(&ctx, p));
|
||||||
StubResource* unused = nullptr;
|
core::RefCountPtr<StubResource> unused;
|
||||||
EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok());
|
EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -338,13 +338,12 @@ TEST(ResourceHandleTest, DeleteUsingResourceHandle) {
|
|||||||
StubResource* r = new StubResource;
|
StubResource* r = new StubResource;
|
||||||
TF_EXPECT_OK(CreateResource(&ctx, p, r));
|
TF_EXPECT_OK(CreateResource(&ctx, p, r));
|
||||||
|
|
||||||
StubResource* lookup_r = nullptr;
|
core::RefCountPtr<StubResource> lookup_r;
|
||||||
TF_EXPECT_OK(LookupResource<StubResource>(&ctx, p, &lookup_r));
|
TF_EXPECT_OK(LookupResource<StubResource>(&ctx, p, &lookup_r));
|
||||||
EXPECT_EQ(lookup_r, r);
|
EXPECT_EQ(lookup_r.get(), r);
|
||||||
|
|
||||||
TF_EXPECT_OK(DeleteResource(&ctx, p));
|
TF_EXPECT_OK(DeleteResource(&ctx, p));
|
||||||
EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
|
EXPECT_NE(LookupResource<StubResource>(&ctx, p, &lookup_r).ok(), true);
|
||||||
r->Unref();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -166,6 +166,7 @@ tf_kernel_library(
|
|||||||
":variable_ops",
|
":variable_ops",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -5022,6 +5023,7 @@ STATE_DEPS = [
|
|||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
] + if_sycl(["//tensorflow/core:sycl_runtime"])
|
] + if_sycl(["//tensorflow/core:sycl_runtime"])
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
@ -7589,6 +7591,7 @@ tf_kernel_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/lib/db:sqlite",
|
"//tensorflow/core/lib/db:sqlite",
|
||||||
"//tensorflow/core/summary:schema",
|
"//tensorflow/core/summary:schema",
|
||||||
|
@ -83,6 +83,7 @@ tf_kernel_library(
|
|||||||
":resources",
|
":resources",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -106,6 +107,7 @@ tf_kernel_library(
|
|||||||
":tree_helper",
|
":tree_helper",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
@ -117,6 +119,7 @@ tf_kernel_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
|
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -51,12 +51,10 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
BoostedTreesEnsembleResource* resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
|
||||||
// Get the resource.
|
// Get the resource.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&resource));
|
&resource));
|
||||||
// Release the reference to the resource once we're done using it.
|
|
||||||
core::ScopedUnref unref_me(resource);
|
|
||||||
|
|
||||||
// Get the inputs.
|
// Get the inputs.
|
||||||
OpInputList bucketized_features_list;
|
OpInputList bucketized_features_list;
|
||||||
@ -198,12 +196,10 @@ class BoostedTreesPredictOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
BoostedTreesEnsembleResource* resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
|
||||||
// Get the resource.
|
// Get the resource.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&resource));
|
&resource));
|
||||||
// Release the reference to the resource once we're done using it.
|
|
||||||
core::ScopedUnref unref_me(resource);
|
|
||||||
|
|
||||||
// Get the inputs.
|
// Get the inputs.
|
||||||
OpInputList bucketized_features_list;
|
OpInputList bucketized_features_list;
|
||||||
@ -302,12 +298,10 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
BoostedTreesEnsembleResource* resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
|
||||||
// Get the resource.
|
// Get the resource.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&resource));
|
&resource));
|
||||||
// Release the reference to the resource once we're done using it.
|
|
||||||
core::ScopedUnref unref_me(resource);
|
|
||||||
|
|
||||||
// Get the inputs.
|
// Get the inputs.
|
||||||
OpInputList bucketized_features_list;
|
OpInputList bucketized_features_list;
|
||||||
|
@ -27,6 +27,7 @@
|
|||||||
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
|
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
|
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -224,12 +225,11 @@ class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
|
|||||||
ResourceHandle handle;
|
ResourceHandle handle;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
HandleFromInput(context, kResourceHandleName, &handle));
|
HandleFromInput(context, kResourceHandleName, &handle));
|
||||||
QuantileStreamResource* stream_resource;
|
core::RefCountPtr<QuantileStreamResource> stream_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
|
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*stream_resource->mutex());
|
mutex_lock l(*stream_resource->mutex());
|
||||||
core::ScopedUnref unref_me(stream_resource);
|
|
||||||
|
|
||||||
OpInputList summaries_list;
|
OpInputList summaries_list;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
@ -281,13 +281,12 @@ class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
QuantileStreamResource* streams_resource;
|
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&streams_resource));
|
&streams_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*streams_resource->mutex());
|
mutex_lock l(*streams_resource->mutex());
|
||||||
core::ScopedUnref unref_me(streams_resource);
|
|
||||||
|
|
||||||
OpInputList bucket_boundaries_list;
|
OpInputList bucket_boundaries_list;
|
||||||
OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
|
OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
|
||||||
@ -336,12 +335,11 @@ class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
|
|||||||
ResourceHandle handle;
|
ResourceHandle handle;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
HandleFromInput(context, kResourceHandleName, &handle));
|
HandleFromInput(context, kResourceHandleName, &handle));
|
||||||
QuantileStreamResource* stream_resource;
|
core::RefCountPtr<QuantileStreamResource> stream_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
|
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*stream_resource->mutex());
|
mutex_lock l(*stream_resource->mutex());
|
||||||
core::ScopedUnref unref_me(stream_resource);
|
|
||||||
|
|
||||||
const Tensor* num_buckets_t;
|
const Tensor* num_buckets_t;
|
||||||
OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
|
OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
|
||||||
@ -391,12 +389,11 @@ class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
|
|||||||
ResourceHandle handle;
|
ResourceHandle handle;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
HandleFromInput(context, kResourceHandleName, &handle));
|
HandleFromInput(context, kResourceHandleName, &handle));
|
||||||
QuantileStreamResource* stream_resource;
|
core::RefCountPtr<QuantileStreamResource> stream_resource;
|
||||||
// Create a reference to the underlying resource using the handle.
|
// Create a reference to the underlying resource using the handle.
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
|
OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
|
||||||
// Remove the reference at the end of this scope.
|
// Remove the reference at the end of this scope.
|
||||||
mutex_lock l(*stream_resource->mutex());
|
mutex_lock l(*stream_resource->mutex());
|
||||||
core::ScopedUnref unref_me(stream_resource);
|
|
||||||
|
|
||||||
const int64 num_streams = stream_resource->num_streams();
|
const int64 num_streams = stream_resource->num_streams();
|
||||||
CHECK_EQ(num_features_, num_streams);
|
CHECK_EQ(num_features_, num_streams);
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -78,11 +79,10 @@ class BoostedTreesGetEnsembleStatesOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
// Looks up the resource.
|
// Looks up the resource.
|
||||||
BoostedTreesEnsembleResource* tree_ensemble_resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&tree_ensemble_resource));
|
&tree_ensemble_resource));
|
||||||
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
|
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(tree_ensemble_resource);
|
|
||||||
|
|
||||||
// Sets the outputs.
|
// Sets the outputs.
|
||||||
const int num_trees = tree_ensemble_resource->num_trees();
|
const int num_trees = tree_ensemble_resource->num_trees();
|
||||||
@ -141,11 +141,10 @@ class BoostedTreesSerializeEnsembleOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
BoostedTreesEnsembleResource* tree_ensemble_resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&tree_ensemble_resource));
|
&tree_ensemble_resource));
|
||||||
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
|
tf_shared_lock l(*tree_ensemble_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(tree_ensemble_resource);
|
|
||||||
Tensor* output_stamp_token_t = nullptr;
|
Tensor* output_stamp_token_t = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
||||||
&output_stamp_token_t));
|
&output_stamp_token_t));
|
||||||
@ -169,11 +168,10 @@ class BoostedTreesDeserializeEnsembleOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
BoostedTreesEnsembleResource* tree_ensemble_resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&tree_ensemble_resource));
|
&tree_ensemble_resource));
|
||||||
mutex_lock l(*tree_ensemble_resource->get_mutex());
|
mutex_lock l(*tree_ensemble_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(tree_ensemble_resource);
|
|
||||||
|
|
||||||
// Get the stamp token.
|
// Get the stamp token.
|
||||||
const Tensor* stamp_token_t;
|
const Tensor* stamp_token_t;
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -55,10 +56,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
// Get decision tree ensemble.
|
// Get decision tree ensemble.
|
||||||
BoostedTreesEnsembleResource* ensemble_resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
mutex_lock l(*ensemble_resource->get_mutex());
|
mutex_lock l(*ensemble_resource->get_mutex());
|
||||||
// Increase the ensemble stamp.
|
// Increase the ensemble stamp.
|
||||||
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
|
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
|
||||||
@ -176,19 +176,19 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
|
int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
|
||||||
BoostedTreesEnsembleResource* const ensemble_resource) {
|
const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) {
|
||||||
int32 num_trees = ensemble_resource->num_trees();
|
int32 num_trees = resource->num_trees();
|
||||||
int32 current_tree = num_trees - 1;
|
int32 current_tree = num_trees - 1;
|
||||||
|
|
||||||
// Increment global attempt stats.
|
// Increment global attempt stats.
|
||||||
ensemble_resource->UpdateGrowingMetadata();
|
resource->UpdateGrowingMetadata();
|
||||||
|
|
||||||
// Note we don't set tree weight to be equal to learning rate, since we
|
// Note we don't set tree weight to be equal to learning rate, since we
|
||||||
// apply learning rate to leaf weights instead, when doing layer-by-layer
|
// apply learning rate to leaf weights instead, when doing layer-by-layer
|
||||||
// boosting.
|
// boosting.
|
||||||
if (num_trees <= 0) {
|
if (num_trees <= 0) {
|
||||||
// Create a new tree with a no-op leaf.
|
// Create a new tree with a no-op leaf.
|
||||||
current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
|
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
|
||||||
}
|
}
|
||||||
return current_tree;
|
return current_tree;
|
||||||
}
|
}
|
||||||
@ -250,10 +250,9 @@ class BoostedTreesCenterBiasOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
// Get decision tree ensemble.
|
// Get decision tree ensemble.
|
||||||
BoostedTreesEnsembleResource* ensemble_resource;
|
core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&ensemble_resource));
|
&ensemble_resource));
|
||||||
core::ScopedUnref unref_me(ensemble_resource);
|
|
||||||
mutex_lock l(*ensemble_resource->get_mutex());
|
mutex_lock l(*ensemble_resource->get_mutex());
|
||||||
// Increase the ensemble stamp.
|
// Increase the ensemble stamp.
|
||||||
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
|
ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/kernels/variable_ops.h"
|
#include "tensorflow/core/kernels/variable_ops.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -65,11 +66,9 @@ class ResourceCountUpToOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
context,
|
&variable));
|
||||||
LookupResource<Var>(context, HandleFromInput(context, 0), &variable));
|
|
||||||
core::ScopedUnref s(variable);
|
|
||||||
mutex_lock l(*variable->mu());
|
mutex_lock l(*variable->mu());
|
||||||
Tensor before_increment = *variable->tensor();
|
Tensor before_increment = *variable->tensor();
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
|
@ -342,6 +342,7 @@ tf_kernel_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/kernels:summary_interface",
|
"//tensorflow/core/kernels:summary_interface",
|
||||||
],
|
],
|
||||||
@ -380,6 +381,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/kernels/data:dataset_utils",
|
"//tensorflow/core/kernels/data:dataset_utils",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||||
#include "tensorflow/core/kernels/data/stats_utils.h"
|
#include "tensorflow/core/kernels/data/stats_utils.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -83,17 +84,16 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
|
|
||||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||||
DatasetBase** output) override {
|
DatasetBase** output) override {
|
||||||
StatsAggregatorResource* stats_aggregator_resource;
|
core::RefCountPtr<StatsAggregatorResource> resource;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
|
OP_REQUIRES_OK(ctx,
|
||||||
&stats_aggregator_resource));
|
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
|
||||||
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
|
|
||||||
string tag;
|
string tag;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
|
||||||
string prefix;
|
string prefix;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
|
||||||
|
|
||||||
*output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
|
*output =
|
||||||
tag, prefix);
|
new Dataset(ctx, input, ctx->input(1), resource.get(), tag, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -101,12 +101,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
public:
|
public:
|
||||||
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||||
const Tensor& resource_handle,
|
const Tensor& resource_handle,
|
||||||
StatsAggregatorResource* stats_aggregator_resource,
|
StatsAggregatorResource* resource, const string& tag,
|
||||||
const string& tag, const string& prefix)
|
const string& prefix)
|
||||||
: DatasetBase(DatasetContext(ctx)),
|
: DatasetBase(DatasetContext(ctx)),
|
||||||
input_(input),
|
input_(input),
|
||||||
resource_handle_(resource_handle),
|
resource_handle_(resource_handle),
|
||||||
stats_aggregator_resource_(stats_aggregator_resource),
|
stats_aggregator_resource_(resource),
|
||||||
tag_(tag),
|
tag_(tag),
|
||||||
prefix_(prefix) {
|
prefix_(prefix) {
|
||||||
input_->Ref();
|
input_->Ref();
|
||||||
@ -169,13 +169,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
std::vector<Tensor>* out_tensors,
|
std::vector<Tensor>* out_tensors,
|
||||||
bool* end_of_sequence) override {
|
bool* end_of_sequence) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
StatsAggregatorResource* stats_aggregator_resource =
|
StatsAggregatorResource* resource =
|
||||||
dataset()->stats_aggregator_resource_;
|
dataset()->stats_aggregator_resource_;
|
||||||
IteratorContext::Params params(ctx);
|
IteratorContext::Params params(ctx);
|
||||||
params.stats_aggregator = std::shared_ptr<StatsAggregator>(
|
params.stats_aggregator = std::shared_ptr<StatsAggregator>(
|
||||||
new StatsAggregatorWithTagAndPrefix(
|
new StatsAggregatorWithTagAndPrefix(resource->stats_aggregator(),
|
||||||
stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
|
dataset()->tag_,
|
||||||
dataset()->prefix_));
|
dataset()->prefix_));
|
||||||
IteratorContext iter_ctx(std::move(params));
|
IteratorContext iter_ctx(std::move(params));
|
||||||
return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
|
return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
|
||||||
}
|
}
|
||||||
|
@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/framework/stats_aggregator.h"
|
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/stats_aggregator.h"
|
||||||
#include "tensorflow/core/framework/summary.pb.h"
|
#include "tensorflow/core/framework/summary.pb.h"
|
||||||
#include "tensorflow/core/kernels/summary_interface.h"
|
#include "tensorflow/core/kernels/summary_interface.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
@ -258,10 +258,9 @@ class StatsAggregatorSummaryOp : public OpKernel {
|
|||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||||
|
|
||||||
StatsAggregatorResource* resource;
|
core::RefCountPtr<StatsAggregatorResource> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
core::ScopedUnref unref_iterator(resource);
|
|
||||||
|
|
||||||
Tensor* summary_t;
|
Tensor* summary_t;
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
|
||||||
@ -281,21 +280,19 @@ class StatsAggregatorSetSummaryWriterOp : public OpKernel {
|
|||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||||
|
|
||||||
StatsAggregatorResource* resource;
|
core::RefCountPtr<StatsAggregatorResource> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
core::ScopedUnref unref_iterator(resource);
|
|
||||||
|
|
||||||
const Tensor& summary_resource_handle_t = ctx->input(1);
|
const Tensor& summary_resource_handle_t = ctx->input(1);
|
||||||
OP_REQUIRES(ctx,
|
OP_REQUIRES(ctx,
|
||||||
TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
|
TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
|
||||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||||
SummaryWriterInterface* sumamry_resource;
|
core::RefCountPtr<SummaryWriterInterface> sumamry_resource;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource));
|
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource));
|
||||||
core::ScopedUnref unref_sumamry_resource(sumamry_resource);
|
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
resource->stats_aggregator()->SetSummaryWriter(sumamry_resource));
|
resource->stats_aggregator()->SetSummaryWriter(sumamry_resource.get()));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
@ -127,12 +128,10 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
|
|
||||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||||
DatasetBase** output) override {
|
DatasetBase** output) override {
|
||||||
ThreadPoolResource* threadpool_resource;
|
core::RefCountPtr<ThreadPoolResource> threadpool_resource;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
|
||||||
&threadpool_resource));
|
&threadpool_resource));
|
||||||
core::ScopedUnref unref_iterator(threadpool_resource);
|
*output = new Dataset(ctx, input, ctx->input(1), threadpool_resource.get());
|
||||||
|
|
||||||
*output = new Dataset(ctx, input, ctx->input(1), threadpool_resource);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
@ -595,10 +596,9 @@ int64 AnonymousIteratorHandleOp::current_id_(0);
|
|||||||
void MakeIteratorOp::Compute(OpKernelContext* ctx) {
|
void MakeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||||
DatasetBase* dataset;
|
DatasetBase* dataset;
|
||||||
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||||
IteratorResource* iterator_resource;
|
core::RefCountPtr<IteratorResource> iterator_resource;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
|
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
|
||||||
core::ScopedUnref unref(iterator_resource);
|
|
||||||
OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
|
OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||||
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
@ -492,10 +493,9 @@ class MultiDeviceIteratorInitOp : public OpKernel {
|
|||||||
|
|
||||||
DatasetBase* dataset;
|
DatasetBase* dataset;
|
||||||
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||||
MultiDeviceIterator* resource;
|
core::RefCountPtr<MultiDeviceIterator> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
|
||||||
core::ScopedUnref unref(resource);
|
|
||||||
|
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
IteratorContext::Params params(ctx);
|
IteratorContext::Params params(ctx);
|
||||||
@ -535,7 +535,7 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
|
|||||||
ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
|
ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
|
||||||
int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
|
int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
|
||||||
|
|
||||||
MultiDeviceIterator* iterator;
|
core::RefCountPtr<MultiDeviceIterator> iterator;
|
||||||
OP_REQUIRES_OK_ASYNC(
|
OP_REQUIRES_OK_ASYNC(
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
|
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
|
||||||
|
|
||||||
@ -557,7 +557,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
|
|||||||
std::placeholders::_1, std::move(done));
|
std::placeholders::_1, std::move(done));
|
||||||
|
|
||||||
iterator->GetNextFromShard(ctx, shard_num, incarnation_id, callback);
|
iterator->GetNextFromShard(ctx, shard_num, incarnation_id, callback);
|
||||||
iterator->Unref();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -577,10 +576,9 @@ class MultiDeviceIteratorToStringHandleOp : public OpKernel {
|
|||||||
|
|
||||||
// Validate that the handle corresponds to a real resource, and
|
// Validate that the handle corresponds to a real resource, and
|
||||||
// that it is an MultiDeviceIterator.
|
// that it is an MultiDeviceIterator.
|
||||||
MultiDeviceIterator* resource;
|
core::RefCountPtr<MultiDeviceIterator> resource;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||||
resource->Unref();
|
|
||||||
|
|
||||||
Tensor* string_handle_t;
|
Tensor* string_handle_t;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
@ -629,9 +627,8 @@ class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
|
|||||||
|
|
||||||
// Validate that the handle corresponds to a real resource, and
|
// Validate that the handle corresponds to a real resource, and
|
||||||
// that it is an MultiDeviceIterator.
|
// that it is an MultiDeviceIterator.
|
||||||
MultiDeviceIterator* resource;
|
core::RefCountPtr<MultiDeviceIterator> resource;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
|
||||||
core::ScopedUnref unref_iterator(resource);
|
|
||||||
if (!output_types_.empty()) {
|
if (!output_types_.empty()) {
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
VerifyTypesMatch(output_types_, resource->output_types()));
|
VerifyTypesMatch(output_types_, resource->output_types()));
|
||||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
|
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
|
||||||
#include "tensorflow/core/kernels/training_op_helpers.h"
|
#include "tensorflow/core/kernels/training_op_helpers.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/guarded_philox_random.h"
|
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||||
@ -371,10 +372,9 @@ class RandomBinomialOp : public OpKernel {
|
|||||||
"Input probs should have length 1 or shape[0], got shape: ",
|
"Input probs should have length 1 or shape[0], got shape: ",
|
||||||
probs_tensor.shape().DebugString()));
|
probs_tensor.shape().DebugString()));
|
||||||
}
|
}
|
||||||
Var* var = nullptr;
|
core::RefCountPtr<Var> var;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
|
||||||
|
|
||||||
ScopedUnlockUnrefVar var_guard(var);
|
|
||||||
Tensor* var_tensor = var->tensor();
|
Tensor* var_tensor = var->tensor();
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
|
ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
|
||||||
@ -404,7 +404,6 @@ class RandomBinomialOp : public OpKernel {
|
|||||||
auto philox = GetPhiloxRandomFromMem(var_data);
|
auto philox = GetPhiloxRandomFromMem(var_data);
|
||||||
UpdateMemWithPhiloxRandom(
|
UpdateMemWithPhiloxRandom(
|
||||||
philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);
|
philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);
|
||||||
var_guard.Release();
|
|
||||||
|
|
||||||
auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
|
auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
|
||||||
binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
|
binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
|
||||||
|
@ -129,7 +129,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ReadVariableOp::Compute(OpKernelContext* ctx) {
|
void ReadVariableOp::Compute(OpKernelContext* ctx) {
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
const ResourceHandle& handle = HandleFromInput(ctx, 0);
|
const ResourceHandle& handle = HandleFromInput(ctx, 0);
|
||||||
const auto status = LookupResource(ctx, handle, &variable);
|
const auto status = LookupResource(ctx, handle, &variable);
|
||||||
OP_REQUIRES(ctx, status.ok(),
|
OP_REQUIRES(ctx, status.ok(),
|
||||||
@ -139,7 +139,6 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
|
|||||||
". This could mean that the variable was uninitialized. ",
|
". This could mean that the variable was uninitialized. ",
|
||||||
status.ToString()));
|
status.ToString()));
|
||||||
|
|
||||||
core::ScopedUnref s(variable);
|
|
||||||
{
|
{
|
||||||
tf_shared_lock ml(*variable->mu());
|
tf_shared_lock ml(*variable->mu());
|
||||||
// We're acquiring a reference to the underlying buffer while
|
// We're acquiring a reference to the underlying buffer while
|
||||||
@ -175,8 +174,7 @@ ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ReadVariablesOp::Compute(OpKernelContext* ctx) {
|
void ReadVariablesOp::Compute(OpKernelContext* ctx) {
|
||||||
std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
|
std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
|
||||||
dtypes_.size());
|
|
||||||
std::vector<const ResourceHandle*> handles(dtypes_.size());
|
std::vector<const ResourceHandle*> handles(dtypes_.size());
|
||||||
for (size_t i = 0; i < dtypes_.size(); ++i) {
|
for (size_t i = 0; i < dtypes_.size(); ++i) {
|
||||||
handles[i] = &HandleFromInput(ctx, i);
|
handles[i] = &HandleFromInput(ctx, i);
|
||||||
@ -265,10 +263,9 @@ class VariableShapeOp : public OpKernel {
|
|||||||
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
|
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
|
||||||
core::ScopedUnref s(variable);
|
|
||||||
variable->mu()->lock_shared();
|
variable->mu()->lock_shared();
|
||||||
TensorShape shape = variable->tensor()->shape();
|
TensorShape shape = variable->tensor()->shape();
|
||||||
variable->mu()->unlock_shared();
|
variable->mu()->unlock_shared();
|
||||||
@ -343,7 +340,7 @@ class AssignVariableOp : public OpKernel {
|
|||||||
"Variable and value dtypes don't match; respectively, ",
|
"Variable and value dtypes don't match; respectively, ",
|
||||||
DataTypeString(dtype_), " and ",
|
DataTypeString(dtype_), " and ",
|
||||||
DataTypeString(context->input(1).dtype())));
|
DataTypeString(context->input(1).dtype())));
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
const Tensor& value = context->input(1);
|
const Tensor& value = context->input(1);
|
||||||
// Note: every resource-variable-manipulating op assumes copy-on-write
|
// Note: every resource-variable-manipulating op assumes copy-on-write
|
||||||
// semantics, and creates a copy of the variable's Tensor if its refcount is
|
// semantics, and creates a copy of the variable's Tensor if its refcount is
|
||||||
@ -361,7 +358,6 @@ class AssignVariableOp : public OpKernel {
|
|||||||
(*ptr)->is_initialized = true;
|
(*ptr)->is_initialized = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
core::ScopedUnref s(variable);
|
|
||||||
mutex_lock ml(*variable->mu());
|
mutex_lock ml(*variable->mu());
|
||||||
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
|
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -404,7 +400,7 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
const Tensor& value = context->input(1);
|
const Tensor& value = context->input(1);
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
|
OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
|
||||||
context, HandleFromInput(context, 0), &variable,
|
context, HandleFromInput(context, 0), &variable,
|
||||||
[](Var** ptr) {
|
[](Var** ptr) {
|
||||||
@ -412,7 +408,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
|
|||||||
*ptr = new Var(DT_VARIANT);
|
*ptr = new Var(DT_VARIANT);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
core::ScopedUnref s(variable);
|
|
||||||
|
|
||||||
// For purposes of forwarding DT_VARIANT, we want the least
|
// For purposes of forwarding DT_VARIANT, we want the least
|
||||||
// restrictive attr; we already know the input is on host.
|
// restrictive attr; we already know the input is on host.
|
||||||
@ -500,10 +495,9 @@ class AssignUpdateVariableOp : public OpKernel {
|
|||||||
explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
|
explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&variable));
|
&variable));
|
||||||
core::ScopedUnref s(variable);
|
|
||||||
|
|
||||||
const Tensor& value = context->input(1);
|
const Tensor& value = context->input(1);
|
||||||
// TODO(apassos): We could possibly avoid the copy done by
|
// TODO(apassos): We could possibly avoid the copy done by
|
||||||
@ -568,13 +562,12 @@ class VarIsInitializedOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, TensorShape({}), &output));
|
context->allocate_output(0, TensorShape({}), &output));
|
||||||
auto output_tensor = output->tensor<bool, 0>();
|
auto output_tensor = output->tensor<bool, 0>();
|
||||||
Var* variable = nullptr;
|
core::RefCountPtr<Var> variable;
|
||||||
Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
|
Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
output_tensor() = false;
|
output_tensor() = false;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
core::ScopedUnref su(variable);
|
|
||||||
mutex_lock ml(*variable->mu());
|
mutex_lock ml(*variable->mu());
|
||||||
output_tensor() = variable->is_initialized;
|
output_tensor() = variable->is_initialized;
|
||||||
}
|
}
|
||||||
@ -623,10 +616,9 @@ class ResourceGatherOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* c) override {
|
void Compute(OpKernelContext* c) override {
|
||||||
Var* v = nullptr;
|
core::RefCountPtr<Var> v;
|
||||||
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
||||||
core::ScopedUnref su(v);
|
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
|
||||||
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
|
|
||||||
// NOTE: We hold the lock for the whole gather operation instead
|
// NOTE: We hold the lock for the whole gather operation instead
|
||||||
// of increasing the reference count of v->tensor() to avoid a
|
// of increasing the reference count of v->tensor() to avoid a
|
||||||
// situation where a write to the same variable will see a
|
// situation where a write to the same variable will see a
|
||||||
@ -765,10 +757,9 @@ class ResourceGatherNdOp : public OpKernel {
|
|||||||
explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
|
explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* c) override {
|
void Compute(OpKernelContext* c) override {
|
||||||
Var* v = nullptr;
|
core::RefCountPtr<Var> v;
|
||||||
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
||||||
core::ScopedUnref su(v);
|
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
|
||||||
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
|
|
||||||
// NOTE: We hold the lock for the whole gather operation instead
|
// NOTE: We hold the lock for the whole gather operation instead
|
||||||
// of increasing the reference count of v->tensor() to avoid a
|
// of increasing the reference count of v->tensor() to avoid a
|
||||||
// situation where a write to the same variable will see a
|
// situation where a write to the same variable will see a
|
||||||
@ -821,10 +812,9 @@ class ResourceScatterUpdateOp : public OpKernel {
|
|||||||
explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}
|
explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* c) override {
|
void Compute(OpKernelContext* c) override {
|
||||||
Var* v = nullptr;
|
core::RefCountPtr<Var> v;
|
||||||
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
||||||
core::ScopedUnref unref_v(v);
|
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
|
||||||
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
|
|
||||||
tf_shared_lock ml(*v->mu());
|
tf_shared_lock ml(*v->mu());
|
||||||
Tensor* params = v->tensor();
|
Tensor* params = v->tensor();
|
||||||
const Tensor& indices = c->input(1);
|
const Tensor& indices = c->input(1);
|
||||||
|
@ -243,10 +243,9 @@ class ScatterNdUpdateOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* c) override {
|
void Compute(OpKernelContext* c) override {
|
||||||
if (dtype_ == DT_RESOURCE) {
|
if (dtype_ == DT_RESOURCE) {
|
||||||
Var* v;
|
core::RefCountPtr<Var> v;
|
||||||
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
||||||
core::ScopedUnref scoped_unref(v);
|
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
|
||||||
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
|
|
||||||
mutex_lock m(*v->mu());
|
mutex_lock m(*v->mu());
|
||||||
DoCompute(c);
|
DoCompute(c);
|
||||||
} else if (use_exclusive_lock_) {
|
} else if (use_exclusive_lock_) {
|
||||||
@ -271,9 +270,8 @@ class ScatterNdUpdateOp : public OpKernel {
|
|||||||
TensorShape params_shape;
|
TensorShape params_shape;
|
||||||
|
|
||||||
if (dtype_ == DT_RESOURCE) {
|
if (dtype_ == DT_RESOURCE) {
|
||||||
Var* v;
|
core::RefCountPtr<Var> v;
|
||||||
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
|
||||||
core::ScopedUnref scoped_unref(v);
|
|
||||||
Tensor* t = v->tensor();
|
Tensor* t = v->tensor();
|
||||||
params = *t;
|
params = *t;
|
||||||
params_shape = params.shape();
|
params_shape = params.shape();
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
// See docs in ../ops/array_ops.cc.
|
// See docs in ../ops/array_ops.cc.
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
@ -323,12 +324,11 @@ class StridedSliceAssignOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (context->input_dtype(0) == DT_RESOURCE) {
|
if (context->input_dtype(0) == DT_RESOURCE) {
|
||||||
Var* v;
|
core::RefCountPtr<Var> v;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context, LookupResource(context, HandleFromInput(context, 0), &v));
|
context, LookupResource(context, HandleFromInput(context, 0), &v));
|
||||||
core::ScopedUnref scoped_unref(v);
|
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
EnsureSparseVariableAccess<Device, T>(context, v));
|
EnsureSparseVariableAccess<Device, T>(context, v.get()));
|
||||||
mutex_lock ml(*v->mu());
|
mutex_lock ml(*v->mu());
|
||||||
old_lhs = v->tensor();
|
old_lhs = v->tensor();
|
||||||
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
|
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/framework/summary.pb.h"
|
#include "tensorflow/core/framework/summary.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/db/sqlite.h"
|
#include "tensorflow/core/lib/db/sqlite.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
#include "tensorflow/core/summary/schema.h"
|
#include "tensorflow/core/summary/schema.h"
|
||||||
@ -45,7 +46,7 @@ class CreateSummaryFileWriterOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
|
||||||
const string filename_suffix = tmp->scalar<string>()();
|
const string filename_suffix = tmp->scalar<string>()();
|
||||||
|
|
||||||
SummaryWriterInterface* s = nullptr;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
|
OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
|
||||||
ctx, HandleFromInput(ctx, 0), &s,
|
ctx, HandleFromInput(ctx, 0), &s,
|
||||||
[max_queue, flush_millis, logdir, filename_suffix,
|
[max_queue, flush_millis, logdir, filename_suffix,
|
||||||
@ -54,7 +55,6 @@ class CreateSummaryFileWriterOp : public OpKernel {
|
|||||||
max_queue, flush_millis, logdir,
|
max_queue, flush_millis, logdir,
|
||||||
filename_suffix, ctx->env(), s);
|
filename_suffix, ctx->env(), s);
|
||||||
}));
|
}));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
|
||||||
@ -75,7 +75,7 @@ class CreateSummaryDbWriterOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
|
||||||
const string user_name = tmp->scalar<string>()();
|
const string user_name = tmp->scalar<string>()();
|
||||||
|
|
||||||
SummaryWriterInterface* s = nullptr;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx,
|
ctx,
|
||||||
LookupOrCreateResource<SummaryWriterInterface>(
|
LookupOrCreateResource<SummaryWriterInterface>(
|
||||||
@ -91,7 +91,6 @@ class CreateSummaryDbWriterOp : public OpKernel {
|
|||||||
db, experiment_name, run_name, user_name, ctx->env(), s));
|
db, experiment_name, run_name, user_name, ctx->env(), s));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
|
||||||
@ -102,9 +101,8 @@ class FlushSummaryWriterOp : public OpKernel {
|
|||||||
explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
OP_REQUIRES_OK(ctx, s->Flush());
|
OP_REQUIRES_OK(ctx, s->Flush());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -128,9 +126,8 @@ class WriteSummaryOp : public OpKernel {
|
|||||||
explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* tmp;
|
const Tensor* tmp;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
||||||
const int64 step = tmp->scalar<int64>()();
|
const int64 step = tmp->scalar<int64>()();
|
||||||
@ -153,9 +150,8 @@ class WriteRawProtoSummaryOp : public OpKernel {
|
|||||||
explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* tmp;
|
const Tensor* tmp;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
|
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
|
||||||
@ -190,9 +186,8 @@ class ImportEventOp : public OpKernel {
|
|||||||
explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* t;
|
const Tensor* t;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("event", &t));
|
OP_REQUIRES_OK(ctx, ctx->input("event", &t));
|
||||||
std::unique_ptr<Event> event{new Event};
|
std::unique_ptr<Event> event{new Event};
|
||||||
@ -211,9 +206,8 @@ class WriteScalarSummaryOp : public OpKernel {
|
|||||||
explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* tmp;
|
const Tensor* tmp;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
||||||
const int64 step = tmp->scalar<int64>()();
|
const int64 step = tmp->scalar<int64>()();
|
||||||
@ -234,9 +228,8 @@ class WriteHistogramSummaryOp : public OpKernel {
|
|||||||
explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* tmp;
|
const Tensor* tmp;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
||||||
const int64 step = tmp->scalar<int64>()();
|
const int64 step = tmp->scalar<int64>()();
|
||||||
@ -263,9 +256,8 @@ class WriteImageSummaryOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* tmp;
|
const Tensor* tmp;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
||||||
const int64 step = tmp->scalar<int64>()();
|
const int64 step = tmp->scalar<int64>()();
|
||||||
@ -299,9 +291,8 @@ class WriteAudioSummaryOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* tmp;
|
const Tensor* tmp;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
||||||
const int64 step = tmp->scalar<int64>()();
|
const int64 step = tmp->scalar<int64>()();
|
||||||
@ -328,9 +319,8 @@ class WriteGraphSummaryOp : public OpKernel {
|
|||||||
explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
SummaryWriterInterface* s;
|
core::RefCountPtr<SummaryWriterInterface> s;
|
||||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||||
core::ScopedUnref unref(s);
|
|
||||||
const Tensor* t;
|
const Tensor* t;
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("step", &t));
|
OP_REQUIRES_OK(ctx, ctx->input("step", &t));
|
||||||
const int64 step = t->scalar<int64>()();
|
const int64 step = t->scalar<int64>()();
|
||||||
|
@ -26,6 +26,7 @@ tf_kernel_library(
|
|||||||
":resources",
|
":resources",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -37,6 +38,7 @@ tf_kernel_library(
|
|||||||
":resources",
|
":resources",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
@ -29,12 +30,10 @@ class TensorForestTreePredictOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
TensorForestTreeResource* decision_tree_resource;
|
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
|
|
||||||
const Tensor* dense_features_t = nullptr;
|
const Tensor* dense_features_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->input("dense_features", &dense_features_t));
|
context->input("dense_features", &dense_features_t));
|
||||||
@ -60,7 +59,7 @@ class TensorForestTreePredictOp : public OpKernel {
|
|||||||
// We will need to run it on a number of trees of diff depth
|
// We will need to run it on a number of trees of diff depth
|
||||||
// and see the num of cpu cycles
|
// and see the num of cpu cycles
|
||||||
const int64 cost_per_traverse = 500;
|
const int64 cost_per_traverse = 500;
|
||||||
auto traverse = [this, &out, &dense_features, decision_tree_resource,
|
auto traverse = [this, &out, &dense_features, &decision_tree_resource,
|
||||||
batch_size](int64 start, int64 end) {
|
batch_size](int64 start, int64 end) {
|
||||||
DCHECK_LE(start, end) << "Start exceeding End";
|
DCHECK_LE(start, end) << "Start exceeding End";
|
||||||
DCHECK_LE(end, batch_size) << "End exceeding batch size";
|
DCHECK_LE(end, batch_size) << "End exceeding batch size";
|
||||||
@ -74,9 +73,10 @@ class TensorForestTreePredictOp : public OpKernel {
|
|||||||
traverse);
|
traverse);
|
||||||
};
|
};
|
||||||
|
|
||||||
void set_output_value(const int32 example_id, const int32 leaf_id,
|
void set_output_value(
|
||||||
const TensorForestTreeResource* decision_tree_resource,
|
const int32 example_id, const int32 leaf_id,
|
||||||
TTypes<float>::Matrix* out) const {
|
const core::RefCountPtr<TensorForestTreeResource>& decision_tree_resource,
|
||||||
|
TTypes<float>::Matrix* out) const {
|
||||||
for (int j = 0; j < logits_dimension_; ++j) {
|
for (int j = 0; j < logits_dimension_; ++j) {
|
||||||
const float logit = decision_tree_resource->get_prediction(leaf_id, j);
|
const float logit = decision_tree_resource->get_prediction(leaf_id, j);
|
||||||
(*out)(example_id, j) = logit;
|
(*out)(example_id, j) = logit;
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||||
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -55,11 +56,10 @@ class TensorForestTreeSerializeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
TensorForestTreeResource* decision_tree_resource;
|
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
Tensor* output_config_t = nullptr;
|
Tensor* output_config_t = nullptr;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
||||||
@ -74,13 +74,11 @@ class TensorForestTreeDeserializeOp : public OpKernel {
|
|||||||
explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
|
explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
|
||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
TensorForestTreeResource* decision_tree_resource;
|
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
|
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
|
|
||||||
const Tensor* tree_config_t;
|
const Tensor* tree_config_t;
|
||||||
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
||||||
|
|
||||||
@ -102,11 +100,10 @@ class TensorForestTreeSizeOp : public OpKernel {
|
|||||||
: OpKernel(context) {}
|
: OpKernel(context) {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
TensorForestTreeResource* decision_tree_resource;
|
core::RefCountPtr<TensorForestTreeResource> decision_tree_resource;
|
||||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
&decision_tree_resource));
|
&decision_tree_resource));
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
|
||||||
Tensor* output_t = nullptr;
|
Tensor* output_t = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, TensorShape(), &output_t));
|
context->allocate_output(0, TensorShape(), &output_t));
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
#include "tensorflow/core/kernels/dense_update_functor.h"
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
#include "tensorflow/core/kernels/variable_ops.h"
|
#include "tensorflow/core/kernels/variable_ops.h"
|
||||||
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -167,10 +168,8 @@ VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
|
|||||||
std::sort(acquire_order.begin(), acquire_order.end(),
|
std::sort(acquire_order.begin(), acquire_order.end(),
|
||||||
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
|
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
|
||||||
|
|
||||||
std::unique_ptr<std::vector<mutex_lock>> locks =
|
auto locks = absl::make_unique<std::vector<mutex_lock>>();
|
||||||
absl::make_unique<std::vector<mutex_lock>>();
|
auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
|
||||||
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
|
|
||||||
absl::make_unique<std::vector<tf_shared_lock>>();
|
|
||||||
locks->reserve(acquire_order.size());
|
locks->reserve(acquire_order.size());
|
||||||
|
|
||||||
for (auto input : acquire_order) {
|
for (auto input : acquire_order) {
|
||||||
@ -241,11 +240,10 @@ template <typename Device, typename T>
|
|||||||
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
|
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
|
||||||
bool lock_held, bool sparse, Tensor* out) {
|
bool lock_held, bool sparse, Tensor* out) {
|
||||||
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
||||||
Var* var;
|
core::RefCountPtr<Var> var;
|
||||||
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
|
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
|
||||||
core::ScopedUnref unref_var(var);
|
|
||||||
if (sparse) {
|
if (sparse) {
|
||||||
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
|
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
|
||||||
*out = *var->tensor();
|
*out = *var->tensor();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user