[Resubmit] Only trigger resharding when necessary.
PiperOrigin-RevId: 332109807
Change-Id: I63bc5f88c32cdf91b1d7e05b56829ed53bc557d4

commit 8d9f388840
parent 1593b96ae7
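The change widens the per-core key emitted by the TPU compile ops from a 2-element to a 3-element DT_STRING vector: element 0 is the program (proto) key, element 1 is the rendezvous key base, and the new element 2 carries a sharding-program fingerprint (empty when no sharding programs exist, kDefaultShardingValue for the default/unformatted state). Downstream consumers (RecvAtHost, SendFromHost, TPUExecute, the reshard default-state constant) now expect length-3 keys, and TpuCompilationCacheInterface threads the new sharding_key output through CompileIfKeyAbsent. The sketch below is only a hedged illustration of how a consumer of such a key could use the fingerprint to skip redundant resharding; it is not the TensorFlow reshard implementation, and NeedsReshard / last_sharding_key are invented names.

// Hypothetical sketch (not TensorFlow code): decide whether variables need to
// be resharded by looking at element 2 of the 3-element compilation key.
#include <string>

#include "tensorflow/core/framework/tensor.h"

// `key_tensor` is the 3-element DT_STRING key produced by the compile op;
// `last_sharding_key` is whatever fingerprint was applied on the previous step.
bool NeedsReshard(const tensorflow::Tensor& key_tensor,
                  const std::string& last_sharding_key) {
  const auto keys = key_tensor.vec<tensorflow::tstring>();
  // Copy the sharding fingerprint out of the new third element.
  const std::string sharding_key(keys(2).data(), keys(2).size());
  // No fingerprint means no sharding programs were compiled; an unchanged
  // fingerprint means the variables are already in the expected format.
  // Only reshard when the fingerprint is present and has changed.
  return !sharding_key.empty() && sharding_key != last_sharding_key;
}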
@@ -537,9 +537,10 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
   // Build a constant default key to specify that the unformatting should
   // transform the variables to the original format.
   builder.setInsertionPointAfter(while_op);
-  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {2});
+  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
   default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
   default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
+  default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
   auto default_state_key = builder.create<TF::ConstOp>(
       while_op.getLoc(),
       tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
@@ -504,9 +504,10 @@ Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
           "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
   AddNodeAttr("dtype", DT_STRING, &default_sharding);
 
-  Tensor t(DT_STRING, {2});
+  Tensor t(DT_STRING, {3});
   t.vec<tstring>()(0) = kDefaultShardingValue;
   t.vec<tstring>()(1) = kDefaultShardingValue;
+  t.vec<tstring>()(2) = kDefaultShardingValue;
   t.AsProtoTensorContent(
       (*default_sharding.mutable_attr())["value"].mutable_tensor());
 
@@ -124,6 +124,9 @@ struct CompiledSubgraph : public core::RefCounted {
   // Compilation cache proto key to identify the cache entry.
   std::vector<std::string> proto_key;
 
+  // Fingerprints of sharding programs if there is any.
+  std::vector<std::string> sharding_key;
+
   // The number of 'external' client-held references to the entry.
   int external_references = 0;
 
@@ -58,9 +58,9 @@ class RecvAtHostOp : public AsyncOpKernel {
     OP_REQUIRES_ASYNC(
         ctx,
         TensorShapeUtils::IsVector(input.shape()) &&
-            input.shape().dim_size(0) == 2,
+            input.shape().dim_size(0) == 3,
         errors::InvalidArgument("Input shape ", input.shape().DebugString(),
-                                " is not a vector of length 2."),
+                                " is not a vector of length 3."),
         done);
     const string rendezvous_key_base = input.vec<tstring>()(1);
     OP_REQUIRES_ASYNC(
@@ -164,10 +164,10 @@ class SendFromHostOp : public OpKernel {
     const Tensor& key_input = ctx->input(ctx->num_inputs() - 1);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsVector(key_input.shape()) &&
-                    key_input.shape().dim_size(0) == 2,
+                    key_input.shape().dim_size(0) == 3,
                 errors::InvalidArgument("Key input shape ",
                                         key_input.shape().DebugString(),
-                                        " is not a vector of length 2."));
+                                        " is not a vector of length 3."));
     const string rendezvous_key_base = key_input.vec<tstring>()(1);
     OP_REQUIRES(
         ctx, ctx->rendezvous() != nullptr,
@@ -362,14 +362,15 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key,
+    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
     std::vector<bool>* may_modify_variables,
     absl::Span<const xla::HloProto* const>* hlo_metadatas,
     const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
   std::vector<CompiledSubgraph*> removed_entries;
   auto status = CompileIfKeyAbsentHelper(
       subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
-      may_modify_variables, &removed_entries, hlo_metadatas, compile_function);
+      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
+      compile_function);
   for (auto entry : removed_entries) {
     UnloadAndDestroy(entry);
   }
@@ -399,7 +400,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key,
+    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    std::vector<CompiledSubgraph*>* removed_entries,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
@@ -497,6 +498,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
   *uid = entry->uid;
   // Let the caller know the keys for each of the cached protos.
   *proto_key = entry->proto_key;
+  *sharding_key = entry->sharding_key;
   *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
   *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();
 
@@ -109,6 +109,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
       const SessionMetadata* session_metadata,
       CompilationRefHolder* per_step_ref_holder, int64* uid,
       std::vector<std::string>* proto_key,
+      std::vector<std::string>* sharding_key,
       std::vector<bool>* may_modify_variables,
       absl::Span<const xla::HloProto* const>* hlo_metadatas,
       const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
@@ -197,6 +198,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
      const SessionMetadata* session_metadata,
      CompilationRefHolder* per_step_ref_holder, int64* uid,
      std::vector<std::string>* proto_key,
+      std::vector<std::string>* sharding_key,
      std::vector<bool>* may_modify_variables,
      std::vector<CompiledSubgraph*>* removed_entries,
      absl::Span<const xla::HloProto* const>* hlo_metadatas,
@@ -657,10 +657,11 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
 
   int64 uid;
   std::vector<std::string> proto_key;
+  std::vector<std::string> sharding_key;
   std::vector<bool> may_modify_variables;
   absl::Span<const xla::HloProto* const> hlo_metadatas;
   Status status = cache->CompileIfKeyAbsent(
-      key, ctx->session_metadata(), ref_holder, &uid, &proto_key,
+      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
       &may_modify_variables, &hlo_metadatas,
       [&](TpuProgramGroupInterface* tpu_program_group) {
         VLOG(1) << "Cloud TPU: Compiling TPU program";
@@ -778,13 +779,21 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
 
   if (status.ok()) {
     for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
-      Tensor output(DT_STRING, TensorShape({2}));
+      Tensor output(DT_STRING, TensorShape({3}));
       if (proto_key.size() == 1) {
         output.vec<tstring>()(0) = proto_key[0];
       } else {
         output.vec<tstring>()(0) = proto_key[i];
       }
       output.vec<tstring>()(1) = rendezvous_key_base;
+      if (sharding_key.empty()) {
+        output.vec<tstring>()(2) = "";
+      } else if (sharding_key.size() == 1) {
+        output.vec<tstring>()(2) = sharding_key[0];
+      } else {
+        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
+        output.vec<tstring>()(2) = sharding_key[i];
+      }
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
@@ -805,9 +814,10 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
   } else {
     // Return error in the invalid case.
     for (int i = 0; i < num_computations_; ++i) {
-      Tensor output(DT_STRING, TensorShape({2}));
+      Tensor output(DT_STRING, TensorShape({3}));
       output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
       output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
+      output.vec<tstring>()(2) = "<<NO SHARDing KEY AS COMPILATION FAILED>>";
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
@@ -72,9 +72,9 @@ Status GetComputationCacheEntry(
   TF_RETURN_IF_ERROR(context->input("key", &key));
   profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
   if (!TensorShapeUtils::IsVector(key->shape()) ||
-      key->shape().dim_size(0) != 2) {
+      key->shape().dim_size(0) != 3) {
     return errors::InvalidArgument(
-        "Key argument to TPUExecute must be a 2-element vector");
+        "Key argument to TPUExecute must be a 3-element vector");
   }
 
   ResourceMgr* rmgr = GetTPUConfigResourceMgr();
@@ -40,7 +40,7 @@ REGISTER_OP("_TPUCompileMlir")
       c->set_output(0, c->Scalar());
       // Programs.
       for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(2));
+        c->set_output(i + 1, c->Vector(3));
       }
       return Status::OK();
     })
@@ -64,7 +64,7 @@ REGISTER_OP("_TPUCompileMlirPlaceholderProgramKey")
     .SetIsStateful()
     .Output("program: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->Vector(2));
+      c->set_output(0, c->Vector(3));
       return Status::OK();
     })
     .SetIsStateful()
@@ -100,7 +100,7 @@ REGISTER_OP("TPUCompile")
      c->set_output(0, c->Scalar());
      // Programs.
      for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(2));
+        c->set_output(i + 1, c->Vector(3));
      }
      // May modify variables.
      for (int i = 0; i < num_computations; ++i) {
@@ -30,7 +30,7 @@ REGISTER_OP("TPUExecute")
      shape_inference::ShapeHandle key;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
      shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
      for (int i = 0; i < c->num_outputs(); ++i) {
        c->set_output(i, c->UnknownShape());
      }
@@ -50,7 +50,7 @@ REGISTER_OP("TPUExecuteAndUpdateVariables")
      shape_inference::ShapeHandle key;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
      shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
      for (int i = 0; i < c->num_outputs(); ++i) {
        c->set_output(i, c->UnknownShape());
      }