Register VariableShapeOp for TPU.
PiperOrigin-RevId: 356370389
Change-Id: I23679fdbc05f02a6d3a6203a77ed7a374cdcb4fb
parent 21ec8483b2
commit 324ab95fb0
tensorflow/compiler/jit/xla_device_ops.h
@@ -117,6 +117,18 @@ class XlaAssignVariableOp : public OpKernel {
                               .TypeConstraint<int64>("out_type")    \
                               .TypeConstraint("T", TYPES),          \
                           ShapeNOp<int64>);                         \
+  REGISTER_KERNEL_BUILDER(Name("VariableShape")                     \
+                              .Device(DEVICE)                       \
+                              .TypeConstraint<int32>("out_type")    \
+                              .HostMemory("output")                 \
+                              .HostMemory("input"),                 \
+                          VariableShapeOp<int32>);                  \
+  REGISTER_KERNEL_BUILDER(Name("VariableShape")                     \
+                              .Device(DEVICE)                       \
+                              .TypeConstraint<int64>("out_type")    \
+                              .HostMemory("output")                 \
+                              .HostMemory("input"),                 \
+                          VariableShapeOp<int64>);                  \
   REGISTER_KERNEL_BUILDER(Name("Size")                              \
                               .Device(DEVICE)                       \
                               .HostMemory("output")                 \
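These registrations sit inside the per-device kernel macro in xla_device_ops.h: each XLA-backed device (TPU included) expands the macro with its own DEVICE constant, and HostMemory pins the resource-handle input and the shape output to host memory, since shapes are host-side metadata. Below is a self-contained, non-TensorFlow illustration of the pattern; Registry, Registrar, and the macro are stand-ins invented for this sketch:

// Stand-alone sketch (not TF code): a macro parameterized on DEVICE stamps
// out one registration per device, which is how the two VariableShape
// entries above reach every XLA-backed device, TPU included.
#include <iostream>
#include <string>
#include <vector>

static std::vector<std::string>& Registry() {
  static std::vector<std::string> r;
  return r;
}

struct Registrar {
  Registrar(const std::string& op, const std::string& device) {
    Registry().push_back(op + " on " + device);  // record (op, device) at startup
  }
};

// Stand-in for REGISTER_KERNEL_BUILDER inside the device macro.
#define REGISTER_DEVICE_KERNELS(DEVICE) \
  static Registrar reg_varshape_##DEVICE("VariableShape", #DEVICE);

REGISTER_DEVICE_KERNELS(XLA_CPU);
REGISTER_DEVICE_KERNELS(XLA_GPU);
REGISTER_DEVICE_KERNELS(TPU);

int main() {
  // Prints one "VariableShape on <device>" line per macro expansion.
  for (const auto& e : Registry()) std::cout << e << "\n";
}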
tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -59,7 +59,7 @@ class VariableShapeOp : public XlaOpKernel {
  private:
   DataType out_dtype_;
 };
-REGISTER_XLA_OP(Name("VariableShape").IsMetadataOp(), VariableShapeOp);
+REGISTER_XLA_OP(Name("VariableShape").CompilationOnly(), VariableShapeOp);
 
 class ReadVariableOp : public XlaOpKernel {
  public:
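Reading the one-line change, hedged from the builder names and this commit's other hunks rather than from the registry documentation: IsMetadataOp() marks an XLA op whose result depends only on the shapes of its inputs, while CompilationOnly() marks a registration that is used only when a graph is actually compiled through XLA. The switch matches the first hunk: now that VariableShape has a regular TF kernel registered on every XLA device, the XlaOpKernel no longer needs to serve as the runtime kernel and is kept for compilation only.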
tensorflow/core/kernels/resource_variable_ops.cc
@@ -294,26 +294,6 @@ REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-template <typename T>
-class VariableShapeOp : public OpKernel {
- public:
-  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
-
-  void Compute(OpKernelContext* ctx) override {
-    core::RefCountPtr<Var> variable;
-    OP_REQUIRES_OK(ctx,
-                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
-    variable->mu()->lock_shared();
-    TensorShape shape = variable->tensor()->shape();
-    variable->mu()->unlock_shared();
-    Tensor* output;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
-    for (int i = 0; i < shape.dims(); ++i) {
-      output->flat<T>()(i) = shape.dim_size(i);
-    }
-  }
-};
-
 REGISTER_KERNEL_BUILDER(
     Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
     VariableShapeOp<int32>);
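Nothing is deleted from the codebase here: the class moves verbatim into resource_variable_ops.h (next two hunks), because the registrations added in xla_device_ops.h instantiate VariableShapeOp<int32> and VariableShapeOp<int64> from other translation units, and a class template must be fully defined wherever it is instantiated. A minimal standalone illustration of that constraint, with hypothetical file names:

// shape_op.h (hypothetical): the full template definition lives in a header.
#ifndef SHAPE_OP_H_
#define SHAPE_OP_H_
template <typename T>
struct ShapeOp {
  T Dim(T d) const { return d; }  // trivial body; visibility is the point
};
#endif  // SHAPE_OP_H_

// registrations.cc (hypothetical): because it includes the header, it can
// instantiate the template. If the header only *declared*
// `template <typename T> struct ShapeOp;`, the line below would fail to
// compile with an incomplete-type error.
#include "shape_op.h"
ShapeOp<int> shape_op_int32;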
tensorflow/core/kernels/resource_variable_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/resource_var.h"
 
 namespace tensorflow {
 
@@ -66,6 +67,26 @@ class DestroyResourceOp : public OpKernel {
   bool ignore_lookup_error_;
 };
 
+template <typename T>
+class VariableShapeOp : public OpKernel {
+ public:
+  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    core::RefCountPtr<Var> variable;
+    OP_REQUIRES_OK(ctx,
+                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
+    variable->mu()->lock_shared();
+    TensorShape shape = variable->tensor()->shape();
+    variable->mu()->unlock_shared();
+    Tensor* output;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
+    for (int i = 0; i < shape.dims(); ++i) {
+      output->flat<T>()(i) = shape.dim_size(i);
+    }
+  }
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_
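For readers new to the kernel: Compute takes a shared (reader) lock on the variable's mutex, copies the shape out, releases the lock, and writes one entry per dimension into a rank-1 output tensor. A standalone sketch of that behavior, with std:: types standing in for TF's Var, mutex, and Tensor:

#include <cstdint>
#include <iostream>
#include <shared_mutex>
#include <vector>

struct FakeVar {
  std::shared_mutex mu;                 // stand-in for variable->mu()
  std::vector<int64_t> shape{2, 3, 5};  // stand-in for tensor()->shape()
};

int main() {
  FakeVar variable;
  std::vector<int64_t> dims;
  {
    // Shared lock mirrors lock_shared()/unlock_shared(): concurrent
    // shape reads do not serialize against each other.
    std::shared_lock<std::shared_mutex> lock(variable.mu);
    dims = variable.shape;  // copy while holding the lock
  }
  // The op's output is rank-1 with one entry per dimension.
  for (int64_t d : dims) std::cout << d << " ";  // prints: 2 3 5
  std::cout << "\n";
}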