address comments on commit e4c76ebc7790368227f2aa4695c177cff364e463

Add XLA-only merge that can merge all types.

This prevents insertion of H2D and D2H copies when XLA-GPU clusters
have int32 outputs. This merge is only used to merge the outputs
from the XlaRun and the PartitionedCall nodes.
This commit is contained in:
Bas Aarts 2019-10-08 08:38:05 -07:00
parent 791bf78c29
commit d628b34f78
7 changed files with 6 additions and 13 deletions

View File

@ -631,7 +631,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
} }
XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} : OpKernel(ctx) {}
void XlaMergeOp::Compute(OpKernelContext* ctx) { void XlaMergeOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaMergeOp " << def().name(); VLOG(3) << "XlaMergeOp " << def().name();

View File

@ -180,9 +180,6 @@ class XlaMergeOp : public OpKernel {
explicit XlaMergeOp(OpKernelConstruction* ctx); explicit XlaMergeOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override; void Compute(OpKernelContext* ctx) override;
private:
const XlaPlatformInfo platform_info_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -106,7 +106,11 @@ REGISTER_OP("_XlaMerge")
}) })
.Doc(R"(XLA Merge Op. For use by the XLA JIT only. .Doc(R"(XLA Merge Op. For use by the XLA JIT only.
Merges the outputs from the TensorFlow fallback execution and the _XlaRun node. Merges the outputs from the PartitionedCall node and the _XlaRun node.
Unlike the TensorFlow merge op, _XlaMerge supports merging inputs of all types.
This prevents the need for copy operations, in particular when an XLA cluster
has int32 outputs. The _XlaMerge op does not have a value_index output that
identifies the chosen input.
)"); )");
} // namespace tensorflow } // namespace tensorflow

View File

@ -106,7 +106,6 @@ constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes); REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_CPU, XlaMergeOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);

View File

@ -72,11 +72,6 @@ class XlaAssignVariableOp : public OpKernel {
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ #define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
#define REGISTER_XLA_MERGE_KERNEL(DEVICE, KERNEL, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_XlaMerge") \
.Device(DEVICE), \
KERNEL);
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \

View File

@ -155,7 +155,6 @@ constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes); REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_GPU, XlaMergeOp, kAllXlaGpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);

View File

@ -99,7 +99,6 @@ REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
kExecAllTypes); kExecAllTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_INTERPRETER, XlaMergeOp, kExecAllTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);