Add gradient function for XlaClusterOutput

PiperOrigin-RevId: 223049594
This commit is contained in:
Yanan Cao 2018-11-27 13:40:54 -08:00 committed by TensorFlower Gardener
parent 891e56199d
commit 654efdeacc
7 changed files with 47 additions and 3 deletions

View File

@ -736,7 +736,10 @@ tf_custom_op_py_library(
visibility = [
":friends",
],
deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
deps = [
"//tensorflow/compiler/jit/ops:xla_ops_grad",
"//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.

View File

@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
e->dst()->type_string() != kXlaClusterOutput) {
return errors::InvalidArgument(
"Undeclared output of XLA computation. A common cause of this error "
"is variable initializers that depend on the XLA computation. Edge: ",
"Undeclared output of XLA computation. Some common causes of this "
"error are: 1) variable initializers that depend on the XLA "
"computation; 2) gradient computations that depend on the XLA "
"computation, which can be mitigated by moving gradient computations "
"inside XLA computation. Offending edge: ",
e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
e->dst_input());
}

View File

@ -18,3 +18,9 @@ tf_gen_op_wrapper_py(
out = "xla_ops.py",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
py_library(
name = "xla_ops_grad",
srcs = ["xla_ops_grad.py"],
deps = ["//tensorflow/python:framework_ops"],
)

View File

@ -0,0 +1,29 @@
"""Gradients for XLA ops."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
@ops.RegisterGradient("XlaClusterOutput")
def _XlaClusterOutputGrad(_, grad):
del grad # unused
raise RuntimeError("Gradient computation of graph in xla.compile() is "
"prohibited because it can cause performance degradation."
"Please move gradient computation inside xla.compile().")

View File

@ -9,6 +9,7 @@ package_group(
"//tensorflow/compiler/jit/...",
"//tensorflow/compiler/tests/...",
"//tensorflow/compiler/tf2xla/...",
"//tensorflow/contrib/compiler/...",
],
)

View File

@ -58,6 +58,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/compiler/jit:xla_ops_py",
"//tensorflow/compiler/jit/ops:xla_ops_grad",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",

View File

@ -23,6 +23,7 @@ import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops