Boosted Trees: switching to version 2 cond (even when running in v1). This should fix retval issues and enable TF2.0 support
PiperOrigin-RevId: 253650966
commit 524eecf61d (parent f8eb295646)
@@ -0,0 +1,33 @@
+op {
+  graph_op_name: "ResourceAccumulatorApplyGradient"
+  visibility: HIDDEN
+  in_arg {
+    name: "handle"
+    description: <<END
+The handle to an accumulator.
+END
+  }
+  in_arg {
+    name: "local_step"
+    description: <<END
+The local_step value at which the gradient was computed.
+END
+  }
+  in_arg {
+    name: "gradient"
+    description: <<END
+A tensor of the gradient to be accumulated.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The data type of accumulated gradients. Needs to correspond to the type
+of the accumulator.
+END
+  }
+  summary: "Applies a gradient to a given accumulator."
+  description: <<END
+Does not add if local_step is lesser than the accumulator's global_step.
+END
+}
@@ -0,0 +1,17 @@
+op {
+  graph_op_name: "ResourceAccumulatorNumAccumulated"
+  visibility: HIDDEN
+  in_arg {
+    name: "handle"
+    description: <<END
+The handle to an accumulator.
+END
+  }
+  out_arg {
+    name: "num_accumulated"
+    description: <<END
+The number of gradients aggregated in the given accumulator.
+END
+  }
+  summary: "Returns the number of gradients aggregated in the given accumulators."
+}
@@ -0,0 +1,21 @@
+op {
+  graph_op_name: "ResourceAccumulatorSetGlobalStep"
+  visibility: HIDDEN
+  in_arg {
+    name: "handle"
+    description: <<END
+The handle to an accumulator.
+END
+  }
+  in_arg {
+    name: "new_global_step"
+    description: <<END
+The new global_step value to set.
+END
+  }
+  summary: "Updates the accumulator with a new value for global_step."
+  description: <<END
+Logs warning if the accumulator's value is already higher than
+new_global_step.
+END
+}
@@ -0,0 +1,37 @@
+op {
+  graph_op_name: "ResourceAccumulatorTakeGradient"
+  visibility: HIDDEN
+  in_arg {
+    name: "handle"
+    description: <<END
+The handle to an accumulator.
+END
+  }
+  in_arg {
+    name: "num_required"
+    description: <<END
+Number of gradients required before we return an aggregate.
+END
+  }
+  out_arg {
+    name: "average"
+    description: <<END
+The average of the accumulated gradients.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The data type of accumulated gradients. Needs to correspond to the type
+of the accumulator.
+END
+  }
+  summary: "Extracts the average gradient in the given ConditionalAccumulator."
+  description: <<END
+The op blocks until sufficient (i.e., more than num_required)
+gradients have been accumulated. If the accumulator has already
+aggregated more than num_required gradients, it returns the average of
+the accumulated gradients. Also automatically increments the recorded
+global_step in the accumulator by 1, and resets the aggregate to 0.
+END
+}
@@ -0,0 +1,47 @@
+op {
+  graph_op_name: "ResourceConditionalAccumulator"
+  visibility: HIDDEN
+  out_arg {
+    name: "handle"
+    description: <<END
+The handle to the accumulator.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The type of the value being accumulated.
+END
+  }
+  attr {
+    name: "shape"
+    description: <<END
+The shape of the values, can be [], in which case shape is unknown.
+END
+  }
+  attr {
+    name: "container"
+    description: <<END
+If non-empty, this accumulator is placed in the given container.
+Otherwise, a default container is used.
+END
+  }
+  attr {
+    name: "shared_name"
+    description: <<END
+If non-empty, this accumulator will be shared under the
+given name across multiple sessions.
+END
+  }
+  summary: "A conditional accumulator for aggregating gradients."
+  description: <<END
+The accumulator accepts gradients marked with local_step greater or
+equal to the most recent global_step known to the accumulator. The
+average can be extracted from the accumulator, provided sufficient
+gradients have been accumulated. Extracting the average automatically
+resets the aggregate to 0, and increments the global_step recorded by
+the accumulator.
+This is a resource version of ConditionalAccumulator that will work in TF2.0
+with tf.cond version 2.
+END
+}
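The api_def above spells out the accumulator contract: a gradient is only folded in when its local_step is at least the accumulator's global_step, and taking the gradient returns the MEAN (by default) of what was aggregated and resets the aggregate. A minimal sketch of that contract through the public tf.ConditionalAccumulator wrapper, assuming a TF 1.x graph session; the constants and names here are illustrative, not part of the change:

import tensorflow as tf

# Illustrative values only: a scalar float accumulator whose dtype/shape
# mirror the attrs documented above.
acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=(), shared_name="acc")

# Each gradient is stamped with the local_step at which it was computed.
apply_ops = [acc.apply_grad(2.0, local_step=0),
             acc.apply_grad(4.0, local_step=0)]

# take_grad blocks until num_required gradients have arrived, returns their
# mean, resets the aggregate, and bumps the accumulator's global_step.
avg = acc.take_grad(num_required=2)

with tf.Session() as sess:
  sess.run(apply_ops)
  print(sess.run(avg))  # 3.0
  # After raising global_step, later gradients with a smaller local_step
  # are silently dropped.
  sess.run(acc.set_global_step(10))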
@@ -30,10 +30,15 @@ class AccumulatorSetGlobalStepOp
       : ConditionalAccumulatorBaseSyncOpKernel(context) {}
 
  protected:
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    return {DT_STRING_REF, DT_INT64};
+  }
+
   void Compute(OpKernelContext* ctx,
                ConditionalAccumulatorBase* accumulator) override {
     // Check signature
-    OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT64}, {}));
+    CheckSignature(ctx, accumulator);
 
     // Get input new_global_step
     const Tensor* new_global_step_tensor;
@@ -56,6 +61,24 @@ class AccumulatorSetGlobalStepOp
 REGISTER_KERNEL_BUILDER(Name("AccumulatorSetGlobalStep").Device(DEVICE_CPU),
                         AccumulatorSetGlobalStepOp);
 
+class ResourceAccumulatorSetGlobalStepOp : public AccumulatorSetGlobalStepOp {
+ public:
+  explicit ResourceAccumulatorSetGlobalStepOp(OpKernelConstruction* context)
+      : AccumulatorSetGlobalStepOp(context) {}
+
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    return {DT_RESOURCE, DT_INT64};
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorSetGlobalStepOp);
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ResourceAccumulatorSetGlobalStep").Device(DEVICE_CPU),
+    ResourceAccumulatorSetGlobalStepOp);
+
 /**
  * Defines a AccumulatorNumAccumulatedOp, which returns the number of gradients
  * that have been accumulated in the given ConditionalAccumulator, and emits it
@@ -68,10 +91,23 @@ class AccumulatorNumAccumulatedOp
       : ConditionalAccumulatorBaseSyncOpKernel(context) {}
 
  protected:
+  void CheckSignature(OpKernelContext* ctx,
+                      ConditionalAccumulatorBase* accumulator) override {
+    // Check input signature
+    OP_REQUIRES_OK(
+        ctx, ctx->MatchSignature(GetExpectedInputs(accumulator), {DT_INT32}));
+  }
+
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    return {DT_STRING_REF};
+  }
+
   void Compute(OpKernelContext* ctx,
                ConditionalAccumulatorBase* accumulator) override {
     // Check signature
-    OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_STRING_REF}, {DT_INT32}));
+    CheckSignature(ctx, accumulator);
 
     Tensor* Taccumulator_size = nullptr;
     OP_REQUIRES_OK(
         ctx, ctx->allocate_output(0, TensorShape({}), &Taccumulator_size));
@@ -86,4 +122,22 @@ class AccumulatorNumAccumulatedOp
 REGISTER_KERNEL_BUILDER(Name("AccumulatorNumAccumulated").Device(DEVICE_CPU),
                         AccumulatorNumAccumulatedOp);
 
+class ResourceAccumulatorNumAccumulatedOp : public AccumulatorNumAccumulatedOp {
+ public:
+  explicit ResourceAccumulatorNumAccumulatedOp(OpKernelConstruction* context)
+      : AccumulatorNumAccumulatedOp(context) {}
+
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    return {DT_RESOURCE};
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorNumAccumulatedOp);
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ResourceAccumulatorNumAccumulated").Device(DEVICE_CPU),
+    ResourceAccumulatorNumAccumulatedOp);
+
 }  // namespace tensorflow
@@ -60,7 +60,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
     if (!accumulator_handle_set_) {
       OP_REQUIRES_OK(ctx, SetAccumulatorHandle(ctx));
     }
-    ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx));
+    SetHandleToOutput(ctx);
   }
 
  protected:
@@ -73,6 +73,12 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
     }
   }
 
+ protected:
+  virtual void SetHandleToOutput(OpKernelContext* ctx)
+      SHARED_LOCKS_REQUIRED(mu_) = 0;
+
+  virtual Status CheckSignature(OpKernelContext* ctx) = 0;
+
  protected:
   typedef std::function<Status(ConditionalAccumulatorBase**)> Creator;
 
@@ -84,6 +90,9 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
   PartialTensorShape shape_;
   ContainerInfo cinfo_;
   string reduction_type_;
+  mutex mu_;
+  PersistentTensor accumulator_handle_ GUARDED_BY(mu_);
+  bool accumulator_handle_set_ GUARDED_BY(mu_);
 
  private:
   Status SetAccumulatorHandle(OpKernelContext* ctx)
@@ -91,8 +100,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
     TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
 
     // Check input signature
-    DataTypeVector expected_inputs = {};
-    TF_RETURN_IF_ERROR(ctx->MatchSignature(expected_inputs, {DT_STRING_REF}));
+    TF_RETURN_IF_ERROR(CheckSignature(ctx));
 
     Creator creator = GetCreator();
     ConditionalAccumulatorBase* accumulator;
@@ -112,11 +120,72 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
     return Status::OK();
   }
 
-  mutex mu_;
-  PersistentTensor accumulator_handle_ GUARDED_BY(mu_);
-  bool accumulator_handle_set_ GUARDED_BY(mu_);
 };
 
+// ------------------Sync kernels ------------------------------------------
+
+/**
+ * General OpKernel for ConditionalAccumulatorBase-related ops.
+ */
+class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel {
+ public:
+  explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) final {
+    ConditionalAccumulatorBase* accumulator;
+    OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator));
+    Compute(ctx, accumulator);
+    accumulator->Unref();
+  }
+
+ protected:
+  virtual void Compute(OpKernelContext* ctx,
+                       ConditionalAccumulatorBase* accumulator) = 0;
+
+  virtual DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) = 0;
+
+  virtual void CheckSignature(OpKernelContext* ctx,
+                              ConditionalAccumulatorBase* accumulator) {
+    // Check input signature
+    DataTypeVector expected_inputs = GetExpectedInputs(accumulator);
+    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+  }
+};
+
+/**
+ * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
+ * given ConditionalAccumulator.
+ */
+class ConditionalAccumulatorBaseApplyGradientOp
+    : public ConditionalAccumulatorBaseSyncOpKernel {
+ public:
+  explicit ConditionalAccumulatorBaseApplyGradientOp(
+      OpKernelConstruction* context)
+      : ConditionalAccumulatorBaseSyncOpKernel(context) {}
+
+ protected:
+  void Compute(OpKernelContext* ctx,
+               ConditionalAccumulatorBase* accumulator) override {
+    // Check input signature
+    CheckSignature(ctx, accumulator);
+
+    // Get input local_step
+    const Tensor* local_step_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor));
+    if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) {
+      ctx->CtxFailureWithWarning(errors::InvalidArgument(
+          "Argument local_step must be scalar, but had bad shape ",
+          local_step_tensor->shape().DebugString()));
+    }
+
+    // Actually try to apply gradient now
+    accumulator->TryApplyGrad(local_step_tensor->scalar<int64>()(), ctx);
+  }
+};
+
+// -------------------- Async kernels --------------------------------------
 /**
  * General OpKernel for ConditionalAccumulatorBase-related ops.
  */
@@ -140,59 +209,18 @@ class ConditionalAccumulatorBaseAsyncOpKernel : public AsyncOpKernel {
   virtual void ComputeAsync(OpKernelContext* ctx,
                             ConditionalAccumulatorBase* accumulator,
                             DoneCallback callback) = 0;
-};
-
-/**
- * General OpKernel for ConditionalAccumulatorBase-related ops.
- */
-class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel {
- public:
-  explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context)
-      : OpKernel(context) {}
-
-  void Compute(OpKernelContext* ctx) final {
-    ConditionalAccumulatorBase* accumulator;
-    OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator));
-    Compute(ctx, accumulator);
-    accumulator->Unref();
-  }
-
- protected:
-  virtual void Compute(OpKernelContext* ctx,
-                       ConditionalAccumulatorBase* accumulator) = 0;
-};
-
-/**
- * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
- * given ConditionalAccumulator.
- */
-class ConditionalAccumulatorBaseApplyGradientOp
-    : public ConditionalAccumulatorBaseSyncOpKernel {
- public:
-  explicit ConditionalAccumulatorBaseApplyGradientOp(
-      OpKernelConstruction* context)
-      : ConditionalAccumulatorBaseSyncOpKernel(context) {}
-
- protected:
+
+  virtual DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) = 0;
+
   virtual void CheckSignature(OpKernelContext* ctx,
-                              ConditionalAccumulatorBase* accumulator) = 0;
-
-  void Compute(OpKernelContext* ctx,
-               ConditionalAccumulatorBase* accumulator) override {
+                              ConditionalAccumulatorBase* accumulator,
+                              DoneCallback callback) {
     // Check input signature
-    CheckSignature(ctx, accumulator);
-
-    // Get input local_step
-    const Tensor* local_step_tensor;
-    OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor));
-    if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) {
-      ctx->CtxFailureWithWarning(errors::InvalidArgument(
-          "Argument local_step must be scalar, but had bad shape ",
-          local_step_tensor->shape().DebugString()));
-    }
-
-    // Actually try to apply gradient now
-    accumulator->TryApplyGrad(local_step_tensor->scalar<int64>()(), ctx);
+    OP_REQUIRES_OK_ASYNC(ctx,
+                         ctx->MatchSignature(GetExpectedInputs(accumulator),
+                                             {accumulator->dtype()}),
+                         callback);
   }
 };
@@ -208,10 +236,6 @@ class ConditionalAccumulatorBaseTakeGradientOp
       : ConditionalAccumulatorBaseAsyncOpKernel(context) {}
 
  protected:
-  virtual void CheckSignature(OpKernelContext* ctx,
-                              ConditionalAccumulatorBase* accumulator,
-                              DoneCallback callback) = 0;
-
   void ComputeAsync(OpKernelContext* ctx,
                     ConditionalAccumulatorBase* accumulator,
                     DoneCallback callback) override {
@@ -41,6 +41,16 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
     };
   }
 
+  Status CheckSignature(OpKernelContext* ctx) override {
+    TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF}));
+    return Status::OK();
+  }
+
+  void SetHandleToOutput(OpKernelContext* ctx)
+      SHARED_LOCKS_REQUIRED(mu_) override {
+    ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx));
+  }
+
   TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulatorOp);
 };
 
@@ -50,6 +60,50 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
                               .TypeConstraint<type>("dtype"),     \
                           ConditionalAccumulatorOp<dev##Device, type>)
 
+// Resource conditional accumulator
+template <typename Device, typename T>
+class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
+ public:
+  explicit ResourceConditionalAccumulatorOp(OpKernelConstruction* context)
+      : ConditionalAccumulatorBaseOp(context) {}
+
+ protected:
+  Creator GetCreator() const override {
+    return [this](ConditionalAccumulatorBase** ret) {
+      ConditionalAccumulator<Device, T>* accumulator =
+          new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+                                                reduction_type_);
+      *ret = accumulator;
+      return Status::OK();
+    };
+  }
+
+  Status CheckSignature(OpKernelContext* ctx) override {
+    TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_RESOURCE}));
+    return Status::OK();
+  }
+
+  void SetHandleToOutput(OpKernelContext* ctx)
+      SHARED_LOCKS_REQUIRED(mu_) override {
+    auto h = accumulator_handle_.AccessTensor(ctx)->template flat<string>();
+    h(0) = cinfo_.container();
+    h(1) = cinfo_.name();
+    OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+                            ctx, 0, cinfo_.container(), cinfo_.name(),
+                            MakeTypeIndex<ConditionalAccumulatorBase>()));
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ResourceConditionalAccumulatorOp);
+};
+
+#define REGISTER_RESOURCE_KERNELS(type, dev)                      \
+  REGISTER_KERNEL_BUILDER(Name("ResourceConditionalAccumulator")  \
+                              .Device(DEVICE_##dev)               \
+                              .TypeConstraint<type>("dtype"),     \
+                          ResourceConditionalAccumulatorOp<dev##Device, type>)
+
+// End of Resource conditional accumulator
+
 #define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU)
 
 TF_CALL_half(REGISTER_KERNELS_CPU);
@@ -59,6 +113,15 @@ TF_CALL_double(REGISTER_KERNELS_CPU);
 #undef REGISTER_KERNELS_CPU
 #undef REGISTER_KERNELS
 
+#define REGISTER_RESOURCE_KERNELS_CPU(type) REGISTER_RESOURCE_KERNELS(type, CPU)
+
+TF_CALL_half(REGISTER_RESOURCE_KERNELS_CPU);
+TF_CALL_float(REGISTER_RESOURCE_KERNELS_CPU);
+TF_CALL_double(REGISTER_RESOURCE_KERNELS_CPU);
+
+#undef REGISTER_KERNELS_CPU
+#undef REGISTER_KERNELS
+
 /**
  * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
  * given ConditionalAccumulator.
@@ -69,13 +132,12 @@ class AccumulatorApplyGradientOp
   explicit AccumulatorApplyGradientOp(OpKernelConstruction* context)
       : ConditionalAccumulatorBaseApplyGradientOp(context) {}
 
- protected:
-  void CheckSignature(OpKernelContext* ctx,
-                      ConditionalAccumulatorBase* accumulator) override {
-    // Check input signature
-    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64};
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    DataTypeVector expected_inputs;
+    expected_inputs = {DT_STRING_REF, DT_INT64};
     expected_inputs.push_back(accumulator->dtype());
-    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+    return expected_inputs;
   }
 
  private:
@@ -85,6 +147,28 @@ class AccumulatorApplyGradientOp
 REGISTER_KERNEL_BUILDER(Name("AccumulatorApplyGradient").Device(DEVICE_CPU),
                         AccumulatorApplyGradientOp);
 
+class ResourceAccumulatorApplyGradientOp
+    : public ConditionalAccumulatorBaseApplyGradientOp {
+ public:
+  explicit ResourceAccumulatorApplyGradientOp(OpKernelConstruction* context)
+      : ConditionalAccumulatorBaseApplyGradientOp(context) {}
+
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    DataTypeVector expected_inputs;
+    expected_inputs = {DT_RESOURCE, DT_INT64};
+    expected_inputs.push_back(accumulator->dtype());
+    return expected_inputs;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorApplyGradientOp);
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ResourceAccumulatorApplyGradient").Device(DEVICE_CPU),
+    ResourceAccumulatorApplyGradientOp);
+
 /**
  * Defines a ConditionalAccumulatorBaseTakeGradientOp, the execution of which
  * returns the average gradient accumulated by the given ConditionalAccumulator.
@@ -95,22 +179,34 @@ class AccumulatorTakeGradientOp
   explicit AccumulatorTakeGradientOp(OpKernelConstruction* context)
       : ConditionalAccumulatorBaseTakeGradientOp(context) {}
 
- protected:
-  void CheckSignature(OpKernelContext* ctx,
-                      ConditionalAccumulatorBase* accumulator,
-                      DoneCallback callback) override {
-    // Check signature
-    OP_REQUIRES_OK_ASYNC(
-        ctx,
-        ctx->MatchSignature({DT_STRING_REF, DT_INT32}, {accumulator->dtype()}),
-        callback);
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    return {DT_STRING_REF, DT_INT32};
   }
 
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorTakeGradientOp);
 };
 
 REGISTER_KERNEL_BUILDER(Name("AccumulatorTakeGradient").Device(DEVICE_CPU),
                         AccumulatorTakeGradientOp);
 
+class ResourceAccumulatorTakeGradientOp
+    : public ConditionalAccumulatorBaseTakeGradientOp {
+ public:
+  explicit ResourceAccumulatorTakeGradientOp(OpKernelConstruction* context)
+      : ConditionalAccumulatorBaseTakeGradientOp(context) {}
+
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    return {DT_RESOURCE, DT_INT32};
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ResourceAccumulatorTakeGradientOp);
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ResourceAccumulatorTakeGradient").Device(DEVICE_CPU),
+    ResourceAccumulatorTakeGradientOp);
+
 }  // namespace tensorflow
@@ -41,6 +41,18 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
     };
   }
 
+  // TODO(tanzheny): actually switch it to resource. You won't be able to use
+  // it with cond2 otherwise.
+  Status CheckSignature(OpKernelContext* ctx) override {
+    TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF}));
+    return Status::OK();
+  }
+
+  void SetHandleToOutput(OpKernelContext* ctx)
+      SHARED_LOCKS_REQUIRED(mu_) override {
+    ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx));
+  }
+
   TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulatorOp);
 };
 
@@ -70,13 +82,12 @@ class SparseAccumulatorApplyGradientOp
       : ConditionalAccumulatorBaseApplyGradientOp(context) {}
 
  protected:
-  void CheckSignature(OpKernelContext* ctx,
-                      ConditionalAccumulatorBase* accumulator) override {
-    // Check input signature
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
     DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64, DT_INT64};
     expected_inputs.push_back(accumulator->dtype());
     expected_inputs.push_back(DT_INT64);
-    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
+    return expected_inputs;
   }
 
  private:
@@ -109,6 +120,11 @@ class SparseAccumulatorTakeGradientOp
         callback);
   }
 
+  DataTypeVector GetExpectedInputs(
+      ConditionalAccumulatorBase* accumulator) override {
+    return {DT_STRING_REF, DT_INT32};
+  }
+
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorTakeGradientOp);
 };
@@ -451,6 +451,61 @@ REGISTER_OP("AccumulatorTakeGradient")
     })
     .Attr("dtype: numbertype");
 
+// -----------------V2 accumulators that use resource -------------------------
+
+REGISTER_OP("ResourceAccumulatorNumAccumulated")
+    .Input("handle: resource")
+    .Output("num_accumulated: int32")
+    .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ResourceAccumulatorSetGlobalStep")
+    .Input("handle: resource")
+    .Input("new_global_step: int64")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      return Status::OK();
+    });
+
+REGISTER_OP("ResourceConditionalAccumulator")
+    .Output("handle: resource")
+    .Attr("dtype: numbertype")
+    .Attr("shape: shape")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->Vector(2));
+      return Status::OK();
+    });
+
+REGISTER_OP("ResourceAccumulatorApplyGradient")
+    .Input("handle: resource")
+    .Input("local_step: int64")
+    .Input("gradient: dtype")
+    .Attr("dtype: numbertype")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      return Status::OK();
+    });
+
+REGISTER_OP("ResourceAccumulatorTakeGradient")
+    .Input("handle: resource")
+    .Input("num_required: int32")
+    .Output("average: dtype")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      // Shape of output is the shape of the accumulator referenced
+      // by 'handle', but which is not available here, so we lose
+      // shape information.
+      return shape_inference::UnknownShape(c);
+    })
+    .Attr("dtype: numbertype");
+
+// TODO(nponomareva): change these all to use resources.
 REGISTER_OP("SparseConditionalAccumulator")
     .Output("handle: Ref(string)")
     .Attr("dtype: numbertype")
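These registrations mirror the existing ref-based accumulator ops one for one, with the Ref(string) handle replaced by a resource handle. A hedged sketch of driving the new ops directly through the generated Python bindings (the same calls the wrapper changes below emit); the values and the shared_name are illustrative only, and the ops exist only once this commit is in the build:

import tensorflow as tf
from tensorflow.python.ops import gen_data_flow_ops

# Illustrative values; the handle is a DT_RESOURCE tensor rather than a
# string ref.
handle = gen_data_flow_ops.resource_conditional_accumulator(
    dtype=tf.float32, shape=[], shared_name="acc_v2")

apply_op = gen_data_flow_ops.resource_accumulator_apply_gradient(
    handle, local_step=tf.constant(0, tf.int64), gradient=tf.constant(1.5))

avg = gen_data_flow_ops.resource_accumulator_take_gradient(
    handle, num_required=1, dtype=tf.float32)

with tf.Session() as sess:
  sess.run(apply_op)
  print(sess.run(avg))  # 1.5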
@@ -24,6 +24,7 @@ import threading
 
 import six
 
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes as _dtypes
 from tensorflow.python.framework import ops
@@ -1215,6 +1216,11 @@ class ConditionalAccumulatorBase(object):
     """
     if name is None:
       name = "%s_NumAccumulated" % self._name
+
+    if compat.forward_compatible(2019, 7, 8):
+      return gen_data_flow_ops.resource_accumulator_num_accumulated(
+          self._accumulator_ref, name=name)
+
     return gen_data_flow_ops.accumulator_num_accumulated(
         self._accumulator_ref, name=name)
 
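The compat.forward_compatible(2019, 7, 8) guard keeps a freshly built Python frontend compatible with older runtimes: the resource ops are only emitted once the forward-compatibility window for that date has passed, or when the horizon is overridden. A small sketch, assuming the standard compat helpers, of how a test can pin either branch; the specific date is the one used in this commit:

from tensorflow.python.compat import compat

# Force the gated (resource-op) path regardless of the build date,
# e.g. inside a unit test.
with compat.forward_compatibility_horizon(2019, 7, 9):
  assert compat.forward_compatible(2019, 7, 8)

# Outside the override, the answer depends on the build's own horizon.
print(compat.forward_compatible(2019, 7, 8))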
@@ -1231,6 +1237,12 @@ class ConditionalAccumulatorBase(object):
     Returns:
       Operation that sets the accumulator's time step.
     """
+    if compat.forward_compatible(2019, 7, 8):
+      return gen_data_flow_ops.resource_accumulator_set_global_step(
+          self._accumulator_ref,
+          math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
+          name=name)
+
     return gen_data_flow_ops.accumulator_set_global_step(
         self._accumulator_ref,
         math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
@@ -1264,12 +1276,23 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
       name: Optional name for the accumulator.
       reduction_type: Reduction type to use when taking the gradient.
     """
-    accumulator_ref = gen_data_flow_ops.conditional_accumulator(
-        dtype=dtype,
-        shape=shape,
-        shared_name=shared_name,
-        name=name,
-        reduction_type=reduction_type)
+    if compat.forward_compatible(2019, 7, 8):
+      accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
+          dtype=dtype,
+          shape=shape,
+          shared_name=shared_name,
+          name=name,
+          reduction_type=reduction_type)
+      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+          handle=accumulator_ref, handle_device=context.context().device_name)
+    else:
+      accumulator_ref = gen_data_flow_ops.conditional_accumulator(
+          dtype=dtype,
+          shape=shape,
+          shared_name=shared_name,
+          name=name,
+          reduction_type=reduction_type)
+
     super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
 
   def apply_grad(self, grad, local_step=0, name=None):
@@ -1292,6 +1315,13 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
     grad = ops.convert_to_tensor(grad, self._dtype)
     grad.get_shape().assert_is_compatible_with(self._shape)
     local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
+
+    if compat.forward_compatible(2019, 7, 8):
+      return gen_data_flow_ops.resource_accumulator_apply_gradient(
+          self._accumulator_ref,
+          local_step=local_step,
+          gradient=grad,
+          name=name)
     return gen_data_flow_ops.accumulator_apply_gradient(
         self._accumulator_ref, local_step=local_step, gradient=grad, name=name)
 
@@ -1317,6 +1347,10 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
     Raises:
       InvalidArgumentError: If num_required < 1
     """
-    out = gen_data_flow_ops.accumulator_take_gradient(
-        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
+    if compat.forward_compatible(2019, 7, 8):
+      out = gen_data_flow_ops.resource_accumulator_take_gradient(
+          self._accumulator_ref, num_required, dtype=self._dtype, name=name)
+    else:
+      out = gen_data_flow_ops.accumulator_take_gradient(
+          self._accumulator_ref, num_required, dtype=self._dtype, name=name)
     out.set_shape(self._shape)
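With the constructor and method changes above, the same user-level tf.ConditionalAccumulator code lowers to the new resource ops whenever the compatibility window allows. A hedged way to check which path was taken, assuming the accumulator_ref property exposed by this module; the horizon override and dtype check are only for illustration:

import tensorflow as tf
from tensorflow.python.compat import compat

with compat.forward_compatibility_horizon(2019, 7, 9):
  acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=())
  # Assumed behaviour: on the new path the handle comes from
  # ResourceConditionalAccumulator and is a DT_RESOURCE tensor; on the old
  # path it is a Ref(string) handle.
  print(acc.accumulator_ref.dtype)  # tf.resource on the resource path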
@@ -3016,6 +3016,22 @@ tf_module {
     name: "ResizeNearestNeighborGrad"
     argspec: "args=[\'grads\', \'size\', \'align_corners\', \'half_pixel_centers\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
   }
+  member_method {
+    name: "ResourceAccumulatorApplyGradient"
+    argspec: "args=[\'handle\', \'local_step\', \'gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "ResourceAccumulatorNumAccumulated"
+    argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "ResourceAccumulatorSetGlobalStep"
+    argspec: "args=[\'handle\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "ResourceAccumulatorTakeGradient"
+    argspec: "args=[\'handle\', \'num_required\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "ResourceApplyAdaMax"
     argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@@ -3084,6 +3100,10 @@ tf_module {
     name: "ResourceApplyRMSProp"
     argspec: "args=[\'var\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
+  member_method {
+    name: "ResourceConditionalAccumulator"
+    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
+  }
   member_method {
     name: "ResourceCountUpTo"
     argspec: "args=[\'resource\', \'limit\', \'T\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -3016,6 +3016,22 @@ tf_module {
     name: "ResizeNearestNeighborGrad"
     argspec: "args=[\'grads\', \'size\', \'align_corners\', \'half_pixel_centers\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
   }
+  member_method {
+    name: "ResourceAccumulatorApplyGradient"
+    argspec: "args=[\'handle\', \'local_step\', \'gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "ResourceAccumulatorNumAccumulated"
+    argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "ResourceAccumulatorSetGlobalStep"
+    argspec: "args=[\'handle\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "ResourceAccumulatorTakeGradient"
+    argspec: "args=[\'handle\', \'num_required\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "ResourceApplyAdaMax"
    argspec: "args=[\'var\', \'m\', \'v\', \'beta1_power\', \'lr\', \'beta1\', \'beta2\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@@ -3084,6 +3100,10 @@ tf_module {
     name: "ResourceApplyRMSProp"
     argspec: "args=[\'var\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
+  member_method {
+    name: "ResourceConditionalAccumulator"
+    argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
+  }
   member_method {
     name: "ResourceCountUpTo"
     argspec: "args=[\'resource\', \'limit\', \'T\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "