cond_v2: make FakeParam output tensor with correct size and device.
Allocating tensors of the expected size is necessary for adding them to TensorLists in the case of cond_v2 nested in while_v2. PiperOrigin-RevId: 221637330
This commit is contained in:
parent
a302885171
commit
299469c1eb
@ -526,21 +526,40 @@ REGISTER_KERNEL_BUILDER(Name("For")
|
||||
.HostMemory("delta"),
|
||||
ForOp);
|
||||
|
||||
// FakeParamOp allocates a tensor with a shape conforming to the expected
|
||||
// output. This is necessary if the value will be stored in a while_loop's
|
||||
// TensorList. The output is otherwise not expected to be consumed by anything
|
||||
// else.
|
||||
class FakeParamOp : public OpKernel {
|
||||
public:
|
||||
explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
|
||||
DataType dtype;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));
|
||||
|
||||
// Set shape to the specified shape, setting unknown dimensions to empty.
|
||||
// If the specified shape is unknown, leave as an empty shape.
|
||||
TensorShape shape;
|
||||
PartialTensorShape partial_shape;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
|
||||
if (!partial_shape.unknown_rank()) {
|
||||
for (int64 d : partial_shape.dim_sizes()) {
|
||||
shape.AddDim(d == -1 ? 0 : d);
|
||||
}
|
||||
}
|
||||
|
||||
// Create a persistent tensor that we can repeatedly return to save memory.
|
||||
// TODO(b/119612758): add optimization to prevent sending this across
|
||||
// devices on each Compute() call.
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
dtype, shape, &value_handle_, nullptr));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// We must produce something (only Switch and Recvs are allowed to output
|
||||
// dead tensors). This output is not expected to be consumed by anything.
|
||||
Tensor output_tensor(dtype_, TensorShape({}));
|
||||
context->set_output(0, output_tensor);
|
||||
context->set_output(0, *value_handle_.AccessTensor(context));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
PersistentTensor value_handle_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
|
||||
|
@ -711,6 +711,34 @@ class ControlFlowTest(test.TestCase):
|
||||
self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
|
||||
self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
|
||||
|
||||
def testCondGradMultiDevice(self):
  """First and second derivatives of a cond computed on another device.

  Places the cond on /cpu:0 and its gradient ops on /cpu:1 to verify that
  cond_v2 gradients work across device boundaries.
  """
  config = config_pb2.ConfigProto(device_count={"CPU": 2},
                                  allow_soft_placement=True)
  with self.cached_session(use_gpu=True, config=config) as sess:
    pred = array_ops.placeholder(dtypes.bool, [])
    x = array_ops.placeholder(dtypes.float32)
    y = array_ops.placeholder(dtypes.float32)

    with ops.device("/cpu:0"):
      z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0)

    # Gradient is computed on a different device than the cond itself.
    with ops.device("/cpu:1"):
      grad = gradients_impl.gradients(z, x)[0]

    self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0)
    self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)

    with ops.device("/cpu:0"):
      grad_grad = gradients_impl.gradients(grad, x)[0]

    # v1 control flow gets None second derivative for some reason.
    if not control_flow_ops.ENABLE_COND_V2:
      self.assertIsNone(grad_grad)
      return

    self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0)
    self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
|
||||
|
||||
def testNestedCond_Simple(self):
|
||||
with self.cached_session():
|
||||
x = constant_op.constant(0., name="X")
|
||||
@ -1657,6 +1685,35 @@ class ControlFlowTest(test.TestCase):
|
||||
r = control_flow_ops.while_loop(c, b, [n])
|
||||
self.assertAllEqual(10, r.eval())
|
||||
|
||||
def testWhileCondGradMultiDevice(self):
  """Derivatives of a cond nested inside a while_loop, across devices.

  Places the while_loop (with an inner cond) on /cpu:0 and its gradient ops
  on /cpu:1 to verify cross-device gradients for the nested case.
  """
  config = config_pb2.ConfigProto(device_count={"CPU": 2},
                                  allow_soft_placement=True)
  with self.cached_session(use_gpu=True, config=config) as sess:
    pred = array_ops.placeholder(dtypes.bool, [])
    x_init = constant_op.constant(1.0)

    with ops.device("/cpu:0"):
      z = control_flow_ops.while_loop(
          lambda i, _: i < 3,
          lambda i, x: (i + 1, control_flow_ops.cond(
              pred, lambda: x * 2.0, lambda: 10.0)),
          [0, x_init])

    # Gradient is computed on a different device than the loop itself.
    with ops.device("/cpu:1"):
      grad = gradients_impl.gradients(z, x_init)[0]

    self.assertEqual(sess.run(grad, {pred: True}), 8.0)
    self.assertEqual(sess.run(grad, {pred: False}), 0.0)

    # Second derivatives are only supported by while_v2.
    if not control_flow_ops.ENABLE_WHILE_V2:
      return

    with ops.device("/cpu:0"):
      grad_grad = gradients_impl.gradients(grad, x_init)[0]

    self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
    self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0)
|
||||
|
||||
# NOTE: It is ok to have parallel_iterations > 1
|
||||
@test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
|
||||
def testWhileUpdateVariable_1(self):
|
||||
|
Loading…
Reference in New Issue
Block a user