diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 9f5723f4fa4..dc5df94e963 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1837,7 +1837,7 @@ absl::flat_hash_map>* GetWhitelistTable() { "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", - "Tile", "Transpose", "InvertPermutation", "Unpack"}}}; + "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}}; // clang-format on return result; } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index bdaeeafd295..e072225566d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -32,6 +32,7 @@ tf_kernel_library( "data_format_ops.cc", "depthtospace_op.cc", "dequantize_op.cc", + "device_index_op.cc", "diag_op.cc", "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc new file mode 100644 index 00000000000..ff058f92cd7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class DeviceIndexOp : public XlaOpKernel { + public: + explicit DeviceIndexOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // When compiling we are not executing on any physical device, so we return + // a sentinel value (size of the list of devices). + ctx->SetOutput( + 0, xla::ConstantR0(ctx->builder(), device_names_.size())); + } + + private: + std::vector device_names_; +}; + +REGISTER_XLA_OP(Name("DeviceIndex"), DeviceIndexOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt index 9a4e5abd110..87c146910ff 100644 --- a/tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt @@ -2,4 +2,10 @@ op { graph_op_name: "DeviceIndex" visibility: HIDDEN summary: "Return the index of device the op runs." + description: <