Mark variables initialized in XLA, so Tensorflow also knows they're initialized.
PiperOrigin-RevId: 276530886 Change-Id: I5c0d82dce44c81139efc766807fba1039f085de3
This commit is contained in:
parent
95a6f4571c
commit
f41d1939cf
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -367,6 +368,7 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
*variable_infos[i].var()->tensor() = output_tensor;
|
*variable_infos[i].var()->tensor() = output_tensor;
|
||||||
|
variable_infos[i].var()->is_initialized |= write.modified;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -540,6 +542,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
|||||||
kernel->input_mapping, resource_var_snapshots, write.type,
|
kernel->input_mapping, resource_var_snapshots, write.type,
|
||||||
write.shape, buffer, allocator);
|
write.shape, buffer, allocator);
|
||||||
*variable_infos[i].var()->tensor() = output_tensor;
|
*variable_infos[i].var()->tensor() = output_tensor;
|
||||||
|
variable_infos[i].var()->is_initialized |= write.modified;
|
||||||
}
|
}
|
||||||
++output_num;
|
++output_num;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user