diff --git a/tensorflow/core/api_def/base_api/api_def_TPUReshardVariables.pbtxt b/tensorflow/core/api_def/base_api/api_def_TPUReshardVariables.pbtxt new file mode 100644 index 00000000000..80a40fe1ed3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_TPUReshardVariables.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "TPUReshardVariables" + visibility: HIDDEN + summary: "Op that reshards on-device TPU variables to specified state." + description: <<END +Op that reshards on-device TPU variables to specified state. Internal use only. + +The sharding state is represented as the key of the compilation that generated +the sharding/unsharding programs along with the main program. new_format_key +specifies the desired state, and format_state_var is the current state of the +variables. +END +} diff --git a/tensorflow/core/tpu/ops/BUILD b/tensorflow/core/tpu/ops/BUILD index a85b599ab31..e36d46472d0 100644 --- a/tensorflow/core/tpu/ops/BUILD +++ b/tensorflow/core/tpu/ops/BUILD @@ -14,6 +14,7 @@ cc_library( ":tpu_compile_op", ":tpu_execute_op", ":tpu_partitioned_ops", + ":tpu_reshard_variables_op", ], alwayslink = 1, ) @@ -88,3 +89,16 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "tpu_reshard_variables_op", + srcs = [ + "tpu_reshard_variables_op.cc", + ], + linkstatic = 1, + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/core/tpu/ops/tpu_reshard_variables_op.cc b/tensorflow/core/tpu/ops/tpu_reshard_variables_op.cc new file mode 100644 index 00000000000..fe35bf781b6 --- /dev/null +++ b/tensorflow/core/tpu/ops/tpu_reshard_variables_op.cc @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +REGISTER_OP("TPUReshardVariables") + .Attr("N: int >= 0") + .Input("vars: N * resource") + .Input("new_format_key: string") + .Input("format_state_var: resource") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + return ::tensorflow::shape_inference::UnknownShape(c); + }); + +} // namespace tensorflow