diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 17e4226405a..ba8973d2afa 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -117,6 +117,18 @@ class XlaAssignVariableOp : public OpKernel {
                               .TypeConstraint<int64>("out_type")               \
                               .TypeConstraint("T", TYPES),                     \
                           ShapeNOp<int64>);                                    \
+  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
+                              .Device(DEVICE)                                  \
+                              .TypeConstraint<int32>("out_type")               \
+                              .HostMemory("output")                            \
+                              .HostMemory("input"),                            \
+                          VariableShapeOp<int32>);                             \
+  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
+                              .Device(DEVICE)                                  \
+                              .TypeConstraint<int64>("out_type")               \
+                              .HostMemory("output")                            \
+                              .HostMemory("input"),                            \
+                          VariableShapeOp<int64>);                             \
   REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                               .Device(DEVICE)                                  \
                               .HostMemory("output")                            \
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 60424f85840..4344643abfd 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -59,7 +59,12 @@ class VariableShapeOp : public XlaOpKernel {
  private:
   DataType out_dtype_;
 };
-REGISTER_XLA_OP(Name("VariableShape").IsMetadataOp(), VariableShapeOp);
+// CompilationOnly: XLA devices run the host-memory VariableShapeOp<T> kernels
+// registered in jit/xla_device_ops.h, so this kernel is registered for
+// compilation only. IsMetadataOp is kept from the previous registration so the
+// op's metadata-op status does not change with this patch.
+REGISTER_XLA_OP(Name("VariableShape").CompilationOnly().IsMetadataOp(),
+                VariableShapeOp);
 
 class ReadVariableOp : public XlaOpKernel {
  public:
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 6ae665a9b88..90463d2f86b 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -294,26 +294,6 @@ REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-template <typename T>
-class VariableShapeOp : public OpKernel {
- public:
-  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
-
-  void Compute(OpKernelContext* ctx) override {
-    core::RefCountPtr<Var> variable;
-    OP_REQUIRES_OK(ctx,
-                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
-    variable->mu()->lock_shared();
-    TensorShape shape = variable->tensor()->shape();
-    variable->mu()->unlock_shared();
-    Tensor* output;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
-    for (int i = 0; i < shape.dims(); ++i) {
-      output->flat<T>()(i) = shape.dim_size(i);
-    }
-  }
-};
-
 REGISTER_KERNEL_BUILDER(
     Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
     VariableShapeOp<int32>);
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 1bb70b537c1..1821eb9092a 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/resource_var.h"
 
 namespace tensorflow {
 
@@ -66,6 +67,48 @@ class DestroyResourceOp : public OpKernel {
   bool ignore_lookup_error_;
 };
 
+// Kernel for the "VariableShape" op. Looks up the resource variable whose
+// handle is input 0 and writes its current shape to output 0 as a rank-1
+// tensor with one element of type T per dimension.
+//
+// The class lives in this header so that kernel registrations outside
+// resource_variable_ops.cc can instantiate it.
+template <typename T>
+class VariableShapeOp : public OpKernel {
+ public:
+  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+  // Fails the op if the handle cannot be resolved to a Var, if the output
+  // cannot be allocated, or if a dimension size is not representable in T.
+  void Compute(OpKernelContext* ctx) override {
+    core::RefCountPtr<Var> variable;
+    OP_REQUIRES_OK(ctx,
+                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
+    TensorShape shape;
+    {
+      // Copy the shape under a shared lock on the variable's mutex; the scoped
+      // lock releases it on every path out of this block.
+      tf_shared_lock lock(*variable->mu());
+      shape = variable->tensor()->shape();
+    }
+    const int rank = shape.dims();
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {rank}, &output));
+    auto flat_output = output->flat<T>();
+    for (int i = 0; i < rank; ++i) {
+      const auto dim_size = shape.dim_size(i);
+      const T narrowed = static_cast<T>(dim_size);
+      // With a 32-bit T a large dimension would otherwise be stored truncated;
+      // report it instead.
+      OP_REQUIRES(ctx, narrowed == dim_size,
+                  errors::InvalidArgument(
+                      "VariableShape: dimension ", i, " has size ", dim_size,
+                      ", which does not fit in the requested out_type"));
+      flat_output(i) = narrowed;
+    }
+  }
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_