[Resubmit] Only trigger resharding when necessary.

PiperOrigin-RevId: 332109807
Change-Id: I63bc5f88c32cdf91b1d7e05b56829ed53bc557d4
Yunxing Dai, 2020-09-16 16:06:30 -07:00, committed by TensorFlower Gardener
parent 1593b96ae7
commit 8d9f388840
10 changed files with 38 additions and 19 deletions
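
In short: the per-core program key that TPUCompile emits, and that TPUExecute and the host transfer ops consume, grows from a 2-element to a 3-element string vector. Slot 0 holds the compilation cache proto key, slot 1 the rendezvous key base, and the new slot 2 a sharding-program fingerprint; an empty slot 2 appears to signal that no resharding is needed, which is what lets resharding be triggered only when necessary. A minimal sketch of the widened layout, using the same Tensor calls that appear in the hunks below (the helper name MakeProgramKey is hypothetical, not part of this commit):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

// Hypothetical helper illustrating the new 3-slot key layout.
tensorflow::Tensor MakeProgramKey(const tensorflow::tstring& proto_key,
                                  const tensorflow::tstring& rendezvous_key_base,
                                  const tensorflow::tstring& sharding_key) {
  tensorflow::Tensor key(tensorflow::DT_STRING, tensorflow::TensorShape({3}));
  key.vec<tensorflow::tstring>()(0) = proto_key;            // compiled program cache key
  key.vec<tensorflow::tstring>()(1) = rendezvous_key_base;  // host transfer rendezvous base
  key.vec<tensorflow::tstring>()(2) = sharding_key;         // "" => no resharding required
  return key;
}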

@@ -537,9 +537,10 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
   // Build a constant default key to specify that the unformatting should
   // transform the variables to the original format.
   builder.setInsertionPointAfter(while_op);
-  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {2});
+  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
   default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
   default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
+  default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
   auto default_state_key = builder.create<TF::ConstOp>(
       while_op.getLoc(),
       tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());

@@ -504,9 +504,10 @@ Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
       "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
   AddNodeAttr("dtype", DT_STRING, &default_sharding);
-  Tensor t(DT_STRING, {2});
+  Tensor t(DT_STRING, {3});
   t.vec<tstring>()(0) = kDefaultShardingValue;
   t.vec<tstring>()(1) = kDefaultShardingValue;
+  t.vec<tstring>()(2) = kDefaultShardingValue;
   t.AsProtoTensorContent(
       (*default_sharding.mutable_attr())["value"].mutable_tensor());

@@ -124,6 +124,9 @@ struct CompiledSubgraph : public core::RefCounted {
   // Compilation cache proto key to identify the cache entry.
   std::vector<std::string> proto_key;
+  // Fingerprints of sharding programs, if there are any.
+  std::vector<std::string> sharding_key;
   // The number of 'external' client-held references to the entry.
   int external_references = 0;

@@ -58,9 +58,9 @@ class RecvAtHostOp : public AsyncOpKernel {
     OP_REQUIRES_ASYNC(
         ctx,
         TensorShapeUtils::IsVector(input.shape()) &&
-            input.shape().dim_size(0) == 2,
+            input.shape().dim_size(0) == 3,
         errors::InvalidArgument("Input shape ", input.shape().DebugString(),
-                                " is not a vector of length 2."),
+                                " is not a vector of length 3."),
         done);
     const string rendezvous_key_base = input.vec<tstring>()(1);
     OP_REQUIRES_ASYNC(
@@ -164,10 +164,10 @@ class SendFromHostOp : public OpKernel {
     const Tensor& key_input = ctx->input(ctx->num_inputs() - 1);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsVector(key_input.shape()) &&
-                    key_input.shape().dim_size(0) == 2,
+                    key_input.shape().dim_size(0) == 3,
                 errors::InvalidArgument("Key input shape ",
                                         key_input.shape().DebugString(),
-                                        " is not a vector of length 2."));
+                                        " is not a vector of length 3."));
     const string rendezvous_key_base = key_input.vec<tstring>()(1);
     OP_REQUIRES(
         ctx, ctx->rendezvous() != nullptr,
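
Note that RecvAtHostOp and SendFromHostOp still read only element 1 of the key, so for them the change amounts to the tightened length check. A hedged sketch of the invariant these two kernels now enforce (a standalone helper for illustration, not part of this commit):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

// The key input must be a string vector of length 3; only element 1
// (the rendezvous key base) is consumed by the host transfer ops.
bool IsValidHostTransferKey(const tensorflow::Tensor& key) {
  return tensorflow::TensorShapeUtils::IsVector(key.shape()) &&
         key.shape().dim_size(0) == 3;
}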

@@ -362,14 +362,15 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key,
+    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
     std::vector<bool>* may_modify_variables,
     absl::Span<const xla::HloProto* const>* hlo_metadatas,
     const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
   std::vector<CompiledSubgraph*> removed_entries;
   auto status = CompileIfKeyAbsentHelper(
       subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
-      may_modify_variables, &removed_entries, hlo_metadatas, compile_function);
+      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
+      compile_function);
   for (auto entry : removed_entries) {
     UnloadAndDestroy(entry);
   }
@@ -399,7 +400,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key,
+    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
     std::vector<bool>* may_modify_variables,
     std::vector<CompiledSubgraph*>* removed_entries,
     absl::Span<const xla::HloProto* const>* hlo_metadatas,
@@ -497,6 +498,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
   *uid = entry->uid;
   // Let the caller know the keys for each of the cached protos.
   *proto_key = entry->proto_key;
+  *sharding_key = entry->sharding_key;
   *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
   *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();

@@ -109,6 +109,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
       const SessionMetadata* session_metadata,
       CompilationRefHolder* per_step_ref_holder, int64* uid,
       std::vector<std::string>* proto_key,
+      std::vector<std::string>* sharding_key,
       std::vector<bool>* may_modify_variables,
       absl::Span<const xla::HloProto* const>* hlo_metadatas,
       const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
@@ -197,6 +198,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
       const SessionMetadata* session_metadata,
       CompilationRefHolder* per_step_ref_holder, int64* uid,
       std::vector<std::string>* proto_key,
+      std::vector<std::string>* sharding_key,
       std::vector<bool>* may_modify_variables,
       std::vector<CompiledSubgraph*>* removed_entries,
       absl::Span<const xla::HloProto* const>* hlo_metadatas,

@@ -657,10 +657,11 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
   int64 uid;
   std::vector<std::string> proto_key;
+  std::vector<std::string> sharding_key;
   std::vector<bool> may_modify_variables;
   absl::Span<const xla::HloProto* const> hlo_metadatas;
   Status status = cache->CompileIfKeyAbsent(
-      key, ctx->session_metadata(), ref_holder, &uid, &proto_key,
+      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
       &may_modify_variables, &hlo_metadatas,
       [&](TpuProgramGroupInterface* tpu_program_group) {
         VLOG(1) << "Cloud TPU: Compiling TPU program";
@@ -778,13 +779,21 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
   if (status.ok()) {
     for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
-      Tensor output(DT_STRING, TensorShape({2}));
+      Tensor output(DT_STRING, TensorShape({3}));
       if (proto_key.size() == 1) {
         output.vec<tstring>()(0) = proto_key[0];
       } else {
         output.vec<tstring>()(0) = proto_key[i];
       }
       output.vec<tstring>()(1) = rendezvous_key_base;
+      if (sharding_key.empty()) {
+        output.vec<tstring>()(2) = "";
+      } else if (sharding_key.size() == 1) {
+        output.vec<tstring>()(2) = sharding_key[0];
+      } else {
+        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
+        output.vec<tstring>()(2) = sharding_key[i];
+      }
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
@@ -805,9 +814,10 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
   } else {
     // Return error in the invalid case.
     for (int i = 0; i < num_computations_; ++i) {
-      Tensor output(DT_STRING, TensorShape({2}));
+      Tensor output(DT_STRING, TensorShape({3}));
       output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
       output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
+      output.vec<tstring>()(2) = "<<NO SHARDING KEY AS COMPILATION FAILED>>";
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
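
The per-core selection above follows a broadcast-or-indexed rule: a single key is shared by every core, otherwise there must be one key per core, and an empty sharding_key vector degrades to an empty string in slot 2. A small sketch of that rule (the helper name KeyForCore is hypothetical, not from this commit):

#include <string>
#include <vector>

// Broadcast-or-indexed lookup mirroring the key selection in ComputeInternal.
std::string KeyForCore(const std::vector<std::string>& keys, int core) {
  if (keys.empty()) return "";           // no sharding programs: slot stays empty
  if (keys.size() == 1) return keys[0];  // one program broadcast to all cores
  return keys[core];                     // otherwise one program per core
}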

@@ -72,9 +72,9 @@ Status GetComputationCacheEntry(
   TF_RETURN_IF_ERROR(context->input("key", &key));
   profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
   if (!TensorShapeUtils::IsVector(key->shape()) ||
-      key->shape().dim_size(0) != 2) {
+      key->shape().dim_size(0) != 3) {
     return errors::InvalidArgument(
-        "Key argument to TPUExecute must be a 2-element vector");
+        "Key argument to TPUExecute must be a 3-element vector");
   }
   ResourceMgr* rmgr = GetTPUConfigResourceMgr();

@@ -40,7 +40,7 @@ REGISTER_OP("_TPUCompileMlir")
       c->set_output(0, c->Scalar());
       // Programs.
       for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(2));
+        c->set_output(i + 1, c->Vector(3));
       }
       return Status::OK();
     })
@@ -64,7 +64,7 @@ REGISTER_OP("_TPUCompileMlirPlaceholderProgramKey")
     .SetIsStateful()
     .Output("program: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->Vector(2));
+      c->set_output(0, c->Vector(3));
       return Status::OK();
     })
     .SetIsStateful()
@@ -100,7 +100,7 @@ REGISTER_OP("TPUCompile")
       c->set_output(0, c->Scalar());
      // Programs.
       for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(2));
+        c->set_output(i + 1, c->Vector(3));
       }
       // May modify variables.
       for (int i = 0; i < num_computations; ++i) {

@@ -30,7 +30,7 @@ REGISTER_OP("TPUExecute")
       shape_inference::ShapeHandle key;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
       shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
       for (int i = 0; i < c->num_outputs(); ++i) {
         c->set_output(i, c->UnknownShape());
       }
@@ -50,7 +50,7 @@ REGISTER_OP("TPUExecuteAndUpdateVariables")
       shape_inference::ShapeHandle key;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
       shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
       for (int i = 0; i < c->num_outputs(); ++i) {
         c->set_output(i, c->UnknownShape());
       }