pfor: add converter for StatefulPartitionedCall and PartitionedCall
PiperOrigin-RevId: 256210960
This commit is contained in:
parent
758db76393
commit
fa33109764
@ -30,6 +30,7 @@ from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
@ -1331,5 +1332,90 @@ class ParsingTest(PForTestCase):
|
||||
self.run_and_assert_equal(pfor, manual)
|
||||
|
||||
|
||||
class PartitionedCallTest(PForTestCase):
|
||||
|
||||
def test_simple(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
return math_ops.square(x) + 1
|
||||
|
||||
z = random_ops.random_uniform([4])
|
||||
|
||||
def loop_fn(i):
|
||||
return f(array_ops.gather(z, i))
|
||||
|
||||
self._test_loop_fn(loop_fn, 4)
|
||||
|
||||
def test_nested_calls(self):
|
||||
|
||||
@def_function.function
|
||||
def inner(x):
|
||||
return math_ops.square(x)
|
||||
|
||||
@def_function.function
|
||||
def outer(y):
|
||||
return math_ops.reduce_sum(inner(y)) + 2
|
||||
|
||||
z = random_ops.random_uniform([4, 2])
|
||||
|
||||
def loop_fn(i):
|
||||
return outer(array_ops.gather(z, i))
|
||||
|
||||
self._test_loop_fn(loop_fn, 4)
|
||||
|
||||
def test_nested_definition(self):
|
||||
|
||||
@def_function.function
|
||||
def outer(y):
|
||||
@def_function.function
|
||||
def inner(x):
|
||||
return math_ops.square(x) + 1
|
||||
|
||||
return math_ops.reduce_sum(inner(y)) + 2
|
||||
|
||||
z = random_ops.random_uniform([4, 2])
|
||||
|
||||
def loop_fn(i):
|
||||
return outer(array_ops.gather(z, i))
|
||||
|
||||
self._test_loop_fn(loop_fn, 4)
|
||||
|
||||
def test_gradients(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
return math_ops.square(x) + 1
|
||||
|
||||
z = random_ops.random_uniform([4, 2])
|
||||
|
||||
def loop_fn(i):
|
||||
z_i = array_ops.gather(z, i)
|
||||
with backprop.GradientTape() as g:
|
||||
g.watch(z_i)
|
||||
out = f(z_i)
|
||||
return out, g.gradient(out, z_i)
|
||||
|
||||
self._test_loop_fn(loop_fn, 4, [dtypes.float32] * 2)
|
||||
|
||||
def test_stateful_with_gradients(self):
|
||||
|
||||
z = random_ops.random_uniform([4, 2])
|
||||
v = variables.Variable(z[0])
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
return math_ops.square(x) + v + 1
|
||||
|
||||
def loop_fn(i):
|
||||
z_i = array_ops.gather(z, i)
|
||||
with backprop.GradientTape() as g:
|
||||
g.watch(z_i)
|
||||
out = f(z_i)
|
||||
return out, g.gradient(out, z_i)
|
||||
|
||||
self._test_loop_fn(loop_fn, 4, [dtypes.float32] * 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -22,9 +22,11 @@ from __future__ import print_function
|
||||
import collections
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -1190,6 +1192,8 @@ class PFor(object):
|
||||
return converted_t.t is not t
|
||||
|
||||
def _add_conversion(self, old_output, new_output):
|
||||
assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output
|
||||
assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output
|
||||
self._conversion_map[old_output] = new_output
|
||||
|
||||
def _convert_helper(self, op_or_tensor):
|
||||
@ -1308,9 +1312,10 @@ class PFor(object):
|
||||
# output as unstacked.
|
||||
new_outputs = [wrap(self._unwrap_or_tile(output), False)]
|
||||
# None of the inputs and control inputs were converted.
|
||||
elif (not is_inside_loop or
|
||||
(not is_stateful and not some_input_converted and
|
||||
not some_control_input_converted)):
|
||||
elif ((not is_inside_loop or
|
||||
(not is_stateful and not some_input_converted and
|
||||
not some_control_input_converted)) and
|
||||
y.graph == ops.get_default_graph()):
|
||||
if y == y_op:
|
||||
assert not isinstance(y_op, WhileOp)
|
||||
new_outputs = y_op
|
||||
@ -1358,6 +1363,8 @@ class PFor(object):
|
||||
assert isinstance(new_outputs, ops.Operation)
|
||||
self._add_conversion(y_op, new_outputs)
|
||||
else:
|
||||
assert len(y_op.outputs) == len(new_outputs), (
|
||||
y_op, y_op.outputs, new_outputs)
|
||||
for old_output, new_output in zip(y_op.outputs, new_outputs):
|
||||
assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
|
||||
self._add_conversion(old_output, new_output)
|
||||
@ -1379,6 +1386,10 @@ class PFor(object):
|
||||
def pfor_ops(self):
|
||||
return self._pfor_ops
|
||||
|
||||
@property
|
||||
def pfor_config(self):
|
||||
return self._pfor_config
|
||||
|
||||
@property
|
||||
def all_indices_partitioned(self):
|
||||
"""all_indices_partitioned property.
|
||||
@ -1523,7 +1534,7 @@ def _convert_fused_batch_norm(pfor_input):
|
||||
y = _unflatten_first_dim(y, n)
|
||||
mean = pfor_input.unstacked_input(3)
|
||||
zeros = array_ops.zeros_like(mean)
|
||||
return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)]
|
||||
return [wrap(y, True)] + [wrap(zeros, False)] * 5
|
||||
|
||||
pfor_input.stack_inputs()
|
||||
data_format = pfor_input.get_attr("data_format")
|
||||
@ -3031,3 +3042,44 @@ def _convert_parse_single_example(pfor_input):
|
||||
sparse_types=sparse_types,
|
||||
dense_shapes=dense_shapes)
|
||||
return [wrap(t, True, True) for t in nest.flatten(output)]
|
||||
|
||||
|
||||
# functional_ops
|
||||
|
||||
|
||||
@RegisterPFor("StatefulPartitionedCall")
|
||||
@RegisterPFor("PartitionedCall")
|
||||
def _convert_partitioned_call(pfor_input):
|
||||
func_name = pfor_input.get_attr("f").name
|
||||
func = pfor_input.op.graph._get_function(compat.as_bytes(func_name))
|
||||
assert isinstance(func.graph, func_graph.FuncGraph), (
|
||||
"Could not find FuncGraph object for %s. Got func %s" % (func_name, func))
|
||||
pfor = pfor_input.pfor
|
||||
converter = PFor(loop_var=pfor.loop_var,
|
||||
loop_len=pfor.loop_len_vector[0],
|
||||
pfor_ops=func.graph.get_operations(),
|
||||
all_indices=pfor.all_indices,
|
||||
all_indices_partitioned=pfor.all_indices_partitioned,
|
||||
pfor_config=pfor.pfor_config)
|
||||
|
||||
# TODO(agarwal): consider caching this function definition.
|
||||
@def_function.function
|
||||
def f(*args):
|
||||
assert all(isinstance(arg, WrappedTensor) for arg in args), args
|
||||
assert len(args) == len(func.graph.inputs), (args, func.graph.inputs)
|
||||
# Map inputs to function arguments.
|
||||
for inp, arg in zip(func.graph.inputs, args):
|
||||
converter._add_conversion(inp, arg)
|
||||
# Convert output tensors.
|
||||
return tuple([converter._convert_helper(x).t
|
||||
for x in func._func_graph_outputs])
|
||||
|
||||
call_outputs = f(*pfor_input.inputs)
|
||||
assert len(call_outputs) == len(func._func_graph_outputs)
|
||||
outputs = []
|
||||
for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs):
|
||||
func_output = converter._convert_helper(output_tensor)
|
||||
outputs.append(wrap(call_output,
|
||||
func_output.is_stacked,
|
||||
func_output.is_sparse_stacked))
|
||||
return outputs
|
||||
|
Loading…
x
Reference in New Issue
Block a user