cond_v2: make FakeParam output tensor with correct size and device.

Allocating tensors of the expected size is necessary for adding them
to TensorLists in the case of cond_v2 nested in while_v2.

PiperOrigin-RevId: 221637330
This commit is contained in:
Skye Wanderman-Milne 2018-11-15 09:34:56 -08:00 committed by TensorFlower Gardener
parent a302885171
commit 299469c1eb
2 changed files with 82 additions and 6 deletions

View File

@@ -526,21 +526,40 @@ REGISTER_KERNEL_BUILDER(Name("For")
.HostMemory("delta"),
ForOp);
// FakeParamOp allocates a tensor with a shape conforming to the expected
// output. This is necessary if the value will be stored in a while_loop's
// TensorList. The output is otherwise not expected to be consumed by anything
// else.
class FakeParamOp : public OpKernel {
 public:
  explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
    DataType dtype;
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));

    // Set shape to the specified shape, setting unknown dimensions to empty.
    // If the specified shape is unknown, leave as an empty shape.
    TensorShape shape;
    PartialTensorShape partial_shape;
    OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
    if (!partial_shape.unknown_rank()) {
      for (int64 d : partial_shape.dim_sizes()) {
        // -1 marks an unknown dimension; collapse it to 0 so the allocated
        // tensor has a concrete (empty along that axis) shape.
        shape.AddDim(d == -1 ? 0 : d);
      }
    }

    // Create a persistent tensor that we can repeatedly return to save memory.
    // TODO(b/119612758): add optimization to prevent sending this across
    // devices on each Compute() call.
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                dtype, shape, &value_handle_, nullptr));
  }

  void Compute(OpKernelContext* context) override {
    // We must produce something (only Switch and Recvs are allowed to output
    // dead tensors). This output is not expected to be consumed by anything.
    context->set_output(0, *value_handle_.AccessTensor(context));
  }

 private:
  // Preallocated output tensor, reused across all Compute() calls.
  PersistentTensor value_handle_;
};
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);

View File

@@ -711,6 +711,34 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
def testCondGradMultiDevice(self):
  """Gradients of a cond work when forward and backward ops span devices."""
  sess_config = config_pb2.ConfigProto(
      device_count={"CPU": 2}, allow_soft_placement=True)
  with self.cached_session(use_gpu=True, config=sess_config) as sess:
    pred = array_ops.placeholder(dtypes.bool, [])
    x = array_ops.placeholder(dtypes.float32)
    y = array_ops.placeholder(dtypes.float32)

    # Build the cond on one CPU device and its gradient on another.
    with ops.device("/cpu:0"):
      z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0)
    with ops.device("/cpu:1"):
      grad = gradients_impl.gradients(z, x)[0]

    feed_true = {pred: True, x: 1.0, y: 2.0}
    feed_false = {pred: False, x: 1.0, y: 2.0}
    self.assertEqual(sess.run(grad, feed_true), 4.0)
    self.assertEqual(sess.run(grad, feed_false), 0.0)

    with ops.device("/cpu:0"):
      grad_grad = gradients_impl.gradients(grad, x)[0]

    # v1 control flow gets None second derivative for some reason.
    if not control_flow_ops.ENABLE_COND_V2:
      self.assertIsNone(grad_grad)
      return

    self.assertEqual(sess.run(grad_grad, feed_true), 0.0)
    self.assertEqual(sess.run(grad_grad, feed_false), 0.0)
def testNestedCond_Simple(self):
with self.cached_session():
x = constant_op.constant(0., name="X")
@@ -1657,6 +1685,35 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertAllEqual(10, r.eval())
def testWhileCondGradMultiDevice(self):
  """Gradients of a cond nested in a while_loop work across devices."""
  sess_config = config_pb2.ConfigProto(
      device_count={"CPU": 2}, allow_soft_placement=True)
  with self.cached_session(use_gpu=True, config=sess_config) as sess:
    pred = array_ops.placeholder(dtypes.bool, [])
    x_init = constant_op.constant(1.0)

    def loop_cond(i, _):
      # Run exactly three iterations.
      return i < 3

    def loop_body(i, x):
      # Each iteration either doubles x or resets it to a constant.
      return i + 1, control_flow_ops.cond(
          pred, lambda: x * 2.0, lambda: 10.0)

    # Loop on one CPU device, gradient on another.
    with ops.device("/cpu:0"):
      z = control_flow_ops.while_loop(loop_cond, loop_body, [0, x_init])
    with ops.device("/cpu:1"):
      grad = gradients_impl.gradients(z, x_init)[0]

    self.assertEqual(sess.run(grad, {pred: True}), 8.0)
    self.assertEqual(sess.run(grad, {pred: False}), 0.0)

    if not control_flow_ops.ENABLE_WHILE_V2:
      return

    with ops.device("/cpu:0"):
      grad_grad = gradients_impl.gradients(grad, x_init)[0]

    self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
    self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0)
# NOTE: It is ok to have parallel_iterations > 1
@test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_1(self):