From 9f2714e03b6c8fa58075a18a20aec044f11a5bf1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Jun 2019 15:45:31 -0700 Subject: [PATCH] Make pfor runnable under xla compilation. PiperOrigin-RevId: 251943109 --- tensorflow/python/ops/parallel_for/BUILD | 17 ++++++- .../ops/parallel_for/control_flow_ops.py | 8 +++- tensorflow/python/ops/parallel_for/pfor.py | 6 +-- .../parallel_for/xla_control_flow_ops_test.py | 48 +++++++++++++++++++ 4 files changed, 74 insertions(+), 5 deletions(-) create mode 100644 tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 71d13774efe..e07e76492c8 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -113,15 +113,30 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:logging_ops", "//tensorflow/python:parsing_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:session", "//tensorflow/python:tensor_array_grad", - "//tensorflow/python:random_ops", "//tensorflow/python:util", ], tags = ["no_rocm"], xla_enable_strict_auto_jit = True, ) +cuda_py_test( + name = "xla_control_flow_ops_test", + srcs = ["xla_control_flow_ops_test.py"], + additional_deps = [ + ":control_flow_ops", + ":test_util", + "//tensorflow/python/compiler/xla:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + ], + tags = ["no_rocm"], + xla_enable_strict_auto_jit = True, + xla_enabled = True, +) + cuda_py_test( name = "array_test", srcs = ["array_test.py"], diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index 89df51a42e4..a198489f9ac 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -159,7 +159,13 @@ def pfor(loop_fn, iters, parallel_iterations=None): """ def f(): return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations) - if context.executing_eagerly(): + control_flow_context = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + # Note that we wrap into a tf.function if in eager execution mode or under + # XLA compilation. The latter is so that we don't compile operations like + # tf.placeholder that are created by the loop body. + if (context.executing_eagerly() or + (control_flow_context is not None and + control_flow_context.IsXLAContext())): f = function.defun(f) return f() diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index b63ffd80bed..94985f97300 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -2264,10 +2264,10 @@ def _convert_cast(pfor_input): @RegisterPForWithArgs("Xlogy", math_ops.xlogy) @RegisterPForWithArgs("Zeta", math_ops.zeta) def _convert_cwise(pfor_input, op_type, op_func): - # Note that ops handled here do not have attributes except "T" and "Tout", and - # hence don't need extra arguments passed to the cwise_op call below. + # Note that ops handled here do not have attributes except those listed below + # and hence don't need extra arguments passed to the cwise_op call below. for attr in pfor_input.op.node_def.attr.keys(): - assert attr in [u"T", u"Tout"], (op_type, attr) + assert attr in [u"T", u"Tout", u"_xla_compile_id"], (op_type, attr) pfor_input.expanddim_inputs_for_broadcast() return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) diff --git a/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py new file mode 100644 index 00000000000..6d2bad23b94 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""XLA tests for pfor.""" +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.compiler.xla import xla +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops +from tensorflow.python.ops.parallel_for.test_util import PForTestCase +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class PForTest(PForTestCase): + + def test_xla(self): + + def compute(x): + return math_ops.reduce_mean(x, axis=0, keepdims=True) + + def vectorized_compute(x): + return pfor_control_flow_ops.vectorized_map(compute, x) + + result = xla.compile(vectorized_compute, + inputs=[array_ops.ones((10, 5, 3))]) + self.run_and_assert_equal(result, array_ops.ones((10, 1, 3))) + + +if __name__ == "__main__": + test.main()