pfor: add converter for StatefulPartitionedCall and PartitionedCall

PiperOrigin-RevId: 256210960
This commit is contained in:
A. Unique TensorFlower 2019-07-02 11:59:51 -07:00 committed by TensorFlower Gardener
parent 758db76393
commit fa33109764
2 changed files with 142 additions and 4 deletions

View File

@@ -30,6 +30,7 @@ from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session
from tensorflow.python.compat import compat
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
@@ -1331,5 +1332,90 @@ class ParsingTest(PForTestCase):
self.run_and_assert_equal(pfor, manual)
class PartitionedCallTest(PForTestCase):
  """Tests pfor conversion of PartitionedCall / StatefulPartitionedCall."""

  def test_simple(self):
    # A single tf.function call inside the loop body.
    @def_function.function
    def square_plus_one(x):
      return math_ops.square(x) + 1

    data = random_ops.random_uniform([4])
    self._test_loop_fn(
        lambda i: square_plus_one(array_ops.gather(data, i)), 4)

  def test_nested_calls(self):
    # One tf.function calling another, producing nested call ops.
    @def_function.function
    def square(x):
      return math_ops.square(x)

    @def_function.function
    def sum_of_squares_plus_two(y):
      return math_ops.reduce_sum(square(y)) + 2

    data = random_ops.random_uniform([4, 2])
    self._test_loop_fn(
        lambda i: sum_of_squares_plus_two(array_ops.gather(data, i)), 4)

  def test_nested_definition(self):
    # A tf.function defined (not just called) inside another tf.function.
    @def_function.function
    def outer_fn(y):

      @def_function.function
      def inner_fn(x):
        return math_ops.square(x) + 1

      return math_ops.reduce_sum(inner_fn(y)) + 2

    data = random_ops.random_uniform([4, 2])
    self._test_loop_fn(lambda i: outer_fn(array_ops.gather(data, i)), 4)

  def test_gradients(self):
    # Differentiates through a stateless tf.function call.
    @def_function.function
    def square_plus_one(x):
      return math_ops.square(x) + 1

    data = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      row = array_ops.gather(data, i)
      with backprop.GradientTape() as tape:
        tape.watch(row)
        result = square_plus_one(row)
      return result, tape.gradient(result, row)

    self._test_loop_fn(loop_fn, 4, [dtypes.float32] * 2)

  def test_stateful_with_gradients(self):
    # Differentiates through a tf.function that reads a variable, which is
    # lowered as a StatefulPartitionedCall.
    data = random_ops.random_uniform([4, 2])
    v = variables.Variable(data[0])

    @def_function.function
    def stateful_fn(x):
      return math_ops.square(x) + v + 1

    def loop_fn(i):
      row = array_ops.gather(data, i)
      with backprop.GradientTape() as tape:
        tape.watch(row)
        result = stateful_fn(row)
      return result, tape.gradient(result, row)

    self._test_loop_fn(loop_fn, 4, [dtypes.float32] * 2)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
  test.main()

View File

@@ -22,9 +22,11 @@ from __future__ import print_function
import collections
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
@@ -1190,6 +1192,8 @@ class PFor(object):
return converted_t.t is not t
def _add_conversion(self, old_output, new_output):
  """Records `new_output` as the converted form of `old_output`.

  Args:
    old_output: a Tensor or Operation from the graph being vectorized.
    new_output: the converted result: a WrappedTensor when `old_output` is a
      Tensor, or an Operation when `old_output` is an Operation.
  """
  assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output
  assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output
  self._conversion_map[old_output] = new_output
