Register VariableShapeOp for TPU.

PiperOrigin-RevId: 356370389
Change-Id: I23679fdbc05f02a6d3a6203a77ed7a374cdcb4fb
This commit is contained in:
Xiao Yu 2021-02-08 15:45:06 -08:00 committed by TensorFlower Gardener
parent 21ec8483b2
commit 324ab95fb0
4 changed files with 34 additions and 21 deletions

View File

@ -117,6 +117,18 @@ class XlaAssignVariableOp : public OpKernel {
.TypeConstraint<int64>("out_type") \
.TypeConstraint("T", TYPES), \
ShapeNOp<int64>); \
REGISTER_KERNEL_BUILDER(Name("VariableShape") \
.Device(DEVICE) \
.TypeConstraint<int32>("out_type") \
.HostMemory("output") \
.HostMemory("input"), \
VariableShapeOp<int32>); \
REGISTER_KERNEL_BUILDER(Name("VariableShape") \
.Device(DEVICE) \
.TypeConstraint<int64>("out_type") \
.HostMemory("output") \
.HostMemory("input"), \
VariableShapeOp<int64>); \
REGISTER_KERNEL_BUILDER(Name("Size") \
.Device(DEVICE) \
.HostMemory("output") \

View File

@ -59,7 +59,7 @@ class VariableShapeOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
REGISTER_XLA_OP(Name("VariableShape").IsMetadataOp(), VariableShapeOp);
REGISTER_XLA_OP(Name("VariableShape").CompilationOnly(), VariableShapeOp);
class ReadVariableOp : public XlaOpKernel {
public:

View File

@ -294,26 +294,6 @@ REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
class VariableShapeOp : public OpKernel {
public:
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* ctx) override {
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
variable->mu()->lock_shared();
TensorShape shape = variable->tensor()->shape();
variable->mu()->unlock_shared();
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
for (int i = 0; i < shape.dims(); ++i) {
output->flat<T>()(i) = shape.dim_size(i);
}
}
};
REGISTER_KERNEL_BUILDER(
Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
VariableShapeOp<int32>);

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
namespace tensorflow {
@ -66,6 +67,26 @@ class DestroyResourceOp : public OpKernel {
bool ignore_lookup_error_;
};
template <typename T>
class VariableShapeOp : public OpKernel {
public:
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* ctx) override {
core::RefCountPtr<Var> variable;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
variable->mu()->lock_shared();
TensorShape shape = variable->tensor()->shape();
variable->mu()->unlock_shared();
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
for (int i = 0; i < shape.dims(); ++i) {
output->flat<T>()(i) = shape.dim_size(i);
}
}
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_