Add gradient function for XlaClusterOutput
PiperOrigin-RevId: 223049594
This commit is contained in:
parent
891e56199d
commit
654efdeacc
@ -736,7 +736,10 @@ tf_custom_op_py_library(
|
||||
visibility = [
|
||||
":friends",
|
||||
],
|
||||
deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit/ops:xla_ops_grad",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
|
||||
],
|
||||
)
|
||||
|
||||
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
|
||||
|
@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
|
||||
e->dst()->type_string() != kXlaClusterOutput) {
|
||||
return errors::InvalidArgument(
|
||||
"Undeclared output of XLA computation. A common cause of this error "
|
||||
"is variable initializers that depend on the XLA computation. Edge: ",
|
||||
"Undeclared output of XLA computation. Some common causes of this "
|
||||
"error are: 1) variable initializers that depend on the XLA "
|
||||
"computation; 2) gradient computations that depend on the XLA "
|
||||
"computation, which can be mitigated by moving gradient computations "
|
||||
"inside XLA computation. Offending edge: ",
|
||||
e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
|
||||
e->dst_input());
|
||||
}
|
||||
|
@ -18,3 +18,9 @@ tf_gen_op_wrapper_py(
|
||||
out = "xla_ops.py",
|
||||
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "xla_ops_grad",
|
||||
srcs = ["xla_ops_grad.py"],
|
||||
deps = ["//tensorflow/python:framework_ops"],
|
||||
)
|
||||
|
29
tensorflow/compiler/jit/ops/xla_ops_grad.py
Normal file
29
tensorflow/compiler/jit/ops/xla_ops_grad.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Gradients for XLA ops."""
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
@ops.RegisterGradient("XlaClusterOutput")
|
||||
def _XlaClusterOutputGrad(_, grad):
|
||||
del grad # unused
|
||||
raise RuntimeError("Gradient computation of graph in xla.compile() is "
|
||||
"prohibited because it can cause performance degradation."
|
||||
"Please move gradient computation inside xla.compile().")
|
@ -9,6 +9,7 @@ package_group(
|
||||
"//tensorflow/compiler/jit/...",
|
||||
"//tensorflow/compiler/tests/...",
|
||||
"//tensorflow/compiler/tf2xla/...",
|
||||
"//tensorflow/contrib/compiler/...",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -58,6 +58,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:xla_ops_py",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops_grad",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
|
@ -23,6 +23,7 @@ import contextlib
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.jit.ops import xla_ops
|
||||
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import ops
|
||||
|
Loading…
Reference in New Issue
Block a user