def _convert_helper(self, op_or_tensor):
@@ -1308,9 +1312,10 @@ class PFor(object):
# output as unstacked.
new_outputs = [wrap(self._unwrap_or_tile(output), False)]
# None of the inputs and control inputs were converted.
elif (not is_inside_loop or
(not is_stateful and not some_input_converted and
not some_control_input_converted)):
elif ((not is_inside_loop or
(not is_stateful and not some_input_converted and
not some_control_input_converted)) and
y.graph == ops.get_default_graph()):
if y == y_op:
assert not isinstance(y_op, WhileOp)
new_outputs = y_op
@@ -1358,6 +1363,8 @@ class PFor(object):
assert isinstance(new_outputs, ops.Operation)
self._add_conversion(y_op, new_outputs)
else:
assert len(y_op.outputs) == len(new_outputs), (
y_op, y_op.outputs, new_outputs)
for old_output, new_output in zip(y_op.outputs, new_outputs):
assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
self._add_conversion(old_output, new_output)
@@ -1379,6 +1386,10 @@ class PFor(object):
def pfor_ops(self):
  """Returns the list of ops this converter was constructed to vectorize."""
  return self._pfor_ops
@property
def pfor_config(self):
  """Returns the pfor config object passed at construction (may be None)."""
  return self._pfor_config
@property
def all_indices_partitioned(self):
"""all_indices_partitioned property.
@@ -1523,7 +1534,7 @@ def _convert_fused_batch_norm(pfor_input):
y = _unflatten_first_dim(y, n)
mean = pfor_input.unstacked_input(3)
zeros = array_ops.zeros_like(mean)
return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)]
return [wrap(y, True)] + [wrap(zeros, False)] * 5
pfor_input.stack_inputs()
data_format = pfor_input.get_attr("data_format")
@@ -3031,3 +3042,44 @@ def _convert_parse_single_example(pfor_input):
sparse_types=sparse_types,
dense_shapes=dense_shapes)
return [wrap(t, True, True) for t in nest.flatten(output)]
# functional_ops


@RegisterPFor("StatefulPartitionedCall")
@RegisterPFor("PartitionedCall")
def _convert_partitioned_call(pfor_input):
  """Converts a (Stateful)PartitionedCall op by vectorizing its function.

  A nested PFor converter is built over the ops of the called FuncGraph,
  sharing the outer converter's loop variable, loop length and indices. The
  conversion is traced inside a fresh tf.function whose call replaces the
  original op.
  """
  func_name = pfor_input.get_attr("f").name
  # NOTE(review): uses the graph's private _get_function API to recover the
  # FuncGraph of the called function from the op's "f" attribute.
  func = pfor_input.op.graph._get_function(compat.as_bytes(func_name))
  assert isinstance(func.graph, func_graph.FuncGraph), (
      "Could not find FuncGraph object for %s. Got func %s" % (func_name, func))
  pfor = pfor_input.pfor
  # Nested converter over the called function's ops, reusing the outer
  # loop's variable, length and index bookkeeping so inner and outer
  # conversions agree.
  converter = PFor(loop_var=pfor.loop_var,
                   loop_len=pfor.loop_len_vector[0],
                   pfor_ops=func.graph.get_operations(),
                   all_indices=pfor.all_indices,
                   all_indices_partitioned=pfor.all_indices_partitioned,
                   pfor_config=pfor.pfor_config)

  # TODO(agarwal): consider caching this function definition.
  @def_function.function
  def f(*args):
    # Args arrive as WrappedTensor objects (see the call below), one per
    # input of the original function graph.
    assert all(isinstance(arg, WrappedTensor) for arg in args), args
    assert len(args) == len(func.graph.inputs), (args, func.graph.inputs)
    # Map inputs to function arguments.
    for inp, arg in zip(func.graph.inputs, args):
      converter._add_conversion(inp, arg)
    # Convert output tensors.
    return tuple([converter._convert_helper(x).t
                  for x in func._func_graph_outputs])

  # Calling `f` runs the conversion of the function body and yields the
  # vectorized outputs; it also populates `converter`'s conversion map,
  # which is consulted below.
  call_outputs = f(*pfor_input.inputs)
  assert len(call_outputs) == len(func._func_graph_outputs)
  outputs = []
  for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs):
    # Re-query the (now cached) conversion of each original output to learn
    # whether it ended up stacked / sparse-stacked, and wrap the call output
    # accordingly for downstream converters.
    func_output = converter._convert_helper(output_tensor)
    outputs.append(wrap(call_output,
                        func_output.is_stacked,
                        func_output.is_sparse_stacked))
  return outputs