Merge pull request #33063 from bas-aarts:xla-merge
PiperOrigin-RevId: 276286841 Change-Id: I4f9cbc4d82cc963676b0b55ea023d4792ee1b0c7
This commit is contained in:
commit
c87a16e17a
@ -157,9 +157,12 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
|
|||||||
new_output = check_numerics_op;
|
new_output = check_numerics_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
|
ops::_XlaMerge xla_merge_op(s.WithOpName("merge_oidx_", oidx),
|
||||||
{Output(old_node, oidx), new_output});
|
Output(old_node, oidx), new_output);
|
||||||
merged_output = merged_outputs[oidx] = merge_op.output;
|
if (xla_merge_op.output.type() == DT_INT32) {
|
||||||
|
LOG(INFO) << "int32 output at index " << oidx;
|
||||||
|
}
|
||||||
|
merged_output = merged_outputs[oidx] = xla_merge_op.output;
|
||||||
}
|
}
|
||||||
|
|
||||||
Node* dst = e->dst();
|
Node* dst = e->dst();
|
||||||
|
@ -208,7 +208,7 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) {
|
|||||||
NodeWith(Op("PartitionedCall"),
|
NodeWith(Op("PartitionedCall"),
|
||||||
CtrlDeps(NodeWith(Op("Identity"),
|
CtrlDeps(NodeWith(Op("Identity"),
|
||||||
Inputs(Out(0, predicated_compilation_key)))));
|
Inputs(Out(0, predicated_compilation_key)))));
|
||||||
auto merge = NodeWith(Op("Merge"), Inputs(Out(tf_call), Out(xla_run)));
|
auto merge = NodeWith(Op("_XlaMerge"), Inputs(Out(tf_call), Out(xla_run)));
|
||||||
auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge)));
|
auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge)));
|
||||||
|
|
||||||
std::unique_ptr<Graph> graph;
|
std::unique_ptr<Graph> graph;
|
||||||
|
@ -533,6 +533,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
|||||||
compilation_successful.scalar<bool>()() = false;
|
compilation_successful.scalar<bool>()() = false;
|
||||||
ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
|
ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
|
||||||
ctx->set_output(1, compilation_successful);
|
ctx->set_output(1, compilation_successful);
|
||||||
|
LOG(INFO) << "Compilation bailout!";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -630,6 +631,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
|||||||
input_output_alias, closure.resource_var_snapshots()));
|
input_output_alias, closure.resource_var_snapshots()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
|
void XlaMergeOp::Compute(OpKernelContext* ctx) {
|
||||||
|
VLOG(3) << "XlaMergeOp " << def().name();
|
||||||
|
int i = 0;
|
||||||
|
if (ctx->has_input(i) || ctx->has_input(++i)) {
|
||||||
|
ctx->set_output(0, ctx->input(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
|
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
|
REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
|
||||||
@ -648,6 +659,10 @@ REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
|
|||||||
XlaCompileOp);
|
XlaCompileOp);
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
|
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
|
||||||
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);
|
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
|
||||||
|
XlaRunOp);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -175,6 +175,13 @@ class XlaRunOp : public OpKernel {
|
|||||||
const XlaPlatformInfo platform_info_;
|
const XlaPlatformInfo platform_info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class XlaMergeOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit XlaMergeOp(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
|
#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
|
||||||
|
@ -95,4 +95,23 @@ Executes a TensorFlow function previously compiled into a LocalExecutable by an
|
|||||||
_XlaCompile op.
|
_XlaCompile op.
|
||||||
)");
|
)");
|
||||||
|
|
||||||
|
REGISTER_OP("_XlaMerge")
|
||||||
|
.Input("partitioned_call: T")
|
||||||
|
.Input("xla_run: T")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("T: type")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
c->set_output(0, c->input(0));
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"(XLA Merge Op. For use by the XLA JIT only.
|
||||||
|
|
||||||
|
Merges the outputs from the PartitionedCall node and the _XlaRun node.
|
||||||
|
Unlike the TensorFlow Merge op, which requires inputs of some types to be
|
||||||
|
placed on the host, the _XlaMerge op can merge inputs of all types when
|
||||||
|
placed on the device. This prevents the need for copy operations, in
|
||||||
|
particluar when an XLA cluster has int32 outputs. The _XlaMerge up does not
|
||||||
|
have a value_index output that identifies the chosen input.
|
||||||
|
)");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -103,6 +103,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
|
|||||||
{"_DeviceArg", NC_ARG},
|
{"_DeviceArg", NC_ARG},
|
||||||
{"_Retval", NC_RETVAL},
|
{"_Retval", NC_RETVAL},
|
||||||
{"_DeviceRetval", NC_RETVAL},
|
{"_DeviceRetval", NC_RETVAL},
|
||||||
|
{"_XlaMerge", NC_MERGE},
|
||||||
});
|
});
|
||||||
|
|
||||||
#undef REF_CLASS
|
#undef REF_CLASS
|
||||||
|
@ -56,7 +56,8 @@ namespace {
|
|||||||
static constexpr const bool kDoNotCheckDuplicates = true;
|
static constexpr const bool kDoNotCheckDuplicates = true;
|
||||||
|
|
||||||
inline bool IsMerge(const NodeDef& node_def) {
|
inline bool IsMerge(const NodeDef& node_def) {
|
||||||
return node_def.op() == "Merge" || node_def.op() == "RefMerge";
|
return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
|
||||||
|
node_def.op() == "_XlaMerge";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool IsNextIteration(const NodeDef& node_def) {
|
inline bool IsNextIteration(const NodeDef& node_def) {
|
||||||
|
@ -47,7 +47,8 @@ namespace tensorflow {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
inline bool IsMerge(const NodeDef& node_def) {
|
inline bool IsMerge(const NodeDef& node_def) {
|
||||||
return node_def.op() == "Merge" || node_def.op() == "RefMerge";
|
return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
|
||||||
|
node_def.op() == "_XlaMerge";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool IsNextIteration(const NodeDef& node_def) {
|
inline bool IsNextIteration(const NodeDef& node_def) {
|
||||||
|
@ -153,6 +153,7 @@ bool IsControlFlow(const NodeDef& node) {
|
|||||||
node.op() == "Exit" ||
|
node.op() == "Exit" ||
|
||||||
node.op() == "LoopCond" ||
|
node.op() == "LoopCond" ||
|
||||||
node.op() == "Merge" ||
|
node.op() == "Merge" ||
|
||||||
|
node.op() == "_XlaMerge" ||
|
||||||
node.op() == "NextIteration" ||
|
node.op() == "NextIteration" ||
|
||||||
node.op() == "Switch" ||
|
node.op() == "Switch" ||
|
||||||
node.op() == "_SwitchN";
|
node.op() == "_SwitchN";
|
||||||
@ -332,7 +333,7 @@ bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
|
|||||||
|
|
||||||
bool IsMerge(const NodeDef& node) {
|
bool IsMerge(const NodeDef& node) {
|
||||||
const auto& op = node.op();
|
const auto& op = node.op();
|
||||||
return op == "Merge" || op == "RefMerge";
|
return op == "Merge" || op == "RefMerge" || op == "_XlaMerge";
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
|
bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
|
||||||
|
@ -26,7 +26,8 @@ namespace graph_transforms {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
inline bool IsMerge(const NodeDef& node_def) {
|
inline bool IsMerge(const NodeDef& node_def) {
|
||||||
return node_def.op() == "Merge" || node_def.op() == "RefMerge";
|
return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
|
||||||
|
node_def.op() == "_XlaMerge";
|
||||||
}
|
}
|
||||||
|
|
||||||
void RecordMatchedNodes(const NodeMatch& match,
|
void RecordMatchedNodes(const NodeMatch& match,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user