Extend ConditionalAccumulator with SUM functionality.
Previously, take_grad returned the average of the gradients aggregated so far. This does not cover other use cases, such as summing quantiles or summing probability distributions from parallel workers. This change extends the accumulator so that the reduction can be either MEAN or SUM.

PiperOrigin-RevId: 211824519
parent bfff3425e0
commit d17016a8df
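For orientation before the diff itself, a minimal usage sketch of the new argument (hedged: written against the TF 1.x graph-mode API that the tests below exercise; variable names are illustrative):

import tensorflow as tf

# reduction_type="SUM" makes take_grad return the accumulated total;
# the default "MEAN" keeps the old averaging behavior.
q = tf.ConditionalAccumulator(
    tf.float32, shape=(1,), reduction_type="SUM")
accum_ops = [q.apply_grad((x,), local_step=0) for x in [10.0, 20.0]]
takeg = q.take_grad(1)

with tf.Session() as sess:
  for op in accum_ops:
    sess.run(op)
  print(sess.run(takeg))  # [30.0] with SUM; [15.0] with the default MEAN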
@@ -51,9 +51,11 @@ class ConditionalAccumulator
   // dtype: The datatype of the gradients to be accumulated.
   // shape: The shape of the accumulated gradients.
   // name: A name to use for the ConditionalAccumulator.
+  // reduction_type: The reduction type, i.e., MEAN or SUM
   ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
-                         const string& name)
-      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+                         const string& name, const string& reduction_type)
+      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+                                                      reduction_type) {}
   ~ConditionalAccumulator() override{};

  protected:
@@ -14,12 +14,17 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/kernels/conditional_accumulator_base.h"
+#include "tensorflow/core/lib/core/errors.h"

 namespace tensorflow {

 ConditionalAccumulatorBase::ConditionalAccumulatorBase(
-    const DataType& dtype, const PartialTensorShape& shape, const string& name)
-    : dtype_(dtype), shape_(shape), name_(name) {
+    const DataType& dtype, const PartialTensorShape& shape, const string& name,
+    const string& reduction_type)
+    : dtype_(dtype),
+      shape_(shape),
+      name_(name),
+      reduction_type_(reduction_type) {
   counter_ = 0;
   current_global_step_ = 0;
 }
@@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
   current_global_step_++;

   // Average the accumulated gradient
-  DivideAccumGradByCounter(ctx);
+  if (reduction_type_ == "MEAN") {
+    DivideAccumGradByCounter(ctx);
+  }

   // Set output for accumulated gradient tensor
   bool successful_set_output = SetOutput(ctx);
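In other words, the accumulator always sums applied gradients internally; reduction_type only controls whether that total is divided by the number of applied gradients at take time. A rough Python sketch of the take-side logic above (illustrative only, not the actual kernel code):

def take_grad(accum_sum, counter, reduction_type="MEAN"):
  # accum_sum: elementwise sum of gradients applied since the last take
  # counter: how many gradients were applied
  if reduction_type == "MEAN":
    return accum_sum / counter  # previous (and default) behavior
  return accum_sum  # "SUM": return the raw accumulated total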
@@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
   // name: A name to use for the ConditionalAccumulator.
   ConditionalAccumulatorBase(const DataType& dtype,
                              const PartialTensorShape& shape,
-                             const string& name);
+                             const string& name, const string& reduction_type);

   typedef AsyncOpKernel::DoneCallback DoneCallback;

@@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
   const DataType dtype_;
   const PartialTensorShape shape_;
   const string name_;
+  const string reduction_type_;
   mutex mu_;
   int counter_ GUARDED_BY(mu_);
   int64 current_global_step_ GUARDED_BY(mu_);
@@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
                                 &accumulator_handle_, nullptr));
     OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("reduction_type", &reduction_type_));
   }

   void Compute(OpKernelContext* ctx) override {
@@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
   DataType dtype_;
   PartialTensorShape shape_;
   ContainerInfo cinfo_;
+  string reduction_type_;

  private:
   Status SetAccumulatorHandle(OpKernelContext* ctx)
@@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
   Creator GetCreator() const override {
     return [this](ConditionalAccumulatorBase** ret) {
       ConditionalAccumulator<Device, T>* accumulator =
-          new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
+          new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+                                                reduction_type_);
       *ret = accumulator;
       return Status::OK();
     };
@@ -50,10 +50,10 @@ class SparseConditionalAccumulator
  public:
   SparseConditionalAccumulator(const DataType& dtype,
                                const PartialTensorShape& shape,
-                               const string& name)
+                               const string& name, const string& reduction_type)
       : TypedConditionalAccumulatorBase<
             std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
-            dtype, shape, name) {
+            dtype, shape, name, reduction_type) {
     accum_idx_vec_ = nullptr;
     count_element_ = nullptr;
     accum_val_ = nullptr;
@@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
   Creator GetCreator() const override {
     return [this](ConditionalAccumulatorBase** ret) {
       SparseConditionalAccumulator<Device, T>* accumulator =
-          new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
-                                                      cinfo_.name());
+          new SparseConditionalAccumulator<Device, T>(
+              dtype_, shape_, cinfo_.name(), reduction_type_);
       *ret = accumulator;
       return Status::OK();
     };
@@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
  public:
   TypedConditionalAccumulatorBase(const DataType& dtype,
                                   const PartialTensorShape& shape,
-                                  const string& name)
-      : ConditionalAccumulatorBase(dtype, shape, name) {}
+                                  const string& name,
+                                  const string& reduction_type)
+      : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}

   /**
    * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
@@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator")
     .Attr("shape: shape")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
+    .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
     .SetIsStateful()
     .SetShapeFn([](InferenceContext* c) {
       c->set_output(0, c->Vector(2));
@@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator")
     .Attr("shape: shape")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
+    .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
     .SetIsStateful()
     .SetShapeFn([](InferenceContext* c) {
       c->set_output(0, c->Vector(2));
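Since the new attr defaults to 'MEAN' in both op registrations, existing graphs keep their behavior unchanged, and values outside {'MEAN', 'SUM'} are rejected when the node is constructed. A small illustrative check (assumes the TF 1.x Python API, as in the tests below):

# Rejected by attr validation before the graph ever runs.
try:
  tf.ConditionalAccumulator(tf.float32, shape=(), reduction_type="Invalid")
except ValueError as e:
  print("rejected:", e)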
@@ -42,14 +42,22 @@ class ConditionalAccumulatorTest(test.TestCase):
     with ops.Graph().as_default():
       q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
       self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-      self.assertProtoEquals("""
+      self.assertProtoEquals(
+          """
         name:'Q' op:'ConditionalAccumulator'
         attr { key: 'dtype' value { type: DT_FLOAT } }
         attr { key: 'shape' value { shape { unknown_rank: true} } }
         attr { key: 'container' value { s: '' } }
         attr { key: 'shared_name' value { s: '' } }
+        attr { key: 'reduction_type' value {s: 'MEAN'} }
         """, q.accumulator_ref.op.node_def)

+  def testConstructorWithInvalidArg(self):
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):
+        data_flow_ops.ConditionalAccumulator(
+            dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
   def testConstructorWithShape(self):
     with ops.Graph().as_default():
       q = data_flow_ops.ConditionalAccumulator(
@@ -57,7 +65,8 @@ class ConditionalAccumulatorTest(test.TestCase):
           name="Q",
           shape=tensor_shape.TensorShape([1, 5, 2, 8]))
       self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-      self.assertProtoEquals("""
+      self.assertProtoEquals(
+          """
         name:'Q' op:'ConditionalAccumulator'
         attr { key: 'dtype' value { type: DT_FLOAT } }
         attr { key: 'shape' value { shape { dim {size: 1 }
@@ -67,6 +76,7 @@ class ConditionalAccumulatorTest(test.TestCase):
         } } }
         attr { key: 'container' value { s: '' } }
         attr { key: 'shared_name' value { s: '' } }
+        attr { key: 'reduction_type' value {s: 'MEAN'} }
         """, q.accumulator_ref.op.node_def)

   def testAccumulatorSizeEmpty(self):
@@ -237,12 +247,11 @@ class ConditionalAccumulatorTest(test.TestCase):
       extract_t.op.run()
       self.assertEqual(q.num_accumulated().eval(), 0)

-  def testAccumulatorTakeGrad(self):
+  def testAccumulatorTakeGradMean(self):
     with self.test_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       elems = [10.0, 20.0]
-      elems_ave = sum(elems) / len(elems)

       accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
       takeg_t = q.take_grad(1)
@@ -251,7 +260,7 @@ class ConditionalAccumulatorTest(test.TestCase):
         accum_op.run()

       val = takeg_t.eval()
-      self.assertEqual(elems_ave, val)
+      self.assertEqual(15.0, val)

       accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
       takeg_t = q.take_grad(constant_op.constant(1))
@@ -260,7 +269,42 @@ class ConditionalAccumulatorTest(test.TestCase):
         accum_op.run()

       val = takeg_t.eval()
-      self.assertEqual(elems_ave, val)
+      self.assertEqual(15.0, val)
+
+  def testAccumulatorTakeGradSum(self):
+    with self.test_session():
+      q = data_flow_ops.ConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([1]),
+          reduction_type="SUM")
+      elems = [10.0, 20.0]
+
+      accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+      takeg_t = q.take_grad(1)
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(30.0, val)
+
+      accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+      takeg_t = q.take_grad(constant_op.constant(1))
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(30.0, val)
+
+  def testAccumulatorTakeGradInvalidReductionType(self):
+    with self.assertRaises(ValueError):
+      data_flow_ops.ConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([1]),
+          reduction_type="Invalid")

   def testAccumulatorInvalidTakeGrad(self):
     with self.test_session():
@@ -277,7 +321,7 @@ class ConditionalAccumulatorTest(test.TestCase):
       with self.assertRaises(errors_impl.InvalidArgumentError):
         takeg_t.eval()

-  def testAccumulatorRepeatedTakeGrad(self):
+  def testAccumulatorRepeatedTakeGradMean(self):
     with self.test_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -304,6 +348,36 @@ class ConditionalAccumulatorTest(test.TestCase):
       val = takeg_t.eval()
       self.assertEqual(elems_ave + 0.0, val)

+  def testAccumulatorRepeatedTakeGradSum(self):
+    with self.test_session():
+      q = data_flow_ops.ConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([1]),
+          reduction_type="SUM")
+
+      elems = [10.0, 20.0]
+      elems_sum = 30.0
+      accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+      takeg_t = q.take_grad(1)
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(elems_sum, val)
+
+      elems = [20.0, 30.0]
+      elems_sum = 50.0
+      accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+      takeg_t = q.take_grad(1)
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(elems_sum, val)
+
   def testAccumulatorIncrementGlobalStep(self):
     with self.test_session():
       q = data_flow_ops.ConditionalAccumulator(
@@ -61,14 +61,22 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q")
       self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-      self.assertProtoEquals("""
+      self.assertProtoEquals(
+          """
         name:'Q' op:'SparseConditionalAccumulator'
         attr { key: 'dtype' value { type: DT_FLOAT } }
         attr { key: 'shape' value { shape { unknown_rank: true} } }
         attr { key: 'container' value { s: '' } }
         attr { key: 'shared_name' value { s: '' } }
+        attr { key: 'reduction_type' value {s: 'MEAN'} }
         """, q.accumulator_ref.op.node_def)

+  def testConstructorWithInvalidArg(self):
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):
+        data_flow_ops.SparseConditionalAccumulator(
+            dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
   def testConstructorWithShape(self):
     with ops.Graph().as_default():
       q = data_flow_ops.SparseConditionalAccumulator(
@@ -76,7 +84,8 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
           name="Q",
           shape=tensor_shape.TensorShape([1, 5, 2, 8]))
       self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-      self.assertProtoEquals("""
+      self.assertProtoEquals(
+          """
         name:'Q' op:'SparseConditionalAccumulator'
         attr { key: 'dtype' value { type: DT_FLOAT } }
         attr { key: 'shape' value { shape { dim {size: 1 }
@@ -86,6 +95,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
         } } }
         attr { key: 'container' value { s: '' } }
         attr { key: 'shared_name' value { s: '' } }
+        attr { key: 'reduction_type' value {s: 'MEAN'} }
         """, q.accumulator_ref.op.node_def)

   def testAccumulatorSizeEmpty(self):
@@ -164,7 +174,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
         result = sess.run(accums[i].take_indexed_slices_grad(1))
         self._assertEqual_indexedslices(expected_tensors[i], result)

-  def testAccumulatorTakeGrad(self):
+  def testAccumulatorTakeGradMean(self):
     with self.test_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=())
@@ -180,9 +190,34 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):

       takeg_t = q.take_indexed_slices_grad(1)
       val = sess.run(takeg_t)
-      self.assertAllEqual(val.indices, [0, 1, 2])
-      self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]])
-      self.assertAllEqual(val.dense_shape, [-1, 2])
+      self.assertAllEqual([0, 1, 2], val.indices)
+      self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values)
+      self.assertAllEqual([-1, 2], val.dense_shape)
+
+  def testAccumulatorTakeGradSum(self):
+    with self.test_session() as sess:
+      q = data_flow_ops.SparseConditionalAccumulator(
+          dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
+
+      grad_indexed_slices = ops.IndexedSlices(
+          indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32))
+      accum_op = q.apply_indexed_slices_grad(grad_indexed_slices)
+      accum_op.run()
+      accum_op = q.apply_grad([0, 2],
+                              np.array([[0, 1], [3, 0]]).astype(np.float32),
+                              [3, 2])
+      accum_op.run()
+
+      takeg_t = q.take_indexed_slices_grad(1)
+      val = sess.run(takeg_t)
+      self.assertAllEqual([0, 1, 2], val.indices)
+      self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values)
+      self.assertAllEqual([-1, 2], val.dense_shape)
+
+  def testAccumulatorTakeGradInvalidReductionType(self):
+    with self.assertRaises(ValueError):
+      data_flow_ops.SparseConditionalAccumulator(
+          dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")

   def testAccumulatorRepeatedTakeGrad(self):
     with self.test_session() as sess:
@@ -222,7 +257,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
       self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]])
       self.assertAllEqual(val.dense_shape, [-1, 2])

-  def testParallelApplyGrad(self):
+  def testParallelApplyGradMean(self):
     with self.test_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -253,6 +288,40 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
           np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
           val, sess)

+  def testParallelApplyGradSum(self):
+    with self.test_session() as sess:
+      q = data_flow_ops.SparseConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([2, 2]),
+          reduction_type="SUM")
+      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+      accum_ops = []
+      for x in elems:
+        x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
+        accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
+      takeg_t = q.take_indexed_slices_grad(1)
+
+      def apply_indexed_slices_grad(accum_op):
+        sess.run(accum_op)
+
+      threads = [
+          self.checkedThread(target=apply_indexed_slices_grad, args=(o,))
+          for o in accum_ops
+      ]
+
+      for thread in threads:
+        thread.start()
+      for thread in threads:
+        thread.join()
+
+      val = sess.run(takeg_t)
+
+      expected_val = 550.0
+      self._assertEqual_nparray(
+          np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
+          val, sess)
+
   def testParallelTakeGrad(self):
     with self.test_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
@@ -1229,7 +1229,8 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
                dtype,
                shape=None,
                shared_name=None,
-               name="conditional_accumulator"):
+               name="conditional_accumulator",
+               reduction_type="MEAN"):
     """Creates a new ConditionalAccumulator.

     Args:
@@ -1238,9 +1239,14 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
       shared_name: Optional. If non-empty, this accumulator will be shared under
         the given name across multiple sessions.
       name: Optional name for the accumulator.
+      reduction_type: Reduction type to use when taking the gradient.
     """
     accumulator_ref = gen_data_flow_ops.conditional_accumulator(
-        dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+        dtype=dtype,
+        shape=shape,
+        shared_name=shared_name,
+        name=name,
+        reduction_type=reduction_type)
     super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)

   def apply_grad(self, grad, local_step=0, name=None):
@@ -1312,15 +1318,21 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
     shared_name: Optional. If non-empty, this accumulator will be shared under
       the given name across multiple sessions.
     name: Optional name for the accumulator.
+    reduction_type: Reduction type to use when taking the gradient.
   """

   def __init__(self,
                dtype,
                shape=None,
                shared_name=None,
-               name="sparse_conditional_accumulator"):
+               name="sparse_conditional_accumulator",
+               reduction_type="MEAN"):
     accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
-        dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+        dtype=dtype,
+        shape=shape,
+        shared_name=shared_name,
+        name=name,
+        reduction_type=reduction_type)
     super(SparseConditionalAccumulator, self).__init__(dtype, shape,
                                                        accumulator_ref)
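And the sparse counterpart, sketched against the test usage above (hedged: assumes the TF 1.x graph-mode API; with SUM, rows that share an index are added elementwise rather than averaged):

import numpy as np
import tensorflow as tf

q = tf.SparseConditionalAccumulator(
    tf.float32, shape=(), reduction_type="SUM")
# Two sparse gradients whose index-0 rows overlap.
apply_a = q.apply_grad(
    grad_indices=[0, 1], grad_values=np.array([[1, 0], [0, 2]], np.float32))
apply_b = q.apply_grad(
    grad_indices=[0, 2], grad_values=np.array([[0, 1], [3, 0]], np.float32))
take = q.take_indexed_slices_grad(2)

with tf.Session() as sess:
  sess.run([apply_a, apply_b])
  val = sess.run(take)
  print(val.indices)  # [0 1 2]
  print(val.values)   # [[1. 1.] [0. 2.] [3. 0.]] -- summed, not averaged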
@@ -17,7 +17,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
   }
   member_method {
     name: "apply_grad"
@@ -17,7 +17,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
   }
   member_method {
     name: "apply_grad"
@@ -17,7 +17,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
   }
   member_method {
     name: "apply_grad"
@@ -17,7 +17,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
  }
   member_method {
     name: "apply_grad"