Additional colocation options and bugfixes for TensorArray
* colocate_with is now set properly when a TensorArray is passed through a while_loop. * Added a new argument, "colocate_with_first_write" (default: True; this is the current behavior). If False, the TensorArray is simply placed on the device from the context it is constructed in, and no colocation constraints are added. PiperOrigin-RevId: 157643133
This commit is contained in:
parent
03fc7022b4
commit
ccdb30763a
@ -53,6 +53,16 @@ def _make_converter(tf_dtype):
|
||||
|
||||
class TensorArrayTest(test.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TensorArrayTest, cls).setUpClass()
|
||||
cls._workers, _ = test.create_local_cluster(num_workers=3, num_ps=0)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(TensorArrayTest, cls).tearDownClass()
|
||||
session_lib.Session.reset(cls._workers[0].target)
|
||||
|
||||
def testTensorArrayWriteRead(self):
|
||||
with self.test_session(use_gpu=True) as session:
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
@ -1225,8 +1235,7 @@ class TensorArrayTest(test.TestCase):
|
||||
ta = ta.split([1.0, 2.0], [1, 1])
|
||||
flows.append(ta.flow)
|
||||
|
||||
workers, _ = test.create_local_cluster(num_workers=3, num_ps=0)
|
||||
session = session_lib.Session(workers[0].target)
|
||||
session = session_lib.Session(self._workers[0].target)
|
||||
|
||||
run_options = config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE)
|
||||
@ -1250,13 +1259,12 @@ class TensorArrayTest(test.TestCase):
|
||||
|
||||
def _body(i, ta_i):
|
||||
with ops.device("/job:worker/task:1/cpu:0"):
|
||||
return i + 1, ta_i.write(i, 0.0)
|
||||
return i + 1, ta_i.write(i, constant_op.constant(0.0))
|
||||
|
||||
_, ta_out = control_flow_ops.while_loop(
|
||||
lambda i, ta: i < 2, _body, loop_vars=[0, ta])
|
||||
|
||||
workers, _ = test.create_local_cluster(num_workers=3, num_ps=0)
|
||||
session = session_lib.Session(workers[0].target)
|
||||
session = session_lib.Session(self._workers[0].target)
|
||||
|
||||
run_options = config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE)
|
||||
@ -1274,6 +1282,36 @@ class TensorArrayTest(test.TestCase):
|
||||
self.assertFalse(
|
||||
[s for s in dev_stats[d] if "/TensorArray" in s.node_name])
|
||||
|
||||
def testTensorArrayDisabledColocateWithFirstWriteCall(self):
|
||||
with ops.device("/job:worker/task:0/cpu:0"):
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
dtype=dtypes.float32, size=2, colocate_with_first_write_call=False)
|
||||
|
||||
def _body(i, ta_i):
|
||||
with ops.device("/job:worker/task:1/cpu:0"):
|
||||
return i + 1, ta_i.write(i, constant_op.constant(0.0))
|
||||
|
||||
_, ta_out = control_flow_ops.while_loop(
|
||||
lambda i, ta: i < 2, _body, loop_vars=[0, ta])
|
||||
|
||||
session = session_lib.Session(self._workers[0].target)
|
||||
|
||||
run_options = config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE)
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
|
||||
session.run(ta_out.flow, options=run_options, run_metadata=run_metadata)
|
||||
self.assertTrue(run_metadata.HasField("step_stats"))
|
||||
dev_stats = {d.device: list(d.node_stats)
|
||||
for d in run_metadata.step_stats.dev_stats}
|
||||
for d in dev_stats:
|
||||
if "/task:0/" in d and "cpu" in d: # Skip any GPU node stats
|
||||
self.assertTrue(
|
||||
[s for s in dev_stats[d] if "/TensorArray" in s.node_name])
|
||||
else:
|
||||
self.assertFalse(
|
||||
[s for s in dev_stats[d] if "/TensorArray" in s.node_name])
|
||||
|
||||
def testTensorArrayIdentity(self):
|
||||
with self.test_session(use_gpu=True) as session:
|
||||
ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
|
||||
|
@ -437,10 +437,14 @@ def _convert_tensorarray_to_flow(tensor_or_tensor_array):
|
||||
|
||||
|
||||
def _make_tensor_array(ta, t_or_flow):
|
||||
# pylint: disable=protected-access
|
||||
new_ta = tensor_array_ops.TensorArray(
|
||||
dtype=ta.dtype, handle=ta.handle, flow=t_or_flow,
|
||||
infer_shape=ta._infer_shape)
|
||||
new_ta._element_shape = ta._element_shape # pylint: disable=protected-access
|
||||
infer_shape=ta._infer_shape,
|
||||
colocate_with_first_write_call=ta._colocate_with_first_write_call)
|
||||
new_ta._colocate_with = ta._colocate_with
|
||||
new_ta._element_shape = ta._element_shape
|
||||
# pylint: enable=protected-access
|
||||
return new_ta
|
||||
|
||||
|
||||
|
@ -99,9 +99,9 @@ def _TensorArrayReadGrad(op, grad):
|
||||
flow = op.inputs[2]
|
||||
dtype = op.get_attr("dtype")
|
||||
grad_source = _GetGradSource(grad)
|
||||
g = tensor_array_ops.TensorArray(
|
||||
dtype=dtype, handle=handle, flow=flow).grad(
|
||||
source=grad_source, flow=flow)
|
||||
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
|
||||
colocate_with_first_write_call=False)
|
||||
.grad(source=grad_source, flow=flow))
|
||||
w_g = g.write(index, grad)
|
||||
return [None, None, w_g.flow]
|
||||
|
||||
@ -125,9 +125,9 @@ def _TensorArrayWriteGrad(op, flow):
|
||||
index = op.inputs[1]
|
||||
dtype = op.get_attr("T")
|
||||
grad_source = _GetGradSource(flow)
|
||||
g = tensor_array_ops.TensorArray(
|
||||
dtype=dtype, handle=handle, flow=flow).grad(
|
||||
source=grad_source, flow=flow)
|
||||
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
|
||||
colocate_with_first_write_call=False)
|
||||
.grad(source=grad_source, flow=flow))
|
||||
grad = g.read(index)
|
||||
return [None, None, grad, flow]
|
||||
|
||||
@ -156,9 +156,9 @@ def _TensorArrayGatherGrad(op, grad):
|
||||
flow = op.inputs[2]
|
||||
dtype = op.get_attr("dtype")
|
||||
grad_source = _GetGradSource(grad)
|
||||
g = tensor_array_ops.TensorArray(
|
||||
dtype=dtype, handle=handle, flow=flow).grad(
|
||||
source=grad_source, flow=flow)
|
||||
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
|
||||
colocate_with_first_write_call=False)
|
||||
.grad(source=grad_source, flow=flow))
|
||||
u_g = g.scatter(indices, grad)
|
||||
return [None, None, u_g.flow]
|
||||
|
||||
@ -180,9 +180,9 @@ def _TensorArrayScatterGrad(op, flow):
|
||||
indices = op.inputs[1]
|
||||
dtype = op.get_attr("T")
|
||||
grad_source = _GetGradSource(flow)
|
||||
g = tensor_array_ops.TensorArray(
|
||||
dtype=dtype, handle=handle, flow=flow).grad(
|
||||
source=grad_source, flow=flow)
|
||||
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
|
||||
colocate_with_first_write_call=False)
|
||||
.grad(source=grad_source, flow=flow))
|
||||
grad = g.gather(indices)
|
||||
return [None, None, grad, flow]
|
||||
|
||||
@ -211,9 +211,9 @@ def _TensorArrayConcatGrad(op, grad, unused_lengths_grad):
|
||||
lengths = op.outputs[1]
|
||||
dtype = op.get_attr("dtype")
|
||||
grad_source = _GetGradSource(grad)
|
||||
g = tensor_array_ops.TensorArray(
|
||||
dtype=dtype, handle=handle, flow=flow).grad(
|
||||
source=grad_source, flow=flow)
|
||||
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
|
||||
colocate_with_first_write_call=False)
|
||||
.grad(source=grad_source, flow=flow))
|
||||
u_g = g.split(grad, lengths=lengths)
|
||||
# handle, flow_in
|
||||
return [None, u_g.flow]
|
||||
@ -235,9 +235,9 @@ def _TensorArraySplitGrad(op, flow):
|
||||
handle = op.inputs[0]
|
||||
dtype = op.get_attr("T")
|
||||
grad_source = _GetGradSource(flow)
|
||||
g = tensor_array_ops.TensorArray(
|
||||
dtype=dtype, handle=handle, flow=flow).grad(
|
||||
source=grad_source, flow=flow)
|
||||
g = (tensor_array_ops.TensorArray(dtype=dtype, handle=handle, flow=flow,
|
||||
colocate_with_first_write_call=False)
|
||||
.grad(source=grad_source, flow=flow))
|
||||
grad = g.concat()
|
||||
# handle, value, lengths, flow_in
|
||||
return [None, grad, None, flow]
|
||||
|
@ -54,6 +54,7 @@ class TensorArray(object):
|
||||
flow=None,
|
||||
infer_shape=True,
|
||||
element_shape=None,
|
||||
colocate_with_first_write_call=True,
|
||||
name=None):
|
||||
"""Construct a new TensorArray or wrap an existing TensorArray handle.
|
||||
|
||||
@ -85,6 +86,11 @@ class TensorArray(object):
|
||||
element_shape: (optional, default: None) A `TensorShape` object specifying
|
||||
the shape constraints of each of the elements of the TensorArray.
|
||||
Need not be fully defined.
|
||||
colocate_with_first_write_call: If `True`, the TensorArray will be
|
||||
colocated on the same device as the Tensor used on its first write
|
||||
(write operations include `write`, `unstack`, and `split`). If `False`,
|
||||
the TensorArray will be placed on the device determined by the
|
||||
device context available during its initialization.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Raises:
|
||||
@ -120,7 +126,11 @@ class TensorArray(object):
|
||||
# Used to keep track of what tensors the TensorArray should be
|
||||
# colocated with. We choose to colocate the TensorArray with the
|
||||
# first tensor written to it.
|
||||
self._colocate_with = []
|
||||
self._colocate_with_first_write_call = colocate_with_first_write_call
|
||||
if colocate_with_first_write_call:
|
||||
self._colocate_with = []
|
||||
else:
|
||||
self._colocate_with = None
|
||||
|
||||
# Record the current static shape for the array elements. The element
|
||||
# shape is defined either by `element_shape` or the shape of the tensor
|
||||
@ -142,8 +152,8 @@ class TensorArray(object):
|
||||
# Construct the TensorArray with an empty device. The first
|
||||
# write into the TensorArray from a Tensor with a set device
|
||||
# will retroactively set the device value of this op.
|
||||
with ops.device(None), ops.colocate_with(None, ignore_existing=True):
|
||||
self._handle, self._flow = gen_data_flow_ops._tensor_array_v3(
|
||||
def create():
|
||||
return gen_data_flow_ops._tensor_array_v3(
|
||||
dtype=dtype,
|
||||
size=size,
|
||||
element_shape=element_shape,
|
||||
@ -151,6 +161,11 @@ class TensorArray(object):
|
||||
clear_after_read=clear_after_read,
|
||||
tensor_array_name=tensor_array_name,
|
||||
name=scope)
|
||||
if colocate_with_first_write_call:
|
||||
with ops.device(None), ops.colocate_with(None, ignore_existing=True):
|
||||
self._handle, self._flow = create()
|
||||
else:
|
||||
self._handle, self._flow = create()
|
||||
|
||||
@property
|
||||
def flow(self):
|
||||
@ -200,10 +215,13 @@ class TensorArray(object):
|
||||
If no internal colocation group is set, colocate with `value` and set
|
||||
the internal colocation group to be value.
|
||||
"""
|
||||
if not self._colocate_with:
|
||||
self._colocate_with.append(value)
|
||||
with ops.colocate_with(self._colocate_with[0]):
|
||||
if not self._colocate_with_first_write_call:
|
||||
yield
|
||||
else:
|
||||
if not self._colocate_with:
|
||||
self._colocate_with.append(value)
|
||||
with ops.colocate_with(self._colocate_with[0]):
|
||||
yield
|
||||
|
||||
def identity(self):
|
||||
"""Returns a TensorArray with the same content and properties.
|
||||
@ -214,8 +232,10 @@ class TensorArray(object):
|
||||
Use this object all for subsequent operations.
|
||||
"""
|
||||
flow = array_ops.identity(self._flow)
|
||||
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow,
|
||||
infer_shape=self._infer_shape)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype, handle=self._handle, flow=flow,
|
||||
infer_shape=self._infer_shape,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
return ta
|
||||
@ -237,7 +257,8 @@ class TensorArray(object):
|
||||
dtype=self._dtype,
|
||||
handle=g_handle,
|
||||
flow=flow,
|
||||
infer_shape=self._infer_shape)
|
||||
infer_shape=self._infer_shape,
|
||||
colocate_with_first_write_call=False)
|
||||
g._element_shape = self._element_shape
|
||||
return g
|
||||
|
||||
@ -286,7 +307,9 @@ class TensorArray(object):
|
||||
value=value,
|
||||
flow_in=self._flow,
|
||||
name=name)
|
||||
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype, handle=self._handle, flow=flow_out,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._infer_shape = self._infer_shape
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
@ -416,7 +439,9 @@ class TensorArray(object):
|
||||
value=value,
|
||||
flow_in=self._flow,
|
||||
name=name)
|
||||
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype, handle=self._handle, flow=flow_out,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._infer_shape = self._infer_shape
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
@ -456,7 +481,9 @@ class TensorArray(object):
|
||||
lengths=lengths_64,
|
||||
flow_in=self._flow,
|
||||
name=name)
|
||||
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype, handle=self._handle, flow=flow_out,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._infer_shape = self._infer_shape
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
|
@ -16,7 +16,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'dtype\', \'size\', \'dynamic_size\', \'clear_after_read\', \'tensor_array_name\', \'handle\', \'flow\', \'infer_shape\', \'element_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'dtype\', \'size\', \'dynamic_size\', \'clear_after_read\', \'tensor_array_name\', \'handle\', \'flow\', \'infer_shape\', \'element_shape\', \'colocate_with_first_write_call\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "close"
|
||||
|
Loading…
Reference in New Issue
Block a user