Register VariableShapeOp for TPU.
PiperOrigin-RevId: 356370389 Change-Id: I23679fdbc05f02a6d3a6203a77ed7a374cdcb4fb
This commit is contained in:
parent
21ec8483b2
commit
324ab95fb0
@ -117,6 +117,18 @@ class XlaAssignVariableOp : public OpKernel {
|
|||||||
.TypeConstraint<int64>("out_type") \
|
.TypeConstraint<int64>("out_type") \
|
||||||
.TypeConstraint("T", TYPES), \
|
.TypeConstraint("T", TYPES), \
|
||||||
ShapeNOp<int64>); \
|
ShapeNOp<int64>); \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("VariableShape") \
|
||||||
|
.Device(DEVICE) \
|
||||||
|
.TypeConstraint<int32>("out_type") \
|
||||||
|
.HostMemory("output") \
|
||||||
|
.HostMemory("input"), \
|
||||||
|
VariableShapeOp<int32>); \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("VariableShape") \
|
||||||
|
.Device(DEVICE) \
|
||||||
|
.TypeConstraint<int64>("out_type") \
|
||||||
|
.HostMemory("output") \
|
||||||
|
.HostMemory("input"), \
|
||||||
|
VariableShapeOp<int64>); \
|
||||||
REGISTER_KERNEL_BUILDER(Name("Size") \
|
REGISTER_KERNEL_BUILDER(Name("Size") \
|
||||||
.Device(DEVICE) \
|
.Device(DEVICE) \
|
||||||
.HostMemory("output") \
|
.HostMemory("output") \
|
||||||
|
@ -59,7 +59,7 @@ class VariableShapeOp : public XlaOpKernel {
|
|||||||
private:
|
private:
|
||||||
DataType out_dtype_;
|
DataType out_dtype_;
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("VariableShape").IsMetadataOp(), VariableShapeOp);
|
REGISTER_XLA_OP(Name("VariableShape").CompilationOnly(), VariableShapeOp);
|
||||||
|
|
||||||
class ReadVariableOp : public XlaOpKernel {
|
class ReadVariableOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
|
@ -294,26 +294,6 @@ REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
|
|||||||
|
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class VariableShapeOp : public OpKernel {
|
|
||||||
public:
|
|
||||||
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
|
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
|
||||||
core::RefCountPtr<Var> variable;
|
|
||||||
OP_REQUIRES_OK(ctx,
|
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
|
|
||||||
variable->mu()->lock_shared();
|
|
||||||
TensorShape shape = variable->tensor()->shape();
|
|
||||||
variable->mu()->unlock_shared();
|
|
||||||
Tensor* output;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
|
|
||||||
for (int i = 0; i < shape.dims(); ++i) {
|
|
||||||
output->flat<T>()(i) = shape.dim_size(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
|
Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
|
||||||
VariableShapeOp<int32>);
|
VariableShapeOp<int32>);
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/resource_var.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -66,6 +67,26 @@ class DestroyResourceOp : public OpKernel {
|
|||||||
bool ignore_lookup_error_;
|
bool ignore_lookup_error_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class VariableShapeOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
core::RefCountPtr<Var> variable;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
|
||||||
|
variable->mu()->lock_shared();
|
||||||
|
TensorShape shape = variable->tensor()->shape();
|
||||||
|
variable->mu()->unlock_shared();
|
||||||
|
Tensor* output;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
|
||||||
|
for (int i = 0; i < shape.dims(); ++i) {
|
||||||
|
output->flat<T>()(i) = shape.dim_size(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_
|
#endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user