From 9a133d73ae4b4664d22bd1aa6d654fec13c52ee1 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Fri, 18 Sep 2020 16:23:20 -0700
Subject: [PATCH] Prevent segfault in `GetSessionHandle{,V2}`.

In eager mode, session state is null.

PiperOrigin-RevId: 332548597
Change-Id: If094812c2e094044220b9ba28f7d7601be042f38
---
 tensorflow/core/kernels/session_ops.cc | 8 +++++++-
 tensorflow/python/ops/raw_ops_test.py  | 8 ++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc
index 9e67fec3c20..ee81ad27632 100644
--- a/tensorflow/core/kernels/session_ops.cc
+++ b/tensorflow/core/kernels/session_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
 // See docs in ../ops/data_flow_ops.cc.
 
 #include <limits.h>
+
 #include <vector>
 
 #include "tensorflow/core/common_runtime/device.h"
@@ -27,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -42,7 +44,11 @@ class GetSessionHandleOp : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     const Tensor& val = ctx->input(0);
-    int64 id = ctx->session_state()->GetNewId();
+    auto session_state = ctx->session_state();
+    OP_REQUIRES(ctx, session_state != nullptr,
+                errors::FailedPrecondition(
+                    "GetSessionHandle called on null session state"));
+    int64 id = session_state->GetNewId();
     TensorStore::TensorAndKey tk{val, id, requested_device()};
     OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
 
diff --git a/tensorflow/python/ops/raw_ops_test.py b/tensorflow/python/ops/raw_ops_test.py
index ee20d58d2f0..6706ef194b2 100644
--- a/tensorflow/python/ops/raw_ops_test.py
+++ b/tensorflow/python/ops/raw_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_data_flow_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import gen_string_ops
 from tensorflow.python.platform import test
@@ -79,6 +80,13 @@ class RawOpsTest(test.TestCase, parameterized.TestCase):
               pad_width=0,
               preserve_short_sequences=False))
 
+  def testGetSessionHandle(self):
+    if context.executing_eagerly():
+      with self.assertRaisesRegex(
+          errors.FailedPreconditionError,
+          "GetSessionHandle called on null session state"):
+        gen_data_flow_ops.GetSessionHandle(value=[1])
+
 
 if __name__ == "__main__":
   ops.enable_eager_execution()