[Resubmit] Only trigger resharding when necessary.
PiperOrigin-RevId: 332109807
Change-Id: I63bc5f88c32cdf91b1d7e05b56829ed53bc557d4

parent 1593b96ae7
commit 8d9f388840
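Summary of the diff below: the TPU compile ops now emit a three-element program key per core instead of a two-element one. Element 0 remains the compilation (proto) key, element 1 the rendezvous key base, and element 2 is a new sharding-program fingerprint. The runtime checks, cache plumbing, and shape functions throughout the stack are widened to match, so variable resharding can be skipped when the sharding program is unchanged.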
@@ -537,9 +537,10 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
   // Build a constant default key to specify that the unformatting should
   // transform the variables to the original format.
   builder.setInsertionPointAfter(while_op);
-  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {2});
+  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
   default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
   default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
+  default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
   auto default_state_key = builder.create<TF::ConstOp>(
       while_op.getLoc(),
       tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
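For orientation, a minimal stand-alone sketch of the widened key layout (plain C++; everything here besides kDefaultShardingValue is illustrative, not code from the change):

    #include <string>
    #include <vector>

    // Illustrative sentinel; the pass uses its own kDefaultShardingValue.
    const std::string kDefaultShardingValue = "";

    // The default state key now has three slots:
    //   0: compilation (proto) key
    //   1: rendezvous key base
    //   2: sharding-program fingerprint (new in this change)
    std::vector<std::string> MakeDefaultStateKey() {
      return {kDefaultShardingValue, kDefaultShardingValue,
              kDefaultShardingValue};
    }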
@@ -504,9 +504,10 @@ Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
       "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
   AddNodeAttr("dtype", DT_STRING, &default_sharding);

-  Tensor t(DT_STRING, {2});
+  Tensor t(DT_STRING, {3});
   t.vec<tstring>()(0) = kDefaultShardingValue;
   t.vec<tstring>()(1) = kDefaultShardingValue;
+  t.vec<tstring>()(2) = kDefaultShardingValue;
   t.AsProtoTensorContent(
       (*default_sharding.mutable_attr())["value"].mutable_tensor());

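The same three-slot constant is seeded on the Graph side. A sketch of how the default shard state node's value attr is built, assuming a TensorFlow source tree (MakeDefaultShardStateConst is a hypothetical helper wrapping the calls from the hunk above):

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/node_def_util.h"
    #include "tensorflow/core/framework/tensor.h"

    namespace tensorflow {

    NodeDef MakeDefaultShardStateConst(const tstring& sentinel) {
      NodeDef default_sharding;
      default_sharding.set_op("Const");
      AddNodeAttr("dtype", DT_STRING, &default_sharding);
      Tensor t(DT_STRING, {3});  // was {2} before this change
      for (int i = 0; i < 3; ++i) t.vec<tstring>()(i) = sentinel;
      t.AsProtoTensorContent(
          (*default_sharding.mutable_attr())["value"].mutable_tensor());
      return default_sharding;
    }

    }  // namespace tensorflow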
@@ -124,6 +124,9 @@ struct CompiledSubgraph : public core::RefCounted {
   // Compilation cache proto key to identify the cache entry.
   std::vector<std::string> proto_key;

+  // Fingerprints of sharding programs, if there are any.
+  std::vector<std::string> sharding_key;
+
   // The number of 'external' client-held references to the entry.
   int external_references = 0;

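A trimmed mirror of the cache entry after this hunk, to show where the new vector sits (illustrative struct; the real CompiledSubgraph has many more fields):

    #include <string>
    #include <vector>

    struct CompiledSubgraphSketch {
      // Compilation cache proto key to identify the cache entry.
      std::vector<std::string> proto_key;
      // Fingerprints of sharding programs, if there are any; left empty
      // when the entry has no resharding variant.
      std::vector<std::string> sharding_key;
      // The number of 'external' client-held references to the entry.
      int external_references = 0;
    };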
@@ -58,9 +58,9 @@ class RecvAtHostOp : public AsyncOpKernel {
     OP_REQUIRES_ASYNC(
         ctx,
         TensorShapeUtils::IsVector(input.shape()) &&
-            input.shape().dim_size(0) == 2,
+            input.shape().dim_size(0) == 3,
         errors::InvalidArgument("Input shape ", input.shape().DebugString(),
-                                " is not a vector of length 2."),
+                                " is not a vector of length 3."),
         done);
     const string rendezvous_key_base = input.vec<tstring>()(1);
     OP_REQUIRES_ASYNC(
@@ -164,10 +164,10 @@ class SendFromHostOp : public OpKernel {
     const Tensor& key_input = ctx->input(ctx->num_inputs() - 1);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsVector(key_input.shape()) &&
-                    key_input.shape().dim_size(0) == 2,
+                    key_input.shape().dim_size(0) == 3,
                 errors::InvalidArgument("Key input shape ",
                                         key_input.shape().DebugString(),
-                                        " is not a vector of length 2."));
+                                        " is not a vector of length 3."));
     const string rendezvous_key_base = key_input.vec<tstring>()(1);
     OP_REQUIRES(
         ctx, ctx->rendezvous() != nullptr,
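Both host-transfer kernels apply the same rule: the key must now be a length-3 string vector, and the rendezvous key base still sits at index 1. A plain-C++ sketch of that invariant (names are illustrative):

    #include <string>
    #include <vector>

    // True iff `key` matches the widened layout:
    //   {proto key, rendezvous key base, sharding key}.
    bool IsValidProgramKey(const std::vector<std::string>& key) {
      return key.size() == 3;
    }

    // Index 1 is unchanged by the widening; only the length check moved.
    const std::string& RendezvousKeyBase(const std::vector<std::string>& key) {
      return key[1];
    }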
@@ -362,14 +362,15 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key,
+    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
     std::vector<bool>* may_modify_variables,
     absl::Span<const xla::HloProto* const>* hlo_metadatas,
     const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
   std::vector<CompiledSubgraph*> removed_entries;
   auto status = CompileIfKeyAbsentHelper(
       subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
-      may_modify_variables, &removed_entries, hlo_metadatas, compile_function);
+      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
+      compile_function);
   for (auto entry : removed_entries) {
     UnloadAndDestroy(entry);
   }
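sharding_key follows the same out-parameter convention as proto_key: the caller supplies a pointer and the helper fills it from the cache entry. A reduced sketch of the pattern (EntrySketch and FillKeys are illustrative):

    #include <string>
    #include <vector>

    struct EntrySketch {
      std::vector<std::string> proto_key;
      std::vector<std::string> sharding_key;
    };

    // Mirrors how CompileIfKeyAbsentHelper hands both key vectors back.
    void FillKeys(const EntrySketch& entry,
                  std::vector<std::string>* proto_key,
                  std::vector<std::string>* sharding_key) {
      *proto_key = entry.proto_key;
      *sharding_key = entry.sharding_key;
    }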
@@ -399,7 +400,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key,
+    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
     std::vector<CompiledSubgraph*>* removed_entries,
     absl::Span<const xla::HloProto* const>* hlo_metadatas,
@@ -497,6 +498,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
   *uid = entry->uid;
   // Let the caller know the keys for each of the cached protos.
   *proto_key = entry->proto_key;
+  *sharding_key = entry->sharding_key;
   *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
   *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();

@@ -109,6 +109,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
       const SessionMetadata* session_metadata,
       CompilationRefHolder* per_step_ref_holder, int64* uid,
       std::vector<std::string>* proto_key,
+      std::vector<std::string>* sharding_key,
       std::vector<bool>* may_modify_variables,
       absl::Span<const xla::HloProto* const>* hlo_metadatas,
       const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
@@ -197,6 +198,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
       const SessionMetadata* session_metadata,
       CompilationRefHolder* per_step_ref_holder, int64* uid,
       std::vector<std::string>* proto_key,
+      std::vector<std::string>* sharding_key,
       std::vector<bool>* may_modify_variables,
      std::vector<CompiledSubgraph*>* removed_entries,
       absl::Span<const xla::HloProto* const>* hlo_metadatas,
@@ -657,10 +657,11 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {

   int64 uid;
   std::vector<std::string> proto_key;
+  std::vector<std::string> sharding_key;
   std::vector<bool> may_modify_variables;
   absl::Span<const xla::HloProto* const> hlo_metadatas;
   Status status = cache->CompileIfKeyAbsent(
-      key, ctx->session_metadata(), ref_holder, &uid, &proto_key,
+      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
       &may_modify_variables, &hlo_metadatas,
       [&](TpuProgramGroupInterface* tpu_program_group) {
         VLOG(1) << "Cloud TPU: Compiling TPU program";
@@ -778,13 +779,21 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {

   if (status.ok()) {
     for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
-      Tensor output(DT_STRING, TensorShape({2}));
+      Tensor output(DT_STRING, TensorShape({3}));
       if (proto_key.size() == 1) {
         output.vec<tstring>()(0) = proto_key[0];
       } else {
         output.vec<tstring>()(0) = proto_key[i];
       }
       output.vec<tstring>()(1) = rendezvous_key_base;
+      if (sharding_key.empty()) {
+        output.vec<tstring>()(2) = "";
+      } else if (sharding_key.size() == 1) {
+        output.vec<tstring>()(2) = sharding_key[0];
+      } else {
+        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
+        output.vec<tstring>()(2) = sharding_key[i];
+      }
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
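Element 2 reuses the broadcast-or-index rule already applied to the proto key: empty means no sharding program was produced, a single fingerprint is shared by every core, and otherwise each core gets its own. A stand-alone sketch of that selection (assert stands in for TF_RET_CHECK):

    #include <cassert>
    #include <string>
    #include <vector>

    std::string ShardingKeyForCore(const std::vector<std::string>& sharding_key,
                                   int core, int num_cores) {
      if (sharding_key.empty()) return "";  // no sharding program at all
      if (sharding_key.size() == 1) return sharding_key[0];  // shared by all
      assert(static_cast<int>(sharding_key.size()) == num_cores);
      return sharding_key[core];  // one fingerprint per core
    }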
@@ -805,9 +814,10 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
   } else {
     // Return error in the invalid case.
     for (int i = 0; i < num_computations_; ++i) {
-      Tensor output(DT_STRING, TensorShape({2}));
+      Tensor output(DT_STRING, TensorShape({3}));
       output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
       output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
+      output.vec<tstring>()(2) = "<<NO SHARDING KEY AS COMPILATION FAILED>>";
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
@@ -72,9 +72,9 @@ Status GetComputationCacheEntry(
   TF_RETURN_IF_ERROR(context->input("key", &key));
   profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
   if (!TensorShapeUtils::IsVector(key->shape()) ||
-      key->shape().dim_size(0) != 2) {
+      key->shape().dim_size(0) != 3) {
     return errors::InvalidArgument(
-        "Key argument to TPUExecute must be a 2-element vector");
+        "Key argument to TPUExecute must be a 3-element vector");
   }

   ResourceMgr* rmgr = GetTPUConfigResourceMgr();
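On the execute side the lookup rejects stale two-element keys outright rather than mis-indexing. A sketch of how a consumer can name the three slots (ExecuteKey and ParseExecuteKey are illustrative, not part of the change):

    #include <array>
    #include <string>

    struct ExecuteKey {
      std::string proto_key;            // element 0
      std::string rendezvous_key_base;  // element 1
      std::string sharding_key;         // element 2; empty when unused
    };

    ExecuteKey ParseExecuteKey(const std::array<std::string, 3>& key) {
      return {key[0], key[1], key[2]};
    }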
@@ -40,7 +40,7 @@ REGISTER_OP("_TPUCompileMlir")
       c->set_output(0, c->Scalar());
       // Programs.
       for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(2));
+        c->set_output(i + 1, c->Vector(3));
       }
       return Status::OK();
     })
@@ -64,7 +64,7 @@ REGISTER_OP("_TPUCompileMlirPlaceholderProgramKey")
     .SetIsStateful()
     .Output("program: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->Vector(2));
+      c->set_output(0, c->Vector(3));
       return Status::OK();
     })
     .SetIsStateful()
@@ -100,7 +100,7 @@ REGISTER_OP("TPUCompile")
       c->set_output(0, c->Scalar());
       // Programs.
       for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(2));
+        c->set_output(i + 1, c->Vector(3));
       }
       // May modify variables.
       for (int i = 0; i < num_computations; ++i) {
@@ -30,7 +30,7 @@ REGISTER_OP("TPUExecute")
       shape_inference::ShapeHandle key;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
       shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
       for (int i = 0; i < c->num_outputs(); ++i) {
         c->set_output(i, c->UnknownShape());
       }
@@ -50,7 +50,7 @@ REGISTER_OP("TPUExecuteAndUpdateVariables")
       shape_inference::ShapeHandle key;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
       shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
       for (int i = 0; i < c->num_outputs(); ++i) {
         c->set_output(i, c->UnknownShape());
       }
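All five registrations above (_TPUCompileMlir, _TPUCompileMlirPlaceholderProgramKey, TPUCompile, TPUExecute, TPUExecuteAndUpdateVariables) move their shape functions from length 2 to length 3 in lockstep, so graph-construction-time shape inference stays consistent with the runtime checks. A minimal constant capturing the shared contract (illustrative, not from the change):

    // Program key layout shared by the compile and execute ops:
    //   0: proto key, 1: rendezvous key base, 2: sharding key.
    constexpr int kProgramKeyElements = 3;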