Prevent segfault in GetSessionHandle{,V2}
.
In eager mode, session state is null. PiperOrigin-RevId: 332548597 Change-Id: If094812c2e094044220b9ba28f7d7601be042f38
This commit is contained in:
parent
73b291b6ac
commit
9a133d73ae
@ -16,6 +16,7 @@ limitations under the License.
|
||||
// See docs in ../ops/data_flow_ops.cc.
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -27,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
@ -42,7 +44,11 @@ class GetSessionHandleOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& val = ctx->input(0);
|
||||
int64 id = ctx->session_state()->GetNewId();
|
||||
auto session_state = ctx->session_state();
|
||||
OP_REQUIRES(ctx, session_state != nullptr,
|
||||
errors::FailedPrecondition(
|
||||
"GetSessionHandle called on null session state"));
|
||||
int64 id = session_state->GetNewId();
|
||||
TensorStore::TensorAndKey tk{val, id, requested_device()};
|
||||
OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
|
||||
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_data_flow_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_string_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -79,6 +80,13 @@ class RawOpsTest(test.TestCase, parameterized.TestCase):
|
||||
pad_width=0,
|
||||
preserve_short_sequences=False))
|
||||
|
||||
def testGetSessionHandle(self):
|
||||
if context.executing_eagerly():
|
||||
with self.assertRaisesRegex(
|
||||
errors.FailedPreconditionError,
|
||||
"GetSessionHandle called on null session state"):
|
||||
gen_data_flow_ops.GetSessionHandle(value=[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user