From 791bf78c29595e9b5e4410d5009cdc2007dc50f1 Mon Sep 17 00:00:00 2001 From: Bas Aarts Date: Fri, 4 Oct 2019 15:29:07 -0700 Subject: [PATCH] Add XLA-only merge that can merge all types. This prevents insertion of H2D and D2H copies when XLA-GPU clusters have int32 outputs. This merge is only used to merge the outputs from the XlaRun and the PartitionedCall node. --- tensorflow/compiler/jit/build_xla_ops_pass.cc | 6 +++--- tensorflow/compiler/jit/kernels/xla_ops.cc | 19 ++++++++++++++++++- tensorflow/compiler/jit/kernels/xla_ops.h | 10 ++++++++++ tensorflow/compiler/jit/ops/xla_ops.cc | 14 ++++++++++++++ tensorflow/compiler/jit/xla_cpu_device.cc | 1 + tensorflow/compiler/jit/xla_device_ops.h | 5 +++++ tensorflow/compiler/jit/xla_gpu_device.cc | 1 + .../compiler/jit/xla_interpreter_device.cc | 1 + tensorflow/core/graph/graph.cc | 1 + tensorflow/core/graph/graph_constructor.cc | 2 +- tensorflow/core/graph/graph_partition.cc | 2 +- tensorflow/core/grappler/op_types.cc | 3 ++- .../tools/graph_transforms/transform_utils.cc | 2 +- 13 files changed, 59 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 32f2d1db813..00bfa5b6b82 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -135,9 +135,9 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node, new_output = check_numerics_op; } - ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx), - {Output(old_node, oidx), new_output}); - merged_output = merged_outputs[oidx] = merge_op.output; + ops::_XlaMerge xla_merge_op(s.WithOpName("merge_oidx_", oidx), + Output(old_node, oidx), new_output); + merged_output = merged_outputs[oidx] = xla_merge_op.output; } Node* dst = e->dst(); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 8b3a8905030..59771172b9a 100644 --- 
a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -630,6 +630,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { input_output_alias, closure.resource_var_snapshots())); } +XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) + : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} + +void XlaMergeOp::Compute(OpKernelContext* ctx) { + VLOG(3) << "XlaMergeOp " << def().name(); + int i=0; + if (ctx->has_input(i) || ctx->has_input(++i)) { + ctx->set_output(0, ctx->input(i)); + } +} + REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); REGISTER_KERNEL_BUILDER(Name("XlaLaunch") @@ -648,6 +659,12 @@ REGISTER_KERNEL_BUILDER(Name("_XlaCompile") XlaCompileOp); REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp); -REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp); +REGISTER_KERNEL_BUILDER(Name("_XlaRun") + .Device(DEVICE_GPU) + .HostMemory("key"), + XlaRunOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp); +REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 3848ac72aac..f812ab8da74 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -175,6 +175,16 @@ class XlaRunOp : public OpKernel { const XlaPlatformInfo platform_info_; }; +class XlaMergeOp : public OpKernel { + public: + explicit XlaMergeOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + const XlaPlatformInfo platform_info_; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 95d12e95fd9..0378d3432c5 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ 
b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -95,4 +95,18 @@ Executes a TensorFlow function previously compiled into a LocalExecutable by an _XlaCompile op. )"); +REGISTER_OP("_XlaMerge") + .Input("partitioned_call: T") + .Input("xla_run: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"(XLA Merge Op. For use by the XLA JIT only. + +Merges the outputs from the TensorFlow fallback execution and the _XlaRun node. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 85c09a027d3..2d7fa276ad4 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -106,6 +106,7 @@ constexpr std::array kAllXlaCpuTypes = { REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes); +REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_CPU, XlaMergeOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 99e95314f64..6b9656dd788 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -72,6 +72,11 @@ class XlaAssignVariableOp : public OpKernel { #define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); +#define REGISTER_XLA_MERGE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaMerge") \ + .Device(DEVICE), \ + KERNEL); + #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ REGISTER_KERNEL_BUILDER( \ Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc 
b/tensorflow/compiler/jit/xla_gpu_device.cc index cead23d816e..372e409d895 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -155,6 +155,7 @@ constexpr std::array kAllXlaGpuTypes = { REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes); +REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_GPU, XlaMergeOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index f720183e196..bc6dc730345 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -99,6 +99,7 @@ REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp, REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, kExecAllTypes); REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); +REGISTER_XLA_MERGE_KERNEL(DEVICE_XLA_INTERPRETER, XlaMergeOp, kExecAllTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 363fe19d335..e7acc56afb4 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -101,6 +101,7 @@ const std::unordered_map& Node::kNodeClassTable = {"_DeviceArg", NC_ARG}, {"_Retval", NC_RETVAL}, {"_DeviceRetval", NC_RETVAL}, + {"_XlaMerge", NC_MERGE}, }); #undef REF_CLASS diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 40d4fbdb1d8..97e8d150c86 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -56,7 +56,7 @@ namespace { static constexpr 
const bool kDoNotCheckDuplicates = true; inline bool IsMerge(const NodeDef& node_def) { - return node_def.op() == "Merge" || node_def.op() == "RefMerge"; + return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge"; } inline bool IsNextIteration(const NodeDef& node_def) { diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index b295085b40d..01c64f2fb03 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -47,7 +47,7 @@ namespace tensorflow { namespace { inline bool IsMerge(const NodeDef& node_def) { - return node_def.op() == "Merge" || node_def.op() == "RefMerge"; + return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge"; } inline bool IsNextIteration(const NodeDef& node_def) { diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index b3d53360802..79fb4f93160 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -153,6 +153,7 @@ bool IsControlFlow(const NodeDef& node) { node.op() == "Exit" || node.op() == "LoopCond" || node.op() == "Merge" || + node.op() == "_XlaMerge" || node.op() == "NextIteration" || node.op() == "Switch" || node.op() == "_SwitchN"; @@ -332,7 +333,7 @@ bool IsMean(const NodeDef& node) { return node.op() == "Mean"; } bool IsMerge(const NodeDef& node) { const auto& op = node.op(); - return op == "Merge" || op == "RefMerge"; + return op == "Merge" || op == "RefMerge" || op == "_XlaMerge"; } bool IsMin(const NodeDef& node) { return node.op() == "Min"; } diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index 6c5b80e3381..27c978ec581 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -26,7 +26,7 @@ namespace graph_transforms { namespace { inline bool 
IsMerge(const NodeDef& node_def) { - return node_def.op() == "Merge" || node_def.op() == "RefMerge"; + return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge"; } void RecordMatchedNodes(const NodeMatch& match,