Set flow to a value during TensorArray creation,

Re-enable tensor_array_ops_test in msan.

PiperOrigin-RevId: 157841785
This commit is contained in:
A. Unique TensorFlower 2017-06-02 09:51:27 -07:00 committed by TensorFlower Gardener
parent edcc5cc13b
commit a56d59a84b
2 changed files with 11 additions and 2 deletions

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@ -101,7 +102,7 @@ Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) {
class TensorArrayCreationOp : public OpKernel {
public:
explicit TensorArrayCreationOp(OpKernelConstruction* context)
: OpKernel(context) {}
: OpKernel(context), device_type_(context->device_type()) {}
void Compute(OpKernelContext* ctx) override {
Tensor tensor_array_output_handle;
@ -133,6 +134,12 @@ class TensorArrayCreationOp : public OpKernel {
// Create the flow output.
Tensor* flow;
OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow));
if (device_type_ == DEVICE_CPU) {
// Value doesn't matter, but this makes msan not complaint about
// copying an uninitialized value. To do this on GPU would require
// a kernel launch or a host->device memcpy, so we avoid that.
flow->flat<float>()(0) = 0;
}
}
}
@ -140,6 +147,9 @@ class TensorArrayCreationOp : public OpKernel {
virtual Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm,
Tensor* tensor_array_output_handle,
TensorArray** output_tensor_array) = 0;
private:
const DeviceType device_type_;
};
// A per-run local tensor array. The tensor array uses a "per-step" resource

View File

@ -1939,7 +1939,6 @@ cuda_py_test(
"//tensorflow/python:variables",
],
flaky = 1, # create_local_cluster sometimes times out.
tags = ["nomsan"], # b/38390993
)
cuda_py_test(