From 324ab95fb097dd6044fa5346ee24fc610017ae85 Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Mon, 8 Feb 2021 15:45:06 -0800 Subject: [PATCH] Register VariableShapeOp for TPU. PiperOrigin-RevId: 356370389 Change-Id: I23679fdbc05f02a6d3a6203a77ed7a374cdcb4fb --- tensorflow/compiler/jit/xla_device_ops.h | 12 +++++++++++ .../compiler/tf2xla/kernels/variable_ops.cc | 2 +- .../core/kernels/resource_variable_ops.cc | 20 ------------------ .../core/kernels/resource_variable_ops.h | 21 +++++++++++++++++++ 4 files changed, 34 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 17e4226405a..ba8973d2afa 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -117,6 +117,18 @@ class XlaAssignVariableOp : public OpKernel { .TypeConstraint("out_type") \ .TypeConstraint("T", TYPES), \ ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("VariableShape") \ + .Device(DEVICE) \ + .TypeConstraint("out_type") \ + .HostMemory("output") \ + .HostMemory("input"), \ + VariableShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("VariableShape") \ + .Device(DEVICE) \ + .TypeConstraint("out_type") \ + .HostMemory("output") \ + .HostMemory("input"), \ + VariableShapeOp); \ REGISTER_KERNEL_BUILDER(Name("Size") \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 60424f85840..4344643abfd 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -59,7 +59,7 @@ class VariableShapeOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("VariableShape").IsMetadataOp(), VariableShapeOp); +REGISTER_XLA_OP(Name("VariableShape").CompilationOnly(), VariableShapeOp); class ReadVariableOp : public XlaOpKernel { public: diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 6ae665a9b88..90463d2f86b 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -294,26 +294,6 @@ REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -template -class VariableShapeOp : public OpKernel { - public: - explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {} - - void Compute(OpKernelContext* ctx) override { - core::RefCountPtr variable; - OP_REQUIRES_OK(ctx, - LookupResource(ctx, HandleFromInput(ctx, 0), &variable)); - variable->mu()->lock_shared(); - TensorShape shape = variable->tensor()->shape(); - variable->mu()->unlock_shared(); - Tensor* output; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output)); - for (int i = 0; i < shape.dims(); ++i) { - output->flat()(i) = shape.dim_size(i); - } - } -}; - REGISTER_KERNEL_BUILDER( Name("VariableShape").Device(DEVICE_CPU).TypeConstraint("out_type"), VariableShapeOp); diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h index 1bb70b537c1..1821eb9092a 100644 --- a/tensorflow/core/kernels/resource_variable_ops.h +++ b/tensorflow/core/kernels/resource_variable_ops.h @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_var.h" namespace tensorflow { @@ -66,6 +67,26 @@ class DestroyResourceOp : public OpKernel { bool ignore_lookup_error_; }; +template +class VariableShapeOp : public OpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* ctx) override { + core::RefCountPtr variable; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &variable)); + variable->mu()->lock_shared(); + TensorShape shape = variable->tensor()->shape(); + variable->mu()->unlock_shared(); + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output)); + for (int i = 0; i < shape.dims(); ++i) { + output->flat()(i) = shape.dim_size(i); + } + } +}; + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_