Set flow to a value during TensorArray creation,
Re-enable tensor_array_ops_test in msan. PiperOrigin-RevId: 157841785
This commit is contained in:
parent
edcc5cc13b
commit
a56d59a84b
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -101,7 +102,7 @@ Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) {
|
||||
class TensorArrayCreationOp : public OpKernel {
|
||||
public:
|
||||
explicit TensorArrayCreationOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
: OpKernel(context), device_type_(context->device_type()) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Tensor tensor_array_output_handle;
|
||||
@ -133,6 +134,12 @@ class TensorArrayCreationOp : public OpKernel {
|
||||
// Create the flow output.
|
||||
Tensor* flow;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow));
|
||||
if (device_type_ == DEVICE_CPU) {
|
||||
// Value doesn't matter, but this makes msan not complaint about
|
||||
// copying an uninitialized value. To do this on GPU would require
|
||||
// a kernel launch or a host->device memcpy, so we avoid that.
|
||||
flow->flat<float>()(0) = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -140,6 +147,9 @@ class TensorArrayCreationOp : public OpKernel {
|
||||
virtual Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm,
|
||||
Tensor* tensor_array_output_handle,
|
||||
TensorArray** output_tensor_array) = 0;
|
||||
|
||||
private:
|
||||
const DeviceType device_type_;
|
||||
};
|
||||
|
||||
// A per-run local tensor array. The tensor array uses a "per-step" resource
|
||||
|
@ -1939,7 +1939,6 @@ cuda_py_test(
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
flaky = 1, # create_local_cluster sometimes times out.
|
||||
tags = ["nomsan"], # b/38390993
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user