[TF/XLA Bridge] Alias pass-through parameters when the input graph contains no TF reference variables
Previously, expensive copies were required for pass-through parameters. Removing those copies is not safe in the presence of TF reference variables, so we only remove them when the graph contains no TF reference variables.

PiperOrigin-RevId: 271241769
parent b7ce325a0d
commit 6a20128c5b
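For context: a pass-through parameter is a cluster input that the cluster returns unchanged as one of its outputs. XLA can alias the output buffer to the input buffer and skip the copy, but a TF reference variable may be mutated in place by other ops, so handing its buffer out as an aliased output is unsafe. A minimal sketch of the decision this commit wires through (the option name and the `has_ref_vars` flag are taken from the diff below):

    // Sketch: alias pass-through parameters only when no node in the cluster
    // is related to a TF reference variable (kXlaHasReferenceVarsAttr == false).
    XlaCompiler::Options options;
    options.alias_passthrough_params = !has_ref_vars;  // copy only when needed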
@@ -472,6 +472,11 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
       /*resources=*/cluster_info.resource_inputs,
       /*must_compile=*/requires_compilation,
       cluster_info.function);
+
+  bool has_ref_attr;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &has_ref_attr));
+  xla_compile.operation.node()->AddAttr(kXlaHasReferenceVarsAttr, has_ref_attr);
   TF_RETURN_IF_ERROR(
       CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
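This mirrors the per-cluster attribute onto the generated _XlaCompile node so the kernel can read it at construction time. A minimal sketch of that attribute round-trip on a tensorflow::Node (the calls are the ones used in the hunk; the node name is a placeholder):

    // Producer side: stamp the attribute onto a node.
    producer_node->AddAttr(kXlaHasReferenceVarsAttr, false);
    // Consumer side: read it back; GetNodeAttr returns an error Status if
    // the attribute is absent.
    bool has_ref_attr;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(producer_node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_attr));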
@@ -149,8 +149,10 @@ TEST_F(BuildXlaOpsTest, ControlDepsPreserved) {
   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
   Node* call;
   TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
+  call->AddAttr(kXlaHasReferenceVarsAttr, false);
   call->set_requested_device(kXlaDeviceName);
   Node* write_op = MakeWrite(root, "write");
+  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);
   root.graph()->AddControlEdge(call, write_op);

   std::unique_ptr<Graph> graph;
@@ -191,8 +193,10 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) {
   Node* call;
   TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
   TF_ASSERT_OK(root.DoShapeInference(call));
+  call->AddAttr(kXlaHasReferenceVarsAttr, false);

   Node* write_op = MakeWrite(root, Output(call), "write_result");
+  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);

   auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false));
   auto predicated_compilation_key =
@@ -226,8 +230,10 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) {
   TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
   call->set_requested_device(kXlaDeviceName);
   TF_ASSERT_OK(root.DoShapeInference(call));
+  call->AddAttr(kXlaHasReferenceVarsAttr, false);

   Node* write_op = MakeWrite(root, Output(call), "write_result");
+  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);

   std::unique_ptr<Graph> graph;
   TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
@@ -250,6 +256,7 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) {
   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
   Node* call;
   TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
+  call->AddAttr(kXlaHasReferenceVarsAttr, false);

   std::unique_ptr<Graph> graph;
   TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
@@ -278,6 +285,7 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
   TF_ASSERT_OK(
       MakeXlaCompiledKernel(root.graph(), "cluster_int32", "C", &call));
   call->set_requested_device(kXlaDeviceName);
+  call->AddAttr(kXlaHasReferenceVarsAttr, false);

   auto var =
       ops::VarHandleOp(root.WithOpName("var"), DT_INT32, TensorShape({}));
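All of the hand-built test graphs above now stamp kXlaHasReferenceVarsAttr onto their cluster nodes: the pass reads the attribute unconditionally, and GetNodeAttr returns an error Status when it is missing, so omitting the line would fail the test before any assertion runs. The recurring one-line pattern, assuming a manually constructed kernel node `call`:

    // Hand-built cluster nodes must carry the attribute the pass now expects.
    call->AddAttr(kXlaHasReferenceVarsAttr, false);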
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -61,6 +62,7 @@ const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
 const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
 const char* const kXlaHostTransferSequencerAttr =
     "_xla_host_transfer_sequencer";
+const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";

 void SortControlInputs(GraphDef* gdef) {
   int64 num_nodes = gdef->node_size();
@@ -1311,6 +1313,14 @@ Status EncapsulateSubgraphsPass::Run(
   }

   *options.graph = std::move(graph_out);
+  TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
+                      GetNodesRelatedToRefVariables(**options.graph, flr));
+  for (Node* node : (*options.graph)->nodes()) {
+    bool has_ref_vars = ref_related_nodes.contains(node);
+    node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
+    VLOG(3) << "Has ref vars = " << has_ref_vars
+            << ", node: " << node->def().SerializeAsString();
+  }
   return Status::OK();
 }
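After this hunk, every node in the post-encapsulation graph carries the attribute. GetNodesRelatedToRefVariables is declared in xla_cluster_util.h (included above); judging by the new tests below, it returns the set of nodes connected to an operation that touches a TF reference variable. A minimal sketch of reading the marking back from a populated Graph `graph` (the same pattern the tests use):

    // Sketch: consume the per-node marking added by the pass.
    for (const Node* node : graph->nodes()) {
      bool has_ref_var;
      TF_CHECK_OK(
          GetNodeAttr(node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var));
      VLOG(2) << node->name() << " is ref-variable related: " << has_ref_var;
    }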
@@ -91,6 +91,9 @@ extern const char* const kXlaNumConstantArgsAttr;
 // Name of the attribute containing the number of resource variable arguments.
 extern const char* const kXlaNumResourceArgsAttr;

+// Name of the attribute defining whether the cluster has reference variables.
+extern const char* const kXlaHasReferenceVarsAttr;
+
 // Sorts each node's control inputs by their names. This guarantees that for two
 // structurally equivalent GraphDefs, we get the same traversal ordering on
 // node's control input fields.
@@ -2581,5 +2581,79 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
 }

+void CreateSubgraphTouchingRefVar(const Scope& s) {
+  Output variable =
+      ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
+  Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
+  Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
+  Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
+
+  Output constant =
+      ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
+  s.graph()->AddControlEdge(constant.node(), variable.node());
+}
+
+TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  CreateSubgraphTouchingRefVar(root);
+
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  SessionOptions session_options;
+  session_options.env = Env::Default();
+  GraphOptimizationPassOptions options;
+  options.session_options = &session_options;
+  FunctionLibraryDefinition library(OpRegistry::Global(), {});
+  options.flib_def = &library;
+  options.graph = &graph;
+
+  EncapsulateSubgraphsPass pass;
+  TF_ASSERT_OK(pass.Run(options));
+
+  for (const Node* node : graph->nodes()) {
+    bool has_ref_var;
+    TF_ASSERT_OK(
+        GetNodeAttr(node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var));
+    EXPECT_TRUE(node->IsSink() || node->IsSource() || has_ref_var)
+        << "All nodes apart from source and sink can access reference variable";
+  }
+}
+
+void CreateSubgraphNotTouchingRefVar(const Scope& s) {
+  Output constant =
+      ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
+  Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
+  Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
+}
+
+TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  CreateSubgraphNotTouchingRefVar(root);
+
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  // TODO(cheshire): reduce boilerplate for creating
+  // GraphOptimizationPassOptions here and elsewhere, probably using a macro.
+  SessionOptions session_options;
+  session_options.env = Env::Default();
+  GraphOptimizationPassOptions options;
+  options.session_options = &session_options;
+  FunctionLibraryDefinition library(OpRegistry::Global(), {});
+  options.flib_def = &library;
+  options.graph = &graph;
+
+  EncapsulateSubgraphsPass pass;
+  TF_ASSERT_OK(pass.Run(options));
+
+  for (const Node* node : graph->nodes()) {
+    bool has_ref_var;
+    TF_ASSERT_OK(
+        GetNodeAttr(node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var));
+    EXPECT_FALSE(has_ref_var) << "The graph does not have reference variables";
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow
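The TODO in the second test asks for a way to cut the repeated GraphOptimizationPassOptions setup. One hypothetical helper that would cover both tests (not part of this commit; the name MakePassOptions is invented):

    // Hypothetical helper sketching the TODO(cheshire) above; not in the commit.
    GraphOptimizationPassOptions MakePassOptions(
        SessionOptions* session_options, FunctionLibraryDefinition* flib_def,
        std::unique_ptr<Graph>* graph) {
      GraphOptimizationPassOptions options;
      options.session_options = session_options;
      options.flib_def = flib_def;
      options.graph = graph;
      return options;
    }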
@@ -9,11 +9,13 @@ XLA_OPS_DEPS = [
     "@com_google_absl//absl/container:flat_hash_map",
     "@com_google_absl//absl/memory",
     "//tensorflow/compiler/jit:common",
+    "//tensorflow/compiler/jit:compilation_passes",
     "//tensorflow/compiler/jit:flags",
     "//tensorflow/compiler/jit:xla_activity_listener",
     "//tensorflow/compiler/jit:xla_activity_proto_cc",
     "//tensorflow/compiler/jit:xla_compilation_cache",
     "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+    "//tensorflow/compiler/jit:xla_cluster_util",
     "//tensorflow/compiler/jit:xla_launch_util",
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:tf2xla_util",
@@ -18,8 +18,10 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -268,7 +270,7 @@ static Status BuildCompilationCache(OpKernelContext* ctx,
 }

 static Status CompileToLocalExecutable(
-    OpKernelContext* ctx, const NameAttrList& function,
+    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
     const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
     absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
     std::map<int, OptionalTensor>* variables,
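Every caller now threads the flag through explicitly. A minimal call under the new signature (mirroring the call sites updated below; the surrounding variables are assumed to be in scope):

    // Sketch: pass has_ref_vars explicitly; `true` keeps the old always-copy
    // behavior, `false` permits aliasing of pass-through parameters.
    Status s = CompileToLocalExecutable(
        ctx, function, /*has_ref_vars=*/true, platform_info, resources,
        constants, /*lazy=*/false, &client, &variables, &kernel, &executable);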
@@ -313,8 +315,9 @@ static Status CompileToLocalExecutable(
     options.shape_representation_fn =
         platform_info.xla_device_metadata()->shape_representation_fn();
   }
-  // TODO(b/138728225): Set options.alias_passthrough_params for clusters
-  // without ref variables.
+  // If reference variables are not present in the graph, we can safely alias
+  // passthrough parameters without performing a copy.
+  options.alias_passthrough_params = !has_ref_vars;

   std::map<int, Tensor> constant_args;
   for (int i : constants) {
@@ -351,8 +354,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {

   {
     Status s = CompileToLocalExecutable(
-        ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false,
-        &client, &variables, &kernel, &executable);
+        ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
+        constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
     if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
                     platform_info_.device_type().type_string() == DEVICE_GPU)) {
       // Suggest auto jit if the failure was with GPU or CPU.
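Note the hard-coded /*has_ref_vars=*/true on this launch path: XlaLocalLaunchBase is not guaranteed to carry the new attribute, and passing true yields alias_passthrough_params = false, which reproduces the previous always-copy behavior. The _XlaCompile path below, by contrast, consults the attribute via has_ref_vars_. In effect (comments only, summarizing the two call sites):

    // XlaLocalLaunchBase: has_ref_vars == true    -> never alias (old behavior).
    // XlaCompileOp:       has_ref_vars_ from attr -> alias only when the
    //                     cluster touches no TF reference variables.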
@@ -451,6 +454,14 @@ bool MustCompileAttr(OpKernelConstruction* ctx) {
                     ctx->GetAttr("must_compile", &must_compile));
   return must_compile;
 }

+bool HasRefVars(OpKernelConstruction* ctx) {
+  bool has_ref_vars;
+  OP_REQUIRES_OK_RETURN(ctx, false,
+                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
+  return has_ref_vars;
+}
+
 }  // namespace

 XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
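OP_REQUIRES_OK_RETURN is the construction-time variant used elsewhere in this file: when the wrapped expression returns a non-OK Status, it records the failure on the OpKernelConstruction context (failing kernel creation) and returns the supplied fallback value, false here. What a caller observes when the attribute is missing (assumed behavior, following the macro's contract):

    // Sketch: constructing an _XlaCompile kernel from a NodeDef that lacks
    // kXlaHasReferenceVarsAttr fails construction; HasRefVars itself returns
    // the fallback `false`, but the recorded error status is what surfaces.
    bool has_ref_vars = HasRefVars(ctx);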
@@ -467,7 +478,8 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
       resources_(ResourcesVector(ctx)),
       function_(FunctionAttr(ctx)),
       platform_info_(PlatformInfoFromContext(ctx)),
-      must_compile_(MustCompileAttr(ctx)) {}
+      must_compile_(MustCompileAttr(ctx)),
+      has_ref_vars_(HasRefVars(ctx)) {}

 void XlaCompileOp::Compute(OpKernelContext* ctx) {
   VLOG(3) << "XlaCompileOp " << def().name()
@@ -488,7 +500,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
     executable = nullptr;
   } else {
     Status status = CompileToLocalExecutable(
-        ctx, function_, platform_info_, resources_, constants_,
+        ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
         /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
     if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
       OP_REQUIRES_OK(ctx, status);
@@ -153,6 +153,9 @@ class XlaCompileOp : public OpKernel {

   const bool must_compile_;

+  // Whether the graph has TF reference variables.
+  const bool has_ref_vars_;
+
   // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
   // error when compiling the cluster this _XlaCompile is supposed to compile.
   // If `cannot_compile_cluster_` is true then we avoid compiling this cluster