Fix experimental_compile=True for graph mode
Previously the attribute only worked in eager mode and was a no-op otherwise. Note that this also resolves the problem of a function with experimental_compile=True not being compiled when called from an experimental_compile=False context.

PiperOrigin-RevId: 286682281
Change-Id: Ifbc6efa2c82ae13f5d124ec6aaf440e1639a42c3
Parent: 8a7097eb9b
Commit: f3c0a3e0e4
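For context, a minimal usage sketch of what this change enables (it mirrors the new experimental_compile_test.py below; an illustration, not code from the commit): a function created with experimental_compile=True is now marked for XLA compilation even when it is traced and executed inside an explicit Graph/Session rather than eagerly.

    from tensorflow.python.client import session
    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops

    with ops.Graph().as_default() as g:

      def fn(x, a):
        return x + a

      # Graph mode: the function is traced into g instead of running eagerly.
      xla_func = def_function.function(fn, experimental_compile=True)
      inputs = array_ops.placeholder(dtypes.float32, [5])
      out = xla_func(inputs, 1)

    # With this fix the traced function carries _XlaMustCompile and is
    # compiled with XLA when the session runs it.
    with session.Session(graph=g) as sess:
      print(sess.run(out, feed_dict={inputs: [1, 2, 2, 3, 3]}))  # [2. 3. 3. 4. 4.]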
@@ -21,7 +21,7 @@ namespace tensorflow {
 bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
                                        const NodeDef& node_def) const {
-  return CanCreateXlaKernel(flr, node_def);
+  return CanCreateXlaKernel(node_def);
 }
 
 Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
@@ -98,12 +98,14 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
   (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-  Status status = xla_kernel_creator.CreateKernel(
-      flr_, ToNodeDef(R"pb(
+  NodeDef callsite =
+      ToNodeDef(R"pb(
         name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
-      )pb"),
-      &kernel_);
+      )pb");
+  (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
+
+  // Note: need to set attribute on the created node.
+  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
   ASSERT_TRUE(status.ok()) << status.ToString();
 
   EXPECT_EQ("XTimesY", kernel_->name());
@@ -23,7 +23,9 @@ limitations under the License.
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/util/ptr_util.h"
@@ -68,40 +70,10 @@ class SinglePassSearch {
 };
 }  // namespace
 
-bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
-                        const NodeDef& node_def) {
-  const FunctionDef* function_def =
-      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
-  if (function_def == nullptr) {
-    // The node def is not calling a function. Individual ops can be
-    // run directly using on-demand mode, no need to create XlaLaunch
-    // kernel for them.
-    return false;
-  }
-
+bool CanCreateXlaKernel(const NodeDef& node_def) {
   // If kXlaMustCompileAttr is set on the node_def, use its value.
   const auto& it = node_def.attr().find(kXlaMustCompileAttr);
-  if (it != node_def.attr().end()) {
-    return it->second.b();
-  }
-
-  // kXlaMustCompileAttr is not set on node_def, check if it is set on
-  // FunctionDef.
-  bool xla_compile = false;
-  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
-      node_def, kXlaMustCompileAttr, &xla_compile);
-  if (!status.ok() || !xla_compile) {
-    if (VLOG_IS_ON(3)) {
-      if (!status.ok()) {
-        VLOG(3) << "No " << kXlaMustCompileAttr << " attr defined for "
-                << node_def.op() << ". status=" << status.ToString();
-      } else {
-        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
-      }
-    }
-    return false;
-  }
-  return true;
+  return it != node_def.attr().end() && it->second.b();
 }
 
 // Given a FunctionLibraryRuntime and a NodeDef calling a function in the
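For readers less familiar with the C++ side, the simplified check above boils down to reading one boolean attribute off the NodeDef. A rough Python rendition (a hypothetical helper, not part of the commit, assuming kXlaMustCompileAttr is "_XlaMustCompile"):

    from tensorflow.core.framework import node_def_pb2


    def has_xla_must_compile(node_def):
      # NodeDef.attr is a proto map<string, AttrValue>; get() returns None when
      # the attribute is absent, matching the C++ find()/end() check above.
      attr = node_def.attr.get("_XlaMustCompile")
      return attr is not None and attr.b


    node = node_def_pb2.NodeDef(name="XTimesY", op="XTimesY")
    node.attr["_XlaMustCompile"].b = True
    assert has_xla_must_compile(node)

The FunctionDef lookup and the FunctionLibraryRuntime argument are gone: whether a node must be compiled is now decided purely from the node's own attributes.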
@@ -118,8 +90,11 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
   FunctionLibraryRuntime::Handle handle;
   // If node_def is not instantiable, e.g., the function does not exist,
   // simply bail out.
+  NameAttrList function;
+  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
+
   TF_RETURN_IF_ERROR(
-      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
   *fbody = flr->GetFunctionBody(handle);
   CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
   const DataTypeVector& arg_types = (*fbody)->arg_types;
@@ -149,7 +124,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
 
 Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
                        std::unique_ptr<OpKernel>* kernel) {
-  if (!CanCreateXlaKernel(*flr, node_def)) {
+  if (!CanCreateXlaKernel(node_def)) {
     return errors::Internal("Invalid node: ", node_def.ShortDebugString());
   }
 
@@ -241,9 +216,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
 
   // Create the kernel.
   NameAttrList function;
-  function.set_name(node_def.op());
-  *(function.mutable_attr()) = node_def.attr();
+  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
 
   Device* dev = flr->device();
   Status s;
   OpKernelConstruction construction(
@@ -24,11 +24,9 @@ namespace tensorflow {
 class FunctionLibraryRuntime;
 class OpKernel;
 
-// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
-// true if 'node_def' is a call to a compilable function defined in 'flr',
-// with the kXlaCompileAttr set.
-bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
-                        const NodeDef& node_def);
+// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
+// set.
+bool CanCreateXlaKernel(const NodeDef& node_def);
 
 // Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
 Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
@@ -2533,6 +2533,7 @@ tf_cuda_library(
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/util:port",
         "//tensorflow/core/util:stats_calculator_portable",
+        "//tensorflow/compiler/jit:common",
     ] + if_static(
         extra_deps = ["@com_google_protobuf//:protobuf"],
         otherwise = ["@com_google_protobuf//:protobuf_headers"],
@@ -1313,10 +1313,12 @@ Status DirectSession::CreateExecutors(
       options_.config.experimental().has_session_metadata()
           ? &options_.config.experimental().session_metadata()
           : nullptr;
+  const CustomKernelCreator* custom_kernel_creator =
+      GetDefaultCustomKernelCreator();
   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
       device_mgr_.get(), options_.env, &options_.config, graph_def_version,
       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
-      nullptr, nullptr, session_metadata));
+      nullptr, custom_kernel_creator, session_metadata));
 
   GraphOptimizer optimizer(optimizer_opts);
   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <utility>
 
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -97,6 +99,11 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
   inp_mtypes->clear();
   out_mtypes->clear();
 
+  bool has_xla_compile = [&] {
+    const auto& it = ndef.attr().find(kXlaMustCompileAttr);
+    return it != ndef.attr().end() && it->second.b();
+  }();
+
   // For functions (which have no KernelDef) and their gradients, we can only
   // best-effort derive the memory type from the data type. For now, we assume
   // int32 is always on host memory and other types are always on device memory.
@@ -104,7 +111,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
   // to derive the correct input/output memory types. We should also split
   // host-memory and non host-memory arguments into separate type lists.
   if (!status.ok() || IsFunctionCallOp(ndef.op())) {
-    if (device_type.type_string() == "TPU") {
+    if (device_type.type_string() == "TPU" || has_xla_compile) {
       // Here we assume that if tf.function() is called within
       // "with tf.device('/device:TPU:0')", the whole function will be compiled
       // and executed on TPU. This is true today, but when we implement auto
@@ -91,3 +91,20 @@ cuda_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+cuda_py_test(
+    name = "experimental_compile_test",
+    srcs = ["experimental_compile_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:resource_variable_ops",
+    ],
+    python_version = "PY3",
+    tags = [
+        "no_mac",
+        "no_windows",
+    ],
+    xla_enabled = True,
+)
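Assuming a standard TensorFlow Bazel setup, the target added above should make the new test runnable with: bazel test //tensorflow/python/compiler/xla:experimental_compile_test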
tensorflow/python/compiler/xla/experimental_compile_test.py (new file, 113 lines)
@@ -0,0 +1,113 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ExperimentalCompileTest(test.TestCase):
+
+  def testBasic(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x, a):
+        return x + a
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.float32, [5])
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        x = xla_func(inputs, 1)
+        with session.Session(graph=g) as sess:
+          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+          self.assertTrue(x.graph.as_graph_def().library.function[0]
+                          .attr["_XlaMustCompile"].b)
+          self.assertAllClose([2, 3, 3, 4, 4], y)
+
+  def testDerivative(self):
+    # XLA support is not yet enabled for TF ROCm
+    if test.is_built_with_rocm():
+      return
+
+    def fn(x, a):
+      return 2 * x + a
+
+    with ops.Graph().as_default() as g:
+      xla_func = def_function.function(fn, experimental_compile=True)
+      with backprop.GradientTape() as tape:
+        inputs = array_ops.placeholder(dtypes.float32, [5])
+        tape.watch(inputs)
+        outputs = xla_func(inputs, 1)
+      grads = tape.gradient(outputs, inputs)
+
+      with session.Session(graph=g) as sess:
+        grads_tensor = sess.run(grads, feed_dict={inputs: [1, 2, 2, 3, 3]})
+        self.assertAllClose([2, 2, 2, 2, 2], grads_tensor)
+        (forward, backward) = xla_func.get_concrete_function(
+            inputs, 1)._delayed_rewrite_functions.forward_backward()
+
+        # Check that the must-compile attribute gets correctly propagated to the
+        # created derivatives.
+        self.assertTrue(forward.definition.attr["_XlaMustCompile"])
+        self.assertTrue(backward.function_def.attr["_XlaMustCompile"])
+
+  def testBasicInt32(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x, a):
+        return x + a
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.int32, [5])
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        x = xla_func(inputs, 1)
+        with session.Session(graph=g) as sess:
+          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+          self.assertTrue(x.graph.as_graph_def().library.function[0]
+                          .attr["_XlaMustCompile"].b)
+          self.assertAllClose([2, 3, 3, 4, 4], y)
+
+  # Checking that we crash on an unsupported operation lets us test that the XLA
+  # compiler was actually invoked.
+  def testUnsupportedOps(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x):
+        return array_ops.unique(x).y  # Unique is not supported by XLA
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.float32, [5])
+      x = xla_func(inputs)
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                     "not compilable"):
+          with session.Session(graph=g) as sess:
+            sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+
+
+if __name__ == "__main__":
+  test.main()
@@ -20,6 +20,7 @@ from __future__ import print_function
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
|
|||||||
# XLA support is not yet enabled for TF ROCm
|
# XLA support is not yet enabled for TF ROCm
|
||||||
self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
|
self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
|
||||||
|
|
||||||
|
def testBasicInt32(self):
|
||||||
|
|
||||||
|
def fn(x, a):
|
||||||
|
return x + a
|
||||||
|
|
||||||
|
xla_func = def_function.function(fn, experimental_compile=True)
|
||||||
|
|
||||||
|
inputs = constant_op.constant([1, 2, 2, 3, 3], dtype=dtypes.int32)
|
||||||
|
if not test.is_built_with_rocm():
|
||||||
|
# XLA support is not yet enabled for TF ROCm
|
||||||
|
self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
|
||||||
|
|
||||||
def testDerivative(self):
|
def testDerivative(self):
|
||||||
if test.is_built_with_rocm():
|
if test.is_built_with_rocm():
|
||||||
return
|
return
|
||||||
|
@@ -1163,17 +1163,19 @@ def partitioned_call(args,
   graph = ops.get_default_graph()
   f.add_to_graph(graph)
   op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
-  op = graph.create_op(
-      op_name,
-      args,
-      tout,
-      name=op_name,
-      attrs={
-          "Tin": tin_attr,
-          "Tout": tout_attr,
-          "f": func_attr,
-          "config_proto": config_proto,
-          "executor_type": executor_type_attr,
-      })
+
+  # Propagate the attribute indicating the need to compile from function to the
+  # call itself.
+  xla_compile_attr = "_XlaMustCompile"
+  op_attrs = {
+      "Tin": tin_attr,
+      "Tout": tout_attr,
+      "f": func_attr,
+      "config_proto": config_proto,
+      "executor_type": executor_type_attr,
+  }
+  if xla_compile_attr in f.definition.attr:
+    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
+  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
   outputs = op.outputs
   return outputs if outputs else op
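The attribute propagation in partitioned_call is what makes the graph-mode path work end to end: the FunctionDef created for the tf.function carries _XlaMustCompile, and the (Stateful)PartitionedCall node emitted for the call now inherits it. A small sketch of how this can be observed (illustrative only; it assumes the same graph-mode construction used in the tests above):

    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops

    with ops.Graph().as_default() as g:
      xla_func = def_function.function(lambda x, a: x + a,
                                       experimental_compile=True)
      out = xla_func(array_ops.placeholder(dtypes.float32, [5]), 1)

    gdef = g.as_graph_def()
    # The library function is marked as must-compile...
    print(gdef.library.function[0].attr["_XlaMustCompile"].b)
    # ...and so are the call node(s) that invoke it.
    print([n.name for n in gdef.node
           if "PartitionedCall" in n.op and "_XlaMustCompile" in n.attr])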