Extend ConditionalAccumulator with SUM functionality.

Previously take_grad represents the average gradients being aggregated.
However this does not cover other use cases such as summing quantiles,
or summing probability distributions from parallel workers. This change
extends the functionality.

PiperOrigin-RevId: 211824519
This commit is contained in:
Zhenyu Tan 2018-09-06 10:01:46 -07:00 committed by TensorFlower Gardener
parent bfff3425e0
commit d17016a8df
16 changed files with 207 additions and 35 deletions

View File

@ -51,9 +51,11 @@ class ConditionalAccumulator
// dtype: The datatype of the gradients to be accumulated. // dtype: The datatype of the gradients to be accumulated.
// shape: The shape of the accumulated gradients. // shape: The shape of the accumulated gradients.
// name: A name to use for the ConditionalAccumulator. // name: A name to use for the ConditionalAccumulator.
// reduction_type: The reduction type, i.e., MEAN or SUM
ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
const string& name) const string& name, const string& reduction_type)
: TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {} : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
reduction_type) {}
~ConditionalAccumulator() override{}; ~ConditionalAccumulator() override{};
protected: protected:

View File

@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/kernels/conditional_accumulator_base.h" #include "tensorflow/core/kernels/conditional_accumulator_base.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow { namespace tensorflow {
ConditionalAccumulatorBase::ConditionalAccumulatorBase( ConditionalAccumulatorBase::ConditionalAccumulatorBase(
const DataType& dtype, const PartialTensorShape& shape, const string& name) const DataType& dtype, const PartialTensorShape& shape, const string& name,
: dtype_(dtype), shape_(shape), name_(name) { const string& reduction_type)
: dtype_(dtype),
shape_(shape),
name_(name),
reduction_type_(reduction_type) {
counter_ = 0; counter_ = 0;
current_global_step_ = 0; current_global_step_ = 0;
} }
@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
current_global_step_++; current_global_step_++;
// Average the accumulated gradient // Average the accumulated gradient
DivideAccumGradByCounter(ctx); if (reduction_type_ == "MEAN") {
DivideAccumGradByCounter(ctx);
}
// Set output for accumulated gradient tensor // Set output for accumulated gradient tensor
bool successful_set_output = SetOutput(ctx); bool successful_set_output = SetOutput(ctx);

View File

@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
// name: A name to use for the ConditionalAccumulator. // name: A name to use for the ConditionalAccumulator.
ConditionalAccumulatorBase(const DataType& dtype, ConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape, const PartialTensorShape& shape,
const string& name); const string& name, const string& reduction_type);
typedef AsyncOpKernel::DoneCallback DoneCallback; typedef AsyncOpKernel::DoneCallback DoneCallback;
@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
const DataType dtype_; const DataType dtype_;
const PartialTensorShape shape_; const PartialTensorShape shape_;
const string name_; const string name_;
const string reduction_type_;
mutex mu_; mutex mu_;
int counter_ GUARDED_BY(mu_); int counter_ GUARDED_BY(mu_);
int64 current_global_step_ GUARDED_BY(mu_); int64 current_global_step_ GUARDED_BY(mu_);

View File

@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
&accumulator_handle_, nullptr)); &accumulator_handle_, nullptr));
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
OP_REQUIRES_OK(context,
context->GetAttr("reduction_type", &reduction_type_));
} }
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
DataType dtype_; DataType dtype_;
PartialTensorShape shape_; PartialTensorShape shape_;
ContainerInfo cinfo_; ContainerInfo cinfo_;
string reduction_type_;
private: private:
Status SetAccumulatorHandle(OpKernelContext* ctx) Status SetAccumulatorHandle(OpKernelContext* ctx)

View File

@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override { Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) { return [this](ConditionalAccumulatorBase** ret) {
ConditionalAccumulator<Device, T>* accumulator = ConditionalAccumulator<Device, T>* accumulator =
new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name()); new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
reduction_type_);
*ret = accumulator; *ret = accumulator;
return Status::OK(); return Status::OK();
}; };

View File

