From 9a133d73ae4b4664d22bd1aa6d654fec13c52ee1 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac <mihaimaruseac@google.com> Date: Fri, 18 Sep 2020 16:23:20 -0700 Subject: [PATCH] Prevent segfault in `GetSessionHandle{,V2}`. In eager mode, session state is null. PiperOrigin-RevId: 332548597 Change-Id: If094812c2e094044220b9ba28f7d7601be042f38 --- tensorflow/core/kernels/session_ops.cc | 8 +++++++- tensorflow/python/ops/raw_ops_test.py | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc index 9e67fec3c20..ee81ad27632 100644 --- a/tensorflow/core/kernels/session_ops.cc +++ b/tensorflow/core/kernels/session_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // See docs in ../ops/data_flow_ops.cc. #include <limits.h> + #include <vector> #include "tensorflow/core/common_runtime/device.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -42,7 +44,11 @@ class GetSessionHandleOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& val = ctx->input(0); - int64 id = ctx->session_state()->GetNewId(); + auto session_state = ctx->session_state(); + OP_REQUIRES(ctx, session_state != nullptr, + errors::FailedPrecondition( + "GetSessionHandle called on null session state")); + int64 id = session_state->GetNewId(); TensorStore::TensorAndKey tk{val, id, requested_device()}; OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk)); diff --git a/tensorflow/python/ops/raw_ops_test.py b/tensorflow/python/ops/raw_ops_test.py index ee20d58d2f0..6706ef194b2 100644 --- a/tensorflow/python/ops/raw_ops_test.py +++ b/tensorflow/python/ops/raw_ops_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_string_ops from tensorflow.python.platform import test @@ -79,6 +80,13 @@ class RawOpsTest(test.TestCase, parameterized.TestCase): pad_width=0, preserve_short_sequences=False)) + def testGetSessionHandle(self): + if context.executing_eagerly(): + with self.assertRaisesRegex( + errors.FailedPreconditionError, + "GetSessionHandle called on null session state"): + gen_data_flow_ops.GetSessionHandle(value=[1]) + if __name__ == "__main__": ops.enable_eager_execution()