Boosted Trees: switching to version 2 cond (even when running in v1). This should fix retval issues and enable TF2.0 support
PiperOrigin-RevId: 253650966
parent f8eb295646
commit 524eecf61d
@@ -0,0 +1,33 @@
op {
  graph_op_name: "ResourceAccumulatorApplyGradient"
  visibility: HIDDEN
  in_arg {
    name: "handle"
    description: <<END
The handle to an accumulator.
END
  }
  in_arg {
    name: "local_step"
    description: <<END
The local_step value at which the gradient was computed.
END
  }
  in_arg {
    name: "gradient"
    description: <<END
A tensor of the gradient to be accumulated.
END
  }
  attr {
    name: "dtype"
    description: <<END
The data type of accumulated gradients. Needs to correspond to the type
of the accumulator.
END
  }
  summary: "Applies a gradient to a given accumulator."
  description: <<END
Does not add if local_step is lesser than the accumulator's global_step.
END
}
@@ -0,0 +1,17 @@
op {
  graph_op_name: "ResourceAccumulatorNumAccumulated"
  visibility: HIDDEN
  in_arg {
    name: "handle"
    description: <<END
The handle to an accumulator.
END
  }
  out_arg {
    name: "num_accumulated"
    description: <<END
The number of gradients aggregated in the given accumulator.
END
  }
  summary: "Returns the number of gradients aggregated in the given accumulators."
}
@@ -0,0 +1,21 @@
op {
  graph_op_name: "ResourceAccumulatorSetGlobalStep"
  visibility: HIDDEN
  in_arg {
    name: "handle"
    description: <<END
The handle to an accumulator.
END
  }
  in_arg {
    name: "new_global_step"
    description: <<END
The new global_step value to set.
END
  }
  summary: "Updates the accumulator with a new value for global_step."
  description: <<END
Logs warning if the accumulator's value is already higher than
new_global_step.
END
}
@@ -0,0 +1,37 @@
op {
  graph_op_name: "ResourceAccumulatorTakeGradient"
  visibility: HIDDEN
  in_arg {
    name: "handle"
    description: <<END
The handle to an accumulator.
END
  }
  in_arg {
    name: "num_required"
    description: <<END
Number of gradients required before we return an aggregate.
END
  }
  out_arg {
    name: "average"
    description: <<END
The average of the accumulated gradients.
END
  }
  attr {
    name: "dtype"
    description: <<END
The data type of accumulated gradients. Needs to correspond to the type
of the accumulator.
END
  }
  summary: "Extracts the average gradient in the given ConditionalAccumulator."
  description: <<END
The op blocks until sufficient (i.e., more than num_required)
gradients have been accumulated. If the accumulator has already
aggregated more than num_required gradients, it returns the average of
the accumulated gradients. Also automatically increments the recorded
global_step in the accumulator by 1, and resets the aggregate to 0.
END
}
@@ -0,0 +1,47 @@
op {
  graph_op_name: "ResourceConditionalAccumulator"
  visibility: HIDDEN
  out_arg {
    name: "handle"
    description: <<END
The handle to the accumulator.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the value being accumulated.
END
  }
  attr {
    name: "shape"
    description: <<END
The shape of the values, can be [], in which case shape is unknown.
END
  }
  attr {
    name: "container"
    description: <<END
If non-empty, this accumulator is placed in the given container.
Otherwise, a default container is used.
END
  }
  attr {
    name: "shared_name"
    description: <<END
If non-empty, this accumulator will be shared under the
given name across multiple sessions.
END
  }
  summary: "A conditional accumulator for aggregating gradients."
  description: <<END
The accumulator accepts gradients marked with local_step greater or
equal to the most recent global_step known to the accumulator. The
average can be extracted from the accumulator, provided sufficient
gradients have been accumulated. Extracting the average automatically
resets the aggregate to 0, and increments the global_step recorded by
the accumulator.
This is a resource version of ConditionalAccumulator that will work in TF2.0
with tf.cond version 2.
END
}
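Taken together, these hidden V2 ops back the same workflow as the existing Ref(string) accumulator ops: gradients stamped with a local_step are applied, and the average is taken once enough of them have arrived. The sketch below illustrates that workflow through the public wrapper; it assumes TF 1.x graph/session execution, and the shared_name and values are illustrative, not taken from this change.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# A MEAN-reducing accumulator over scalar float32 gradients.
acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=(), shared_name="acc")

# Gradients stamped with a local_step older than the accumulator's
# global_step are silently dropped (ResourceAccumulatorApplyGradient).
apply_ops = [acc.apply_grad(g, local_step=0) for g in (1.0, 3.0)]
num_op = acc.num_accumulated()

# Blocks until num_required gradients have been aggregated, returns their
# average, bumps global_step, and resets the aggregate
# (ResourceAccumulatorTakeGradient).
avg = acc.take_grad(num_required=2)

with tf.Session() as sess:
  sess.run(apply_ops)
  print(sess.run(num_op))  # -> 2
  print(sess.run(avg))     # -> 2.0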
@@ -30,10 +30,15 @@ class AccumulatorSetGlobalStepOp
      : ConditionalAccumulatorBaseSyncOpKernel(context) {}

 protected:
  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_STRING_REF, DT_INT64};
  }

  void Compute(OpKernelContext* ctx,
               ConditionalAccumulatorBase* accumulator) override {
    // Check signature
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT64}, {}));
    CheckSignature(ctx, accumulator);

    // Get input new_global_step
    const Tensor* new_global_step_tensor;
@@ -56,6 +61,24 @@ class AccumulatorSetGlobalStepOp
REGISTER_KERNEL_BUILDER(Name("AccumulatorSetGlobalStep").Device(DEVICE_CPU),
                        AccumulatorSetGlobalStepOp);

class ResourceAccumulatorSetGlobalStepOp : public AccumulatorSetGlobalStepOp {
 public:
  explicit ResourceAccumulatorSetGlobalStepOp(OpKernelConstruction* context)
      : AccumulatorSetGlobalStepOp(context) {}

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_RESOURCE, DT_INT64};
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorSetGlobalStepOp);
};

REGISTER_KERNEL_BUILDER(
    Name("ResourceAccumulatorSetGlobalStep").Device(DEVICE_CPU),
    ResourceAccumulatorSetGlobalStepOp);

/**
 * Defines a AccumulatorNumAccumulatedOp, which returns the number of gradients
 * that have been accumulated in the given ConditionalAccumulator, and emits it
@@ -68,10 +91,23 @@ class AccumulatorNumAccumulatedOp
      : ConditionalAccumulatorBaseSyncOpKernel(context) {}

 protected:
  void CheckSignature(OpKernelContext* ctx,
                      ConditionalAccumulatorBase* accumulator) override {
    // Check input signature
    OP_REQUIRES_OK(
        ctx, ctx->MatchSignature(GetExpectedInputs(accumulator), {DT_INT32}));
  }

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_STRING_REF};
  }

  void Compute(OpKernelContext* ctx,
               ConditionalAccumulatorBase* accumulator) override {
    // Check signature
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_STRING_REF}, {DT_INT32}));
    CheckSignature(ctx, accumulator);

    Tensor* Taccumulator_size = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &Taccumulator_size));
@@ -86,4 +122,22 @@ class AccumulatorNumAccumulatedOp
REGISTER_KERNEL_BUILDER(Name("AccumulatorNumAccumulated").Device(DEVICE_CPU),
                        AccumulatorNumAccumulatedOp);

class ResourceAccumulatorNumAccumulatedOp : public AccumulatorNumAccumulatedOp {
 public:
  explicit ResourceAccumulatorNumAccumulatedOp(OpKernelConstruction* context)
      : AccumulatorNumAccumulatedOp(context) {}

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_RESOURCE};
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorNumAccumulatedOp);
};

REGISTER_KERNEL_BUILDER(
    Name("ResourceAccumulatorNumAccumulated").Device(DEVICE_CPU),
    ResourceAccumulatorNumAccumulatedOp);

}  // namespace tensorflow
@@ -60,7 +60,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
    if (!accumulator_handle_set_) {
      OP_REQUIRES_OK(ctx, SetAccumulatorHandle(ctx));
    }
    ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx));
    SetHandleToOutput(ctx);
  }

 protected:
@@ -73,6 +73,12 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
    }
  }

 protected:
  virtual void SetHandleToOutput(OpKernelContext* ctx)
      SHARED_LOCKS_REQUIRED(mu_) = 0;

  virtual Status CheckSignature(OpKernelContext* ctx) = 0;

 protected:
  typedef std::function<Status(ConditionalAccumulatorBase**)> Creator;

@@ -84,6 +90,9 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
  PartialTensorShape shape_;
  ContainerInfo cinfo_;
  string reduction_type_;
  mutex mu_;
  PersistentTensor accumulator_handle_ GUARDED_BY(mu_);
  bool accumulator_handle_set_ GUARDED_BY(mu_);

 private:
  Status SetAccumulatorHandle(OpKernelContext* ctx)
@@ -91,8 +100,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));

    // Check input signature
    DataTypeVector expected_inputs = {};
    TF_RETURN_IF_ERROR(ctx->MatchSignature(expected_inputs, {DT_STRING_REF}));
    TF_RETURN_IF_ERROR(CheckSignature(ctx));

    Creator creator = GetCreator();
    ConditionalAccumulatorBase* accumulator;
@@ -112,11 +120,72 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
    return Status::OK();
  }

  mutex mu_;
  PersistentTensor accumulator_handle_ GUARDED_BY(mu_);
  bool accumulator_handle_set_ GUARDED_BY(mu_);
};

// ------------------Sync kernels ------------------------------------------

/**
 * General OpKernel for ConditionalAccumulatorBase-related ops.
 */
class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel {
 public:
  explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) final {
    ConditionalAccumulatorBase* accumulator;
    OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator));
    Compute(ctx, accumulator);
    accumulator->Unref();
  }

 protected:
  virtual void Compute(OpKernelContext* ctx,
                       ConditionalAccumulatorBase* accumulator) = 0;

  virtual DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) = 0;

  virtual void CheckSignature(OpKernelContext* ctx,
                              ConditionalAccumulatorBase* accumulator) {
    // Check input signature
    DataTypeVector expected_inputs = GetExpectedInputs(accumulator);
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
  }
};

/**
 * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
 * given ConditionalAccumulator.
 */
class ConditionalAccumulatorBaseApplyGradientOp
    : public ConditionalAccumulatorBaseSyncOpKernel {
 public:
  explicit ConditionalAccumulatorBaseApplyGradientOp(
      OpKernelConstruction* context)
      : ConditionalAccumulatorBaseSyncOpKernel(context) {}

 protected:
  void Compute(OpKernelContext* ctx,
               ConditionalAccumulatorBase* accumulator) override {
    // Check input signature
    CheckSignature(ctx, accumulator);

    // Get input local_step
    const Tensor* local_step_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor));
    if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) {
      ctx->CtxFailureWithWarning(errors::InvalidArgument(
          "Argument local_step must be scalar, but had bad shape ",
          local_step_tensor->shape().DebugString()));
    }

    // Actually try to apply gradient now
    accumulator->TryApplyGrad(local_step_tensor->scalar<int64>()(), ctx);
  }
};

// -------------------- Async kernels --------------------------------------
/**
 * General OpKernel for ConditionalAccumulatorBase-related ops.
 */
@@ -140,59 +209,18 @@ class ConditionalAccumulatorBaseAsyncOpKernel : public AsyncOpKernel {
  virtual void ComputeAsync(OpKernelContext* ctx,
                            ConditionalAccumulatorBase* accumulator,
                            DoneCallback callback) = 0;
};

/**
 * General OpKernel for ConditionalAccumulatorBase-related ops.
 */
class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel {
 public:
  explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context)
      : OpKernel(context) {}
  virtual DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) = 0;

  void Compute(OpKernelContext* ctx) final {
    ConditionalAccumulatorBase* accumulator;
    OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator));
    Compute(ctx, accumulator);
    accumulator->Unref();
  }

 protected:
  virtual void Compute(OpKernelContext* ctx,
                       ConditionalAccumulatorBase* accumulator) = 0;
};

/**
 * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
 * given ConditionalAccumulator.
 */
class ConditionalAccumulatorBaseApplyGradientOp
    : public ConditionalAccumulatorBaseSyncOpKernel {
 public:
  explicit ConditionalAccumulatorBaseApplyGradientOp(
      OpKernelConstruction* context)
      : ConditionalAccumulatorBaseSyncOpKernel(context) {}

 protected:
  virtual void CheckSignature(OpKernelContext* ctx,
                              ConditionalAccumulatorBase* accumulator) = 0;

  void Compute(OpKernelContext* ctx,
               ConditionalAccumulatorBase* accumulator) override {
                            ConditionalAccumulatorBase* accumulator,
                            DoneCallback callback) {
    // Check input signature
    CheckSignature(ctx, accumulator);

    // Get input local_step
    const Tensor* local_step_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor));
    if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) {
      ctx->CtxFailureWithWarning(errors::InvalidArgument(
          "Argument local_step must be scalar, but had bad shape ",
          local_step_tensor->shape().DebugString()));
    }

    // Actually try to apply gradient now
    accumulator->TryApplyGrad(local_step_tensor->scalar<int64>()(), ctx);
    OP_REQUIRES_OK_ASYNC(ctx,
                         ctx->MatchSignature(GetExpectedInputs(accumulator),
                                             {accumulator->dtype()}),
                         callback);
  }
};

@@ -208,10 +236,6 @@ class ConditionalAccumulatorBaseTakeGradientOp
      : ConditionalAccumulatorBaseAsyncOpKernel(context) {}

 protected:
  virtual void CheckSignature(OpKernelContext* ctx,
                              ConditionalAccumulatorBase* accumulator,
                              DoneCallback callback) = 0;

  void ComputeAsync(OpKernelContext* ctx,
                    ConditionalAccumulatorBase* accumulator,
                    DoneCallback callback) override {
@@ -41,6 +41,16 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
    };
  }

  Status CheckSignature(OpKernelContext* ctx) override {
    TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF}));
    return Status::OK();
  }

  void SetHandleToOutput(OpKernelContext* ctx)
      SHARED_LOCKS_REQUIRED(mu_) override {
    ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx));
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulatorOp);
};

@@ -50,6 +60,50 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
                              .TypeConstraint<type>("dtype"), \
                          ConditionalAccumulatorOp<dev##Device, type>)

// Resource conditional accumulator
template <typename Device, typename T>
class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
 public:
  explicit ResourceConditionalAccumulatorOp(OpKernelConstruction* context)
      : ConditionalAccumulatorBaseOp(context) {}

 protected:
  Creator GetCreator() const override {
    return [this](ConditionalAccumulatorBase** ret) {
      ConditionalAccumulator<Device, T>* accumulator =
          new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
                                                reduction_type_);
      *ret = accumulator;
      return Status::OK();
    };
  }

  Status CheckSignature(OpKernelContext* ctx) override {
    TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_RESOURCE}));
    return Status::OK();
  }

  void SetHandleToOutput(OpKernelContext* ctx)
      SHARED_LOCKS_REQUIRED(mu_) override {
    auto h = accumulator_handle_.AccessTensor(ctx)->template flat<string>();
    h(0) = cinfo_.container();
    h(1) = cinfo_.name();
    OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
                            ctx, 0, cinfo_.container(), cinfo_.name(),
                            MakeTypeIndex<ConditionalAccumulatorBase>()));
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ResourceConditionalAccumulatorOp);
};

#define REGISTER_RESOURCE_KERNELS(type, dev)                     \
  REGISTER_KERNEL_BUILDER(Name("ResourceConditionalAccumulator") \
                              .Device(DEVICE_##dev)              \
                              .TypeConstraint<type>("dtype"),    \
                          ResourceConditionalAccumulatorOp<dev##Device, type>)

// End of Resource conditional accumulator

#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU)

TF_CALL_half(REGISTER_KERNELS_CPU);
@@ -59,6 +113,15 @@ TF_CALL_double(REGISTER_KERNELS_CPU);
#undef REGISTER_KERNELS_CPU
#undef REGISTER_KERNELS

#define REGISTER_RESOURCE_KERNELS_CPU(type) REGISTER_RESOURCE_KERNELS(type, CPU)

TF_CALL_half(REGISTER_RESOURCE_KERNELS_CPU);
TF_CALL_float(REGISTER_RESOURCE_KERNELS_CPU);
TF_CALL_double(REGISTER_RESOURCE_KERNELS_CPU);

#undef REGISTER_KERNELS_CPU
#undef REGISTER_KERNELS

/**
 * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
 * given ConditionalAccumulator.
@@ -69,13 +132,12 @@ class AccumulatorApplyGradientOp
  explicit AccumulatorApplyGradientOp(OpKernelConstruction* context)
      : ConditionalAccumulatorBaseApplyGradientOp(context) {}

 protected:
  void CheckSignature(OpKernelContext* ctx,
                      ConditionalAccumulatorBase* accumulator) override {
    // Check input signature
    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64};
  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    DataTypeVector expected_inputs;
    expected_inputs = {DT_STRING_REF, DT_INT64};
    expected_inputs.push_back(accumulator->dtype());
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
    return expected_inputs;
  }

 private:
@@ -85,6 +147,28 @@ class AccumulatorApplyGradientOp
REGISTER_KERNEL_BUILDER(Name("AccumulatorApplyGradient").Device(DEVICE_CPU),
                        AccumulatorApplyGradientOp);

class ResourceAccumulatorApplyGradientOp
    : public ConditionalAccumulatorBaseApplyGradientOp {
 public:
  explicit ResourceAccumulatorApplyGradientOp(OpKernelConstruction* context)
      : ConditionalAccumulatorBaseApplyGradientOp(context) {}

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    DataTypeVector expected_inputs;
    expected_inputs = {DT_RESOURCE, DT_INT64};
    expected_inputs.push_back(accumulator->dtype());
    return expected_inputs;
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorApplyGradientOp);
};

REGISTER_KERNEL_BUILDER(
    Name("ResourceAccumulatorApplyGradient").Device(DEVICE_CPU),
    ResourceAccumulatorApplyGradientOp);

/**
 * Defines a ConditionalAccumulatorBaseTakeGradientOp, the execution of which
 * returns the average gradient accumulated by the given ConditionalAccumulator.
@@ -95,22 +179,34 @@ class AccumulatorTakeGradientOp
  explicit AccumulatorTakeGradientOp(OpKernelConstruction* context)
      : ConditionalAccumulatorBaseTakeGradientOp(context) {}

 protected:
  void CheckSignature(OpKernelContext* ctx,
                      ConditionalAccumulatorBase* accumulator,
                      DoneCallback callback) override {
    // Check signature
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_INT32}, {accumulator->dtype()}),
        callback);
  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_STRING_REF, DT_INT32};
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorTakeGradientOp);
};

REGISTER_KERNEL_BUILDER(Name("AccumulatorTakeGradient").Device(DEVICE_CPU),
                        AccumulatorTakeGradientOp);

class ResourceAccumulatorTakeGradientOp
    : public ConditionalAccumulatorBaseTakeGradientOp {
 public:
  explicit ResourceAccumulatorTakeGradientOp(OpKernelConstruction* context)
      : ConditionalAccumulatorBaseTakeGradientOp(context) {}

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_RESOURCE, DT_INT32};
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorTakeGradientOp);
};

REGISTER_KERNEL_BUILDER(
    Name("ResourceAccumulatorTakeGradient").Device(DEVICE_CPU),
    ResourceAccumulatorTakeGradientOp);

}  // namespace tensorflow
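The functional difference in ResourceConditionalAccumulatorOp is that the accumulator is now addressed through a DT_RESOURCE handle (via MakeResourceHandleToOutput) instead of a Ref(string) output. Resource tensors can be captured by the function bodies that cond v2 builds, which is what the commit message is getting at. A minimal sketch of the pattern this enables, assuming TF 1.14+ graph mode with control flow v2 enabled (the predicate and values are illustrative):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.enable_control_flow_v2()  # cond v2: branches become FuncGraphs

acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=())
pred = tf.placeholder(tf.bool, shape=())

def apply_branch():
  # The resource handle held by `acc` is captured by this branch function;
  # a Ref(string) handle could not cross the function boundary.
  with tf.control_dependencies([acc.apply_grad(1.0)]):
    return tf.constant(0.0)

def skip_branch():
  return tf.constant(0.0)

maybe_apply = tf.cond(pred, apply_branch, skip_branch)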
@@ -41,6 +41,18 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
    };
  }

  // TODO(tanzheny): actually switch it to resource. You won't be able to use
  // it with cond2 otherwise.
  Status CheckSignature(OpKernelContext* ctx) override {
    TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF}));
    return Status::OK();
  }

  void SetHandleToOutput(OpKernelContext* ctx)
      SHARED_LOCKS_REQUIRED(mu_) override {
    ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx));
  }

  TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulatorOp);
};

@@ -70,13 +82,12 @@ class SparseAccumulatorApplyGradientOp
      : ConditionalAccumulatorBaseApplyGradientOp(context) {}

 protected:
  void CheckSignature(OpKernelContext* ctx,
                      ConditionalAccumulatorBase* accumulator) override {
    // Check input signature
  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64, DT_INT64};
    expected_inputs.push_back(accumulator->dtype());
    expected_inputs.push_back(DT_INT64);
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
    return expected_inputs;
  }

 private:
@@ -109,6 +120,11 @@ class SparseAccumulatorTakeGradientOp
        callback);
  }

  DataTypeVector GetExpectedInputs(
      ConditionalAccumulatorBase* accumulator) override {
    return {DT_STRING_REF, DT_INT32};
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorTakeGradientOp);
};
@@ -451,6 +451,61 @@ REGISTER_OP("AccumulatorTakeGradient")
    })
    .Attr("dtype: numbertype");

// -----------------V2 accumulators that use resource -------------------------

REGISTER_OP("ResourceAccumulatorNumAccumulated")
    .Input("handle: resource")
    .Output("num_accumulated: int32")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("ResourceAccumulatorSetGlobalStep")
    .Input("handle: resource")
    .Input("new_global_step: int64")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });

REGISTER_OP("ResourceConditionalAccumulator")
    .Output("handle: resource")
    .Attr("dtype: numbertype")
    .Attr("shape: shape")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->Vector(2));
      return Status::OK();
    });

REGISTER_OP("ResourceAccumulatorApplyGradient")
    .Input("handle: resource")
    .Input("local_step: int64")
    .Input("gradient: dtype")
    .Attr("dtype: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    });

REGISTER_OP("ResourceAccumulatorTakeGradient")
    .Input("handle: resource")
    .Input("num_required: int32")
    .Output("average: dtype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // Shape of output is the shape of the accumulator referenced
      // by 'handle', but which is not available here, so we lose
      // shape information.
      return shape_inference::UnknownShape(c);
    })
    .Attr("dtype: numbertype");
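These registrations mirror the existing accumulator ops one-for-one, with the Ref(string) handle replaced by a resource handle. The sketch below exercises the registered signatures directly through the generated raw-op wrappers; normally these stay hidden behind tf.ConditionalAccumulator, and the shared_name and values here are illustrative only.

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import gen_data_flow_ops

tf.disable_eager_execution()

# Create the accumulator resource and one gradient-apply / take pair.
handle = gen_data_flow_ops.resource_conditional_accumulator(
    dtype=tf.float32, shape=tf.TensorShape([]), shared_name="acc_v2")
apply_op = gen_data_flow_ops.resource_accumulator_apply_gradient(
    handle, local_step=tf.constant(0, tf.int64), gradient=tf.constant(4.0))
avg = gen_data_flow_ops.resource_accumulator_take_gradient(
    handle, num_required=1, dtype=tf.float32)

with tf.Session() as sess:
  sess.run(apply_op)    # accepted: local_step >= the accumulator's global_step
  print(sess.run(avg))  # -> 4.0, and the aggregate is reset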
// TODO(nponomareva): change these all to use resources.
REGISTER_OP("SparseConditionalAccumulator")
    .Output("handle: Ref(string)")
    .Attr("dtype: numbertype")
@@ -24,6 +24,7 @@ import threading

import six

from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops
@@ -1215,6 +1216,11 @@ class ConditionalAccumulatorBase(object):
    """
    if name is None:
      name = "%s_NumAccumulated" % self._name

    if compat.forward_compatible(2019, 7, 8):
      return gen_data_flow_ops.resource_accumulator_num_accumulated(
          self._accumulator_ref, name=name)

    return gen_data_flow_ops.accumulator_num_accumulated(
        self._accumulator_ref, name=name)

@@ -1231,6 +1237,12 @@ class ConditionalAccumulatorBase(object):
    Returns:
      Operation that sets the accumulator's time step.
    """
    if compat.forward_compatible(2019, 7, 8):
      return gen_data_flow_ops.resource_accumulator_set_global_step(
          self._accumulator_ref,
          math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
          name=name)

    return gen_data_flow_ops.accumulator_set_global_step(
        self._accumulator_ref,
        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
@@ -1264,12 +1276,23 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
      name: Optional name for the accumulator.
      reduction_type: Reduction type to use when taking the gradient.
    """
    accumulator_ref = gen_data_flow_ops.conditional_accumulator(
        dtype=dtype,
        shape=shape,
        shared_name=shared_name,
        name=name,
        reduction_type=reduction_type)
    if compat.forward_compatible(2019, 7, 8):
      accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
          dtype=dtype,
          shape=shape,
          shared_name=shared_name,
          name=name,
          reduction_type=reduction_type)
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=accumulator_ref, handle_device=context.context().device_name)
    else:
      accumulator_ref = gen_data_flow_ops.conditional_accumulator(
          dtype=dtype,
          shape=shape,
          shared_name=shared_name,
          name=name,
          reduction_type=reduction_type)

    super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)

  def apply_grad(self, grad, local_step=0, name=None):
@@ -1292,6 +1315,13 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
    grad = ops.convert_to_tensor(grad, self._dtype)
    grad.get_shape().assert_is_compatible_with(self._shape)
    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)

    if compat.forward_compatible(2019, 7, 8):
      return gen_data_flow_ops.resource_accumulator_apply_gradient(
          self._accumulator_ref,
          local_step=local_step,
          gradient=grad,
          name=name)
    return gen_data_flow_ops.accumulator_apply_gradient(
        self._accumulator_ref, local_step=local_step, gradient=grad, name=name)

@@ -1317,8 +1347,12 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
    Raises:
      InvalidArgumentError: If num_required < 1
    """
    out = gen_data_flow_ops.accumulator_take_gradient(
        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
    if compat.forward_compatible(2019, 7, 8):
      out = gen_data_flow_ops.resource_accumulator_take_gradient(
          self._accumulator_ref, num_required, dtype=self._dtype, name=name)
    else:
      out = gen_data_flow_ops.accumulator_take_gradient(
          self._accumulator_ref, num_required, dtype=self._dtype, name=name)
    out.set_shape(self._shape)
    return out
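Each Python entry point above picks between the legacy op and the resource op behind a compat.forward_compatible(2019, 7, 8) check, so GraphDefs stay loadable by binaries built before that horizon. A sketch of how the new path can be pinned down, e.g. in a test, using the companion forward_compatibility_horizon context manager (the asserted graph contents are illustrative):

import tensorflow.compat.v1 as tf
from tensorflow.python.compat import compat

tf.disable_eager_execution()

# Move the horizon past 2019-07-08 so forward_compatible(2019, 7, 8) returns
# True and the Resource* ops are recorded in the GraphDef.
with compat.forward_compatibility_horizon(2019, 7, 9):
  g = tf.Graph()
  with g.as_default():
    acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=())
    acc.apply_grad(1.0)
  op_types = {op.type for op in g.get_operations()}
  assert "ResourceConditionalAccumulator" in op_types
  assert "ResourceAccumulatorApplyGradient" in op_types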
@@ -3016,6 +3016,22 @@ tf_module {
    name: "ResizeNearestNeighborGrad"
    argspec: "args=[\'grads\', \'size\', \'align_corners\', \'half_pixel_centers\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorApplyGradient"
    argspec: "args=[\'handle\', \'local_step\', \'gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorNumAccumulated"
    argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorSetGlobalStep"
    argspec: "args=[\'handle\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorTakeGradient"
    argspec: "args=[\'handle\', \'num_required\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceApplyAdaMax"
    argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@@ -3084,6 +3100,10 @@ tf_module {
    name: "ResourceApplyRMSProp"
    argspec: "args=[\'var\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "ResourceConditionalAccumulator"
    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
  }
  member_method {
    name: "ResourceCountUpTo"
    argspec: "args=[\'resource\', \'limit\', \'T\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

@@ -3016,6 +3016,22 @@ tf_module {
    name: "ResizeNearestNeighborGrad"
    argspec: "args=[\'grads\', \'size\', \'align_corners\', \'half_pixel_centers\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorApplyGradient"
    argspec: "args=[\'handle\', \'local_step\', \'gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorNumAccumulated"
    argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorSetGlobalStep"
    argspec: "args=[\'handle\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceAccumulatorTakeGradient"
    argspec: "args=[\'handle\', \'num_required\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "ResourceApplyAdaMax"
    argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@@ -3084,6 +3100,10 @@ tf_module {
    name: "ResourceApplyRMSProp"
    argspec: "args=[\'var\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "ResourceConditionalAccumulator"
    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
  }
  member_method {
    name: "ResourceCountUpTo"
    argspec: "args=[\'resource\', \'limit\', \'T\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "