address comments on commit e4c76ebc7790368227f2aa4695c177cff364e463
Add XLA-only merge that can merge all types. This prevents insertion of H2D and D2H copies when XLA-GPU clusters have int32 outputs. This merge is only used to merge the outputs from the XlaRun and the PartitionedCall node.
This commit is contained in:
parent
791bf78c29
commit
d628b34f78
@ -631,7 +631,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx)
|
XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx)
|
||||||
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
|
: OpKernel(ctx) {}
|
||||||
|
|
||||||
void XlaMergeOp::Compute(OpKernelContext* ctx) {
|
void XlaMergeOp::Compute(OpKernelContext* ctx) {
|
||||||
VLOG(3) << "XlaMergeOp " << def().name();
|
VLOG(3) << "XlaMergeOp " << def().name();
|
||||||
|
@ -180,9 +180,6 @@ class XlaMergeOp : public OpKernel {
|
|||||||
explicit XlaMergeOp(OpKernelConstruction* ctx);
|
explicit XlaMergeOp(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override;
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
|
||||||
private:
|
|
||||||
const XlaPlatformInfo platform_info_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -106,7 +106,11 @@ REGISTER_OP("_XlaMerge")
|
|||||||
})
|
})
|
||||||
.Doc(R"(XLA Merge Op. For use by the XLA JIT only.
|
.Doc(R"(XLA Merge Op. For use by the XLA JIT only.
|
||||||
|
|
||||||
Merges the outputs from the TensorFlow fallback execution and the _XlaRun node.
|
Merges the outputs from the PartitionedCall node and the _XlaRun node.
|
||||||
|
Unlike the TensorFlow merge op, _XlaMerge supports merging inputs of all types.
|
||||||
|
This prevents the need for copy operations, in particular when an XLA cluster
|
||||||
|
has int32 outputs. The _XlaMerge op does not have a value_index output that
|
||||||
|
identifies the chosen input.
|
||||||
)");
|
)");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -106,7 +106,6 @@ constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
|
|||||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
||||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
|
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
|
||||||
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
|
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
|
||||||
REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_CPU, XlaMergeOp, kAllXlaCpuTypes);
|
|
||||||
|
|
||||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
|
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
|
||||||
|
|
||||||
|
@ -72,11 +72,6 @@ class XlaAssignVariableOp : public OpKernel {
|
|||||||
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
|
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
|
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
|
||||||
|
|
||||||
#define REGISTER_XLA_MERGE_KERNEL(DEVICE, KERNEL, TYPES) \
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("_XlaMerge") \
|
|
||||||
.Device(DEVICE), \
|
|
||||||
KERNEL);
|
|
||||||
|
|
||||||
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
|
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
|
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
|
||||||
|
@ -155,7 +155,6 @@ constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
|
|||||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
|
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
|
||||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
|
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
|
||||||
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
|
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
|
||||||
REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_GPU, XlaMergeOp, kAllXlaGpuTypes);
|
|
||||||
|
|
||||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
|
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
|
||||||
|
|
||||||
|
@ -99,7 +99,6 @@ REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
|
|||||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
|
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
|
||||||
kExecAllTypes);
|
kExecAllTypes);
|
||||||
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
|
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
|
||||||
REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_INTERPRETER, XlaMergeOp, kExecAllTypes);
|
|
||||||
|
|
||||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
|
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
|
||||||
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
|
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
|
||||||
|
Loading…
Reference in New Issue
Block a user