@ -50,10 +50,10 @@ class SparseConditionalAccumulator
public: public:
SparseConditionalAccumulator(const DataType& dtype, SparseConditionalAccumulator(const DataType& dtype,
const PartialTensorShape& shape, const PartialTensorShape& shape,
const string& name) const string& name, const string& reduction_type)
: TypedConditionalAccumulatorBase< : TypedConditionalAccumulatorBase<
std::tuple<const Tensor*, const Tensor*, const Tensor*>>( std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
dtype, shape, name) { dtype, shape, name, reduction_type) {
accum_idx_vec_ = nullptr; accum_idx_vec_ = nullptr;
count_element_ = nullptr; count_element_ = nullptr;
accum_val_ = nullptr; accum_val_ = nullptr;

View File

@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override { Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) { return [this](ConditionalAccumulatorBase** ret) {
SparseConditionalAccumulator<Device, T>* accumulator = SparseConditionalAccumulator<Device, T>* accumulator =
new SparseConditionalAccumulator<Device, T>(dtype_, shape_, new SparseConditionalAccumulator<Device, T>(
cinfo_.name()); dtype_, shape_, cinfo_.name(), reduction_type_);
*ret = accumulator; *ret = accumulator;
return Status::OK(); return Status::OK();
}; };

View File

@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
public: public:
TypedConditionalAccumulatorBase(const DataType& dtype, TypedConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape, const PartialTensorShape& shape,
const string& name) const string& name,
: ConditionalAccumulatorBase(dtype, shape, name) {} const string& reduction_type)
: ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
/** /**
* Attempts to add a gradient to the accumulator. An ApplyGrad attempt is * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is

View File

@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator")
.Attr("shape: shape") .Attr("shape: shape")
.Attr("container: string = ''") .Attr("container: string = ''")
.Attr("shared_name: string = ''") .Attr("shared_name: string = ''")
.Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2)); c->set_output(0, c->Vector(2));
@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator")
.Attr("shape: shape") .Attr("shape: shape")
.Attr("container: string = ''") .Attr("container: string = ''")
.Attr("shared_name: string = ''") .Attr("shared_name: string = ''")
.Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2)); c->set_output(0, c->Vector(2));

View File

@ -42,14 +42,22 @@ class ConditionalAccumulatorTest(test.TestCase):
with ops.Graph().as_default(): with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q") q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
self.assertProtoEquals(""" self.assertProtoEquals(
"""
name:'Q' op:'ConditionalAccumulator' name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } } attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } } attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } } attr { key: 'shared_name' value { s: '' } }
attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def) """, q.accumulator_ref.op.node_def)
def testConstructorWithInvalidArg(self):
with ops.Graph().as_default():
with self.assertRaises(ValueError):
data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", reduction_type="Invalid")
def testConstructorWithShape(self): def testConstructorWithShape(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator( q = data_flow_ops.ConditionalAccumulator(
@ -57,7 +65,8 @@ class ConditionalAccumulatorTest(test.TestCase):
name="Q", name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8])) shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
self.assertProtoEquals(""" self.assertProtoEquals(
"""
name:'Q' op:'ConditionalAccumulator' name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 } attr { key: 'shape' value { shape { dim {size: 1 }
@ -67,6 +76,7 @@ class ConditionalAccumulatorTest(test.TestCase):
} } } } } }
attr { key: 'container' value { s: '' } } attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } } attr { key: 'shared_name' value { s: '' } }
attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def) """, q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self): def testAccumulatorSizeEmpty(self):
@ -237,12 +247,11 @@ class ConditionalAccumulatorTest(test.TestCase):
extract_t.op.run() extract_t.op.run()
self.assertEqual(q.num_accumulated().eval(), 0) self.assertEqual(q.num_accumulated().eval(), 0)
def testAccumulatorTakeGrad(self): def testAccumulatorTakeGradMean(self):
with self.test_session(): with self.test_session():
q = data_flow_ops.ConditionalAccumulator( q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0] elems = [10.0, 20.0]
elems_ave = sum(elems) / len(elems)
accum_ops = [q.apply_grad((x,), local_step=0) for x in elems] accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
takeg_t = q.take_grad(1) takeg_t = q.take_grad(1)
@ -251,7 +260,7 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run() accum_op.run()
val = takeg_t.eval() val = takeg_t.eval()
self.assertEqual(elems_ave, val) self.assertEqual(15.0, val)
accum_ops = [q.apply_grad((x,), local_step=1) for x in elems] accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
takeg_t = q.take_grad(constant_op.constant(1)) takeg_t = q.take_grad(constant_op.constant(1))
@ -260,7 +269,42 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run() accum_op.run()
val = takeg_t.eval() val = takeg_t.eval()
self.assertEqual(elems_ave, val) self.assertEqual(15.0, val)
def testAccumulatorTakeGradSum(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
shape=tensor_shape.TensorShape([1]),
reduction_type="SUM")
elems = [10.0, 20.0]
accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
takeg_t = q.take_grad(1)
for accum_op in accum_ops:
accum_op.run()
val = takeg_t.eval()
self.assertEqual(30.0, val)
accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
takeg_t = q.take_grad(constant_op.constant(1))
for accum_op in accum_ops:
accum_op.run()
val = takeg_t.eval()
self.assertEqual(30.0, val)
def testAccumulatorTakeGradInvalidReductionType(self):
with self.assertRaises(ValueError):
data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
shape=tensor_shape.TensorShape([1]),
reduction_type="Invalid")
def testAccumulatorInvalidTakeGrad(self): def testAccumulatorInvalidTakeGrad(self):
with self.test_session(): with self.test_session():
@ -277,7 +321,7 @@ class ConditionalAccumulatorTest(test.TestCase):
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
takeg_t.eval() takeg_t.eval()
def testAccumulatorRepeatedTakeGrad(self): def testAccumulatorRepeatedTakeGradMean(self):
with self.test_session(): with self.test_session():
q = data_flow_ops.ConditionalAccumulator( q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@ -304,6 +348,36 @@ class ConditionalAccumulatorTest(test.TestCase):
val = takeg_t.eval() val = takeg_t.eval()
self.assertEqual(elems_ave + 0.0, val) self.assertEqual(elems_ave + 0.0, val)
def testAccumulatorRepeatedTakeGradSum(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
shape=tensor_shape.TensorShape([1]),
reduction_type="SUM")
elems = [10.0, 20.0]
elems_sum = 30.0
accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
takeg_t = q.take_grad(1)
for accum_op in accum_ops:
accum_op.run()
val = takeg_t.eval()
self.assertEqual(elems_sum, val)
elems = [20.0, 30.0]
elems_sum = 50.0
accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
takeg_t = q.take_grad(1)
for accum_op in accum_ops:
accum_op.run()
val = takeg_t.eval()
self.assertEqual(elems_sum, val)
def testAccumulatorIncrementGlobalStep(self): def testAccumulatorIncrementGlobalStep(self):
with self.test_session(): with self.test_session():
q = data_flow_ops.ConditionalAccumulator( q = data_flow_ops.ConditionalAccumulator(

View File

@ -61,14 +61,22 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
q = data_flow_ops.SparseConditionalAccumulator( q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q") dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
self.assertProtoEquals(""" self.assertProtoEquals(
"""
name:'Q' op:'SparseConditionalAccumulator' name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } } attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } } attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } } attr { key: 'shared_name' value { s: '' } }
attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def) """, q.accumulator_ref.op.node_def)
def testConstructorWithInvalidArg(self):
with ops.Graph().as_default():
with self.assertRaises(ValueError):
data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", reduction_type="Invalid")
def testConstructorWithShape(self): def testConstructorWithShape(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
q = data_flow_ops.SparseConditionalAccumulator( q = data_flow_ops.SparseConditionalAccumulator(
@ -76,7 +84,8 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
name="Q", name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8])) shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
self.assertProtoEquals(""" self.assertProtoEquals(
"""
name:'Q' op:'SparseConditionalAccumulator' name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 } attr { key: 'shape' value { shape { dim {size: 1 }
@ -86,6 +95,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
} } } } } }
attr { key: 'container' value { s: '' } } attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } } attr { key: 'shared_name' value { s: '' } }
attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def) """, q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self): def testAccumulatorSizeEmpty(self):
@ -164,7 +174,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
result = sess.run(accums[i].take_indexed_slices_grad(1)) result = sess.run(accums[i].take_indexed_slices_grad(1))
self._assertEqual_indexedslices(expected_tensors[i], result) self._assertEqual_indexedslices(expected_tensors[i], result)
def testAccumulatorTakeGrad(self): def testAccumulatorTakeGradMean(self):
with self.test_session() as sess: with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator( q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=()) dtypes_lib.float32, name="Q", shape=())
@ -180,9 +190,34 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
takeg_t = q.take_indexed_slices_grad(1) takeg_t = q.take_indexed_slices_grad(1)
val = sess.run(takeg_t) val = sess.run(takeg_t)
self.assertAllEqual(val.indices, [0, 1, 2]) self.assertAllEqual([0, 1, 2], val.indices)
self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]]) self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values)
self.assertAllEqual(val.dense_shape, [-1, 2]) self.assertAllEqual([-1, 2], val.dense_shape)
def testAccumulatorTakeGradSum(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
grad_indexed_slices = ops.IndexedSlices(
indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32))
accum_op = q.apply_indexed_slices_grad(grad_indexed_slices)
accum_op.run()
accum_op = q.apply_grad([0, 2],
np.array([[0, 1], [3, 0]]).astype(np.float32),
[3, 2])
accum_op.run()
takeg_t = q.take_indexed_slices_grad(1)
val = sess.run(takeg_t)
self.assertAllEqual([0, 1, 2], val.indices)
self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values)
self.assertAllEqual([-1, 2], val.dense_shape)
def testAccumulatorTakeGradInvalidReductionType(self):
with self.assertRaises(ValueError):
data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")
def testAccumulatorRepeatedTakeGrad(self): def testAccumulatorRepeatedTakeGrad(self):
with self.test_session() as sess: with self.test_session() as sess:
@ -222,7 +257,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]]) self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]])
self.assertAllEqual(val.dense_shape, [-1, 2]) self.assertAllEqual(val.dense_shape, [-1, 2])
def testParallelApplyGrad(self): def testParallelApplyGradMean(self):
with self.test_session() as sess: with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator( q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2])) dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@ -253,6 +288,40 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32), np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
val, sess) val, sess)
def testParallelApplyGradSum(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32,
name="Q",
shape=tensor_shape.TensorShape([2, 2]),
reduction_type="SUM")
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
accum_ops = []
for x in elems:
x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
takeg_t = q.take_indexed_slices_grad(1)
def apply_indexed_slices_grad(accum_op):
sess.run(accum_op)
threads = [
self.checkedThread(target=apply_indexed_slices_grad, args=(o,))
for o in accum_ops
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
val = sess.run(takeg_t)
expected_val = 550.0
self._assertEqual_nparray(
np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
val, sess)
def testParallelTakeGrad(self): def testParallelTakeGrad(self):
with self.test_session() as sess: with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator( q = data_flow_ops.SparseConditionalAccumulator(

View File

@ -1229,7 +1229,8 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
dtype, dtype,
shape=None, shape=None,
shared_name=None, shared_name=None,
name="conditional_accumulator"): name="conditional_accumulator",
reduction_type="MEAN"):
"""Creates a new ConditionalAccumulator. """Creates a new ConditionalAccumulator.
Args: Args:
@ -1238,9 +1239,14 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions. the given name across multiple sessions.
name: Optional name for the accumulator. name: Optional name for the accumulator.
reduction_type: Reduction type to use when taking the gradient.
""" """
accumulator_ref = gen_data_flow_ops.conditional_accumulator( accumulator_ref = gen_data_flow_ops.conditional_accumulator(
dtype=dtype, shape=shape, shared_name=shared_name, name=name) dtype=dtype,
shape=shape,
shared_name=shared_name,
name=name,
reduction_type=reduction_type)
super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
def apply_grad(self, grad, local_step=0, name=None): def apply_grad(self, grad, local_step=0, name=None):
@ -1312,15 +1318,21 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions. the given name across multiple sessions.
name: Optional name for the accumulator. name: Optional name for the accumulator.
reduction_type: Reduction type to use when taking the gradient.
""" """
def __init__(self, def __init__(self,
dtype, dtype,
shape=None, shape=None,
shared_name=None, shared_name=None,
name="sparse_conditional_accumulator"): name="sparse_conditional_accumulator",
reduction_type="MEAN"):
accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator( accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
dtype=dtype, shape=shape, shared_name=shared_name, name=name) dtype=dtype,
shape=shape,
shared_name=shared_name,
name=name,
reduction_type=reduction_type)
super(SparseConditionalAccumulator, self).__init__(dtype, shape, super(SparseConditionalAccumulator, self).__init__(dtype, shape,
accumulator_ref) accumulator_ref)

View File

@ -17,7 +17,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
} }
member_method { member_method {
name: "apply_grad" name: "apply_grad"

View File

@ -17,7 +17,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
} }
member_method { member_method {
name: "apply_grad" name: "apply_grad"

View File

@ -17,7 +17,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
} }
member_method { member_method {
name: "apply_grad" name: "apply_grad"

View File

@ -17,7 +17,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
} }
member_method { member_method {
name: "apply_grad" name: "apply_grad"