Open source op definition for TPUReshardVariable op
PiperOrigin-RevId: 358268125 Change-Id: I88c2de3f0bdb237cb1b235c2691a684ab644e916
This commit is contained in:
parent
aed7a7b5e8
commit
014f02fea5
@ -0,0 +1,13 @@
|
||||
op {
|
||||
graph_op_name: "TPUReshardVariables"
|
||||
visibility: HIDDEN
|
||||
summary: "Op that reshards on-device TPU variables to specified state."
|
||||
description: <<END
|
||||
Op that reshards on-device TPU variables to specified state. Internal use only.
|
||||
|
||||
The sharding state is represented as the key of the compilation that generated
|
||||
the sharding/unsharding programs along with the main program. new_format_key
|
||||
specifies the desired state, and format_state_var is the current state of the
|
||||
variables.
|
||||
END
|
||||
}
|
@ -14,6 +14,7 @@ cc_library(
|
||||
":tpu_compile_op",
|
||||
":tpu_execute_op",
|
||||
":tpu_partitioned_ops",
|
||||
":tpu_reshard_variables_op",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -88,3 +89,16 @@ cc_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_reshard_variables_op",
|
||||
srcs = [
|
||||
"tpu_reshard_variables_op.cc",
|
||||
],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
33
tensorflow/core/tpu/ops/tpu_reshard_variables_op.cc
Normal file
33
tensorflow/core/tpu/ops/tpu_reshard_variables_op.cc
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("TPUReshardVariables")
|
||||
.Attr("N: int >= 0")
|
||||
.Input("vars: N * resource")
|
||||
.Input("new_format_key: string")
|
||||
.Input("format_state_var: resource")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
return ::tensorflow::shape_inference::UnknownShape(c);
|
||||
});
|
||||
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user