Make pfor runnable under xla compilation.
PiperOrigin-RevId: 251943109
This commit is contained in:
parent
4f9230e2bd
commit
9f2714e03b
@ -113,15 +113,30 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:logging_ops",
|
"//tensorflow/python:logging_ops",
|
||||||
"//tensorflow/python:parsing_ops",
|
"//tensorflow/python:parsing_ops",
|
||||||
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:session",
|
"//tensorflow/python:session",
|
||||||
"//tensorflow/python:tensor_array_grad",
|
"//tensorflow/python:tensor_array_grad",
|
||||||
"//tensorflow/python:random_ops",
|
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
tags = ["no_rocm"],
|
tags = ["no_rocm"],
|
||||||
xla_enable_strict_auto_jit = True,
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "xla_control_flow_ops_test",
|
||||||
|
srcs = ["xla_control_flow_ops_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":control_flow_ops",
|
||||||
|
":test_util",
|
||||||
|
"//tensorflow/python/compiler/xla:xla",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
],
|
||||||
|
tags = ["no_rocm"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
|
xla_enabled = True,
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "array_test",
|
name = "array_test",
|
||||||
srcs = ["array_test.py"],
|
srcs = ["array_test.py"],
|
||||||
|
@ -159,7 +159,13 @@ def pfor(loop_fn, iters, parallel_iterations=None):
|
|||||||
"""
|
"""
|
||||||
def f():
|
def f():
|
||||||
return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
|
return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
|
||||||
if context.executing_eagerly():
|
control_flow_context = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
|
||||||
|
# Note that we wrap into a tf.function if in eager execution mode or under
|
||||||
|
# XLA compilation. The latter is so that we don't compile operations like
|
||||||
|
# tf.placeholder that are created by the loop body.
|
||||||
|
if (context.executing_eagerly() or
|
||||||
|
(control_flow_context is not None and
|
||||||
|
control_flow_context.IsXLAContext())):
|
||||||
f = function.defun(f)
|
f = function.defun(f)
|
||||||
return f()
|
return f()
|
||||||
|
|
||||||
|
@ -2264,10 +2264,10 @@ def _convert_cast(pfor_input):
|
|||||||
@RegisterPForWithArgs("Xlogy", math_ops.xlogy)
|
@RegisterPForWithArgs("Xlogy", math_ops.xlogy)
|
||||||
@RegisterPForWithArgs("Zeta", math_ops.zeta)
|
@RegisterPForWithArgs("Zeta", math_ops.zeta)
|
||||||
def _convert_cwise(pfor_input, op_type, op_func):
|
def _convert_cwise(pfor_input, op_type, op_func):
|
||||||
# Note that ops handled here do not have attributes except "T" and "Tout", and
|
# Note that ops handled here do not have attributes except those listed below
|
||||||
# hence don't need extra arguments passed to the cwise_op call below.
|
# and hence don't need extra arguments passed to the cwise_op call below.
|
||||||
for attr in pfor_input.op.node_def.attr.keys():
|
for attr in pfor_input.op.node_def.attr.keys():
|
||||||
assert attr in [u"T", u"Tout"], (op_type, attr)
|
assert attr in [u"T", u"Tout", u"_xla_compile_id"], (op_type, attr)
|
||||||
pfor_input.expanddim_inputs_for_broadcast()
|
pfor_input.expanddim_inputs_for_broadcast()
|
||||||
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
|
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
|
||||||
|
|
||||||
|
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""XLA tests for pfor."""
|
||||||
|
# pylint: disable=g-direct-tensorflow-import
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.compiler.xla import xla
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
|
||||||
|
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
class PForTest(PForTestCase):
|
||||||
|
|
||||||
|
def test_xla(self):
|
||||||
|
|
||||||
|
def compute(x):
|
||||||
|
return math_ops.reduce_mean(x, axis=0, keepdims=True)
|
||||||
|
|
||||||
|
def vectorized_compute(x):
|
||||||
|
return pfor_control_flow_ops.vectorized_map(compute, x)
|
||||||
|
|
||||||
|
result = xla.compile(vectorized_compute,
|
||||||
|
inputs=[array_ops.ones((10, 5, 3))])
|
||||||
|
self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
Loading…
Reference in New Issue
Block a user