Replace pyjwt with authlib in `org.matrix.login.jwt` (#13011)
This commit is contained in:
parent
e12ff697a4
commit
7d99414edf
|
@ -0,0 +1 @@
|
||||||
|
Replaced usage of PyJWT with methods from Authlib in `org.matrix.login.jwt`. Contributed by Hannes Lerchl.
|
25
docs/jwt.md
25
docs/jwt.md
|
@ -37,19 +37,19 @@ As with other login types, there are additional fields (e.g. `device_id` and
|
||||||
## Preparing Synapse
|
## Preparing Synapse
|
||||||
|
|
||||||
The JSON Web Token integration in Synapse uses the
|
The JSON Web Token integration in Synapse uses the
|
||||||
[`PyJWT`](https://pypi.org/project/pyjwt/) library, which must be installed
|
[`Authlib`](https://docs.authlib.org/en/latest/index.html) library, which must be installed
|
||||||
as follows:
|
as follows:
|
||||||
|
|
||||||
* The relevant libraries are included in the Docker images and Debian packages
|
* The relevant libraries are included in the Docker images and Debian packages
|
||||||
provided by `matrix.org` so no further action is needed.
|
provided by `matrix.org` so no further action is needed.
|
||||||
|
|
||||||
* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
|
* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
|
||||||
install synapse[pyjwt]` to install the necessary dependencies.
|
install synapse[jwt]` to install the necessary dependencies.
|
||||||
|
|
||||||
* For other installation mechanisms, see the documentation provided by the
|
* For other installation mechanisms, see the documentation provided by the
|
||||||
maintainer.
|
maintainer.
|
||||||
|
|
||||||
To enable the JSON web token integration, you should then add an `jwt_config` section
|
To enable the JSON web token integration, you should then add a `jwt_config` section
|
||||||
to your configuration file (or uncomment the `enabled: true` line in the
|
to your configuration file (or uncomment the `enabled: true` line in the
|
||||||
existing section). See [sample_config.yaml](./sample_config.yaml) for some
|
existing section). See [sample_config.yaml](./sample_config.yaml) for some
|
||||||
sample settings.
|
sample settings.
|
||||||
|
@ -57,7 +57,7 @@ sample settings.
|
||||||
## How to test JWT as a developer
|
## How to test JWT as a developer
|
||||||
|
|
||||||
Although JSON Web Tokens are typically generated from an external server, the
|
Although JSON Web Tokens are typically generated from an external server, the
|
||||||
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
|
example below uses a locally generated JWT.
|
||||||
|
|
||||||
1. Configure Synapse with JWT logins, note that this example uses a pre-shared
|
1. Configure Synapse with JWT logins, note that this example uses a pre-shared
|
||||||
secret and an algorithm of HS256:
|
secret and an algorithm of HS256:
|
||||||
|
@ -70,10 +70,21 @@ examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
|
||||||
```
|
```
|
||||||
2. Generate a JSON web token:
|
2. Generate a JSON web token:
|
||||||
|
|
||||||
```bash
|
You can use the following short Python snippet to generate a JWT
|
||||||
$ pyjwt --key=my-secret-token --alg=HS256 encode sub=test-user
|
protected by an HMAC.
|
||||||
eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc
|
Take care that the `secret` and the algorithm given in the `header` match
|
||||||
|
the entries from `jwt_config` above.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from authlib.jose import jwt
|
||||||
|
|
||||||
|
header = {"alg": "HS256"}
|
||||||
|
payload = {"sub": "user1", "aud": ["audience"]}
|
||||||
|
secret = "my-secret-token"
|
||||||
|
result = jwt.encode(header, payload, secret)
|
||||||
|
print(result.decode("ascii"))
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Query for the login types and ensure `org.matrix.login.jwt` is there:
|
3. Query for the login types and ensure `org.matrix.login.jwt` is there:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
@ -2946,8 +2946,10 @@ Additional sub-options for this setting include:
|
||||||
tokens. Defaults to false.
|
tokens. Defaults to false.
|
||||||
* `secret`: This is either the private shared secret or the public key used to
|
* `secret`: This is either the private shared secret or the public key used to
|
||||||
decode the contents of the JSON web token. Required if `enabled` is set to true.
|
decode the contents of the JSON web token. Required if `enabled` is set to true.
|
||||||
* `algorithm`: The algorithm used to sign the JSON web token. Supported algorithms are listed at
|
* `algorithm`: The algorithm used to sign (or HMAC) the JSON web token.
|
||||||
https://pyjwt.readthedocs.io/en/latest/algorithms.html Required if `enabled` is set to true.
|
Supported algorithms are listed
|
||||||
|
[here (section JWS)](https://docs.authlib.org/en/latest/specs/rfc7518.html).
|
||||||
|
Required if `enabled` is set to true.
|
||||||
* `subject_claim`: Name of the claim containing a unique identifier for the user.
|
* `subject_claim`: Name of the claim containing a unique identifier for the user.
|
||||||
Optional, defaults to `sub`.
|
Optional, defaults to `sub`.
|
||||||
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
|
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
|
||||||
|
|
|
@ -815,7 +815,7 @@ python-versions = ">=3.5"
|
||||||
name = "pyjwt"
|
name = "pyjwt"
|
||||||
version = "2.4.0"
|
version = "2.4.0"
|
||||||
description = "JSON Web Token implementation in Python"
|
description = "JSON Web Token implementation in Python"
|
||||||
category = "main"
|
category = "dev"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
|
@ -1546,9 +1546,9 @@ docs = ["sphinx", "repoze.sphinx.autointerface"]
|
||||||
test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"]
|
test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"]
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "pyjwt", "txredisapi", "hiredis", "Pympler"]
|
all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler"]
|
||||||
cache_memory = ["Pympler"]
|
cache_memory = ["Pympler"]
|
||||||
jwt = ["pyjwt"]
|
jwt = ["authlib"]
|
||||||
matrix-synapse-ldap3 = ["matrix-synapse-ldap3"]
|
matrix-synapse-ldap3 = ["matrix-synapse-ldap3"]
|
||||||
oidc = ["authlib"]
|
oidc = ["authlib"]
|
||||||
opentracing = ["jaeger-client", "opentracing"]
|
opentracing = ["jaeger-client", "opentracing"]
|
||||||
|
@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.7.1"
|
python-versions = "^3.7.1"
|
||||||
content-hash = "37bd4bccfdb5a869635f2135a85bea4a0729af7375a27de153b4fd9a4aebc195"
|
content-hash = "73882e279e0379482f2fc7414cb71addfd408ca48ad508ff8a02b0cb544762af"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
attrs = [
|
attrs = [
|
||||||
|
|
|
@ -175,7 +175,6 @@ lxml = { version = ">=4.2.0", optional = true }
|
||||||
sentry-sdk = { version = ">=0.7.2", optional = true }
|
sentry-sdk = { version = ">=0.7.2", optional = true }
|
||||||
opentracing = { version = ">=2.2.0", optional = true }
|
opentracing = { version = ">=2.2.0", optional = true }
|
||||||
jaeger-client = { version = ">=4.0.0", optional = true }
|
jaeger-client = { version = ">=4.0.0", optional = true }
|
||||||
pyjwt = { version = ">=1.6.4", optional = true }
|
|
||||||
txredisapi = { version = ">=1.4.7", optional = true }
|
txredisapi = { version = ">=1.4.7", optional = true }
|
||||||
hiredis = { version = "*", optional = true }
|
hiredis = { version = "*", optional = true }
|
||||||
Pympler = { version = "*", optional = true }
|
Pympler = { version = "*", optional = true }
|
||||||
|
@ -196,7 +195,7 @@ systemd = ["systemd-python"]
|
||||||
url_preview = ["lxml"]
|
url_preview = ["lxml"]
|
||||||
sentry = ["sentry-sdk"]
|
sentry = ["sentry-sdk"]
|
||||||
opentracing = ["jaeger-client", "opentracing"]
|
opentracing = ["jaeger-client", "opentracing"]
|
||||||
jwt = ["pyjwt"]
|
jwt = ["authlib"]
|
||||||
# hiredis is not a *strict* dependency, but it makes things much faster.
|
# hiredis is not a *strict* dependency, but it makes things much faster.
|
||||||
# (if it is not installed, we fall back to slow code.)
|
# (if it is not installed, we fall back to slow code.)
|
||||||
redis = ["txredisapi", "hiredis"]
|
redis = ["txredisapi", "hiredis"]
|
||||||
|
@ -222,7 +221,7 @@ all = [
|
||||||
"psycopg2", "psycopg2cffi", "psycopg2cffi-compat",
|
"psycopg2", "psycopg2cffi", "psycopg2cffi-compat",
|
||||||
# saml2
|
# saml2
|
||||||
"pysaml2",
|
"pysaml2",
|
||||||
# oidc
|
# oidc and jwt
|
||||||
"authlib",
|
"authlib",
|
||||||
# url_preview
|
# url_preview
|
||||||
"lxml",
|
"lxml",
|
||||||
|
@ -230,8 +229,6 @@ all = [
|
||||||
"sentry-sdk",
|
"sentry-sdk",
|
||||||
# opentracing
|
# opentracing
|
||||||
"jaeger-client", "opentracing",
|
"jaeger-client", "opentracing",
|
||||||
# jwt
|
|
||||||
"pyjwt",
|
|
||||||
# redis
|
# redis
|
||||||
"txredisapi", "hiredis",
|
"txredisapi", "hiredis",
|
||||||
# cache_memory
|
# cache_memory
|
||||||
|
|
|
@ -18,10 +18,10 @@ from synapse.types import JsonDict
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
|
||||||
MISSING_JWT = """Missing jwt library. This is required for jwt login.
|
MISSING_AUTHLIB = """Missing authlib library. This is required for jwt login.
|
||||||
|
|
||||||
Install by running:
|
Install by running:
|
||||||
pip install pyjwt
|
pip install synapse[jwt]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,11 +43,11 @@ class JWTConfig(Config):
|
||||||
self.jwt_audiences = jwt_config.get("audiences")
|
self.jwt_audiences = jwt_config.get("audiences")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jwt
|
from authlib.jose import JsonWebToken
|
||||||
|
|
||||||
jwt # To stop unused lint.
|
JsonWebToken # To stop unused lint.
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ConfigError(MISSING_JWT)
|
raise ConfigError(MISSING_AUTHLIB)
|
||||||
else:
|
else:
|
||||||
self.jwt_enabled = False
|
self.jwt_enabled = False
|
||||||
self.jwt_secret = None
|
self.jwt_secret = None
|
||||||
|
|
|
@ -420,17 +420,31 @@ class LoginRestServlet(RestServlet):
|
||||||
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
|
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
import jwt
|
from authlib.jose import JsonWebToken, JWTClaims
|
||||||
|
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
|
||||||
|
|
||||||
|
jwt = JsonWebToken([self.jwt_algorithm])
|
||||||
|
claim_options = {}
|
||||||
|
if self.jwt_issuer is not None:
|
||||||
|
claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
|
||||||
|
if self.jwt_audiences is not None:
|
||||||
|
claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
claims = jwt.decode(
|
||||||
token,
|
token,
|
||||||
self.jwt_secret,
|
key=self.jwt_secret,
|
||||||
algorithms=[self.jwt_algorithm],
|
claims_cls=JWTClaims,
|
||||||
issuer=self.jwt_issuer,
|
claims_options=claim_options,
|
||||||
audience=self.jwt_audiences,
|
|
||||||
)
|
)
|
||||||
except jwt.PyJWTError as e:
|
except BadSignatureError:
|
||||||
|
# We handle this case separately to provide a better error message
|
||||||
|
raise LoginError(
|
||||||
|
403,
|
||||||
|
"JWT validation failed: Signature verification failed",
|
||||||
|
errcode=Codes.FORBIDDEN,
|
||||||
|
)
|
||||||
|
except JoseError as e:
|
||||||
# A JWT error occurred, return some info back to the client.
|
# A JWT error occurred, return some info back to the client.
|
||||||
raise LoginError(
|
raise LoginError(
|
||||||
403,
|
403,
|
||||||
|
@ -438,7 +452,23 @@ class LoginRestServlet(RestServlet):
|
||||||
errcode=Codes.FORBIDDEN,
|
errcode=Codes.FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = payload.get(self.jwt_subject_claim, None)
|
try:
|
||||||
|
claims.validate(leeway=120) # allows 2 min of clock skew
|
||||||
|
|
||||||
|
# Enforce the old behavior which is rolled out in productive
|
||||||
|
# servers: if the JWT contains an 'aud' claim but none is
|
||||||
|
# configured, the login attempt will fail
|
||||||
|
if claims.get("aud") is not None:
|
||||||
|
if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
|
||||||
|
raise InvalidClaimError("aud")
|
||||||
|
except JoseError as e:
|
||||||
|
raise LoginError(
|
||||||
|
403,
|
||||||
|
"JWT validation failed: %s" % (str(e),),
|
||||||
|
errcode=Codes.FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
user = claims.get(self.jwt_subject_claim, None)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
|
||||||
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jwt
|
from authlib.jose import jwk, jwt
|
||||||
|
|
||||||
HAS_JWT = True
|
HAS_JWT = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -841,7 +841,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertIn(b"SSO account deactivated", channel.result["body"])
|
self.assertIn(b"SSO account deactivated", channel.result["body"])
|
||||||
|
|
||||||
|
|
||||||
@skip_unless(HAS_JWT, "requires jwt")
|
@skip_unless(HAS_JWT, "requires authlib")
|
||||||
class JWTTestCase(unittest.HomeserverTestCase):
|
class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
|
@ -866,11 +866,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
|
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
|
||||||
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
|
header = {"alg": self.jwt_algorithm}
|
||||||
result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
|
result: bytes = jwt.encode(header, payload, secret)
|
||||||
if isinstance(result, bytes):
|
|
||||||
return result.decode("ascii")
|
return result.decode("ascii")
|
||||||
return result
|
|
||||||
|
|
||||||
def jwt_login(self, *args: Any) -> FakeChannel:
|
def jwt_login(self, *args: Any) -> FakeChannel:
|
||||||
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||||
|
@ -902,7 +900,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["error"], "JWT validation failed: Signature has expired"
|
channel.json_body["error"],
|
||||||
|
"JWT validation failed: expired_token: The token is expired",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_jwt_not_before(self) -> None:
|
def test_login_jwt_not_before(self) -> None:
|
||||||
|
@ -912,7 +911,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["error"],
|
channel.json_body["error"],
|
||||||
"JWT validation failed: The token is not yet valid (nbf)",
|
"JWT validation failed: invalid_token: The token is not valid yet",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_no_sub(self) -> None:
|
def test_login_no_sub(self) -> None:
|
||||||
|
@ -934,7 +933,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["error"], "JWT validation failed: Invalid issuer"
|
channel.json_body["error"],
|
||||||
|
'JWT validation failed: invalid_claim: Invalid claim "iss"',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Not providing an issuer.
|
# Not providing an issuer.
|
||||||
|
@ -943,7 +943,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["error"],
|
channel.json_body["error"],
|
||||||
'JWT validation failed: Token is missing the "iss" claim',
|
'JWT validation failed: missing_claim: Missing "iss" claim',
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_iss_no_config(self) -> None:
|
def test_login_iss_no_config(self) -> None:
|
||||||
|
@ -965,7 +965,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["error"], "JWT validation failed: Invalid audience"
|
channel.json_body["error"],
|
||||||
|
'JWT validation failed: invalid_claim: Invalid claim "aud"',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Not providing an audience.
|
# Not providing an audience.
|
||||||
|
@ -974,7 +975,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["error"],
|
channel.json_body["error"],
|
||||||
'JWT validation failed: Token is missing the "aud" claim',
|
'JWT validation failed: missing_claim: Missing "aud" claim',
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_aud_no_config(self) -> None:
|
def test_login_aud_no_config(self) -> None:
|
||||||
|
@ -983,7 +984,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["error"], "JWT validation failed: Invalid audience"
|
channel.json_body["error"],
|
||||||
|
'JWT validation failed: invalid_claim: Invalid claim "aud"',
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_default_sub(self) -> None:
|
def test_login_default_sub(self) -> None:
|
||||||
|
@ -1010,7 +1012,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
|
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
|
||||||
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
|
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
|
||||||
# signed by the private key.
|
# signed by the private key.
|
||||||
@skip_unless(HAS_JWT, "requires jwt")
|
@skip_unless(HAS_JWT, "requires authlib")
|
||||||
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
@ -1071,11 +1073,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
|
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
|
||||||
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
|
header = {"alg": "RS256"}
|
||||||
result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
|
if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
|
||||||
if isinstance(result, bytes):
|
secret = jwk.dumps(secret, kty="RSA")
|
||||||
|
result: bytes = jwt.encode(header, payload, secret)
|
||||||
return result.decode("ascii")
|
return result.decode("ascii")
|
||||||
return result
|
|
||||||
|
|
||||||
def jwt_login(self, *args: Any) -> FakeChannel:
|
def jwt_login(self, *args: Any) -> FakeChannel:
|
||||||
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||||
|
|
Loading…
Reference in New Issue