Add XLA-only merge that can merge all types.
This prevents insertion of H2D and D2H copies when XLA-GPU clusters have int32 outputs. This merge is only used to merge the outputs of the _XlaRun node and the PartitionedCall node.
parent 18f700fa7e
commit 791bf78c29
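Background note (not part of the commit): on GPU the regular Merge kernel is registered per dtype, and its int32 variant is pinned to host memory, which is what forces the D2H/H2D copies described above. A simplified, hedged sketch of the contrast; the Merge registration below is illustrative only and does not reproduce the exact registrations in TensorFlow's control-flow kernels:

    // Illustrative only: a per-type, host-memory registration of a Merge-style
    // kernel keeps int32 values on the host on GPU, so merging an int32
    // cluster output would insert device<->host copies.
    REGISTER_KERNEL_BUILDER(Name("Merge")
                                .Device(DEVICE_GPU)
                                .TypeConstraint<int32>("T")
                                .HostMemory("inputs")
                                .HostMemory("output"),
                            MergeOp);

    // The new _XlaMerge kernel (added below) carries no type constraint and no
    // host-memory annotations, so int32 outputs can stay on the device.
    REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);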
@@ -135,9 +135,9 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
         new_output = check_numerics_op;
       }
 
-      ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
-                          {Output(old_node, oidx), new_output});
-      merged_output = merged_outputs[oidx] = merge_op.output;
+      ops::_XlaMerge xla_merge_op(s.WithOpName("merge_oidx_", oidx),
+                                  Output(old_node, oidx), new_output);
+      merged_output = merged_outputs[oidx] = xla_merge_op.output;
     }
 
     Node* dst = e->dst();
@@ -630,6 +630,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
           input_output_alias, closure.resource_var_snapshots()));
 }
 
+XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx)
+    : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
+
+void XlaMergeOp::Compute(OpKernelContext* ctx) {
+  VLOG(3) << "XlaMergeOp " << def().name();
+  int i = 0;
+  if (ctx->has_input(i) || ctx->has_input(++i)) {
+    ctx->set_output(0, ctx->input(i));
+  }
+}
+
 REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
 
 REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
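Note on the Compute body above: at run time only one of the two producers actually executes (either the fallback PartitionedCall or the _XlaRun node), so exactly one input is present and it is forwarded as output 0. The same logic restated with comments, for readability only:

    void XlaMergeOp::Compute(OpKernelContext* ctx) {
      VLOG(3) << "XlaMergeOp " << def().name();
      // Input 0 is the fallback (PartitionedCall) output, input 1 the _XlaRun
      // output; only the branch that ran supplies a tensor.
      int i = 0;
      if (ctx->has_input(i) || ctx->has_input(++i)) {
        ctx->set_output(0, ctx->input(i));  // Forward whichever input is live.
      }
    }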
@@ -648,6 +659,12 @@ REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                         XlaCompileOp);
 
 REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
-REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaRun")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("key"),
+                        XlaRunOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);
 
 }  // namespace tensorflow
@@ -175,6 +175,16 @@ class XlaRunOp : public OpKernel {
   const XlaPlatformInfo platform_info_;
 };
 
+class XlaMergeOp : public OpKernel {
+ public:
+  explicit XlaMergeOp(OpKernelConstruction* ctx);
+
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  const XlaPlatformInfo platform_info_;
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
@@ -95,4 +95,18 @@ Executes a TensorFlow function previously compiled into a LocalExecutable by an
 _XlaCompile op.
 )");
 
+REGISTER_OP("_XlaMerge")
+    .Input("partitioned_call: T")
+    .Input("xla_run: T")
+    .Output("output: T")
+    .Attr("T: type")
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    })
+    .Doc(R"(XLA Merge Op. For use by the XLA JIT only.
+
+Merges the outputs from the TensorFlow fallback execution and the _XlaRun node.
+)");
+
 }  // namespace tensorflow
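A hedged sketch of building a node directly against this op definition; the producer names "fallback_call" and "xla_run" and the helper name are illustrative, T is inferred from the inputs, and the _XlaMerge op registration above must be linked in for the registry lookup to succeed:

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/node_def_builder.h"
    #include "tensorflow/core/framework/types.pb.h"
    #include "tensorflow/core/lib/core/status.h"

    // Sketch only: merge output 0 of the fallback call with output 0 of _XlaRun.
    tensorflow::Status BuildXlaMergeNode(tensorflow::NodeDef* node_def) {
      return tensorflow::NodeDefBuilder("merge_oidx_0", "_XlaMerge")
          .Input("fallback_call", 0, tensorflow::DT_INT32)
          .Input("xla_run", 0, tensorflow::DT_INT32)
          .Finalize(node_def);
    }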
@@ -106,6 +106,7 @@ constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
 REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
 REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
+REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_CPU, XlaMergeOp, kAllXlaCpuTypes);
 
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
 
@@ -72,6 +72,11 @@ class XlaAssignVariableOp : public OpKernel {
 #define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
   REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
 
+#define REGISTER_XLA_MERGE_KERNEL(DEVICE, KERNEL, TYPES) \
+  REGISTER_KERNEL_BUILDER(Name("_XlaMerge")              \
+                              .Device(DEVICE),           \
+                          KERNEL);
+
 #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES)                 \
   REGISTER_KERNEL_BUILDER(                                         \
       Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
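Like REGISTER_XLA_RUN_KERNEL just above it, the new REGISTER_XLA_MERGE_KERNEL macro ignores its TYPES argument, so _XlaMerge is registered for every dtype on each XLA device; that is what lets it merge all types, int32 included, without extra copies. For example, the DEVICE_XLA_CPU invocation earlier in this diff expands roughly to:

    // REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_CPU, XlaMergeOp, kAllXlaCpuTypes);
    // expands (TYPES unused) to approximately:
    REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_XLA_CPU), XlaMergeOp);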
@@ -155,6 +155,7 @@ constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
 REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
 REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
+REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_GPU, XlaMergeOp, kAllXlaGpuTypes);
 
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
 
@@ -99,6 +99,7 @@ REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
 REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
                             kExecAllTypes);
 REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
+REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_INTERPRETER, XlaMergeOp, kExecAllTypes);
 
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
 REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
@@ -101,6 +101,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
         {"_DeviceArg", NC_ARG},
         {"_Retval", NC_RETVAL},
         {"_DeviceRetval", NC_RETVAL},
+        {"_XlaMerge", NC_MERGE},
     });
 
 #undef REF_CLASS
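Registering _XlaMerge as NC_MERGE here (and extending the IsMerge helpers below) matters because the executor and several graph utilities special-case merge nodes: a merge becomes ready when any one input is available rather than all of them. A minimal sketch of the kind of consumer-side check this affects; the helper name is illustrative:

    #include "tensorflow/core/graph/graph.h"

    // Sketch only: downstream code branches on the node class via Node::IsMerge
    // rather than comparing op-name strings.
    bool FiresOnAnyInput(const tensorflow::Node* n) {
      // Merge-like nodes (now including _XlaMerge) become ready when any single
      // input is available; all other nodes wait for every input.
      return n->IsMerge();
    }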
@@ -56,7 +56,7 @@ namespace {
 static constexpr const bool kDoNotCheckDuplicates = true;
 
 inline bool IsMerge(const NodeDef& node_def) {
-  return node_def.op() == "Merge" || node_def.op() == "RefMerge";
+  return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge";
 }
 
 inline bool IsNextIteration(const NodeDef& node_def) {
@@ -47,7 +47,7 @@ namespace tensorflow {
 namespace {
 
 inline bool IsMerge(const NodeDef& node_def) {
-  return node_def.op() == "Merge" || node_def.op() == "RefMerge";
+  return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge";
 }
 
 inline bool IsNextIteration(const NodeDef& node_def) {
@@ -153,6 +153,7 @@ bool IsControlFlow(const NodeDef& node) {
          node.op() == "Exit" ||
          node.op() == "LoopCond" ||
          node.op() == "Merge" ||
+         node.op() == "_XlaMerge" ||
          node.op() == "NextIteration" ||
          node.op() == "Switch" ||
          node.op() == "_SwitchN";
@@ -332,7 +333,7 @@ bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
 
 bool IsMerge(const NodeDef& node) {
   const auto& op = node.op();
-  return op == "Merge" || op == "RefMerge";
+  return op == "Merge" || op == "RefMerge" || op == "_XlaMerge";
 }
 
 bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
@@ -26,7 +26,7 @@ namespace graph_transforms {
 
 namespace {
 inline bool IsMerge(const NodeDef& node_def) {
-  return node_def.op() == "Merge" || node_def.op() == "RefMerge";
+  return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge";
 }
 
 void RecordMatchedNodes(const NodeMatch& match,